# export
from fastai.datasets import URLs, untar_data
from pathlib import Path
import pandas as pd, numpy as np, torch, re, PIL, os, mimetypes, csv, itertools
import matplotlib.pyplot as plt
from collections import OrderedDict
from typing import *
from enum import Enum
from functools import partial,reduce
from torch import as_tensor,Tensor
from IPython.core.debugger import set_trace
# export
def ifnone(a, b):
    "Return `a`, unless it is None, in which case return `b`."
    if a is None:
        return b
    return a
def noop(x, *args, **kwargs):
    "Return `x` unchanged; any extra positional or keyword arguments are ignored."
    return x
def noops(self, x, *args, **kwargs):
    "Method-shaped `noop`: return `x` unchanged, ignoring `self` and any extra arguments."
    return x
def range_of(x):
    "Return the list of indices of `x`: `[0, 1, ..., len(x)-1]`."
    n = len(x)
    return list(range(n))
# Monkey-patch: give every tensor a read-only `ndim` property that returns `Tensor.dim()`.
torch.Tensor.ndim = property(lambda x: x.dim())
import operator
def test(a, b, cmp, cname=None):
    "Assert that `cmp(a,b)` holds; the failure message shows `cname` (default: `cmp.__name__`) and both values."
    label = cmp.__name__ if cname is None else cname
    assert cmp(a, b), f"{label}:\n{a}\n{b}"
def test_eq(a, b):
    "Assert `a == b`."
    test(a, b, operator.eq, '==')

def test_ne(a, b):
    "Assert `a != b`."
    test(a, b, operator.ne, '!=')

def test_equal(a, b):
    "Assert `torch.equal(a, b)`."
    test(a, b, torch.equal, '==')

def test_np_eq(a, b):
    "Assert `np.array_equal(a, b)`."
    test(a, b, np.array_equal, '==')
# test
# `noop` must hand back its first argument untouched
test_eq(noop(1),1)
# export
def listify(o):
    """Make `o` a list.

    `None` becomes `[]`; a list is returned as is (not copied); a string, a
    non-iterable, or an iterable that has no length (a generator, a rank-0
    tensor) is wrapped in a one-element list; any other iterable is converted
    with `list`.
    """
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str) or not isinstance(o, Iterable): return [o]
    # Rank 0 tensors in PyTorch are Iterable but don't have a length.
    # `len` signals "no length" with TypeError; anything else (e.g. KeyboardInterrupt) must propagate.
    try: len(o)
    except TypeError: return [o]
    return list(o)
def tuplify(o):
    "Make `o` a tuple, using the same conversion rules as `listify`."
    as_list = listify(o)
    return tuple(as_list)
# export
def compose(*funcs):
    "Return a one-argument function applying `funcs` in the order given (`noop` itself if there are none)."
    def _chain(outer, inner):
        # run `inner` first, then feed its result to everything chained so far
        return lambda x: outer(inner(x))
    return reduce(_chain, reversed(funcs), noop)
def is_listy(x:Any)->bool:
    "Check whether `x` is a `tuple` or a `list`."
    return isinstance(x, (tuple,list))

def tensor(x, *rest):
    """Like `torch.as_tensor`, but handle lists too, and can pass multiple vector elements directly.

    Extra positional arguments are appended to `x` (after `tuplify`). An empty
    list/tuple yields `tensor(0)`. An int32 result is converted to int64 after
    emitting a warning.
    """
    # `warn` was used without ever being imported, so the int32 branch raised NameError.
    from warnings import warn
    if len(rest): x = tuplify(x)+rest
    # Pytorch bug in dataloader using num_workers>0
    if is_listy(x) and len(x)==0: return tensor(0)
    res = torch.tensor(x) if is_listy(x) else as_tensor(x)
    if res.dtype is torch.int32:
        warn('Tensor is int32: upgrading to int64; for better performance use int64 input')
        return res.long()
    return res
# test
# listify: None, lists, strings, ranges and tensors
test_eq(listify(None),[])
test_eq(listify([1,2,3]),[1,2,3])
test_ne(listify([1,2,3]),[1,2,])
test_eq(listify('abc'),['abc'])
test_eq(listify(range(0,3)),[0,1,2])
# a rank-0 tensor has no length, so it is wrapped rather than iterated
test_eq(listify(tensor(0)),[tensor(0)])
test_eq(listify([tensor(0),tensor(1)]),[tensor(0),tensor(1)])
test_eq(listify(tensor([0.,1.1])),[0,1.1])
# tuplify goes through listify, then converts to a tuple
test_eq(tuplify(None),())
test_eq(tuplify([1,2,3]),(1,2,3))
test_eq(tuplify(tensor([0.,1.1])),(0,1.1))
#export
from inspect import getfullargspec
def has_param(func, p):
    """Check if `func` accepts `p` as argument.

    Both positional-or-keyword and keyword-only parameters count; the names of
    the `*args`/`**kwargs` catch-alls do not.
    """
    spec = getfullargspec(func)
    return p in spec.args or p in spec.kwonlyargs
def feed_kwargs(func, *args, **kwargs):
    """Feed `args` and the `kwargs` `func` accepts to `func`.

    If `func` takes `**kwargs`, everything is passed through. Otherwise only the
    keywords naming one of its parameters are kept, including keyword-only
    parameters (which is how `functools.partial` exposes keywords it has bound).
    """
    signature = getfullargspec(func)
    if signature.varkw is not None: return func(*args, **kwargs)
    accepted = set(signature.args) | set(signature.kwonlyargs)
    passed_kwargs = {k:v for k,v in kwargs.items() if k in accepted}
    return func(*args, **passed_kwargs)
#test
# function without **kwargs: unknown keywords (here `y`) are dropped before the call
def test_func(a, b, x=2): return a+b+x
test_eq([has_param(test_func, p) for p in ['a', 'c', 'x']], [True,False,True])
test_eq(feed_kwargs(test_func, 1, 2, x=3), 6)
test_eq(feed_kwargs(test_func, 1, 2, y=3), 5)
# function with **kwargs: every keyword is passed through
def test_func(a, b, x=2, **kwargs): return a+b+x
test_eq(feed_kwargs(test_func, 1, 2, x=3), 6)
test_eq(feed_kwargs(test_func, 1, 2, y=3), 5)
# export
def order_sorted(funcs, order_key='_order'):
    "Listify `funcs` and sort it by the attribute named `order_key` (a missing attribute counts as 0)."
    def _rank(o):
        return getattr(o, order_key, 0)
    return sorted(listify(funcs), key=_rank)
def apply_all(x, funcs, *args, order_key='_order', filter_kwargs=False, **kwargs):
    "Apply all `funcs` to `x` in order, pass along `args` and `kwargs` (filtered per function if `filter_kwargs`)."
    ordered = order_sorted(funcs, order_key=order_key)
    for func in ordered:
        if filter_kwargs:
            # only hand each function the keywords it can accept
            x = feed_kwargs(func, x, *args, **kwargs)
        else:
            x = func(x, *args, **kwargs)
    return x
# test
# basic behavior
def _test_f1(x, a=2): return x**a
def _test_f2(x, a=2): return a*x
# same (default) order: applied as listed, (2**2)*2 == 8
test_eq(apply_all(2, [_test_f1, _test_f2]),8)
# order
# _test_f1 now sorts after _test_f2: (2*2)**2 == 16
_test_f1._order = 1
test_eq(apply_all(2, [_test_f1, _test_f2]),16)
#args
test_eq(apply_all(2, [_test_f1, _test_f2], 3),216)
#kwargs
test_eq(apply_all(2, [_test_f1, _test_f2], a=3),216)
# export
def mask2idxs(mask):
    "Return the list of positions at which `mask` is truthy."
    return list(itertools.compress(itertools.count(), mask))
# export
def uniqueify(x, sort=False, bidir=False, start=None):
    "Return the unique elements in `x`, optionally `sort`-ed, optionally return the reverse correspondence."
    # an ordered dict keeps the first occurrence of each element, in order of appearance
    uniques = list(OrderedDict.fromkeys(x))
    if start is not None:
        uniques = listify(start) + uniques
    if sort:
        uniques.sort()
    if not bidir:
        return uniques
    return uniques, {val: idx for idx, val in enumerate(uniques)}
# test
test_eq(set(uniqueify([1,1,0,5,0,3])),{0,1,3,5})
test_eq(uniqueify([1,1,0,5,0,3], sort=True),[0,1,3,5])
# bidir=True also returns the element -> position mapping
v,o = uniqueify([1,1,0,5,0,3], bidir=True)
test_eq(v,[1,0,5,3])
test_eq(o,{1:0, 0: 1, 5: 2, 3: 3})
v,o = uniqueify([1,1,0,5,0,3], sort=True, bidir=True)
test_eq(v,[0,1,3,5])
test_eq(o,{0:0, 1: 1, 3: 2, 5: 3})
# export
def setify(o):
    "Make `o` a set: a set is returned as is, anything else goes through `listify` first."
    if isinstance(o, set):
        return o
    return set(listify(o))
# test
# setify follows listify's rules: None -> empty set, a string is a single element
test_eq(setify(None),set())
test_eq(setify('abc'),{'abc'})
test_eq(setify([1,2,2]),{1,2})
test_eq(setify(range(0,3)),{0,1,2})
test_eq(setify({1,2}),{1,2})
# export
def onehot(x, c, a=1.):
    "Return the `a`-hot encoded tensor for `x` with `c` classes."
    encoded = torch.zeros(c)
    if a < 1:
        # every position starts at (1-a)/(c-1); the positions in `x` are overwritten below
        off_value = (1 - a) / (c - 1)
        encoded += off_value
    encoded[x] = a
    return encoded
# test
test_equal(onehot(1,5), tensor([0.,1.,0.,0.,0.]))
test_equal(onehot([1,3],5), tensor([0.,1.,0.,1.,0.]))
test_equal(onehot(tensor([1,3]),5), tensor([0.,1.,0.,1.,0.]))
test_equal(onehot([True,False,True,True,False],5), tensor([1.,0.,1.,1.,0.]))
test_equal(onehot([],5), tensor([0.,0.,0.,0.,0.]))
test_equal(onehot(1,5,0.9), tensor([0.025,0.9,0.025,0.025,0.025]))
# export
def _get_files(p, fs, extensions=None):
p = Path(p)
res = [p/f for f in fs if not f.startswith('.')
and ((not extensions) or f'.{f.split(".")[-1].lower()}' in extensions)]
return res
def get_files(path, extensions=None, recurse=False, include=None):
    """Get all the files in `path` with optional `extensions`.

    `extensions` is matched case-insensitively. With `recurse`, sub-folders are
    walked too: hidden (dot) folders are skipped, and if `include` is given only
    the top-level folders named in it are entered. `include` may be a single
    name or a collection of names.
    """
    path = Path(path)
    extensions = {e.lower() for e in setify(extensions)}
    if not recurse:
        with os.scandir(path) as entries:
            names = [o.name for o in entries if o.is_file()]
        return _get_files(path, names, extensions)
    # A bare string would turn `o in include` into a substring test ('rain' in 'train'),
    # so normalise it to a set of folder names first.
    if include is not None: include = setify(include)
    res = []
    for i,(p,d,f) in enumerate(os.walk(path)): # returns (dirpath, dirnames, filenames)
        # pruning `d` in place controls which sub-folders os.walk descends into
        if include is not None and i==0: d[:] = [o for o in d if o in include]
        else: d[:] = [o for o in d if not o.startswith('.')]
        res += _get_files(p, f, extensions)
    return res
# test
# NOTE: fetches the MNIST_TINY sample via `untar_data`; the counts below are specific to that dataset
path = untar_data(URLs.MNIST_TINY)
test_eq(len(get_files(path/'train'/'3')),346)
test_eq(len(get_files(path/'train'/'3', extensions='.png')),346)
test_eq(len(get_files(path/'train'/'3', extensions='.jpg')),0)
# without recurse, only files directly inside the folder are returned
test_eq(len(get_files(path/'train', extensions='.png')),0)
test_eq(len(get_files(path/'train', extensions='.png', recurse=True)),709)
# `include` restricts which top-level folders are walked
test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train'])),709)
test_eq(len(get_files(path, extensions='.png', recurse=True, include=['train', 'test'])),729)
DataSource is the base class of the data block API and is defined from items, tfms and filters. It can represent multiple datasets (train, valid, or more) that are contained in the items: each element of filters is a boolean mask or a collection of ints that says which items are in which dataset.
When accessing an element, tfms are applied to it with optional tfm_kwargs passed along. Those kwargs are filtered so that each tfm only gets the ones it accepts. At its base a tfm is just a simple function (open an image, resize it, one-hot encode a category, etc.) but it can be more complex (see the Transform class below). Some transforms need a setup (for instance the transform that changes a category to its index needs to compute all the classes) and some can be reversible for display purposes (if you change a category to an index, you still want to display the category name later on, or if you normalize your image, you need to undo that to display it). DataSource calls the potential setup function of each of its transforms at initialization and it has a decode method that will reverse the transforms that can be reversed.
# export
class DataSource():
    """Collection of `items` split into datasets by `filters`, with `tfms` applied on access.

    `filters` holds one entry per dataset: either a boolean mask over `items` or
    a collection of indices into `items` (default: a single dataset with every
    item). `tfms` are sorted by `_order`; each one's optional `setup` is called
    with this object before the transform is added. `tfm_kwargs` are offered to
    every transform and filtered to what each one accepts.
    """
    # NOTE: defining `__eq__` without `__hash__` makes instances unhashable.
    def __init__(self, items, tfms=None, filters=None, **tfm_kwargs):
        if filters is None: filters = [range(len(items))]
        # Only the first filter is inspected to tell masks from index lists;
        # an empty first filter is left as is instead of raising IndexError.
        if len(filters[0]) and isinstance(filters[0][0], bool):
            filters = [mask2idxs(filt) for filt in filters]
        self.items,self.filters,self.tfms = listify(items),listify(filters),[]
        self.tfm_kwargs = tfm_kwargs
        tfms = order_sorted(tfms)
        for tfm in tfms:
            # setup runs before the tfm is appended, so it sees the data as produced by the earlier tfms only
            getattr(tfm, 'setup', noop)(self)
            self.tfms.append(tfm)

    def transformed(self, tfms, **tfm_kwargs):
        "Return a new object of the same class with `tfms` added and `tfm_kwargs` merged over the existing ones."
        tfms = listify(tfms)
        tfm_kwargs = {**self.tfm_kwargs, **tfm_kwargs}
        return self.__class__(self.items, self.tfms + tfms, self.filters, **tfm_kwargs)

    def __len__(self): return len(self.filters)
    def len(self, filt=0):
        "Number of items in dataset `filt`."
        return len(self.filters[filt])
    def __getitem__(self, i): return FilteredList(self, i)
    def sublist(self, filt):
        "All the transformed items of dataset `filt`, as a list."
        return [self.get(j,filt) for j in range(self.len(filt))]

    def get(self, idx, filt=0):
        "Transformed item(s) `idx` of dataset `filt`; `idx` is an int, a collection of ints or a boolean mask."
        if hasattr(idx,'__len__') and getattr(idx,'ndim',1):
            # rank>0 collection
            if len(idx)==0: return [] # nothing selected; also avoids `idx[0]` on an empty collection
            if isinstance(idx[0],bool):
                assert len(idx)==self.len(filt) # bool mask
                return [self.get(i,filt) for i,m in enumerate(idx) if m]
            return [self.get(i,filt) for i in idx] # index list
        if self.filters: idx = self.filters[filt][idx]
        res = self.items[idx]
        if self.tfms: res = apply_all(res, self.tfms, filt=filt, filter_kwargs=True, **self.tfm_kwargs)
        return res

    def decode(self, o, filt=0):
        "Undo the transforms that have a `decode`, last one first; `o` is returned unchanged if there are no tfms."
        if not self.tfms: return o
        return apply_all(o, [getattr(f, 'decode', noop) for f in reversed(self.tfms)],
                         filt=filt, filter_kwargs=True, **self.tfm_kwargs)

    def __iter__(self):
        # one lazy generator of transformed items per dataset
        for i in range_of(self.filters):
            yield (self.get(j,i) for j in range(self.len(i)))

    def __eq__(self,b):
        if not isinstance(b,DataSource): b = DataSource(b)
        if len(b) != len(self): return False
        # every dataset has to match, not just the first one
        for i in range_of(self.filters):
            if b.len(i) != self.len(i): return False
            if not all(self.get(j,i)==b.get(j,i) for j in range_of(self.filters[i])): return False
        return True

    def __repr__(self):
        res = f'{self.__class__.__name__}\n'
        for i,o in enumerate(self):
            l = self.len(i)
            res += f'{i}: ({l} items) ['
            # only the first 10 items are transformed and shown
            res += ','.join(itertools.islice(map(str,o), 10))
            if l>10: res += '...'
            res += ']\n'
        return res

    @property
    def train(self): return self[0]
    @property
    def valid(self): return self[1]
A FilteredList is a convenience access to one dataset of a DataSource.
#export
class FilteredList:
    "View on dataset `filt` of the `DataSource` `dsrc`; every call is forwarded to `dsrc` with that filter."
    def __init__(self, dsrc, filt):
        self.dsrc = dsrc
        self.filt = filt

    def __getitem__(self, i):
        return self.dsrc.get(i, self.filt)

    def __len__(self):
        return self.dsrc.len(self.filt)

    def __iter__(self):
        return (self.dsrc.get(j, self.filt) for j in range_of(self))

    def __repr__(self):
        n = len(self)
        # only the first 10 items are fetched and shown
        shown = ','.join(itertools.islice(map(str, self), 10))
        more = '...' if len(self) > 10 else ''
        return f'({n} items) [' + shown + more + ']\n'

    def decode(self, o):
        return self.dsrc.decode(o, self.filt)
# test
#Indexing
# comparing with a plain list works because DataSource.__eq__ wraps the other operand in a DataSource
dsrc = DataSource(range(5))
test_eq(dsrc,[0,1,2,3,4])
test_eq(dsrc.sublist(0),[0,1,2,3,4])
test_ne(dsrc,[0,1,2,3,5])
test_eq(dsrc.get(2),2)
test_eq(dsrc.get([1,2]),[1,2])
test_eq(dsrc.get([True,False,False,True,False]),[0,3])
# test
#filters can be indices or boolean masks
dsrc = DataSource(range(5), filters=[[0,2], [1,3,4]])
test_eq(list(dsrc[0]),[0,2])
test_eq(list(dsrc[1]),[1,3,4])
#Subsets don't have to be disjoints
dsrc = DataSource(range(5), filters=[[False,True,True,False,True], [True,False,False,True,True]])
test_eq(list(dsrc[0]),[1,2,4])
test_eq(list(dsrc[1]),[0,3,4])
# test
#Base transform
dsrc = DataSource(range(5), lambda x:x*2)
test_eq(dsrc,[0,2,4,6,8])
test_eq(dsrc.sublist(0),[0,2,4,6,8])
test_ne(dsrc,[1,2,4,6,8])
test_eq(dsrc.get(2), 4)
test_eq(dsrc.get([1,2]), [2,4])
test_eq(dsrc.get([True,False,False,True,False]), [0,6])
# test
#Different transforms for the two subsets
# the transform receives `filt` because it names it as a parameter
dsrc = DataSource(range(5), lambda x,filt:x if filt == 0 else x*2, [[1,2],[0,3,4]])
test_eq(dsrc.sublist(0),[1,2])
test_eq(dsrc.sublist(1),[0,6,8])
test_eq(dsrc.get(2,1), 8)
test_eq(dsrc.get([1,2], 1), [6,8])
test_eq(dsrc.get([False,True], 0), [2])
# test
def add(x, a=1): return x+a
def multiply(x, a=2): return x*a
def square(x): return x**2
#Tfms are ordered by their `_order` key when applied
#Test _order
square._order = 0
multiply._order = 1
add._order = 2
dsrc = DataSource([0,1,2,3], tfms=[add, multiply, square], filters=[[0,1,2,3]])
test_eq(dsrc.get(2), ((2**2) * 2) + 1)
#Kwargs are passed to tfms when they can be
#Test kwargs
# `square` takes no `a`, so only `multiply` and `add` receive a=3
dsrc = DataSource([0,1,2,3], tfms=[add, multiply, square], filters=[[0,1,2,3]], a=3)
test_eq(dsrc.get(2), ((2**2) * 3) + 3)
#Test decode
def add_undo(x, a=1): return x-a
def multiply_undo(x, a=2): return x/a
add.decode = add_undo
multiply.decode = multiply_undo
# `square` has no decode, so it is left out of the reverse pass
dsrc = DataSource([0,1,2,3], tfms=[add, multiply, square], filters=[[0,1,2,3]])
test_eq(dsrc.decode(9), (9-1)/2)
dsrc = DataSource([0,1,2,3], tfms=[add, multiply, square], filters=[[0,1,2,3]], a=3)
test_eq(dsrc.decode(9), (9-3)/3)
# test
# indexing a DataSource gives a FilteredList bound to that filter
dsrc = DataSource(range(5), lambda x,filt:x if filt == 0 else x*2, [[1,2],[0,3,4]])
fl = dsrc[1]
test_eq(list(fl),[0,6,8])
test_eq(fl[2], 8)
test_eq(fl[[1,2]], [6,8])
test_eq(fl[[False,True,True]], [6,8])
This is the base class to define a transform when we want something more complex than a function. The _order helps sort the transforms before applying them. setup is a preparation step to get the state ready on the data (which is a DataSource). __call__ is the main function that applies the transform to o and decode does the reverse operation.
NB: You should only implement decode if your transform needs to be reversed for display purposes. For instance we want to reverse the operation class to index, but we don't want to reverse the operation open this image.
# export
class Transform():
    "Base class for transforms that need more than a plain function; every hook is a no-op or identity by default."
    _order = 0

    def setup(self, dsrc):
        "One-time setup on `dsrc`."
        return

    def __call__(self, o):
        "Apply the transform to `o`."
        return o

    def decode(self, o):
        "Reverse the transform on `o`, for display."
        return o
On top of this, a transform that is applied to define a base object can have a show method: for instance the transform that opens an Image has a show method. When trying to display objects, the API will decode them and grab the last transform that provides a show method (this can be overridden by passing a custom show, but we'll see that later).
The next transform is a bit more complex and is responsible for converting a single item to xs/ys.
#export
def _get_show_func(tfms):
for t in reversed(tfms):
if hasattr(t, 'show') and t.show is not None: return t.show
return None
def show_xs(xs, shows, ax=None, **kwargs):
    "Display each element of `xs` with the matching entry of `shows`; the `ax` returned by one call is passed to the next."
    for item, shower in zip(xs, shows):
        # can pass func or obj with a `show` method
        show_fn = getattr(shower, 'show', shower)
        ax = feed_kwargs(show_fn, item, ax=ax, **kwargs)
#export
class TupleTransform():
    "Run several independent transform pipelines on the same item, giving one output per pipeline."
    def __init__(self, *tfms):
        # each argument is one pipeline (a transform or a list of transforms), sorted by `_order`
        self.tfms = [order_sorted(pipeline) for pipeline in listify(tfms)]

    def __call__(self, o, filt=0, **kwargs):
        return [apply_all(o, pipeline, filt=filt, filter_kwargs=True, **kwargs)
                for pipeline in self.tfms]

    def decode(self, o, filt=0, **kwargs):
        def _undo(x, pipeline):
            # reverse pass over one pipeline; transforms without `decode` are replaced by `noop`
            decoders = [getattr(f, 'decode', noop) for f in reversed(pipeline)]
            return apply_all(x, decoders, filt=filt, filter_kwargs=True, **kwargs)
        return [_undo(x, pipeline) for x, pipeline in zip(o, self.tfms)]

    def setup(self, dsrc):
        saved = getattr(dsrc, 'tfms', []).copy()
        for pipeline in self.tfms:
            for t in pipeline:
                # each transform is set up on data that went through the earlier ones of its own pipeline
                getattr(t, 'setup', noop)(dsrc)
                dsrc.tfms.append(t)
            # put `dsrc.tfms` back as it was before moving on to the next pipeline
            dsrc.tfms = saved.copy()

    def show(self, o, shows=None, **kwargs):
        overrides = shows or [None]*len(self.tfms)
        # a None entry falls back to the pipeline's own show function
        resolved = [ifnone(show, _get_show_func(pipeline)) for pipeline, show in zip(self.tfms, overrides)]
        show_xs(o, resolved, **kwargs)
#export
# every file extension that `mimetypes` maps to an image/* MIME type
image_extensions = {ext for ext, mime in mimetypes.types_map.items() if mime.startswith('image/')}
def get_image_files(path, include=None, **kwargs):
    "Get image files in `path` recursively."
    # extra `kwargs` are accepted but not used
    return get_files(path, image_extensions, True, include)
def image_getter(suf='', **kwargs):
    "Return a function that gets the image files found under `source/suf`."
    def _inner(o, **kwargs):
        folder = o/suf
        return get_image_files(folder, **kwargs)
    return _inner
# test
# NOTE: fetches the MNIST_TINY sample via `untar_data`; the counts below are specific to that dataset
path = untar_data(URLs.MNIST_TINY)
test_eq(len(get_image_files(path)),1428)
test_eq(len(get_image_files(path/'train')),709)
test_eq(len(get_image_files(path, include='train')),709)
test_eq(len(get_image_files(path, include=['train','valid'])),1408)
# test
path = untar_data(URLs.MNIST_TINY)
# image_getter builds the getter; the optional suffix picks a sub-folder of the source
test_eq(len(image_getter()(path)),1428)
test_eq(len(image_getter('train')(path)),709)
Convention: a function that has the name of a verb and ends with er returns a function (to get transforms directly or for use in the high-level API below).
# export
def random_splitter(valid_pct=0.2, seed=None, **kwargs):
    "Split `items` between train/val with `valid_pct` randomly."
    def _inner(o, **kwargs):
        # re-seeding on every call makes the split reproducible for a given `seed`
        if seed is not None:
            torch.manual_seed(seed)
        perm = torch.randperm(len(o))
        n_valid = int(valid_pct * len(o))
        valid_idx = perm[:n_valid]
        train_idx = perm[n_valid:]
        return train_idx, valid_idx
    return _inner
#test
# NOTE(review): the expected permutation is tied to torch's random generator for seed 42 — confirm it still holds on the torch version in use
trn,val = random_splitter(seed=42)([0,1,2,3,4,5])
test_equal(trn, tensor([3, 2, 4, 1, 5]))
test_equal(val, tensor([0]))
# export
def _grandparent_mask(items, name):
return [(o.parent.parent.name if isinstance(o, Path) else o.split(os.path.sep)[-2]) == name for o in items]
def grandparent_splitter(train_name='train', valid_name='valid', **kwargs):
    "Split `items` from the grand parent folder names (`train_name` and `valid_name`)."
    def _inner(o, **kwargs):
        train_mask = _grandparent_mask(o, train_name)
        valid_mask = _grandparent_mask(o, valid_name)
        return train_mask, valid_mask
    return _inner
path = untar_data(URLs.MNIST_TINY)
#test
#With `Path` filenames (built from `path`)
path = untar_data(URLs.MNIST_TINY)
items = [path/'train'/'3'/'9932.png', path/'valid'/'7'/'7189.png',
path/'valid'/'7'/'7320.png', path/'train'/'7'/'9833.png',
path/'train'/'3'/'7666.png', path/'valid'/'3'/'925.png',
path/'train'/'7'/'724.png', path/'valid'/'3'/'93055.png']
# the splitter returns two boolean masks, train first
trn,val = grandparent_splitter()(items)
test_eq(trn,[True,False,False,True,True,False,True,False])
test_eq(val,[False,True,True,False,False,True,False,True])
# export
def parent_label(o, **kwargs):
    """Label `item` with the parent folder name.

    `o` can be a `Path` or a path string. Strings used to be split by hand
    and labelled with the file name instead of the folder. Extra `kwargs` are
    accepted and ignored.
    """
    return Path(o).parent.name
def re_labeller(pat):
    "Label `item` with the first group of regex `pat`, searched in `str(item)`."
    pat = re.compile(pat)
    def _inner(o, **kwargs):
        match = pat.search(str(o))
        assert match,f'Failed to find "{pat}" in "{o}"'
        label = match.group(1)
        return label
    return _inner
Let's grab the Pets dataset first. Our DataSource will contain all the image files as items, and we'll randomly select two filters with 80% and 20% of the data. To get our xs, we need to apply a Transform that opens the image in the filenames. We'll call it Imagify.
#export
def show_image(im, ax=None, figsize=None, **kwargs):
    "Show image `im` on `ax` (a new figure if `ax` is None), with the axis turned off; return `ax`."
    if ax is None:
        ax = plt.subplots(figsize=figsize)[1]
    # a tensor whose first dimension is < 5 is treated as channels-first and permuted to channels-last
    channels_first = isinstance(im, Tensor) and im.shape[0] < 5
    if channels_first:
        im = im.permute(1, 2, 0)
    ax.imshow(im, **kwargs)
    ax.axis('off')
    return ax
#export
class Imagify(Transform):
    """Open the image stored in a file name.

    `f` is the function used to open the file (default `PIL.Image.open`);
    `cmap` and `alpha` are the defaults used by `show`.
    """
    def __init__(self, f=PIL.Image.open, cmap=None, alpha=1.): self.f,self.cmap,self.alpha = f,cmap,alpha
    # use the configured opener: `PIL.Image.open` was hard-coded here, so a custom `f` was ignored
    def __call__(self, fn): return self.f(fn)
    def show(self, im, ax=None, figsize=None, cmap=None, alpha=None):
        "Show `im` with `show_image`; `cmap`/`alpha` default to the values given at construction."
        cmap = ifnone(cmap,self.cmap)
        alpha = ifnone(alpha,self.alpha)
        return show_image(im, ax, figsize=figsize, cmap=cmap, alpha=alpha)
To get our ys, we'll need to apply the re pattern (function re_labeller from before) and a transform that creates the list of classes, then maps categories to their indices. We call that transform Categorize. It needs a setup to create the classes, and it's reversible so we implement decode. This is also a base transform so we implement its show method.
Since it needs to run after the labelling transform (here our re pattern labeller) we give it an _order of 1.
# export
class Categorize(Transform):
    "Map a category to its index in `vocab`, the sorted unique values found in the training set at setup."
    _order=1

    def __init__(self):
        self.vocab = None

    def __call__(self, o):
        return self.o2i[o]

    def decode(self, o):
        return self.vocab[o]

    def show(self, o, ax=None):
        "Use `o` as the title of `ax`, or print it if there is no `ax`."
        if ax is not None: ax.set_title(o)
        else: print(o)

    def setup(self, dsrc):
        # keep an existing vocab: setup can run more than once on the same object
        if self.vocab is not None: return
        train_vals = list(dsrc.train)
        self.vocab,self.o2i = uniqueify(train_vals, sort=True, bidir=True)
Now we can create a DataSource that contains our dataset. We grab all the image files, split them randomly and build a TupleTransform from open an image / labelling + categorizing.
# NOTE: fetches the PETS dataset via `untar_data`
source = untar_data(URLs.PETS)/"images"
items = get_image_files(source)
# default random_splitter: 20% of the items go to the second (validation) split
split_idx = random_splitter()(items)
xt = Imagify()
yt = Categorize()
# the label is the part of the file name before the trailing `_<digits>` and extension
labeller = re_labeller(pat = r'/([^/]+)_\d+.jpg$')
# two pipelines: x opens the image, y labels then categorizes
tfm = TupleTransform(xt,[labeller,yt])
pets = DataSource(items, tfm, split_idx)
To access an element we need to specify index/filter (the latter defaults to 0)
# item 0 of dataset 0
xy = pets.get(0,0)
xy
We can decode an element for display purposes!
# reverse the transforms that define `decode` (dataset 0)
xy = pets.decode((xy), 0)
The show_xs function is there to combine the show methods of the base transforms to display x and y together. We can either pass a transform that has a show method or a custom list of show methods.
# objects with a `show` method can be passed directly; show_xs picks the method up
show_xs(xy, (xt, yt))
Let's monkey-patch a show method to DataSource to do this automatically for us. The TupleTransform will use show_xs by default, but can either pass a custom show_func, or also use the kwargs to pass along a custom list of show methods (set None for the ones you don't want to override).
# export
def _dsrc_show(self, o, filt=0, show_func=None, **kwargs):
    "Decode `o` for dataset `filt`, then display it with `show_func` (default: the last `show` found in `self.tfms`)."
    decoded = self.decode(o, filt)
    shower = _get_show_func(self.tfms) if show_func is None else show_func
    shower(decoded, **kwargs)
# monkey-patch: expose it as `DataSource.show`
DataSource.show = _dsrc_show
def _fl_show(self, o, show_func=None, **kwargs):
    "Decode `o`, then display it with `show_func` (default: the last `show` found in the parent `DataSource`'s tfms)."
    decoded = self.decode(o)
    shower = _get_show_func(self.dsrc.tfms) if show_func is None else show_func
    shower(decoded, **kwargs)
# monkey-patch: expose it as `FilteredList.show`
FilteredList.show = _fl_show
# decode and display the first item of the first dataset
pets.show(pets.get(0,0))
Before we can batch our images, we'll need to apply some basic image transformations: converting to RGB, making them all the same size and also converting them to tensors. We have to be prepared for different kinds of targets: sometimes the transform won't be applied to the target, but sometimes it will, and in different ways. We support images, segmentation masks, points or bounding boxes.
# export
# kinds of target an image transform can be asked to handle (`No`: leave the target alone)
TfmY = Enum('TfmY', ['Mask', 'Image', 'Point', 'Bbox', 'No'])
The ImageTransform class seems ugly but it just dispatches the apply or decode function properly between x and y, and allow different implementations of each function. This will be very important for data augmentation in the next notebook.
# export
class ImageTransform():
    "Basic class for image transforms."
    _order,_data_aug = 10,False

    def randomize(self):
        "Hook called once per item before x and the targets are transformed; does nothing by default."
        pass

    def __call__(self, o, filt=0, **kwargs):
        # data augmentation is only applied to dataset 0
        if self._data_aug and filt != 0: return o
        x,*targets = o
        self.x,self.filt = x,filt # Saves the x in case it's needed in the apply for y and filt
        self.randomize() # Ensures we have the same state for x and y
        new_x = self.apply(x)
        new_targets = tuple(self.apply_y(t, **kwargs) for t in targets)
        return (new_x,) + new_targets

    def decode(self, o, filt=0, **kwargs):
        if self._data_aug and filt != 0: return o
        x,*targets = o
        self.x,self.filt = x,filt
        old_x = self.unapply(x)
        old_targets = tuple(self.unapply_y(t, **kwargs) for t in targets)
        return (old_x,) + old_targets

    def _tfm_name(self, t, is_decode=False):
        "Name of the method handling target kind `t`, e.g. `apply_mask` or `unapply_mask`."
        prefix = 'unapply_' if is_decode else 'apply_'
        return prefix + t.name.lower()

    # Forward direction. By default a mask is handled like an image and a bbox like a point.
    def apply(self, x): return x
    def apply_no(self, y): return y
    def apply_image(self, y): return self.apply(y)
    def apply_mask(self, y): return self.apply_image(y)
    def apply_point(self, y): return y
    def apply_bbox(self, y): return self.apply_point(y)
    def apply_y(self, y, tfm_y=TfmY.No):
        "Dispatch `y` to the `apply_*` method matching `tfm_y`."
        handler = getattr(self, self._tfm_name(tfm_y))
        return handler(y)

    # Reverse direction, same fallbacks.
    def unapply(self, x): return x
    def unapply_no(self, y): return y
    def unapply_image(self, y): return self.unapply(y)
    def unapply_mask(self, y): return self.unapply_image(y)
    def unapply_point(self, y): return y
    def unapply_bbox(self, y): return self.unapply_point(y)
    def unapply_y(self, y, tfm_y=TfmY.No):
        "Dispatch `y` to the `unapply_*` method matching `tfm_y`."
        handler = getattr(self, self._tfm_name(tfm_y,True))
        return handler(y)
# test
import random

class FakeTransform(ImageTransform):
    "Test transform: adds a random int to x, fixed offsets to masks and points."
    def randomize(self):
        self.a = random.randint(1,10)
    def apply(self, x):
        return x + self.a
    def apply_mask(self, x):
        return x + 5
    def apply_point(self, x):
        return x + 2

tfm = FakeTransform()
x,y = 5,10
#Basic behavior: x has changed, not y
t1 = tfm((x,y))
assert t1[0]!=x and t1[1]==y, t1
#Check the same random integer was used for x and y when transforming y
t1 = tfm((x,y), tfm_y=TfmY.Image)
test_eq(t1[0] - 5,t1[1] - 10)
#Check mask, point,bbox implementations
mask_out  = tfm((x,y), tfm_y=TfmY.Mask) [1]
point_out = tfm((x,y), tfm_y=TfmY.Point)[1]
bbox_out  = tfm((x,y), tfm_y=TfmY.Bbox) [1]
test_eq(mask_out,15)
test_eq(point_out,12)
test_eq(bbox_out,12)
Our first transform decodes an image to 'RGB'. We can specify different modes for the xs and ys, the default is 'RGB' for x, then mode_x for y if our ys are images, 'L' for y if our ys are segmentation masks.
#export
class DecodeImg(ImageTransform):
    "Convert regular image to RGB, masks to L mode."
    def __init__(self, mode_x='RGB', mode_y=None):
        self.mode_x = mode_x
        self.mode_y = mode_y

    def apply(self, x):
        return x.convert(self.mode_x)

    def apply_image(self, y):
        # image targets use `mode_y`, falling back to the mode of x
        mode = ifnone(self.mode_y, self.mode_x)
        return y.convert(mode)

    def apply_mask(self, y):
        # mask targets use `mode_y`, falling back to 'L'
        mode = ifnone(self.mode_y, 'L')
        return y.convert(mode)
Our second transform resizes an image, using a given mode. It defaults to bilinear for images and nearest for segmentation masks.
# export
class ResizeFixed(ImageTransform):
    "Resize image to `size` using `mode_x` (and `mode_y` on targets)."
    _order=15
    def __init__(self, size, mode_x=PIL.Image.BILINEAR, mode_y=None):
        if isinstance(size,int): size=(size,size)
        first, second = size[0], size[1]
        self.size = (second, first) #PIL takes size in the otherway round
        self.mode_x = mode_x
        self.mode_y = mode_y

    def apply(self, x):
        return x.resize(self.size, self.mode_x)

    def apply_image(self, y):
        # image targets use `mode_y`, falling back to the mode of x
        resample = ifnone(self.mode_y, self.mode_x)
        return y.resize(self.size, resample)

    def apply_mask(self, y):
        # mask targets use `mode_y`, falling back to nearest-neighbour
        resample = ifnone(self.mode_y, PIL.Image.NEAREST)
        return y.resize(self.size, resample)
The transformation to tensors is done in two steps just in case one wants to apply transforms to byte tensors. The permutation of axes needs to be reversed for display, so we have an unapply function (which is what is called by decode in an ImageTransform).
# export
class ToByteTensor(ImageTransform):
    "Transform our items to byte tensors."
    _order=20

    def apply(self, x):
        w,h = x.size
        flat = torch.ByteTensor(torch.ByteStorage.from_buffer(x.tobytes()))
        # view the raw bytes as (h, w, channels), then move the channels first
        return flat.view(h,w,-1).permute(2,0,1)

    def unapply(self, x):
        # a single channel is returned as a 2D tensor, otherwise the channels are moved back last
        if x.shape[0] == 1:
            return x[0]
        return x.permute(1,2,0)
Lastly we convert our tensors to floats (or ints for segmentation masks) and divides by 255 (can specify a different value and a div_y)
# export
class ToFloatTensor(ImageTransform):
    "Transform our items to float tensors (int in the case of mask)."
    _order=5 #Need to run after CUDA on the GPU
    def __init__(self, div_x=255., div_y=None):
        self.div_x = div_x
        self.div_y = div_y

    def apply(self, x):
        return x.float().div_(self.div_x)

    def apply_mask(self, x):
        # masks become long tensors and are only divided when `div_y` is given
        as_long = x.long()
        if self.div_y is None:
            return as_long
        return as_long.div_(self.div_y)

    def unapply(self, x):
        return torch.clamp(x, 0, 1)

    def unapply_mask(self, x):
        return x
Let's test that it all works properly.
# item-level image transforms; applied in `_order` (10, 15, 20), which here matches the listed order
tfms = [DecodeImg(), ResizeFixed(128), ToByteTensor()]
# new DataSource with these transforms added to the existing ones
pets_t = pets.transformed(tfms)
pets_t.show(pets_t.get(0,0))
With transforms to make our images tensors of the same size, we're ready to create batches and dataloaders. We wrap a PyTorch dataloader to add batch transforms. Additional kwargs will be passed along. Those transforms can be decoded like before, for display purposes (like normalization).
# export
class TfmDataLoader():
    "Wrap dataloader `dl`: every batch goes through `tfms` (sorted by `_order`), with `tfm_kwargs` filtered per transform."
    def __init__(self, dl, tfms=None, **tfm_kwargs):
        self.dl = dl
        self.tfms = order_sorted(tfms)
        self.tfm_kwargs = tfm_kwargs

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        for batch in self.dl:
            yield apply_all(batch, self.tfms, filter_kwargs=True, **self.tfm_kwargs)

    def one_batch(self):
        "First transformed batch of a fresh iteration."
        return next(iter(self))

    def decode_batch(self):
        "Grab one batch and decode it."
        batch = self.one_batch()
        return self.decode(batch)

    def decode(self, o):
        "Undo the batch transforms that have a `decode`, last one first."
        decoders = [getattr(f, 'decode', noop) for f in reversed(self.tfms)]
        return apply_all(o, decoders, filter_kwargs=True, **self.tfm_kwargs)

    @property
    def dataset(self):
        return self.dl.dataset
Then we add a basic function to create dataloaders from a DataSource.
# export
from torch.utils.data.dataloader import DataLoader
def get_dl(dset, bs=64, tfms=None, tfm_kwargs=None, **kwargs):
    "Build a PyTorch `DataLoader` on `dset` (extra `kwargs` go to it) and wrap it in a `TfmDataLoader` with `tfms`."
    loader = DataLoader(dset, bs, **kwargs)
    extra = ifnone(tfm_kwargs, {})
    return TfmDataLoader(loader, tfms=tfms, **extra)
def get_dls(dsrc, bs=64, tfms=None, tfm_kwargs=None, **kwargs):
    "One `TfmDataLoader` per dataset of `dsrc`; only the first dataset (index 0) is shuffled."
    def _loader(i):
        return get_dl(dsrc[i], bs, shuffle=(i == 0), tfms=tfms, tfm_kwargs=tfm_kwargs, **kwargs)
    return [_loader(i) for i in range_of(dsrc)]
# one dataloader per dataset of `pets_t`, with the float conversion done as a batch transform
dls = get_dls(pets_t, tfms=ToFloatTensor())
This is a convenience function to grab the k-th item in a batch, even if the batch is constituted of lists of tensors.
#export
def grab_item(b,k):
    "Pick element `k` of `b`, recursing into lists and tuples (which always come back as lists)."
    if not isinstance(b, (list,tuple)):
        return b[k]
    return [grab_item(part, k) for part in b]
# export
class DataBunch():
    "Basic wrapper around several `DataLoader`s."
    def __init__(self, *dls): self.dls = dls
    def __getitem__(self, i): return self.dls[i]
    @property
    def train_dl(self): return self.dls[0]
    @property
    def valid_dl(self): return self.dls[1]
    @property
    def train_ds(self): return self.train_dl.dataset
    @property
    def valid_ds(self): return self.valid_dl.dataset

    def show_batch(self, i=0, items=9, cols=3, figsize=None, show_func=None, **kwargs):
        """Show the first `items` elements of a decoded batch of dataloader `i` on a grid with `cols` columns.

        Extra `kwargs` and `show_func` are passed to the dataset's `show`.
        """
        b = self.dls[i].decode_batch()
        rows = (items+cols-1) // cols
        if figsize is None: figsize = (cols*3, rows*3)
        # squeeze=False always gives a 2D array of axes, so a 1x1 grid no longer breaks `flatten`
        _,axs = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        for k,ax in enumerate(axs.flatten()):
            if k >= items:
                # the grid can have more cells than `items`: blank the spare ones instead of drawing extra elements
                ax.axis('off')
                continue
            self[i].dataset.show(grab_item(b,k), ax=ax, show_func=show_func, **kwargs)
data = DataBunch(*dls)
# first transformed batch of the first dataloader
x,y = data[0].one_batch()
x.shape,x.type(),y.shape,y.type()
data.show_batch()
Finally let's monkey-patch a databunch function in DataSource to quickly create a databunch.
# export
def _dsrc_databunch(self, bs=64, tfms=None, tfm_kwargs=None, **kwargs):
    "Create a `DataBunch` with one dataloader per dataset of this `DataSource` (`tfms` are batch transforms)."
    loaders = get_dls(self, bs=bs, tfms=tfms, tfm_kwargs=tfm_kwargs, **kwargs)
    return DataBunch(*loaders)
# monkey-patch: expose it as `DataSource.databunch`
DataSource.databunch = _dsrc_databunch
# CUDA device with index 0
device = torch.device('cuda',0)
# export
from fastai.torch_core import to_device, to_cpu
import torch.nn.functional as F
# export
class Cuda(Transform):
    "Batch transform that sends the batch to `device` (`to_device`); `decode` brings it back with `to_cpu`."
    _order = 0

    def __init__(self, device):
        self.device = device

    def __call__(self, b, tfm_y=TfmY.No):
        # `tfm_y` is accepted but not used
        return to_device(b, self.device)

    def decode(self, b):
        return to_cpu(b)
We'll see other batch transforms in the next chapter but one that is pretty common is normalization.
# export
class Normalize(Transform):
    "Batch transform normalizing x (if `do_x`) and y (if `do_y`) of an `(x,y)` batch with `mean` and `std`; reversible."
    _order=99

    def __init__(self, mean, std, do_x=True, do_y=False):
        self.mean = mean
        self.std = std
        self.do_x = do_x
        self.do_y = do_y

    def __call__(self, b):
        x,y = b
        x = self.normalize(x) if self.do_x else x
        y = self.normalize(y) if self.do_y else y
        return x,y

    def decode(self, b):
        x,y = b
        x = self.denorm(x) if self.do_x else x
        y = self.denorm(y) if self.do_y else y
        return x,y

    def normalize(self, x):
        centered = x - self.mean
        return centered / self.std

    def denorm(self, x):
        scaled = x * self.std
        return scaled + self.mean
# per-channel stats reshaped to (1, 3, 1, 1) and moved to the GPU
mean,std = tensor([0.5,0.5,0.5]).view(1,-1,1,1).cuda(),tensor([0.5,0.5,0.5]).view(1,-1,1,1).cuda()
# batch transforms run in `_order`: Cuda (0), ToFloatTensor (5), Normalize (99)
data = pets_t.databunch(tfms = [Cuda(device), ToFloatTensor(), Normalize(mean,std)])
data.show_batch()
To make it easy to use this data block API, we add a high level class. User provides the type of inputs/targets in types. Then they implement the three following functions to gather the data:
- get_items takes the source and returns the list of all items
- split takes the items and returns two (or more) lists of indices or boolean masks that explain how to split the data into train and valid (potentially several valid) sets
- label_func returns the corresponding label for an item

Then during the initialization, default transforms for x, y and the full datasource are collected (they can be overridden by a custom tfms_x, tfms_y or tfm_ds).
When calling datasource, the source is fetched by calling get_source, which then allows to collect the items (with get_items) and the different splits (with split). label_func is added to the y transforms and a DataSource can be created, with additional tfms passed.
When calling databunch, the datasource (created with ds_tfms) is converted, with additional batch transforms (in dl_tfms).
An Item is just a class containing three attributes:
- tfm: default transforms associated to that item
- tfm_ds: default transforms associated to that item that are applied to the tuple (x,y)
- tfm_kwargs: default kwargs to pass to all transforms (they will be filtered and only passed to the transforms that accept them). For instance {'tfm_y': TfmY.Mask} in a SegmentMask.

ds_tfms = [DecodeImg(), ResizeFixed(128), ToByteTensor()]
# batch-level transforms, passed as `dl_tfms` to `DataBlock.databunch` further down
dl_tfms = [Cuda(device), ToFloatTensor()]
#export
class Item():
    "Describes one element of a sample through its default transforms and kwargs (all None here)."
    tfm = None         # default transform class(es) for this item
    tfm_ds = None      # default transform class(es) applied at the dataset level
    tfm_kwargs = None  # default kwargs offered to the transforms
def resolve_tfms(o, tfmx, tfmy=None):
    "Return `o` if it is not None, else a list with one instance of each class in `tfmx` then `tfmy`."
    if o is None:
        classes = listify(tfmx) + listify(tfmy)
        return [cls() for cls in classes]
    return o
#export
class DataBlock():
    "High-level data block: subclasses set `types` and provide `get_items`, `split` and `label_func`."
    types = (Item,Item)

    # The three hooks below are looked up on the class and called with an extra
    # `self=` keyword (see `datasource`). The stubs accept it so that a missing
    # hook surfaces as `NotImplementedError` instead of a `TypeError` about `self`.
    @staticmethod
    def get_items(source, self=None):
        "Return the list of all items found in `source`."
        raise NotImplementedError

    @staticmethod
    def split(items, self=None):
        "Return the splits (lists of indices or boolean masks) of `items`."
        raise NotImplementedError

    @staticmethod
    def label_func(item, self=None):
        "Return the label of `item`."
        raise NotImplementedError

    def __init__(self, source, tfms=None, tfms_ds=None):
        "Store `source` and resolve per-type transforms, dataset-level transforms and default transform kwargs."
        self.source = source
        if tfms is None: tfms = (None,)*len(self.types)
        # For each type, a custom entry in `tfms` wins over the type's default `tfm`.
        self.tfms = [resolve_tfms(tfm, x.tfm) for tfm,x in zip(tfms,self.types)]
        # Only the first two types contribute default dataset-level transforms.
        self.tfms_ds = resolve_tfms(tfms_ds, *[getattr(x,"tfm_ds") for x in self.types[:2]])
        self.tfm_kwargs = {}
        # Later types override the kwargs of earlier ones on key collisions.
        for t in self.types: self.tfm_kwargs.update(t.tfm_kwargs or {})

    def datasource(self, tfms=None, **tfm_kwargs):
        "Collect the items, split and label them, then apply the dataset-level transforms plus `tfms`."
        cls = self.__class__
        # Hooks are fetched from the class and receive the instance through the `self=` keyword.
        items = cls.get_items(self.source, self=self)
        split_idx = cls.split(items, self=self)
        # A subclass may define `label_funcs` (one per type); otherwise x is left as is and y uses `label_func`.
        lfs = getattr(cls, 'label_funcs', (noop,cls.label_func))
        ttfms = [[partial(lf, self=self)]+listify(tfm) for lf,tfm in zip(lfs,self.tfms)]
        ds = DataSource(items, TupleTransform(*ttfms), split_idx)
        # Call-time `tfm_kwargs` take precedence over the defaults gathered from `types`.
        ds = ds.transformed(self.tfms_ds + listify(tfms), **{**self.tfm_kwargs, **tfm_kwargs})
        return ds

    def databunch(self, ds_tfms=None, dl_tfms=None, bs=64, tfm_kwargs=None, **kwargs):
        "Create the datasource with `ds_tfms`, then wrap its dataloaders (with `dl_tfms`) in a `DataBunch`."
        tfm_kwargs = ifnone(tfm_kwargs, {})
        dls = get_dls(self.datasource(tfms=ds_tfms, **tfm_kwargs), bs, tfms=dl_tfms,
                      tfm_kwargs={**self.tfm_kwargs, **tfm_kwargs}, **kwargs)
        return DataBunch(*dls)

    @property
    def xt(self):
        "First transform resolved for the inputs."
        return self.tfms[0][0]

    @property
    def yt(self):
        "First transform resolved for the targets."
        return self.tfms[1][0]
Here are some examples of items:
# export
class Image(Item):
    "Item type whose default transform is `Imagify`."
    tfm = Imagify
class Category(Item):
    "Item type whose default transform is `Categorize`."
    tfm = Categorize
And here is an example of use of the API:
class PetsData(DataBlock):
    "Data block with `Image` inputs and `Category` targets, labelled through a regex."
    types = Image,Category
    get_items = image_getter()
    split = random_splitter()
    # The captured group is the part of the last path component before `_<digits>.jpg`.
    # The dot is escaped so it only matches a literal '.', not any character.
    label_func = re_labeller(pat = r'/([^/]+)_\d+\.jpg$')
# Point at the `images` sub-folder of the path returned by `untar_data` for `URLs.PETS`.
source = untar_data(URLs.PETS)/"images"
dsrc = PetsData(source)
# `ds_tfms` go to the datasource, `dl_tfms` to the dataloaders (see `DataBlock.databunch`).
data = dsrc.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms)
data.show_batch()
# `yt` is the first transform resolved for the targets; its `vocab` entries are joined here.
' '.join(dsrc.yt.vocab)
class MnistData(DataBlock):
    "Data block with `Image` inputs and `Category` targets, split by the 'training'/'testing' folder names."
    types = Image, Category
    get_items = get_image_files
    split = grandparent_splitter(valid_name='testing', train_name='training')
    label_func = parent_label
# Path returned by `untar_data` for `URLs.MNIST`.
source = untar_data(URLs.MNIST)
# Only `ToByteTensor` is applied at the datasource level here.
data = MnistData(source).databunch(ds_tfms=[ToByteTensor()], dl_tfms=dl_tfms)
data.show_batch()
There are several ways to get our images to display properly. First, we can pass custom show functions through `shows` and change the function that shows the x.
# One entry per tuple element: a gray-colormap `show_image` for x, `None` for y (presumably keeps the default — TODO confirm).
data.show_batch(shows = (partial(show_image, cmap='gray'), None))
Or trust the API to dispatch the kwargs
# Pass `cmap` as a plain keyword argument to `show_batch`.
data.show_batch(cmap='gray')
Or we could create a BlackAndWhiteImage class that uses the transform Imagify with a default cmap to gray:
class BlackAndWhiteImage(Item):
    "Item type whose default transform is `Imagify` with `cmap='gray'` pre-filled."
    tfm = partial(Imagify, cmap='gray')
class MnistData(DataBlock):
    "Same as the previous `MnistData`, with `BlackAndWhiteImage` as the input type."
    types = BlackAndWhiteImage, Category
    get_items = get_image_files
    split = grandparent_splitter(valid_name='testing', train_name='training')
    label_func = parent_label
# Rebuild the databunch with the redefined `MnistData`; no show-related argument is passed this time.
data = MnistData(source).databunch(ds_tfms=[ToByteTensor()], dl_tfms=dl_tfms)
data.show_batch()
# Path returned by `untar_data` for `URLs.PLANET_SAMPLE`.
path = untar_data(URLs.PLANET_SAMPLE)
# export
class MultiCategorize(Transform):
    "Encode collections of labels as lists of indices into a vocab built from the training set."
    _order=1

    def __init__(self):
        self.vocab = None

    def __call__(self,x):
        "Indices of the labels in `x`; labels missing from the vocab are dropped."
        return [self.o2i[lbl] for lbl in x if lbl in self.o2i]

    def decode(self, o):
        "Vocab entries for the indices in `o`."
        return [self.vocab[idx] for idx in o]

    def show(self, o, ax=None):
        "Print the ';'-joined labels `o`, or use them as the title of `ax` when one is given."
        display = print if ax is None else ax.set_title
        display(';'.join(o))

    def setup(self, dsrc):
        "Build `vocab` and `o2i` from every label in `dsrc.train`, unless a vocab is already set."
        if self.vocab is not None: return
        labels = set()
        for row in dsrc.train: labels.update(row)
        self.vocab,self.o2i = uniqueify(list(labels), sort=True, bidir=True)
class OneHotEncode(Transform):
    "One-hot encode index lists, sized by the vocab of a `MultiCategorize` found during `setup`."
    _order=10

    def setup(self, items):
        "Set `self.c` to the vocab length of the last `MultiCategorize` in `items.tfms`, or `None` if there is none."
        self.c = None
        for candidate in items.tfms:
            if isinstance(candidate, MultiCategorize): self.c = len(candidate.vocab)

    def __call__(self, o):
        "Return `onehot(o, self.c)`, or `o` unchanged when no class count is known."
        if self.c is None: return o
        return onehot(o, self.c)

    def decode(self, o):
        "Positions of `o` whose value equals 1."
        return [pos for pos,val in enumerate(o) if val == 1]
class MultiCategory(Item):
    "Item type whose default transforms are `MultiCategorize` followed by `OneHotEncode`."
    tfm = [MultiCategorize, OneHotEncode]
# test
tfm = MultiCategorize()
#Even if 'c' is the first class, vocab is sorted for reproducibility
# All four items are listed in the first filter, the second filter is empty.
ds = DataSource([['c','a'], ['a','b'], ['b'], []], filters=[[0,1,2,3], []])
tfm.setup(ds)
test_eq(tfm.vocab,['a','b','c'])
# Encoding keeps the input order of the labels.
test_eq(tfm(['b','a']),[1,0])
test_eq(tfm.decode([2,0]),['c','a'])
# export
def get_str_column(df, col_name, prefix='', suffix='', delim=None):
    """Read `col_name` in `df` as strings, optionally adding `prefix` or `suffix`.

    Without `delim`, returns a 1-D string array. With `delim`, each value is split
    with `csv.reader`: the result is a 2-D string array when every row has the same
    number of fields, and a 1-D object array of lists when the rows are ragged.
    """
    values = df[col_name].values.astype(str)
    values = np.char.add(np.char.add(prefix, values), suffix)
    if delim is None: return values
    rows = list(csv.reader(values, delimiter=delim))
    if len({len(row) for row in rows}) <= 1: return np.array(rows)
    # Ragged rows: `np.array(rows)` raises on recent NumPy, so fill an object
    # array element by element (one list of fields per row).
    ragged = np.empty(len(rows), dtype=object)
    for i,row in enumerate(rows): ragged[i] = row
    return ragged
# test
# Column 'a' holds single words, column 'b' holds space-separated pairs.
df = pd.DataFrame({'a': ['cat', 'dog', 'car'], 'b': ['a b', 'c d', 'a e']})
test_np_eq(get_str_column(df, 'a'), np.array(['cat', 'dog', 'car']))
test_np_eq(get_str_column(df, 'a', prefix='o'), np.array(['ocat', 'odog', 'ocar']))
test_np_eq(get_str_column(df, 'a', suffix='.png'), np.array(['cat.png', 'dog.png', 'car.png']))
# With `delim`, every value is split into its fields.
test_np_eq(get_str_column(df, 'b', delim=' '), np.array([['a','b'], ['c','d'], ['a','e']]))
class PlanetData(DataBlock):
    "Data block with `Image` inputs and `MultiCategory` targets read from `labels.csv`."
    types = Image,MultiCategory

    def get_items(source, self):
        "Build the image paths from `labels.csv` and remember each path's tags in `self.item2label`."
        labels_df = pd.read_csv(self.source/'labels.csv')
        items = get_str_column(labels_df, 'image_name', prefix=f'{self.source}/train/', suffix='.jpg')
        labels = get_str_column(labels_df, 'tags', delim=' ')
        self.item2label = dict(zip(items, labels))
        return items

    split = random_splitter()

    def label_func(item, self):
        "Tags recorded for `item` by `get_items`."
        return self.item2label[item]
# Path returned by `untar_data` for `URLs.PLANET_SAMPLE`; `PlanetData.get_items` reads `labels.csv` from it.
source = untar_data(URLs.PLANET_SAMPLE)
dsrc = PlanetData(source)
data = dsrc.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms)
data.show_batch()
# export
class SegmentMask(Item):
    "Item type shown through `Imagify` with the 'tab20' colormap at half opacity; sets `tfm_y` to `TfmY.Mask`."
    tfm = partial(Imagify, alpha=0.5, cmap='tab20')
    tfm_kwargs = dict(tfm_y=TfmY.Mask)
class CamvidData(DataBlock):
    "Data block with `Image` inputs and `SegmentMask` targets stored under `labels/`."
    types = Image,SegmentMask
    get_items = image_getter('images')
    split = random_splitter()

    def label_func(o, self):
        "Target path for `o`: same stem with a `_P` suffix, inside the `labels` folder of the source."
        return self.source/'labels'/f'{o.stem}_P{o.suffix}'
# Path returned by `untar_data` for `URLs.CAMVID_TINY`.
source = untar_data(URLs.CAMVID_TINY)
data = CamvidData(source).databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
# Same colormap as the one `SegmentMask` pre-fills for `Imagify`.
data.show_batch(cmap='tab20')
import pickle
# export
class PointScaler(Transform):
    "Transform on `(x, y)` tuples: reshapes the points `y` to (n,2), optionally flipping their columns and rescaling them."
    _order = 5 #Run before we apply any ImageTransform

    def __init__(self, do_scale=True, y_first=False):
        self.do_scale = do_scale
        self.y_first = y_first

    def __call__(self, o, tfm_y=TfmY.No):
        "Return `(x, points)` with the points as a float tensor of shape (n,2)."
        img,pts = o
        if not isinstance(pts, torch.Tensor): pts = tensor(pts)
        pts = pts.view(-1, 2).float()
        if not self.y_first: pts = pts.flip(1)
        # Dividing by `x.size` and mapping with `*2 ... -1` sends [0, size] onto [-1, 1];
        # assumes `x.size` is a 2-element sequence (e.g. a PIL image) — TODO confirm.
        if self.do_scale: pts = pts * 2/tensor(list(img.size)).float() - 1
        return (img,pts)

    def decode(self, o, tfm_y=TfmY.No):
        "Flip the point columns and map them from [-1, 1] back to `x.shape[:2]` units."
        # NOTE(review): the flip and rescale are unconditional here, whatever `y_first`/`do_scale` are.
        img,pts = o
        flipped = pts.flip(1)
        rescaled = (flipped+1) * tensor([img.shape[:2]]).float()/2
        return (img,rescaled)
class PointShow(Transform):
    "Display-only transform that scatters points on a matplotlib axis."
    def show(self, x, ax=None):
        "Scatter `x` on `ax` as small red dots, with column 1 horizontal and column 0 vertical."
        # `ax` defaults to None in the signature: fall back to the current axes
        # instead of failing on `None.scatter`.
        if ax is None: ax = plt.gca()
        params = {'s': 10, 'marker': '.', 'c': 'r'}
        ax.scatter(x[:, 1], x[:, 0], **params)
class Points(Item):
    "Item type shown with `PointShow` and scaled with `PointScaler`; sets `tfm_y` to `TfmY.Point`."
    tfm = PointShow
    tfm_ds = PointScaler
    tfm_kwargs = dict(tfm_y=TfmY.Point)
class BiwiData(DataBlock):
    "Data block with `Image` inputs and `Points` targets looked up in `centers.pkl`."
    types = Image,Points

    def __init__(self, source, *args, **kwargs):
        "Initialize the block, then load the file-name -> target mapping from `source/'centers.pkl'`."
        super().__init__(source, *args, **kwargs)
        # NOTE(review): unpickling can execute arbitrary code; only load `centers.pkl` from a trusted source.
        # The `with` block closes the file handle once the mapping is loaded.
        with open(source/'centers.pkl', 'rb') as f: self.fn2ctr = pickle.load(f)

    get_items = image_getter('images')
    split = random_splitter()
    # The mapping is keyed by file name (`o.name`), not by full path.
    label_func = lambda o,self: self.fn2ctr[o.name]
# `BiwiData.__init__` also loads `centers.pkl` from the path returned by `untar_data`.
dblk = BiwiData(untar_data(URLs.BIWI_SAMPLE))
data = dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
data.show_batch()
#export
from fastai.vision.data import get_annotations
from matplotlib import patches, patheffects
def _draw_outline(o, lw):
    "Set the path effects of `o` to a black stroke of width `lw` followed by the normal rendering."
    stroke = patheffects.Stroke(linewidth=lw, foreground='black')
    o.set_path_effects([stroke, patheffects.Normal()])
def _draw_rect(ax, b, color='white', text=None, text_size=14, hw=True, rev=False):
    "Draw the outlined rectangle `b` on `ax`, with an optional outlined `text` at its first corner."
    lx,ly,w,h = b
    # `rev` swaps the two axes of the box description.
    if rev: lx,ly,w,h = ly,lx,h,w
    # With `hw=False` the last two values are a second corner, converted here to width/height.
    if not hw: w,h = w-lx,h-ly
    rect = patches.Rectangle((lx,ly), w, h, fill=False, edgecolor=color, lw=2)
    _draw_outline(ax.add_patch(rect), 4)
    if text is None: return
    label = ax.text(lx,ly, text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')
    _draw_outline(label,1)
#export
class BBoxScaler(PointScaler):
    "`PointScaler` for targets of the form `(boxes, labels)`: boxes are scaled as points, labels pass through."
    def __call__(self, o, tfm_y=TfmY.Bbox):
        "Scale `y[0]` through `PointScaler.__call__` and view the result as (n,4); `y[1]` is untouched."
        x,y = o
        scaled = super().__call__((x,y[0]))[1]
        return x, (scaled.view(-1,4),y[1])

    def decode(self, o, tfm_y=TfmY.Bbox):
        "Decode `y[0]` through `PointScaler.decode` as (n,2) points, then view it back as (n,4)."
        x,y = o
        _,decoded = super().decode((x,y[0].view(-1,2)))
        return x, (decoded.view(-1,4),y[1])
class BBoxEncoder(Transform):
    "Encode the labels of `(boxes, labels)` pairs as indices into a vocab that starts with '#bg'."
    _order=1

    def __init__(self):
        self.vocab = None

    def __call__(self,o):
        "Replace the labels by their indices; labels missing from the vocab are dropped."
        first,lbls = o
        return (first,[self.otoi[l] for l in lbls if l in self.otoi])

    def decode(self, o):
        "Replace the indices by their vocab entries."
        first,idxs = o
        return first, [self.vocab[i] for i in idxs]

    def setup(self, dsrc):
        "Build `vocab` and `otoi` from the labels of `dsrc.train`, unless a vocab is already set."
        if self.vocab is not None: return
        seen = set()
        for _,lbls in dsrc.train: seen.update(lbls)
        self.vocab,self.otoi = uniqueify(list(seen), sort=True, bidir=True, start='#bg')

    def show(self, x, ax):
        "Draw every box of `x` on `ax` with its label, skipping '#bg' entries."
        boxes,names = x
        for box,name in zip(boxes, names):
            if name != '#bg': _draw_rect(ax, box, hw=False, rev=True, text=name)
# export
class BBox(Item):
    "Item type encoded with `BBoxEncoder` and scaled with `BBoxScaler`; sets `tfm_y` to `TfmY.Bbox`."
    tfm = BBoxEncoder
    tfm_ds = BBoxScaler
    tfm_kwargs = {'tfm_y': TfmY.Bbox}
# export
def bb_pad_collate(samples, pad_idx=0):
    "Collate `(img, (boxes, labels))` samples, left-padding boxes with zeros and labels with `pad_idx`."
    n_samples = len(samples)
    # Every sample is padded to the largest number of labels in the batch.
    max_len = max(len(sample[1][1]) for sample in samples)
    bboxes = torch.zeros(n_samples, max_len, 4)
    labels = torch.zeros(n_samples, max_len).long() + pad_idx
    imgs = []
    for i,sample in enumerate(samples):
        imgs.append(sample[0][None])
        bbs,lbls = sample[1]
        # Samples without any box keep their all-padding rows.
        if bbs.nelement() == 0: continue
        n_obj = len(lbls)
        # Real entries go at the end of the row, so the padding sits in front.
        bboxes[i,-n_obj:] = bbs
        labels[i,-n_obj:] = tensor(lbls)
    return torch.cat(imgs,0), (bboxes,labels)
class CocoData(DataBlock):
    "Data block with `Image` inputs and `BBox` targets read from `train.json`."
    types = Image,BBox

    def __init__(self, source, *args, **kwargs):
        "Initialize the block, then map each image name to its annotation from `source/'train.json'`."
        super().__init__(source, *args, **kwargs)
        img_names,annotations = get_annotations(source/'train.json')
        self.img2bbox = dict(zip(img_names, annotations))

    get_items = image_getter('train')
    split = random_splitter()

    def label_func(o, self):
        "Annotation recorded for the file name of `o`."
        return self.img2bbox[o.name]

    def databunch(self, ds_tfms=None, dl_tfms=None, bs=64, tfm_kwargs=None, **kwargs):
        "Same as `DataBlock.databunch`, with `bb_pad_collate` as the collate function."
        return super().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=bs, tfm_kwargs=tfm_kwargs,
                                 collate_fn=bb_pad_collate, **kwargs)
# Path returned by `untar_data` for `URLs.COCO_TINY`; `CocoData.__init__` reads `train.json` from it.
source = untar_data(URLs.COCO_TINY)
dblk = CocoData(source)
data = dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
data.show_batch()
You can also use DataSource directly with just one transform that does everything, without using the blocks. You will have to provide your show method if you want to use show_batch however (no need to decode if you do everything in one transform).
def size_f(x):
    "Return `x.size` as a float tensor."
    size = tensor(x.size)
    return size.float()
path = untar_data(URLs.COCO_TINY)
# Module-level mapping from image name to annotation; `CocoTransform` below reads it directly.
fns, lbl_bbox = get_annotations(path/'train.json')
img2bbox = dict(zip(fns, lbl_bbox))
class CocoTransform(Transform):
    "Single transform that opens an image and encodes its boxes and labels from the module-level `img2bbox`."
    def __init__(self):
        self.vocab = None

    def setup(self, data):
        "Build `vocab` and `otoi` from the labels of every file in `data.train`, unless a vocab is already set."
        if self.vocab is not None: return
        seen = set()
        for fn in data.train: seen.update(img2bbox[fn.name][1])
        self.vocab,self.otoi = uniqueify(list(seen), sort=True, bidir=True, start='#bg')

    def __call__(self, fn):
        "Return `(img, [boxes, labels])` for the file `fn`, with boxes rescaled to [-1, 1]."
        img = PIL.Image.open(fn)
        bbox,lbl = img2bbox[fn.name]
        #flip and rescale to -1,1
        corners = tensor(bbox).view(-1,2).flip(1)
        corners = corners * 2/size_f(img) - 1
        encoded = [self.otoi[l] for l in lbl if l in self.otoi]
        return (img, [corners.view(-1,4), encoded])

    def show(self, o, ax):
        "Show the image on `ax` and draw each non-padding box with its decoded label."
        img,(bbox,lbl) = o
        show_image(img, ax)
        names = [self.vocab[l] for l in lbl if l != 0] #Unpad and decode
        bbox = bbox[-len(names):,] #Unpad
        # Map the corners from [-1, 1] back to `img.shape[:2]` units, then flip the columns back.
        corners = (bbox.view(-1,2) + 1) * tensor(img.shape[:2]).float() / 2
        boxes = corners.flip(1).view(-1,4)
        for box,name in zip(boxes, names): _draw_rect(ax, box, hw=False, rev=True, text=name)
fnames = get_image_files(path)
splits = random_splitter()(fnames)
# A single `CocoTransform` does all the item-level work; `ds_tfms` are then applied on top.
ds = DataSource(fnames, tfms=CocoTransform(), filters=splits, tfm_y=TfmY.Bbox)
ds = ds.transformed(tfms=ds_tfms)
# `bb_pad_collate` pads boxes and labels so samples with different counts can be batched.
data = DataBunch(*get_dls(ds, 16, collate_fn=bb_pad_collate, tfms=dl_tfms))
data.show_batch()
class CocoData2(DataBlock):
    "Three-type data block (`Image`, `Item`, `MultiCategory`) using one label function per type."
    types = Image,Item,MultiCategory

    def __init__(self, source, *args, **kwargs):
        "Initialize the block, then map each image name to its annotation from `source/'train.json'`."
        super().__init__(source, *args, **kwargs)
        img_names,annotations = get_annotations(source/'train.json')
        self.img2bbox = dict(zip(img_names, annotations))

    get_items = image_getter('train')
    split = random_splitter()
    # One function per type: the item itself, then element 0 and element 1 of its annotation.
    label_funcs = (noop,
                   lambda o,self: self.img2bbox[o.name][0],
                   lambda o,self: self.img2bbox[o.name][1])

    def databunch(self, ds_tfms=None, dl_tfms=None, bs=64, tfm_kwargs=None, **kwargs):
        "Forward every argument unchanged to `DataBlock.databunch`."
        return super().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=bs,
                                 tfm_kwargs=tfm_kwargs, **kwargs)
source = untar_data(URLs.COCO_TINY)
dblk = CocoData2(source)
# The datasource can be built on its own, without going through `databunch`.
dsrc = dblk.datasource(tfms=ds_tfms)
data = dblk.databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
data.one_batch(1)
! python notebook2script.py "200_datablock_config.ipynb"