# export
from fastai.datasets import URLs, untar_data
from pathlib import Path
import pandas as pd, re, PIL, os, mimetypes, csv, itertools
import matplotlib.pyplot as plt
from collections import OrderedDict
from enum import Enum
from warnings import warn
from functools import partial,reduce
from PIL import Image
This is the base class to define a transform when we want something more complex than a function. `_order` is used to sort the transforms before applying them. `setup` is a preparation step that gets the state ready from the data (which is a DataSource). `__call__` is the main function: it applies the transform (through `encodes`) to `o`, and `decode` does the reverse operation (through `decodes`).
NB: You should only implement `decodes` if your transform needs to be reversed for display purposes. For instance, we want to reverse the operation that maps a class to its index, but we don't want to reverse the operation that opens an image.
# export
class Transform():
    "Transform built from an `encodes` function, an optional `decodes` inverse and an optional subset filter `filt`."
    _order,filt = 0,None
    def __init__(self, encodes=None, decodes=None, filt=None, order=None):
        self.filt = filt
        if encodes is not None: self.encodes = encodes
        if decodes is not None: self.decodes = decodes
        if order is not None: self._order=order

    @classmethod
    def create(cls, f, filt=None):
        "Return `f` as is if it already acts as a transform, otherwise wrap it (forwarding `filt`)."
        return f if hasattr(f,'setup') or isinstance(f,Transform) else cls(f, filt=filt)

    def setup(self, items):
        "Hook called by `Transforms.add`; delegates to `setups`, so subclasses only override `setups`."
        return self.setups(items)
    def setups(self, items):
        "Prepare any state from `items` (a `DataSource`); does nothing by default."
        pass

    def __call__(self, o, filt=None, **kwargs):
        # A transform restricted to one subset returns items of other subsets untouched
        if self.filt is not None and self.filt!=filt: return o
        return self.encodes(o, **kwargs)
    def decode(self, o, filt=None, **kwargs):
        if self.filt is not None and self.filt!=filt: return o
        return self.decodes(o, **kwargs)
    def __repr__(self): return str(self.encodes) if self.__class__==Transform else str(self.__class__)
    def decodes(self, o, *args, **kwargs): return o
def order_sorted(funcs, order='_order'):
    "Return `funcs` as a list, stably sorted by each element's `order` attribute (0 when absent)."
    def _rank(fn): return getattr(fn, order, 0)
    as_list = listify(funcs)
    return sorted(as_list, key=_rank)
# export
def opt_call(f, fname, *args, **kwargs):
    "Call the `fname` attribute of `f` with the given arguments, using `noop` when `f` has no such attribute."
    method = getattr(f, fname, noop)
    return method(*args, **kwargs)
class Transforms():
    "Ordered pipeline of transforms, each one set up in turn against the data."
    def __init__(self, tfms, order='_order'):
        self.order,self.tfms = order,[]
        # Pending transforms: they only move into `self.tfms` when `setup` runs
        self._tfms = [Transform.create(t) for t in listify(tfms)]
    def __call__(self, x, **kwargs): return self._apply(x, **kwargs)
    def decode(self, x, **kwargs): return self._apply(x, rev=True, fname='decode', **kwargs)
    def _apply(self, x, rev=False, fname='__call__', **kwargs):
        tfms = reversed(self.tfms) if rev else self.tfms
        for f in tfms: x = opt_call(f, fname, x, **kwargs)
        return x
    def __repr__(self): return str(self.tfms)
    def delete(self, idx): del(self.tfms[idx])
    def remove(self, tfm): self.tfms.remove(tfm)
    def setup(self, items=None): self.add(self._tfms, items)
    def add(self, tfm, items):
        # We only add one at a time so that each setup has access to correct tfm subset.
        # Sort by the attribute chosen at construction (`self.order`), not always `_order`.
        for t in order_sorted(tfm, self.order):
            self.tfms.append(t)
            opt_call(t, 'setup', items)
    def __getattr__(self, k):
        "Delegate unknown attributes to the most recently added transform that has a non-None value."
        # Read `tfms` straight from __dict__: when it is missing (e.g. an instance created without
        # __init__ during copying/unpickling), `self.tfms` would re-enter __getattr__ forever.
        for t in reversed(self.__dict__.get('tfms', [])):
            a = getattr(t, k, None)
            if a is not None: return a
        raise AttributeError(k)
DataSource is the base class of the data block API and is defined from items, tfms and filts. It can represent multiple datasets (train, valid, or more) contained in the items: each element of filts is a boolean mask or a collection of ints that says which items are in which dataset.
When accessing an element, tfms are applied to it with optional tfm_kwargs passed along. Those kwargs are filtered so that each tfm only gets the ones it accepts. At its base a tfm is just a simple function (open an image, resize it, one-hot encode a category, etc.), but it can be more complex (see the Transform class above). Some transforms need a setup (for instance, the transform that changes a category to its index needs to compute all the classes), and some can be reversed for display purposes (if you change a category to an index, you still want to display the category name later on, and if you normalize your image, you need to undo that to display it). DataSource calls the setup function of its transforms (when they have one) at initialization, and it has a decode method that reverses the transforms that can be reversed.
# export
def coll_repr(c, max=1000, max_n=10):
    "Short repr of collection `c`: its length, then its first `max_n` elements (`...` marks truncation)."
    # NOTE: `max` is unused; it is kept so existing callers that pass it keep working.
    n = len(c)
    shown = ','.join(itertools.islice(map(str,c), max_n))
    return f'(#{n}) [' + shown + ('...' if n>max_n else '') + ']\n'
# export
class DataSource():
    "Items plus one or more subsets (`filts`), with a transform applied lazily whenever an item is read."
    def __init__(self, items, tfms=noop, tfm=None, filts=None):
        if filts is None: filts = [range_of(items)]
        # Choose mask vs. index-list handling from the first element of the first non-empty filter;
        # peeking at `filts[0][0]` blindly raised IndexError when the first subset was empty.
        first = next((f[0] for f in filts if len(f)), None)
        ft = mask2idxs if isinstance(first, bool) else listify
        self.filts = listify(ft(filt) for filt in filts)
        self.items,self.tfm = listify(items),ifnone(tfm, Transforms(tfms))
        # Transforms may compute state from this source during setup
        self.tfm.setup(self)
    def __len__(self): return len(self.filts)
    def len(self, filt=0): return len(self.filts[filt])
    def __getitem__(self, i): return FilteredList(self, i)
    def decode(self, o, filt=0, **kwargs): return self.tfm.decode(o, filt=filt, **kwargs)
    def decoded(self, idx, filt=0): return self.decode(self.get(idx,filt), filt)
    def __iter__(self): return (self[i] for i in range_of(self))
    def get(self, idx, filt=0):
        "Transformed item(s) at `idx` of subset `filt`; `idx` may be an int, a list of indices or a boolean mask."
        if hasattr(idx,'__len__') and getattr(idx,'ndim',1):
            # rank>0 collection; check it is non-empty before peeking at idx[0]
            if len(idx) and isinstance(idx[0],bool): idx = mask2idxs(idx)
            return [self.get(i,filt) for i in idx] # index list
        it = self.items[self.filts[filt][idx]]
        return self.tfm(it, filt=filt)
    def __eq__(self,b):
        if not isinstance(b,DataSource): b = DataSource(b)
        return len(b) == len(self) and all(o==p for o,p in zip(self,b))
    def __repr__(self):
        res = f'{self.__class__.__name__}\n'
        for i,o in enumerate(self): res += f'{i}: {coll_repr(o)}'
        return res
    @property
    def train(self): return self[0]
    @property
    def valid(self): return self[1]
#export
class FilteredList:
    "Read-only view of subset `filt` of a `DataSource`: indexing and iteration go through `dsrc.get`."
    def __init__(self, dsrc, filt):
        self.dsrc = dsrc
        self.filt = filt
    def __len__(self): return self.dsrc.len(self.filt)
    def __getitem__(self, i): return self.dsrc.get(i, self.filt)
    def __iter__(self): return map(self.__getitem__, range_of(self))
    def decode(self, o): return self.dsrc.decode(o, self.filt)
    def __eq__(self, b):
        if len(b) != len(self): return False
        return all(mine==theirs for mine,theirs in zip(self,b))
    def __repr__(self): return coll_repr(self)
# test
#Indexing
# No filts: a single subset holding every item
dsrc = DataSource(range(5))
# Comparing with a plain list wraps the list in a DataSource first
test_eq(dsrc,[0,1,2,3,4])
test_eq(list(dsrc[0]),[0,1,2,3,4])
test_ne(dsrc,[0,1,2,3,5])
# `get` accepts an int, a list of indices or a boolean mask
test_eq(dsrc.get(2),2)
test_eq(dsrc.get([1,2]),[1,2])
test_eq(dsrc.get([True,False,False,True,False]),[0,3])
# test
#filts can be indices or boolean masks
dsrc = DataSource(range(5), filts=[[0,2], [1,3,4]])
test_eq(list(dsrc[0]),[0,2])
test_eq(list(dsrc[1]),[1,3,4])
#Subsets don't have to be disjoint
dsrc = DataSource(range(5), filts=[[False,True,True,False,True], [True,False,False,True,True]])
test_eq(list(dsrc[0]),[1,2,4])
test_eq(list(dsrc[1]),[0,3,4])
dsrc
# test
#Base transform
# A plain function is wrapped into a Transform and applied to every item
dsrc = DataSource(range(5), lambda x:x*2)
test_eq(dsrc,[0,2,4,6,8])
test_eq(list(dsrc[0]),[0,2,4,6,8])
test_ne(dsrc,[1,2,4,6,8])
test_eq(dsrc.get(2), 4)
test_eq(dsrc.get([1,2]), [2,4])
test_eq(dsrc.get([True,False,False,True,False]), [0,6])
# test
#Different transforms for the two subsets
# The transform has filt=1, so items of subset 0 come back untransformed
dsrc = DataSource(range(5), Transform(lambda x: x*2, filt=1), filts=[[1,2],[0,3,4]])
test_eq(list(dsrc[0]),[1,2])
test_eq(list(dsrc[1]),[0,6,8])
test_eq(dsrc.get(2,1), 8)
test_eq(dsrc.get([1,2], 1), [6,8])
test_eq(dsrc.get([False,True], 0), [2])
# test
def add(x, a=1): return x+a
def multiply(x, a=2): return x*a
def square(x): return x**2
def add_undo(x, a=1): return x-a
def multiply_undo(x, a=2): return x/a
addt = Transform(add, add_undo, order=2)
multt = Transform(multiply, multiply_undo, order=1)
sqrt = Transform(square, order=0)
#Test _order
# Applied by `_order` (square, multiply, add), whatever the list order
tfms = [addt,multt,sqrt]
dsrc = DataSource([0,1,2,3], tfms, filts=[range(4)])
test_eq(dsrc.get(2), ((2**2) * 2) + 1)
#Test decode
# Decoding runs in reverse order; `sqrt` has no decodes, so it passes values through
dsrc = DataSource([0,1,2,3], tfms, filts=[[0,1,2,3]])
test_eq(dsrc.decode(9), (9-1)/2)
# export
def _get_files(p, fs, extensions=None):
    # Keep non-hidden names from `fs` (as paths under `p`) whose lowercased last suffix is in
    # `extensions`; every non-hidden name is kept when `extensions` is falsy.
    p = Path(p)
    def _keep(name):
        if name.startswith('.'): return False
        if not extensions: return True
        return f'.{name.split(".")[-1].lower()}' in extensions
    return [p/name for name in fs if _keep(name)]
def get_files(path, extensions=None, recurse=False, include=None):
    "Get all the files in `path` with optional `extensions`, optionally recursing (top-level folders limited to `include`)."
    path = Path(path)
    extensions = {e.lower() for e in setify(extensions)}
    # A bare string names one folder; left as a str, `o in include` would do a substring match
    if isinstance(include, str): include = [include]
    if recurse:
        res = []
        for i,(p,d,f) in enumerate(os.walk(path)): # returns (dirpath, dirnames, filenames)
            # Pruning `d` in place controls which subfolders os.walk descends into
            if include is not None and i==0: d[:] = [o for o in d if o in include]
            else: d[:] = [o for o in d if not o.startswith('.')]
            res += _get_files(p, f, extensions)
    else:
        # The context manager closes the scandir iterator (and its OS handle) promptly
        with os.scandir(path) as entries: f = [o.name for o in entries if o.is_file()]
        res = _get_files(path, f, extensions)
    return res
# test
# Expected counts are those of the MNIST_TINY download
path = untar_data(URLs.MNIST_TINY)
test_eq(len(get_files(path/'train'/'3')),346)
test_eq(len(get_files(path/'train'/'3', extensions='.png')),346)
test_eq(len(get_files(path/'train'/'3', extensions='.jpg')),0)
# Without recurse only files directly in `train` are listed (there are none)
test_eq(len(get_files(path/'train', extensions='.png')),0)
test_eq(len(get_files(path/'train', extensions='.png', recurse=True)),709)
test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train'])),709)
test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train', 'test'])),729)
#export
# Every extension that `mimetypes` maps to an image/* type
image_extensions = {k for k,v in mimetypes.types_map.items() if v.startswith('image/')}
def get_image_files(path, include=None, **kwargs):
    "Recursively collect image files under `path` (optionally only in top-level folders `include`)."
    # Extra kwargs (e.g. `self=` passed by DataBlock.datasource) are accepted and ignored
    return get_files(path, extensions=image_extensions, recurse=True, include=include)
def image_getter(suf='', **kwargs):
    "Return a function listing image files in `o/suf`; call-time kwargs override those given here."
    def _inner(o, **kw):
        merged = dict(kwargs)
        merged.update(kw)
        return get_image_files(o/suf, **merged)
    return _inner
# test
path = untar_data(URLs.MNIST_TINY)
# All images under the dataset root, then restricted by folder
test_eq(len(get_image_files(path)),1428)
test_eq(len(get_image_files(path/'train')),709)
test_eq(len(get_image_files(path, include='train')),709)
test_eq(len(get_image_files(path, include=['train','valid'])),1408)
# test
path = untar_data(URLs.MNIST_TINY)
# image_getter(suf) lists images under path/suf
test_eq(len(image_getter()(path)),1428)
test_eq(len(image_getter('train')(path)),709)
#export
def show_image(im, ax=None, figsize=None, title=None, **kwargs):
    "Display `im` (anything `ax.imshow` accepts, or a channel-first tensor) on `ax`, creating one if needed."
    if ax is None: ax = plt.subplots(figsize=figsize)[1]
    # Tensors with a small leading dimension are taken as CHW and moved to HWC for imshow
    if isinstance(im,Tensor) and im.shape[0]<5: im = im.permute(1,2,0)
    ax.imshow(im, **kwargs)
    if title is not None: ax.set_title(title)
    ax.axis('off')
    return ax
def show_title(o, ax=None):
    "Use `o` as the title of `ax`, or print it when no axis is given."
    if ax is not None: ax.set_title(o)
    else: print(o)
Convention: a function whose name is a verb ending in -er returns a function (to get transforms directly or for use in the high-level API below).
# export
def random_splitter(valid_pct=0.2, seed=None, **kwargs):
    "Return a function splitting `items` randomly into (train_idxs, valid_idxs), with `valid_pct` of them in valid."
    def _inner(o, **kwargs):
        # Seed a private generator so a fixed `seed` doesn't reset torch's global RNG as a side effect
        g = None if seed is None else torch.Generator().manual_seed(seed)
        rand_idx = torch.randperm(len(o), generator=g)
        cut = int(valid_pct * len(o))
        return rand_idx[cut:],rand_idx[:cut]
    return _inner
#test
# Expected indices come from torch's RNG with seed 42 (presumably tied to the torch version)
trn,val = random_splitter(seed=42)([0,1,2,3,4,5])
test_equal(trn, tensor([3, 2, 4, 1, 5]))
test_equal(val, tensor([0]))
# export
def _grandparent_mask(items, name):
    "Boolean mask: whether the grandparent folder of each item (Path or str) is called `name`."
    # Path(o) also handles strings; the old `split(os.path.sep)[-2]` returned the parent, not the grandparent
    return [Path(o).parent.parent.name == name for o in items]
def grandparent_splitter(train_name='train', valid_name='valid', **kwargs):
    "Return a splitter giving (train_mask, valid_mask) from each item's grandparent folder name."
    def _inner(o, **kwargs):
        train_mask = _grandparent_mask(o, train_name)
        valid_mask = _grandparent_mask(o, valid_name)
        return train_mask,valid_mask
    return _inner
path = untar_data(URLs.MNIST_TINY)
#test
#With Path filenames (these items are pathlib.Path objects)
path = untar_data(URLs.MNIST_TINY)
items = [path/'train'/'3'/'9932.png', path/'valid'/'7'/'7189.png',
path/'valid'/'7'/'7320.png', path/'train'/'7'/'9833.png',
path/'train'/'3'/'7666.png', path/'valid'/'3'/'925.png',
path/'train'/'7'/'724.png', path/'valid'/'3'/'93055.png']
trn,val = grandparent_splitter()(items)
test_eq(trn,[True,False,False,True,True,False,True,False])
test_eq(val,[False,True,True,False,False,True,False,True])
# export
def parent_label(o, **kwargs):
    "Label `item` (Path or str) with the name of its parent folder."
    # Path(o) also handles strings; the old `split(os.path.sep)[-1]` returned the file name instead
    return Path(o).parent.name
def re_labeller(pat):
    "Return a labeller giving the first capture group of regex `pat` searched in `str(item)`."
    pat = re.compile(pat)
    def _inner(o, **kwargs):
        match = pat.search(str(o))
        assert match,f'Failed to find "{pat}" in "{o}"'
        return match.group(1)
    return _inner
Let's grab the Pets dataset first. Our DataSource will contain all the image files as items, and we'll randomly split them into two filts with 80% and 20% of the data. To get our xs, we need to apply a Transform that opens the image from each filename. Here `PetTfm` does that (and also labels each image); later, the opening step gets its own transform, Imagify.
# Folder holding the Pets image files
source = untar_data(URLs.PETS)/"images"
class PetTfm(Transform):
    "Pets transform: filename -> (PIL image, class index); `setups` builds the vocab from `dsrc.train` labels."
    def __init__(self, source):
        super().__init__()
        self.source = source
        self.vocab = None
        self.labeller = re_labeller(pat = r'/([^/]+)_\d+.jpg$')
    def setups(self, dsrc):
        labels = map(self.labeller, dsrc.train)
        self.vocab,self.o2i = uniqueify(labels, sort=True, bidir=True)
    def encodes(self, o):
        # Before setup there is no vocab: pass filenames through so setup can label them
        if self.vocab is None: return o
        img = Image.open(o)
        return img, self.o2i[self.labeller(o)]
    def decodes(self, o):
        return o[0], self.vocab[o[1]]
    def show(self, o, ax=None):
        show_image(o[0], ax, title=o[1])
tfm = PetTfm(source)
items = get_image_files(source)
# Default valid_pct=0.2 gives an 80/20 train/valid split
split_idx = random_splitter()(items)
pets = DataSource(items, tfm, filts=split_idx)
To access an element, we need to specify its index and its filter (the latter defaults to 0).
# First item of subset 1 (validation)
xy = pets.get(0,1); xy
We can decode an element for display purposes.
# Decode with the same filter used to encode
xyd = pets.decode(xy, 1); xyd
tfm.show(xyd)
# export
def _dsrc_show(self, o, filt=0, **kwargs):
    "Decode `o` for subset `filt` and display it with `self.tfm.show`."
    decoded = self.decode(o, filt)
    self.tfm.show(decoded, **kwargs)
def _fl_show(self, o, **kwargs):
    self.dsrc.show(o, self.filt, **kwargs)
def _fl_show_at(self, i, **kwargs):
    self.show(self[i], **kwargs)
# Monkey-patch the display helpers onto DataSource and FilteredList
DataSource.show = _dsrc_show
FilteredList.show,FilteredList.show_at = _fl_show,_fl_show_at
# Show an already-fetched item, or let the subset view fetch it by index
pets.show(pets.get(0,1))
pets.valid.show_at(0)
Before we can batch our images, we'll need to apply some basic image transformations: converting to RGB, making them all the same size and converting them to tensors. We also have to be prepared for different kinds of targets: sometimes the transform won't be applied to the target, and sometimes it will, in different ways. We support images, segmentation masks, points and bounding boxes.
# export
class TfmY(Enum):
    "How image transforms treat the target (dispatched to `apply_mask`, `apply_image`, ..., `apply_no`)."
    Mask = 1
    Image = 2
    Point = 3
    Bbox = 4
    No = 5
The ImageTransform class looks ugly, but it just dispatches the apply and unapply functions properly between x and y, and allows a different implementation for each kind of target. This will be very important for data augmentation in the next notebook.
# export
class ImageTransform(Transform):
    "Base for image transforms: `apply` on x, and on each target the `apply_<kind>` hook selected by `_tfm_y`."
    _order,_tfm_y = 10,TfmY.No
    def randomize(self): pass

    def encodes(self, o, **kwargs):
        self.x,*y = o
        # Draw random parameters once, so x and every target share the same random state
        self.randomize()
        new_x = self.apply(self.x)
        # NOTE(review): kwargs are forwarded to apply_y, which takes none; non-empty kwargs raise TypeError
        return (new_x, *[self.apply_y(t, **kwargs) for t in y])
    def decodes(self, o, **kwargs):
        self.x,*y = o
        new_x = self.unapply(self.x)
        return (new_x, *[self.unapply_y(t, **kwargs) for t in y])

    def _tfm_name(self, is_decode=False):
        verb = 'unapply' if is_decode else 'apply'
        return f"{verb}_{self._tfm_y.name.lower()}"

    # Per-target hooks for subclasses to override: masks default to image behaviour,
    # bboxes to point behaviour, points and "no" targets are left unchanged.
    def apply_no(self, y): return y
    def apply_image(self, y): return self.apply(y)
    def apply_mask(self, y): return self.apply_image(y)
    def apply_point(self, y): return y
    def apply_bbox(self, y): return self.apply_point(y)
    def unapply_no(self, y): return y
    def unapply_image(self, y): return self.unapply(y)
    def unapply_mask(self, y): return self.unapply_image(y)
    def unapply_point(self, y): return y
    def unapply_bbox(self, y): return self.unapply_point(y)
    def apply(self, x): return x
    def unapply(self, x): return x
    def apply_y(self, y): return getattr(self, self._tfm_name(False))(y)
    def unapply_y(self, y): return getattr(self, self._tfm_name(True))(y)
# test
import random
# Fake transform: x gets a random offset; masks +5, points +2
class FakeTransform(ImageTransform):
    def randomize(self): self.a = random.randint(1,1000)
    def apply(self, x): return x + self.a
    def apply_mask(self, x): return x + 5
    def apply_point(self, x): return x + 2
tfm = FakeTransform()
xy = x,y = 5,10
#Basic behavior: x has changed, not y
t1 = tfm(xy)
assert t1[0]!=x and t1[1]==y, t1
#Check the same random integer was used for x and y when transforming y
tfm._tfm_y=TfmY.Image; t1 = tfm(xy)
test_eq(t1[0] - 5,t1[1] - 10)
#Check mask, point,bbox implementations
tfm._tfm_y=TfmY.Mask ; test_eq(tfm(xy)[1],15)
tfm._tfm_y=TfmY.Point; test_eq(tfm(xy)[1],12)
tfm._tfm_y=TfmY.Bbox ; test_eq(tfm(xy)[1],12)
Our first transform converts an image to 'RGB'. We can specify different modes for the xs and ys: the default is 'RGB' for x; for y it is mode_x if our ys are images, and 'L' if our ys are segmentation masks.
#export
class DecodeImg(ImageTransform):
    "Convert x to `mode_x`; image targets to `mode_y` (else `mode_x`), mask targets to `mode_y` (else 'L')."
    def __init__(self, mode_x='RGB', mode_y=None):
        self.mode_x = mode_x
        self.mode_y = mode_y
    def apply(self, x): return x.convert(self.mode_x)
    def apply_image(self, y):
        mode = self.mode_x if self.mode_y is None else self.mode_y
        return y.convert(mode)
    def apply_mask(self, y):
        mode = 'L' if self.mode_y is None else self.mode_y
        return y.convert(mode)
Our second transform resizes an image, using a given mode. It defaults to bilinear for images and nearest for segmentation masks.
# export
class ResizeFixed(ImageTransform):
    "Resize to a fixed `size` ((rows, cols), or an int for a square) with `mode_x`; targets use `mode_y` (masks default to nearest)."
    _order=15
    def __init__(self, size, mode_x=Image.BILINEAR, mode_y=None):
        if isinstance(size,int): size = (size,size)
        rows,cols = size[0],size[1]
        self.size = (cols,rows) # PIL expects (width, height)
        self.mode_x,self.mode_y = mode_x,mode_y
    def apply(self, x): return x.resize(self.size, self.mode_x)
    def apply_image(self, y):
        resample = self.mode_x if self.mode_y is None else self.mode_y
        return y.resize(self.size, resample)
    def apply_mask(self, y):
        resample = Image.NEAREST if self.mode_y is None else self.mode_y
        return y.resize(self.size, resample)
The transformation to tensors is done in two steps just in case one wants to apply transforms to byte tensors. The permutation of axes needs to be reversed for display, so we have an unapply function (which is what is called by decode in an ImageTransform).
# export
class ToByteTensor(ImageTransform):
    "Turn images into channel-first byte tensors; `unapply` goes back to HWC (or HW when there is one channel)."
    _order=20
    def apply(self, x):
        w,h = x.size
        raw = torch.ByteTensor(torch.ByteStorage.from_buffer(x.tobytes()))
        return raw.view(h,w,-1).permute(2,0,1)
    def unapply(self, x):
        if x.shape[0] == 1: return x[0]
        return x.permute(1,2,0)
Lastly, we convert our tensors to floats (or ints for segmentation masks) and divide by 255 (you can pass a different value as `div_x`, and a divisor for masks as `div_y`).
# export
class ToFloatTensor(ImageTransform):
    "Transform our items to float tensors divided by `div_x` (long tensors for masks, optionally divided by `div_y`)."
    _order=5 #Need to run after CUDA on the GPU
    def __init__(self, div_x=255., div_y=None): self.div_x,self.div_y = div_x,div_y
    def apply(self, x): return x.float().div_(self.div_x)
    def apply_mask(self, x):
        if self.div_y is None: return x.long()
        # In-place true division of a long tensor errors on current PyTorch; use floor division
        # (same as truncation for the non-negative values of a byte mask) and cast back to long.
        return (x.long() // self.div_y).long()
    def unapply(self, x): return torch.clamp(x, 0., 1.)
    def unapply_mask(self, x): return x
Let's test that it all works properly.
# Sorted by _order: PetTfm(0), DecodeImg(10), ResizeFixed(15), ToByteTensor(20)
tfms = [PetTfm(source), DecodeImg(), ResizeFixed(128), ToByteTensor()]
pets = DataSource(items, tfms, filts=split_idx)
xy = pets.get(0,1); xy[0].type()
pets.show(xy)
With transforms to make our images tensors of the same size, we're ready to create batches and dataloaders. We wrap a PyTorch dataloader to add batch transforms. Additional kwargs will be passed along. Those transforms can be decoded like before, for display purposes (like normalization).
def apply_all(o, fs, fname=None, **kwargs):
    "Pass `o` through each of `fs` in turn (calling each one's `fname` attribute instead, when given)."
    for f in fs:
        fn = f if fname is None else getattr(f,fname,noop)
        o = fn(o, **kwargs)
    return o
# export
class TfmDataLoader():
    "Wrap a PyTorch `DataLoader`, applying the `_order`-sorted batch `tfms` (with `tfm_kwargs`) to every batch."
    def __init__(self, dl, tfms=None, **tfm_kwargs):
        self.dl,self.tfms,self.tfm_kwargs = dl,order_sorted(tfms),tfm_kwargs
    def __len__(self): return len(self.dl)
    def __iter__(self):
        # Forward the stored kwargs to the batch transforms (they used to be accepted and then ignored)
        for b in self.dl: yield apply_all(b, self.tfms, **self.tfm_kwargs)
    def one_batch(self): return next(iter(self))
    def decode_batch(self): return self.decode(self.one_batch())
    def decode(self, o): return apply_all(o, reversed(self.tfms), fname='decode')
    def __getattr__(self, k):
        "Delegate unknown attributes to the wrapped dataset."
        # Read `dl` from __dict__: if it is missing (e.g. while unpickling), going through the
        # `dataset` property would re-enter __getattr__ without end.
        dl = self.__dict__.get('dl')
        if dl is None: raise AttributeError(k)
        try: return getattr(dl.dataset, k)
        except AttributeError: raise AttributeError(k) from None
    @property
    def dataset(self): return self.dl.dataset
Then we add a basic function to create dataloaders from a DataSource.
# export
from torch.utils.data.dataloader import DataLoader
def get_dl(dset, bs=64, tfms=None, tfm_kwargs=None, **kwargs):
    "Build a `TfmDataLoader` over `dset`; extra kwargs go to the PyTorch `DataLoader`."
    loader = DataLoader(dset, bs, **kwargs)
    extra = {} if tfm_kwargs is None else tfm_kwargs
    return TfmDataLoader(loader, tfms=tfms, **extra)
def get_dls(dsrc, bs=64, tfms=None, tfm_kwargs=None, **kwargs):
    "One dataloader per subset of `dsrc`; only subset 0 (training) is shuffled."
    return [get_dl(dsrc[idx], bs, shuffle=(idx==0), tfms=tfms, tfm_kwargs=tfm_kwargs, **kwargs)
            for idx in range_of(dsrc)]
# One dataloader per subset, with float conversion applied to each batch
dls = get_dls(pets, tfms=ToFloatTensor())
This is a convenience function to grab the k-th item in a batch, even if the batch is made of lists of tensors.
#export
def grab_item(b,k):
    "Return element `k` of batch `b`, recursing into lists/tuples so nested batches work."
    if not isinstance(b, (list,tuple)): return b[k]
    return [grab_item(part,k) for part in b]
def show_batch(b, show, items=9, cols=3, figsize=None, show_func=None, **kwargs):
    "Display the first `items` elements of batch `b` on a grid with `cols` columns, using `show`."
    rows = (items+cols-1) // cols
    if figsize is None: figsize = (cols*3, rows*3)
    # squeeze=False always returns a 2D array of axes; with one row and column a bare Axes has no .flatten
    fig,axs = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    for k,ax in enumerate(axs.flatten()):
        # The grid can have more cells than `items`: hide the spare axes instead of reading past them
        if k >= items:
            ax.axis('off')
            continue
        show(grab_item(b,k), ax=ax, show_func=show_func, **kwargs)
# export
class DataBunch():
    "Group of dataloaders: index 0 is the training one, index 1 the validation one."
    def __init__(self, *dls): self.dls = dls
    def __getitem__(self, i): return self.dls[i]

    @property
    def train_dl(self): return self[0]
    @property
    def valid_dl(self): return self[1]
    @property
    def train_ds(self): return self.train_dl.dataset
    @property
    def valid_ds(self): return self.valid_dl.dataset

    def show_batch(self, i=0, items=9, cols=3, figsize=None, show_func=None, **kwargs):
        "Decode one batch of dataloader `i` and plot it with the `show` of that loader's dataset."
        dl = self[i]
        batch = dl.decode_batch()
        show_batch(batch, dl.dataset.show, items, cols, figsize=figsize, show_func=show_func, **kwargs)
data = DataBunch(*dls)
# Inspect one training batch's shapes and tensor types
x,y = data[0].one_batch()
x.shape,x.type(),y.shape,y.type()
data.show_batch()
Finally let's monkey-patch a databunch function in DataSource to quickly create a databunch.
# export
def _dsrc_databunch(self, bs=64, tfms=None, **kwargs):
    "Build a `DataBunch` over every subset of this `DataSource`, keeping a back-reference in `.dsrc`."
    loaders = get_dls(self, bs=bs, tfms=tfms, **kwargs)
    bunch = DataBunch(*loaders)
    bunch.dsrc = self
    return bunch
# Monkey-patch onto DataSource
DataSource.databunch = _dsrc_databunch
# First CUDA device; used by the Cuda batch transform below
device = torch.device('cuda',0)
# export
from fastai.torch_core import to_device, to_cpu
import torch.nn.functional as F
# export
class Cuda(Transform):
    "Batch transform: move batches to `device` with `to_device`; `decodes` brings them back with `to_cpu`."
    _order = 0
    def __init__(self, device):
        self.device = device
    def encodes(self, b, tfm_y=TfmY.No):
        dev = self.device
        return to_device(b, dev)
    def decodes(self, b):
        return to_cpu(b)
We'll see other batch transforms in the next chapter but one that is pretty common is normalization.
# export
class Normalize(Transform):
    "Batch transform computing `(v - mean) / std` on x (and/or y); `decodes` inverts it."
    _order=99
    def __init__(self, mean, std, do_x=True, do_y=False):
        self.mean,self.std = mean,std
        self.do_x,self.do_y = do_x,do_y
    def encodes(self, b):
        x,y = b
        x = self.normalize(x) if self.do_x else x
        y = self.normalize(y) if self.do_y else y
        return x,y
    def decodes(self, b):
        x,y = b
        x = self.denorm(x) if self.do_x else x
        y = self.denorm(y) if self.do_y else y
        return x,y
    def normalize(self, x): return (x - self.mean) / self.std
    def denorm(self, x): return x * self.std + self.mean
# Per-channel stats shaped (1,C,1,1) to broadcast over batches; `.cuda()` requires a GPU
mean,std = tensor([0.5,0.5,0.5]).view(1,-1,1,1).cuda(),tensor([0.5,0.5,0.5]).view(1,-1,1,1).cuda()
data = pets.databunch(tfms = [Cuda(device), ToFloatTensor(), Normalize(mean,std)])
data.show_batch()
x ,y = data[0].one_batch()
xd,yd = data[0].decode_batch()
# Compare normalized vs decoded statistics
x.type(), xd.type(), x.mean(), x.std(), xd.mean(), xd.std()
# Item-level and batch-level transforms reused below
ds_tfms = [DecodeImg(), ResizeFixed(128), ToByteTensor()]
dl_tfms = [Cuda(device), ToFloatTensor()]
#export
class Imagify(Transform):
    "Open an image from a filename with `f`; `show` displays it with default `cmap` and `alpha`."
    def __init__(self, f=Image.open, cmap=None, alpha=1.): self.f,self.cmap,self.alpha = f,cmap,alpha
    # Use the configured opener (the previous version ignored `f` and always called Image.open)
    def encodes(self, fn): return self.f(fn)
    def show(self, im, ax=None, figsize=None, cmap=None, alpha=None):
        return show_image(im, ax, figsize=figsize,
                          cmap=ifnone(cmap,self.cmap),
                          alpha=ifnone(alpha,self.alpha))
To get our ys, we'll need to apply the re pattern (function re_labeller from before) and a transform that creates the list of classes, then maps categories to their indices. We call that transform Categorize. It needs a setup to create the classes, and it's reversible, so we implement decode. This is also a base transform, so we implement its show method.
Since it needs to run after the labelling transform (here our re pattern labeller) we give it an _order of 1.
# export
class Categorize(Transform):
    "Map labels to indices using a sorted vocab built in `setups` from the training subset; reversible."
    _order=1
    def __init__(self): self.o2i = None
    def setups(self, dsrc): self.vocab,self.o2i = uniqueify(dsrc.train, sort=True, bidir=True)
    def encodes(self, o):
        # Until setup has built `o2i`, labels pass through unchanged
        if not self.o2i: return o
        return self.o2i[o]
    def decodes(self, o): return self.vocab[o]
    def show(self, o, ax=None): show_title(o, ax)
# Pets label: the part of the filename before `_<digits>.jpg`
labeller = re_labeller(pat = r'/([^/]+)_\d+.jpg$')
A transform that is applied to define a base object can have a show method: for instance, the transform that opens an image has a show method. When trying to display objects, the API will decode them and grab the first transform that provides a show method (this can be overridden by passing a custom show, but we'll see that later).
The show_xs function is there to combine the show methods of the base transforms to display x and y together. We can either pass a transform that has a show method or a custom list of show methods.
#export
def show_xs(xs, shows, ax=None, **kwargs):
    "Show each element of `xs` with the matching entry of `shows`, all on the same axis; return that axis."
    for x,show in zip(xs,shows):
        # can pass func or obj with a `show` method
        show = getattr(show, 'show', show)
        res = show(x, ax=ax, **kwargs)
        # Some shows (e.g. show_title) return None: keep the axis we already have, so later
        # elements still draw on it instead of falling back to printing.
        if res is not None: ax = res
    return ax
#export
class DsrcTfm():
    "Turn each item into a list with one pipeline per element (`ttfms`), then apply `tfm` to that list."
    def __init__(self, ttfms, tfm=noop):
        self.ttfms = [Transforms(t) for t in listify(ttfms)]
        self.tfm,self.activ,self.done_setup = Transforms(tfm),None,False
    def __call__(self, o, **kwargs):
        # While one element pipeline is being set up, only that pipeline runs
        if self.activ: return self.activ(o, **kwargs)
        o = [t(o, **kwargs) for t in self.ttfms]
        return self.tfm(o, **kwargs)
    def decode(self, o, **kwargs):
        o = self.tfm.decode(o, **kwargs)
        return [t.decode(p, **kwargs) for p,t in zip(o,self.ttfms)]
    def setup(self, dsrc):
        if self.done_setup: return
        for tfm in self.ttfms:
            self.activ = tfm
            # try/finally: a failing setup must not leave `activ` set, or every later call
            # would silently run only that one pipeline
            try: tfm.setup(dsrc)
            finally: self.activ = None
        self.tfm.setup(dsrc)
        self.done_setup = True
    def show(self, o, **kwargs): return show_xs(o, self.ttfms, **kwargs)
    def __repr__(self): return f'DsrcTfm({self.ttfms}\n{self.tfm})\n'
    @property
    def xt(self): return self.ttfms[0]
    @property
    def yt(self): return self.ttfms[1]
# x pipeline: open image; y pipeline: label then categorize; ds_tfms then run on the (x, y) list
tfm = DsrcTfm([Imagify(), [labeller,Categorize()]], ds_tfms)
pets = DataSource(items, tfm=tfm, filts=split_idx)
def tfm_dsrc(items, filts, xt, yt, labeller, ds_tfms=None):
    "Build a DataSource whose items go through an `xt` pipeline and a `[labeller, yt]` pipeline, then `ds_tfms`."
    return DataSource(items, tfm=DsrcTfm([xt, [labeller,yt]], ds_tfms), filts=filts)
pets = tfm_dsrc(items, split_idx, Imagify(), Categorize(), labeller, ds_tfms=ds_tfms)
xy = pets.decoded(0)
# NOTE(review): `tfm` is the DsrcTfm built in the previous cell, not the one inside this `pets`
show_xs(xy, tfm.ttfms)
pets.train.show_at(0)
data = pets.databunch(bs=16, tfms=dl_tfms)
data.show_batch()
#export
class DataBlock():
    "Base for data pipelines: subclasses provide `types`, `get_items`, `split` and `label_func`."
    # The base stubs accept **kwargs because `datasource` calls them with `self=self`; without it
    # they raised TypeError instead of the intended NotImplementedError.
    @staticmethod
    def get_items(source, **kwargs): raise NotImplementedError
    @staticmethod
    def split(items, **kwargs): raise NotImplementedError
    @staticmethod
    def label_func(item, **kwargs): raise NotImplementedError
    def types(self):
        "Return the (x transform(s), y transform(s)) pair for this block."
        raise NotImplementedError
    def __init__(self, source):
        self.source = source
        xt,yt = self.types()
        # `yt` may be a list of transforms (e.g. multi_category()): splice it after the labeller.
        # A nested list would otherwise be wrapped by Transform.create as if it were a function.
        self.tfm = DsrcTfm([xt, [self.__class__.label_func, *listify(yt)]])
    def datasource(self, tfms=None):
        # `self=` is passed so subclass getters/splitters can read instance state
        items = self.__class__.get_items(self.source, self=self)
        split_idx = self.__class__.split(items, self=self)
        return DataSource(items, [self.tfm]+listify(tfms), filts=split_idx)
    def databunch(self, ds_tfms=None, dl_tfms=None, bs=64, **kwargs):
        return self.datasource(tfms=ds_tfms).databunch(bs, tfms=dl_tfms, **kwargs)
class PetsData(DataBlock):
    "Pets images, labelled by the filename part before `_<digits>.jpg`, split randomly."
    get_items = image_getter()
    split = random_splitter()
    label_func = re_labeller(pat = r'/([^/]+)_\d+.jpg$')
    def types(self): return Imagify(),Categorize()
source = untar_data(URLs.PETS)/"images"
dblk = PetsData(source)
data = dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
data.dsrc.train.show_at(0)
data.show_batch()
# Vocabulary built by the Categorize transform (was a NameError: `dblklk`)
' '.join(dblk.tfm.yt.vocab)
class MnistData(DataBlock):
    "MNIST from folders: label = parent folder, split by 'training'/'testing' grandparent folder."
    label_func = parent_label
    split = grandparent_splitter(train_name='training', valid_name='testing')
    get_items = get_image_files
    def types(self): return Imagify(),Categorize()
source = untar_data(URLs.MNIST)
# MNIST images are already a fixed size, so only the tensor conversion is needed
data = MnistData(source).databunch(ds_tfms=[ToByteTensor()], dl_tfms=dl_tfms)
data.show_batch()
There are several ways to get our images displayed properly. First, we can pass custom show functions that change how the x is shown.
# NOTE(review): DataBunch.show_batch has no `shows` parameter; it ends up in **kwargs that are
# forwarded to each item's show call — confirm the transforms' show methods accept it.
data.show_batch(shows = (partial(show_image, cmap='gray'), None))
Or just set the default cmap to gray in types:
class MnistDataBW(MnistData):
    "`MnistData` whose images are shown with a gray colormap by default."
    def types(self):
        xt = Imagify(cmap='gray')
        return xt,Categorize()
data = MnistDataBW(source).databunch(ds_tfms=[ToByteTensor()], dl_tfms=dl_tfms)
data.show_batch()
# Planet sample, used by the multi-label example below
path = untar_data(URLs.PLANET_SAMPLE)
# export
def onehot(x, c, a=1.):
    "Length-`c` float vector with `a` at index/indices `x`; when `a<1` the rest share `1-a` evenly."
    fill = (1-a)/(c-1) if a<1 else None
    res = torch.zeros(c) if fill is None else torch.zeros(c) + fill
    res[x] = a
    return res
# test
# x can be an int, a list or tensor of indices, or a boolean mask
test_equal(onehot(1,5), tensor([0.,1.,0.,0.,0.]))
test_equal(onehot([1,3],5), tensor([0.,1.,0.,1.,0.]))
test_equal(onehot(tensor([1,3]),5), tensor([0.,1.,0.,1.,0.]))
test_equal(onehot([True,False,True,True,False],5), tensor([1.,0.,1.,1.,0.]))
test_equal(onehot([],5), tensor([0.,0.,0.,0.,0.]))
# Label smoothing: the remaining 0.1 is spread over the other 4 classes
test_equal(onehot(1,5,0.9), tensor([0.025,0.9,0.025,0.025,0.025]))
# export
class MultiCategorize(Transform):
    "Multi-label version of Categorize: map each label of a list to its index in a sorted vocab."
    _order=1
    def __init__(self): self.vocab = None
    # Implemented as encodes/decodes so the base Transform consumes the `filt=` kwarg forwarded by
    # Transforms; the old `__call__(self, x)` override raised TypeError on any DataSource access.
    def encodes(self, x):
        # Before setup (vocab still None) labels pass through, so `setup` can collect them
        if self.vocab is None: return x
        return [self.o2i[o] for o in x if o in self.o2i]
    def decodes(self, o): return [self.vocab[i] for i in o]
    @property
    def c(self): return len(self.vocab)
    def show(self, o, ax=None):
        (print if ax is None else ax.set_title)(';'.join(o))
    def setup(self, dsrc):
        if self.vocab is not None: return
        vals = set()
        for c in dsrc.train: vals = vals.union(set(c))
        self.vocab,self.o2i = uniqueify(list(vals), sort=True, bidir=True)
class OneHotEncode(Transform):
    "One-hot encode a list of indices into a vector of length `c` (the class count of the preceding transform)."
    _order=10
    c = None  # stays None until setup, so encodes passes its input through
    def setup(self, items):
        # `items` is the DataSource; DataSource has no `activ_tfm`. Its `tfm` resolves `activ` (the
        # element pipeline DsrcTfm is setting up), whose `c` comes from MultiCategorize.
        # NOTE(review): assumes this runs inside a DsrcTfm pipeline — confirm.
        self.c = items.tfm.activ.c
    # encodes/decodes (not __call__/decode) so the base Transform handles the forwarded `filt=` kwarg
    def encodes(self, o): return onehot(o, self.c) if self.c is not None else o
    def decodes(self, o): return [i for i,x in enumerate(o) if x == 1]
def multi_category():
    "Fresh pair of transforms for multi-label targets: labels -> indices -> one-hot vector."
    cat,ohe = MultiCategorize(),OneHotEncode()
    return [cat,ohe]
# test
tfm = MultiCategorize()
#Even if 'c' is the first class, vocab is sorted for reproducibility
# Building the DataSource runs the transform's setup on subset 0
ds = DataSource([['c','a'], ['a','b'], ['b'], []], [tfm], filts=[[0,1,2,3], []])
test_eq(tfm.vocab,['a','b','c'])
test_eq(tfm(['b','a']),[1,0])
test_eq(tfm.decode([2,0]),['c','a'])
# export
def get_str_column(df, col_name, prefix='', suffix='', delim=None):
    "Read `col_name` in `df` as strings, optionally adding `prefix`/`suffix` and splitting each value on `delim`."
    values = df[col_name].values.astype(str)
    values = np.char.add(np.char.add(prefix, values), suffix)
    if delim is not None:
        rows = list(csv.reader(values, delimiter=delim))
        if len({len(r) for r in rows}) <= 1: values = np.array(rows)
        else:
            # Rows of different lengths (e.g. a variable number of tags) can't form a 2D array and
            # recent NumPy raises on ragged input, so build a 1D object array of lists instead.
            values = np.empty(len(rows), dtype=object)
            for i,r in enumerate(rows): values[i] = r
    return values
# test
df = pd.DataFrame({'a': ['cat', 'dog', 'car'], 'b': ['a b', 'c d', 'a e']})
test_np_eq(get_str_column(df, 'a'), np.array(['cat', 'dog', 'car']))
test_np_eq(get_str_column(df, 'a', prefix='o'), np.array(['ocat', 'odog', 'ocar']))
test_np_eq(get_str_column(df, 'a', suffix='.png'), np.array(['cat.png', 'dog.png', 'car.png']))
# Equal-length splits give a 2D array
test_np_eq(get_str_column(df, 'b', delim=' '), np.array([['a','b'], ['c','d'], ['a','e']]))
class PlanetData(DataBlock):
    # Planet sample: image names and space-separated tags read from `labels.csv`
    def types(self): return Imagify(),multi_category()
    # Called by DataBlock.datasource as `get_items(self.source, self=self)`; caches the
    # filename -> tags mapping on the instance for `label_func`.
    def get_items(source, self):
        df = pd.read_csv(self.source/'labels.csv')
        items = get_str_column(df, 'image_name', prefix=f'{self.source}/train/', suffix='.jpg')
        labels = get_str_column(df, 'tags', delim=' ')
        self.item2label = {i:s for i,s in zip(items,labels)}
        return items
    split = random_splitter()
    # NOTE(review): DataBlock wraps `label_func` in a Transform whose encodes is called with the
    # item only, so the required `self` argument is never supplied — confirm how this is meant to bind.
    def label_func(item, self): return self.item2label[item]
source = untar_data(URLs.PLANET_SAMPLE)
# Note: despite the name, `dsrc` is a DataBlock here
dsrc = PlanetData(source)
data = dsrc.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms)
data.show_batch()
# export
# NOTE(review): `Item` is not defined or imported in this file — confirm where it comes from
# and how it uses `tfm` / `tfm_kwargs`.
class SegmentMask(Item):
    # Mask display: Imagify preset with the 'tab20' colormap at half opacity
    tfm = partial(Imagify, cmap='tab20', alpha=0.5)
    # Presumably forwarded so image transforms treat y as a mask — TODO confirm
    tfm_kwargs = {'tfm_y': TfmY.Mask}
class CamvidData(DataBlock):
    # NOTE(review): `types` is a tuple here, but DataBlock.__init__ calls `self.types()`;
    # this looks like a leftover from an older API — confirm before use.
    types = Image,SegmentMask
    get_items = image_getter('images')
    split = random_splitter()
    # Mask of images/<stem><suffix> lives at labels/<stem>_P<suffix>
    # NOTE(review): like other label_funcs, this requires `self`, which the wrapping Transform doesn't pass.
    label_func = lambda o,self: self.source/'labels'/f'{o.stem}_P{o.suffix}'
source = untar_data(URLs.CAMVID_TINY)
data = CamvidData(source).databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
# Extra kwargs are forwarded to the show functions
data.show_batch(cmap='tab20')
import pickle
# export
class PointScaler(Transform):
    "Turn point targets into an (n,2) float tensor, flipped unless `y_first`, and scaled to [-1,1] when `do_scale`."
    _order = 5 #Run before we apply any ImageTransform
    def __init__(self, do_scale=True, y_first=False):
        self.do_scale,self.y_first = do_scale,y_first
    # NOTE(review): overrides Transform.__call__/decode without a `filt` parameter, while Transforms
    # forwards `filt=` on DataSource access — confirm how this is invoked.
    def __call__(self, o, tfm_y=TfmY.No):
        x,pts = o
        if not isinstance(pts, torch.Tensor): pts = tensor(pts)
        pts = pts.view(-1, 2).float()
        if not self.y_first: pts = pts.flip(1)
        # NOTE(review): if `x` is a PIL image, `x.size` is (width, height), yet the columns were just
        # flipped — for non-square images each coordinate looks scaled by the other dimension; confirm.
        if self.do_scale: pts = pts * 2/tensor(list(x.size)).float() - 1
        return (x,pts)
    def decode(self, o, tfm_y=TfmY.No):
        x,pts = o
        # NOTE(review): always flips and unscales, whatever `y_first` / `do_scale` are — confirm intended
        pts = pts.flip(1)
        pts = (pts+1) * tensor([x.shape[:2]]).float()/2
        return (x,pts)
class PointShow(Transform):
    "Display transform for point targets: scatter-plots them as small red dots on `ax`."
    def show(self, x, ax=None):
        # Column 1 of `x` goes on the horizontal axis and column 0 on the vertical one.
        # NOTE(review): `ax` must be supplied; the `None` default is not handled.
        ax.scatter(x[:, 1], x[:, 0], s=10, marker='.', c='r')
class Points(Item):
    "Point-target item: `PointShow` as `tfm`, `PointScaler` as `tfm_ds`, pipeline kwarg `tfm_y=TfmY.Point`."
    tfm = PointShow
    tfm_ds = PointScaler
    tfm_kwargs = dict(tfm_y=TfmY.Point)
class BiwiData(DataBlock):
    "BIWI sample: items from `image_getter('images')`, labelled via the file-name -> point map in `centers.pkl`."
    types = Image,Points
    def __init__(self, source, *args, **kwargs):
        super().__init__(source, *args, **kwargs)
        # Use a context manager so the pickle file is closed instead of leaking the handle.
        # NOTE(review): `pickle.load` can execute arbitrary code; only load trusted datasets.
        with open(source/'centers.pkl', 'rb') as f: self.fn2ctr = pickle.load(f)
    get_items = image_getter('images')
    split = random_splitter()
    label_func = lambda o,self: self.fn2ctr[o.name]
# BIWI sample demo: build a point-regression DataBunch and display a batch.
dblk = BiwiData(untar_data(URLs.BIWI_SAMPLE))
data = dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
data.show_batch()
#export
from fastai.vision.data import get_annotations
from matplotlib import patches, patheffects
def _draw_outline(o, lw):
    "Set path effects on matplotlib artist `o`: a black stroke of width `lw`, then its normal rendering."
    effects = [patheffects.Stroke(linewidth=lw, foreground='black'), patheffects.Normal()]
    o.set_path_effects(effects)
def _draw_rect(ax, b, color='white', text=None, text_size=14, hw=True, rev=False):
    """Draw an unfilled, outlined rectangle for box `b` on `ax`, optionally captioned with `text`.

    With `hw=True`, `b` is `(x0, y0, w, h)`; otherwise its last two values are the opposite
    corner and are converted to a width/height. `rev=True` swaps the coordinates of each pair first.
    """
    x0, y0, w, h = b
    if rev:
        x0, y0, w, h = y0, x0, h, w
    if not hw:
        # Second corner -> extent.
        w, h = w - x0, h - y0
    rect = ax.add_patch(patches.Rectangle((x0, y0), w, h, fill=False, edgecolor=color, lw=2))
    _draw_outline(rect, 4)
    if text is None:
        return
    caption = ax.text(x0, y0, text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')
    _draw_outline(caption, 1)
#export
class BBoxScaler(PointScaler):
    "Apply `PointScaler` to the boxes of a `(bboxes, labels)` target, viewing each box as two points."
    def __call__(self, o, tfm_y=TfmY.Bbox):
        img, target = o
        points = super().__call__((img, target[0]))[1]
        # Regroup pairs of points back into 4-value boxes; labels pass through untouched.
        return img, (points.view(-1,4), target[1])

    def decode(self, o, tfm_y=TfmY.Bbox):
        img, target = o
        points = super().decode((img, target[0].view(-1,2)))[1]
        return img, (points.view(-1,4), target[1])
class BBoxencodes(Transform):
    """Map the labels of a `(bboxes, labels)` target to vocabulary indices, and back in `decode`.

    The vocabulary is built once in `setup` from every label in `dsrc.train` via `uniqueify`
    (sorted, with `'#bg'` passed as `start`). Labels missing from the vocabulary are dropped
    together with their boxes, so boxes and labels stay aligned.
    """
    _order=1
    def __init__(self): self.vocab = None
    def __call__(self,o):
        x,y = o
        y = list(y)
        keep = [i for i,lbl in enumerate(y) if lbl in self.otoi]
        # Only rebuild the box list when something was dropped, so the common path is unchanged.
        # NOTE(review): assumes `x` is indexable per box (a list at this stage) — confirm.
        if len(keep) < len(y): x = [x[i] for i in keep]
        return (x,[self.otoi[y[i]] for i in keep])
    def decode(self, o):
        x,y = o
        return x, [self.vocab[i] for i in y]
    def setup(self, dsrc):
        if self.vocab is not None: return
        vals = set()
        # Grow one set in place instead of rebuilding it with `union` for every item.
        for _,c in dsrc.train: vals.update(c)
        self.vocab,self.otoi = uniqueify(list(vals), sort=True, bidir=True, start='#bg')
    def show(self, x, ax):
        bbox,label = x
        for b,l in zip(bbox, label):
            if l != '#bg': _draw_rect(ax, b, hw=False, rev=True, text=l)
# export
class BBox(Item):
    "Bounding-box item: `BBoxencodes` as `tfm`, `BBoxScaler` as `tfm_ds`, pipeline kwarg `tfm_y=TfmY.Bbox`."
    tfm = BBoxencodes
    tfm_ds = BBoxScaler
    tfm_kwargs = dict(tfm_y=TfmY.Bbox)
# export
def bb_pad_collate(samples, pad_idx=0):
    """Collate `(image, (bboxes, labels))` samples into one batch, padding boxes and labels.

    `max_len` is the largest label count in the batch. Each sample's boxes and labels are
    right-aligned in tensors of that length; padding slots keep zero boxes and label `pad_idx`.

    Returns `(images, (bboxes, labels))`, where images are stacked along a new first dimension,
    `bboxes` is `(len(samples), max_len, 4)` and `labels` is `(len(samples), max_len)` long.
    """
    max_len = max(len(s[1][1]) for s in samples)
    bboxes = torch.zeros(len(samples), max_len, 4)
    labels = torch.zeros(len(samples), max_len).long() + pad_idx
    for i,s in enumerate(samples):
        bbs, lbls = s[1]
        n = len(lbls)
        # Skip empty targets. Without the `n == 0` check, `-n:` would be `-0:` and select
        # the whole row instead of nothing.
        if n == 0 or bbs.nelement() == 0: continue
        bboxes[i,-n:] = bbs
        labels[i,-n:] = torch.as_tensor(lbls)
    imgs = torch.stack([s[0] for s in samples])
    return imgs, (bboxes,labels)
class CocoData(DataBlock):
    "COCO tiny: items from `image_getter('train')`, targets from `train.json`, batched with `bb_pad_collate`."
    types = Image,BBox
    def __init__(self, source, *args, **kwargs):
        super().__init__(source, *args, **kwargs)
        images, lbl_bbox = get_annotations(source/'train.json')
        # Image name -> its annotation, looked up by `label_func` through `o.name`.
        self.img2bbox = dict(zip(images, lbl_bbox))
    get_items = image_getter('train')
    split = random_splitter()
    label_func = lambda o,self: self.img2bbox[o.name]
    def databunch(self, ds_tfms=None, dl_tfms=None, bs=64, tfm_kwargs=None, **kwargs):
        # Pad variable-length box targets by default. A caller-supplied `collate_fn` now wins
        # instead of raising a duplicate-keyword TypeError.
        kwargs.setdefault('collate_fn', bb_pad_collate)
        return super().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=bs, tfm_kwargs=tfm_kwargs, **kwargs)
# COCO tiny demo: object-detection DataBunch via `CocoData` (padded with `bb_pad_collate`).
source = untar_data(URLs.COCO_TINY)
dblk = CocoData(source)
data = dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
data.show_batch()
You can also use `DataSource` directly with a single transform that does everything, without using the blocks. In that case you have to provide your own `show` method if you want to use `show_batch` (there is no need for a `decode` when everything happens in one transform).
def size_f(x):
    "Float tensor built from `x.size` (for a PIL image this is `(width, height)`)."
    dims = x.size
    return tensor(dims).float()
# Standalone setup (no DataBlock): map each image name to its annotation.
# `CocoTransform` reads this module-level `img2bbox` directly.
path = untar_data(URLs.COCO_TINY)
fns, lbl_bbox = get_annotations(path/'train.json')
img2bbox = dict(zip(fns, lbl_bbox))
class CocoTransform(Transform):
    """All-in-one COCO transform: open an image file and build its `[bboxes, labels]` target.

    Boxes are coordinate-flipped and scaled to [-1, 1] with the image size. Labels become indices
    into a vocabulary built in `setup`, and unknown labels are dropped together with their boxes.
    `show` undoes both steps for display.
    """
    def __init__(self): self.vocab = None
    def setup(self, data):
        if self.vocab is not None: return
        vals = set()
        # Grow one set in place instead of rebuilding it with `union` for every item.
        for c in data.train: vals.update(img2bbox[c.name][1])
        self.vocab,self.otoi = uniqueify(list(vals), sort=True, bidir=True, start='#bg')
    def __call__(self, fn):
        img = Image.open(fn)
        bbox,lbl = img2bbox[fn.name]
        # Keep boxes and labels aligned: drop a box whenever its label is not in the vocab.
        keep = [i for i,l in enumerate(lbl) if l in self.otoi]
        bbox = [bbox[i] for i in keep]
        lbl = [self.otoi[lbl[i]] for i in keep]
        #flip and rescale to -1,1
        bbox = tensor(bbox).view(-1,2).flip(1) * 2/size_f(img) - 1
        return (img, [bbox.view(-1,4), lbl])
    def show(self, o, ax):
        img,(bbox,lbl) = o
        show_image(img, ax)
        lbl = [self.vocab[l] for l in lbl if l != 0] #Unpad and decode
        bbox = bbox[-len(lbl):,] #Unpad
        # Undo `__call__` in reverse order: flip back first, then rescale. Scaling before the
        # flip would swap the width/height factors on non-square images.
        # NOTE(review): assumes `img.shape[:2]` is (height, width) at display time — confirm.
        bbox = bbox.view(-1,2).flip(1)
        bbox = (bbox + 1) * tensor(img.shape[:2]).float() / 2
        bbox = bbox.view(-1,4)
        for b,l in zip(bbox, lbl): _draw_rect(ax, b, hw=False, rev=True, text=l)
# Build the DataSource and DataBunch by hand around the single `CocoTransform`,
# splitting randomly and padding batches with `bb_pad_collate`.
fnames = get_image_files(path)
splits = random_splitter()(fnames)
ds = DataSource(fnames, tfms=CocoTransform(), filts=splits, tfm_y=TfmY.Bbox)
ds = ds.transformed(tfms=ds_tfms)
data = DataBunch(*get_dls(ds, 16, collate_fn=bb_pad_collate, tfms=dl_tfms))
data.show_batch()
# IPython shell escape (not plain Python): runs `notebook2script.py` on this notebook,
# presumably exporting the `# export` cells to a module.
! python notebook2script.py "200_datablock_config.ipynb"