#default_exp data.core
#export
from fastai2.torch_basics import *
from fastai2.test import *
from fastai2.data.load import *
from nbdev.showdoc import *
Core functionality for gathering data
The classes here provide functionality for applying a list of transforms to a set of items (TfmdList, DataSource) or a DataLoader (TfmdDL), as well as the base class used to gather the data for model training: DataBunch.
#export
@typedispatch
def show_batch(x, y, samples, ctxs=None, max_n=9, **kwargs):
    "Default `show_batch`: show each element of the first `max_n` `samples`, threading `ctxs` through"
    # `x` and `y` are not used here; they only drive the type dispatch
    ctxs = Inf.nones if ctxs is None else ctxs
    n_elements = len(samples[0])
    for elt_idx in range(n_elements):
        column = samples.itemgot(elt_idx)
        # Each element is shown in the context returned by the previous element's `show`
        ctxs = [item.show(ctx=ctx, **kwargs)
                for item, ctx, _ in zip(column, ctxs, range(max_n))]
    return ctxs
show_batch is a type-dispatched function that is responsible for showing decoded samples. x and y are the input and the target in the batch to be shown, and are passed along to dispatch on their types. There is a different implementation of show_batch if x is a TensorImage or a TensorText, for instance (see vision.core or text.data for more details). ctxs can be passed, but the function is responsible for creating them if necessary. kwargs depend on the specific implementation.
#export
@typedispatch
def show_results(x, y, samples, outs, ctxs=None, max_n=9, **kwargs):
    "Default `show_results`: show every element of `samples`, then every element of `outs`, in shared `ctxs`"
    # `x` and `y` are not used here; they only drive the type dispatch
    ctxs = Inf.nones if ctxs is None else ctxs
    # Decoded samples first, then the decoded outputs; each `show` receives the previous context
    for coll in (samples, outs):
        for elt_idx in range(len(coll[0])):
            ctxs = [item.show(ctx=ctx, **kwargs)
                    for item, ctx, _ in zip(coll.itemgot(elt_idx), ctxs, range(max_n))]
    return ctxs
show_results is a type-dispatched function that is responsible for showing decoded samples and their corresponding outs. As in show_batch, x and y are the input and the target in the batch to be shown, and are passed along to dispatch on their types. ctxs can be passed, but the function is responsible for creating them if necessary. kwargs depend on the specific implementation.
#export
# Extra names to export for the type-dispatched functions above.
# NOTE(review): presumably read by nbdev's exporter to extend `__all__`; keep it a plain list literal.
_all_ = ["show_batch", "show_results"]
#export
_batch_tfms = ('after_item','before_batch','after_batch')
#export
@delegates()
class TfmdDL(DataLoader):
    "Transformed `DataLoader`: `after_item`, `before_batch` and `after_batch` are `Pipeline`s that can be decoded"
    def __init__(self, dataset, bs=16, shuffle=False, num_workers=None, **kwargs):
        # Default worker count: `defaults.cpus`, capped at 16
        if num_workers is None: num_workers = min(16, defaults.cpus)
        # Wrap each batch callback in a `Pipeline`; only `before_batch` gets `as_item=True`
        for nm in _batch_tfms:
            kwargs[nm] = Pipeline(kwargs.get(nm,None), as_item=(nm=='before_batch'))
        super().__init__(dataset, bs=bs, shuffle=shuffle, num_workers=num_workers, **kwargs)
        # Set up the pipelines only once `DataLoader.__init__` has run, passing the built loader
        for nm in _batch_tfms: kwargs[nm].setup(self)
    def _one_pass(self):
        # Push item 0 through the whole pipeline once to record the device, the number of
        # inputs and the output types (later given to `retain_types` when decoding).
        # NOTE(review): uses `do_item(0)`, so presumably fails on an empty dataset.
        its = self.after_batch(self.do_batch([self.do_item(0)]))
        self._device = find_device(its)
        # 1 input if the batch is not a list/tuple or has a single element, else all but the last
        self._n_inp = 1 if not isinstance(its, (list,tuple)) or len(its)==1 else len(its)-1
        # This instance attribute shadows the `_retain_dl` method defined below
        self._retain_dl = partial(retain_types, typs=mapped(type,its))
    def _retain_dl(self,b):
        # Only reached on the first call: `_one_pass` installs the instance-level replacement
        self._one_pass()
        # we just replaced ourselves, so this is *not* recursive! :)
        return self._retain_dl(b)
    def before_iter(self):
        super().before_iter()
        # Copy the dataset's `split_idx` (None if it has none) onto every batch pipeline
        split_idx = getattr(self.dataset, 'split_idx', None)
        for nm in _batch_tfms:
            f = getattr(self,nm)
            if isinstance(f,Pipeline): f.split_idx=split_idx
    # Restore the recorded types, then undo `after_batch` and then `before_batch`
    def decode(self, b): return self.before_batch.decode(self.after_batch.decode(self._retain_dl(b)))
    def decode_batch(self, b, max_n=9, full=True): return self._decode_batch(self.decode(b), max_n, full)
    def _decode_batch(self, b, max_n=9, full=True):
        # Per sample: undo `after_item`, then apply the dataset's own `decode` (if it has one)
        f = self.after_item.decode
        f = compose(f, partial(getattr(self.dataset,'decode',noop), full = full))
        return L(batch_to_samples(b, max_n=max_n)).map(f)
    def _pre_show_batch(self, b, max_n=9):
        "Decode `b` to be ready for `show_batch`"
        b = self.decode(b)
        # A decoded batch that can show itself needs no per-sample decoding
        if hasattr(b, 'show'): return b,None,None
        its = self._decode_batch(b, max_n, full=False)
        # Normalize a non-listy batch to a one-element layout
        if not is_listy(b): b,its = [b],L((o,) for o in its)
        # (inputs, targets, decoded samples)
        return detuplify(b[:self.n_inp]),detuplify(b[self.n_inp:]),its
    def show_batch(self, b=None, max_n=9, ctxs=None, show=True, **kwargs):
        if b is None: b = self.one_batch()
        # With `show=False`, return the (x, y, samples) triple instead of displaying it
        if not show: return self._pre_show_batch(b, max_n=max_n)
        # Calls the module-level type-dispatched `show_batch`, not this method
        show_batch(*self._pre_show_batch(b, max_n=max_n), ctxs=ctxs, max_n=max_n, **kwargs)
    def show_results(self, b, out, max_n=9, ctxs=None, show=True, **kwargs):
        x,y,its = self.show_batch(b, max_n=max_n, show=False)
        # Replace the targets of `b` with the model outputs and decode that batch too.
        # NOTE(review): `b[:self.n_inp] + (...)` assumes `b` is a tuple — confirm for list batches.
        b_out = b[:self.n_inp] + (tuple(out) if is_listy(out) else (out,))
        x1,y1,outs = self.show_batch(b_out, max_n=max_n, show=False)
        # Keep only the decoded output elements (those after the inputs) of each sample
        res = (x,x1,None,None) if its is None else (x, y, its, outs.itemgot(slice(self.n_inp,None)))
        if not show: return res
        # Calls the module-level type-dispatched `show_results`, not this method
        show_results(*res, ctxs=ctxs, max_n=max_n, **kwargs)
    @property
    def device(self):
        # Computed lazily by one pass through the pipeline
        if not hasattr(self, '_device'): _ = self._one_pass()
        return self._device
    @property
    def n_inp(self):
        # Prefer the dataset's own `n_inp`; otherwise infer it from one pass
        if hasattr(self.dataset, 'n_inp'): return self.dataset.n_inp
        if not hasattr(self, '_n_inp'): self._one_pass()
        return self._n_inp
A TfmdDL is a DataLoader that creates a Pipeline from a list of Transforms for each of the callbacks after_item, before_batch and after_batch. As a result, it can decode or show a processed batch.
# Attach the documentation of `TfmdDL`'s public methods
add_docs(
    TfmdDL,
    decode="Decode `b` using `tfms`",
    decode_batch="Decode `b` entirely",
    show_batch="Show `b` (defaults to `one_batch`), a list of lists of pipeline outputs (i.e. output of a `DataLoader`)",
    show_results="Show each item of `b` and `out`",
    before_iter="override",
)
class _Category(int, ShowTitle): pass
#Test retain type
class NegTfm(Transform):
    def encodes(self, x): return torch.neg(x)
    def decodes(self, x): return torch.neg(x)
tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=NegTfm(), bs=4, num_workers=4)
b = tdl.one_batch()
test_eq(type(b[0]), TensorImage)
# A plain tensor batch is decoded back to the types recorded from one pass (`TensorImage`)
b = (tensor([1.,1.,1.,1.]),)
test_eq(type(tdl.decode_batch(b)[0][0]), TensorImage)
class A(Transform):
    def encodes(self, x): return x
    def decodes(self, x): return Int(x)
@Transform
def f(x)->None: return Tuple((x,x))
start = torch.arange(50)
test_eq_type(f(2), Tuple((2,2)))
a = A()
# `after_item` returns a pair, so each batch is (x, y) with `y` keeping its `Tuple` type
tdl = TfmdDL(start, after_item=lambda x: (a(x), f(x)), bs=4)
x,y = tdl.one_batch()
test_eq(type(y), Tuple)
s = tdl.decode_batch((x,y))
test_eq(type(s[0][1]), Tuple)
tdl = TfmdDL(torch.arange(0,50), after_item=A(), after_batch=NegTfm(), bs=4)
test_eq(tdl.dataset[0], start[0])
# Number of batches is ceil(50/4)
test_eq(len(tdl), (50-1)//4+1)
test_eq(tdl.bs, 4)
test_stdout(tdl.show_batch, '0\n1\n2\n3')
show_doc(TfmdDL.one_batch)
DataLoader.one_batch[source]
DataLoader.one_batch()
tfm = NegTfm()
tdl = TfmdDL(start, after_batch=tfm, bs=4)
b = tdl.one_batch()
# `after_batch` negates the batch...
test_eq(tensor([0,-1,-2,-3]), b)
show_doc(TfmdDL.decode)
# ...and `decode` undoes it
test_eq(tdl.decode(b), tensor(0,1,2,3))
show_doc(TfmdDL.decode_batch)
test_eq(tdl.decode_batch(b), [0,1,2,3])
show_doc(TfmdDL.show_batch)
TfmdDL.show_batch[source]
TfmdDL.show_batch(b=None, max_n=9, ctxs=None, show=True, **kwargs)
Show b (defaults to one_batch), a list of lists of pipeline outputs (i.e. output of a DataLoader)
# export
@docs
class DataBunch(GetAttr):
    "Basic wrapper around several `DataLoader`s."
    # Attributes not found on the `DataBunch` are looked up on `train_dl`
    _default='train_dl'
    def __init__(self, *dls, path='.'): self.dls,self.path = dls,Path(path)
    def __getitem__(self, i): return self.dls[i]
    def new_empty(self):
        dls = [dl.new(dl.dataset.new_empty()) for dl in self.dls]
        # Carry `path` over; without it the empty copy silently falls back to '.'
        return type(self)(*dls, path=self.path)
    # `train_dl`/`valid_dl` are `dls[0]`/`dls[1]`; `train_ds`/`valid_ds` are their datasets
    train_dl,valid_dl = add_props(lambda i,x: x[i])
    train_ds,valid_ds = add_props(lambda i,x: x[i].dataset)
    @classmethod
    @delegates(TfmdDL.__init__)
    def from_dblock(cls, dblock, source, path='.', type_tfms=None, item_tfms=None, batch_tfms=None, **kwargs):
        return dblock.databunch(source, path=path, type_tfms=type_tfms, item_tfms=item_tfms, batch_tfms=batch_tfms, **kwargs)
    _docs=dict(__getitem__="Retrieve `DataLoader` at `i` (`0` is training, `1` is validation)",
               train_dl="Training `DataLoader`",
               valid_dl="Validation `DataLoader`",
               train_ds="Training `Dataset`",
               valid_ds="Validation `Dataset`",
               new_empty="Create a new empty version of `self` with the same transforms",
               from_dblock="Create a databunch from a given `dblock`")
dbch = DataBunch(tdl,tdl)
x = dbch.train_dl.one_batch()
x2 = first(tdl)
test_eq(x,x2)
# `DataBunch` has no `one_batch`: `GetAttr` forwards it to `train_dl`
x2 = dbch.one_batch()
test_eq(x,x2)
show_doc(DataBunch.__getitem__)
DataBunch.__getitem__[source]
DataBunch.__getitem__(i)
Retrieve DataLoader at i (0 is training, 1 is validation)
# Index 0 is the training `DataLoader`
x2 = dbch[0].one_batch()
test_eq(x,x2)
show_doc(DataBunch.train_dl, name="train_dl")
train_dl[source]Training DataLoader
# Render the documentation of the `valid_dl` property
show_doc(DataBunch.valid_dl, name="valid_dl")
valid_dl[source]Validation DataLoader
# Render the documentation of the `train_ds` property
show_doc(DataBunch.train_ds, name="train_ds")
train_ds[source]Training Dataset
# Render the documentation of the `valid_ds` property
show_doc(DataBunch.valid_ds, name="valid_ds")
valid_ds[source]Validation Dataset
#export
class FilteredBase:
    "Base class for lists with subsets"
    _dl_type = TfmdDL
    def __init__(self, *args, dl_type=None, **kwargs):
        # Per-instance override of the `DataLoader` class used by `databunch`
        if dl_type is not None: self._dl_type = dl_type
        # Expose the keyword arguments of `self._dl_type.__init__` in `databunch`'s signature
        self.databunch = delegates(self._dl_type.__init__)(self.databunch)
        super().__init__(*args, **kwargs)
    @property
    def n_subsets(self): return len(self.splits)  # `splits` is provided by subclasses
    def _new(self, items, **kwargs): return super()._new(items, splits=self.splits, **kwargs)
    def subset(self, i):
        "Return subset `i` of `self`; subclasses must override this"
        # `NotImplementedError`, not the `NotImplemented` constant (raising that is a `TypeError`)
        raise NotImplementedError
    def databunch(self, bs=16, val_bs=None, shuffle_train=True, n=None, path='.', dl_type=None, dl_kwargs=None, **kwargs):
        "Create a `DataBunch` with one `DataLoader` per subset"
        if dl_kwargs is None: dl_kwargs = [{}] * self.n_subsets
        ns = self.n_subsets-1
        # Subset 0 uses `bs`; the others use `val_bs`, defaulting to 1.5x `bs`
        bss = ([None]*(ns+1) if bs is None
               else [bs] + [3*bs//2]*ns if val_bs is None
               else [bs] + [val_bs]*ns)
        # Only subset 0 may be shuffled; it also drops its last incomplete batch when shuffled
        shuffles = [shuffle_train] + [False]*ns
        if dl_type is None: dl_type = self._dl_type
        dls = [dl_type(self.subset(i), bs=b, shuffle=s, drop_last=s, n=n if i==0 else None, **kwargs, **dk)
               for i,(b,s,dk) in enumerate(zip(bss,shuffles,dl_kwargs))]
        return DataBunch(*dls, path=path)
# `train` and `valid` are aliases for `subset(0)` and `subset(1)`
FilteredBase.train,FilteredBase.valid = add_props(lambda i,x: x.subset(i))
#export
class TfmdList(FilteredBase, L, GetAttr):
    "A `Pipeline` of `tfms` applied to a collection of `items`"
    # Attributes not found on the list are looked up on the `Pipeline`
    _default='tfms'
    def __init__(self, items, tfms, use_list=None, do_setup=True, as_item=True, split_idx=None, train_setup=True, splits=None):
        super().__init__(items, use_list=use_list)
        # Default splits: subset 0 holds every item (`slice(None)`), subset 1 is empty
        self.splits = L([slice(None),[]] if splits is None else splits).map(mask2idxs)
        if isinstance(tfms,TfmdList): tfms = tfms.tfms
        # An existing `Pipeline` is not set up again (presumably it already was)
        if isinstance(tfms,Pipeline): do_setup=False
        self.tfms = Pipeline(tfms, as_item=as_item, split_idx=split_idx)
        if do_setup: self.setup(train_setup=train_setup)
    # New `TfmdList` sharing this one's transforms (without re-running setup)
    def _new(self, items, **kwargs): return super()._new(items, tfms=self.tfms, do_setup=False, **kwargs)
    def subset(self, i): return self._new(self._get(self.splits[i]), split_idx=i)
    def _after_item(self, o): return self.tfms(o)
    def __repr__(self): return f"{self.__class__.__name__}: {self.items}\ntfms - {self.tfms.fs}"
    def __iter__(self): return (self[i] for i in range(len(self)))
    def show(self, o, **kwargs): return self.tfms.show(o, **kwargs)
    def decode(self, o, **kwargs): return self.tfms.decode(o, **kwargs)
    def __call__(self, o, **kwargs): return self.tfms.__call__(o, **kwargs)
    # Set up the pipeline on the training subset, or on all items when `train_setup=False`
    def setup(self, train_setup=True): self.tfms.setup(getattr(self,'train',self) if train_setup else self)
    # NOTE(review): returns the counts (>1) of indices that appear in several splits, not the
    # indices themselves; mainly useful as a truthiness check
    def overlapping_splits(self): return L(Counter(self.splits.concat()).values()).filter(gt(1))
    def __getitem__(self, idx):
        res = super().__getitem__(idx)
        # `_after_item` is a method here, so this is presumably a hook for subclasses that set it to None
        if self._after_item is None: return res
        # A single index gives one transformed item; otherwise transform each item of the result
        return self._after_item(res) if is_indexer(idx) else res.map(self._after_item)
# Attach the documentation of `TfmdList`'s public methods (backticks now balanced for markdown rendering)
add_docs(TfmdList,
         setup="Transform setup with self",
         decode="From `Pipeline`",
         show="From `Pipeline`",
         overlapping_splits="All splits that are in more than one split",
         subset="New `TfmdList` with same tfms that only includes items in `i`th split")
#exports
def decode_at(o, idx):
    "Decoded item at `idx`"
    # Fetch the (encoded) item first, then let `o` decode it
    item = o[idx]
    return o.decode(item)
#exports
def show_at(o, idx, **kwargs):
    "Show item at `idx`"
    # `kwargs` (e.g. `ctx`) are forwarded to `o.show`; its return value is passed through
    return o.show(o[idx], **kwargs)
A TfmdList combines a collection of objects with a Pipeline. tfms can either be a Pipeline or a list of transforms, in which case it will wrap them in a Pipeline. use_list is passed along to L with the items, and as_item and split_idx are passed to each transform of the Pipeline. do_setup indicates if the Pipeline.setup method should be called during initialization.
class _IntFloatTfm(Transform):
    def encodes(self, o): return Int(o)
    def decodes(self, o): return Float(o)
int2f_tfm=_IntFloatTfm()
def _neg(o): return -o
# Same function for encoding and decoding: negation is its own inverse
neg_tfm = Transform(_neg, _neg)
items = L([1.,2.,3.]); tfms = [neg_tfm, int2f_tfm]
tl = TfmdList(items, tfms=tfms)
test_eq_type(tl[0], Int(-1))
test_eq_type(tl[1], Int(-2))
test_eq_type(tl.decode(tl[2]), Float(3.))
test_stdout(lambda: show_at(tl, 2), '-3')
tl
TfmdList: [1.0, 2.0, 3.0] tfms - (#2) [Transform: True (object,object) -> _neg (object,object) -> _neg,_IntFloatTfm: True (object,object) -> encodes (object,object) -> decodes]
# add splits to TfmdList
splits = [[0,2],[1]]
tl = TfmdList(items, tfms=tfms, splits=splits)
test_eq(tl.n_subsets, 2)
test_eq(tl.train, tl.subset(0))
test_eq(tl.valid, tl.subset(1))
test_eq(tl.train.items, items[splits[0]])
test_eq(tl.valid.items, items[splits[1]])
# Each subset's pipeline carries the subset index as its `split_idx`
test_eq(tl.train.tfms.split_idx, 0)
test_eq(tl.valid.tfms.split_idx, 1)
test_eq_type(tl.splits, L(splits))
assert not tl.overlapping_splits()
# Items can be DataFrame rows; several indices return several transformed items
df = pd.DataFrame(dict(a=[1,2,3],b=[2,3,4]))
tl = TfmdList(df, lambda o: o.a+1, splits=[[0],[1,2]])
test_eq(tl[1,2], [3,4])
tr = tl.subset(0)
test_eq(tr[:], [2])
val = tl.subset(1)
test_eq(val[:], [3,4])
class _B(Transform):
    def __init__(self): self.m = 0
    def encodes(self, o): return o+self.m
    def decodes(self, o): return o-self.m
    def setups(self, items): self.m = tensor(items).float().mean().item()
# test for setup, which updates `self.m`; `tl.m` is forwarded to the pipeline
tl = TfmdList(items, _B())
test_eq(tl.m, 2)
Here's how we can use TfmdList.setup to implement a simple category list, getting labels from a mock file list:
class _Cat(Transform):
    # Sorts after `_lbl` below, whose `order` is left at its default
    order = 1
    def encodes(self, o): return int(self.o2i[o])
    def decodes(self, o): return Str(self.vocab[o])
    def setups(self, items): self.vocab,self.o2i = uniqueify(L(items), sort=True, bidir=True)
tcat = _Cat()
def _lbl(o): return Str(o.split('_')[0])
# Check that tfms are sorted by `order` & `_lbl` is called first
fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','dog_1.jpg']
tl = TfmdList(fns, [tcat,_lbl])
exp_voc = ['cat','dog']
test_eq(tcat.vocab, exp_voc)
test_eq(tl.tfms.vocab, exp_voc)
test_eq(tl.vocab, exp_voc)
test_eq(tl, (1,0,0,0,1))
test_eq([tl.decode(o) for o in tl], ('dog','cat','cat','cat','dog'))
#Check only the training set is taken into account for setup
tl = TfmdList(fns, [tcat,_lbl], splits=[[0,4], [1,2,3]])
test_eq(tcat.vocab, ['dog'])
# `TfmdDL.before_iter` copies the dataset's `split_idx` onto its batch pipelines,
# so `NegTfm(split_idx=1)` only fires while `tds.split_idx == 1`
tfm = NegTfm(split_idx=1)
tds = TfmdList(start, A())
tdl = TfmdDL(tds, after_batch=tfm, bs=4)
x = tdl.one_batch()
test_eq(x, torch.arange(4))
tds.split_idx = 1
x = tdl.one_batch()
test_eq(x, -torch.arange(4))
tds.split_idx = 0
x = tdl.one_batch()
test_eq(x, torch.arange(4))
tds = TfmdList(start, A())
tdl = TfmdDL(tds, after_batch=NegTfm(), bs=4)
test_eq(tdl.dataset[0], start[0])
test_eq(len(tdl), (len(tds)-1)//4+1)
test_eq(tdl.bs, 4)
test_stdout(tdl.show_batch, '0\n1\n2\n3')
show_doc(TfmdList.subset)
#export
@docs
@delegates(TfmdList)
class DataSource(FilteredBase):
    "A dataset that creates a tuple from each `tfms`"
    def __init__(self, items=None, tfms=None, tls=None, n_inp=None, dl_type=None, **kwargs):
        super().__init__(dl_type=dl_type)
        # One `TfmdList` per element of `tfms` (a single `None` pipeline by default), unless `tls` is given
        self.tls = L(tls if tls else [TfmdList(items, t, **kwargs) for t in L(ifnone(tfms,[None]))])
        # Default `n_inp`: 1 for a single element, otherwise every element but the last
        self.n_inp = (1 if len(self.tls)==1 else len(self.tls)-1) if n_inp is None else n_inp
    def __getitem__(self, it):
        res = tuple([tl[it] for tl in self.tls])
        # A single index gives one tuple; several indices (or a mask) give a list of tuples
        return res if is_indexer(it) else list(zip(*res))
    # Unknown attributes are gathered from the `TfmdList`s
    def __getattr__(self,k): return gather_attrs(self, k, 'tls')
    def __dir__(self): return super().__dir__() + gather_attr_names(self, 'tls')
    def __len__(self): return len(self.tls[0])
    def __iter__(self): return (self[i] for i in range(len(self)))
    def __repr__(self): return coll_repr(self)
    def decode(self, o, full=True): return tuple(tl.decode(o_, full=full) for o_,tl in zip(o,tuplify(self.tls, match=o)))
    def subset(self, i): return type(self)(tls=L(tl.subset(i) for tl in self.tls), n_inp=self.n_inp)
    # NOTE(review): `FilteredBase._new` forwards to `super()._new`, which `object` does not define,
    # so this presumably only works when a subclass's MRO supplies `_new` — confirm before relying on it
    def _new(self, items, *args, **kwargs): return super()._new(items, tfms=self.tfms, do_setup=False, **kwargs)
    def overlapping_splits(self): return self.tls[0].overlapping_splits()
    # `splits`, `split_idx` and `items` are read from the first `TfmdList`
    @property
    def splits(self): return self.tls[0].splits
    @property
    def split_idx(self): return self.tls[0].tfms.split_idx
    @property
    def items(self): return self.tls[0].items
    @items.setter
    def items(self, v):
        for tl in self.tls: tl.items = v
    def show(self, o, ctx=None, **kwargs):
        # Each element is shown in the context returned by the previous one
        for o_,tl in zip(o,self.tls): ctx = tl.show(o_, ctx=ctx, **kwargs)
        return ctx
    def new_empty(self):
        # Rebuild each `TfmdList` on just the first item, keeping its transforms and `split_idx`
        tls = [tl._new([self.items[0]], split_idx=tl.split_idx) for tl in self.tls]
        return type(self)(tls=tls, n_inp=self.n_inp)
    @contextmanager
    def set_split_idx(self, i):
        old_split_idx = self.split_idx
        for tl in self.tls: tl.tfms.split_idx = i
        try:
            yield self
        finally:
            # Restore the previous `split_idx` even if the body of the `with` raises
            for tl in self.tls: tl.tfms.split_idx = old_split_idx
    _docs=dict(
        decode="Decode each element of `o` with the corresponding `TfmdList`",
        show="Show item `o` in `ctx`",
        databunch="Get a `DataBunch`",
        overlapping_splits="All splits that are in more than one split",
        subset="New `DataSource` that only includes subset `i`",
        new_empty="Create a new empty version of the `self`, keeping only the transforms",
        set_split_idx="Contextmanager to use the same `DataSource` with another `split_idx`"
    )
A DataSource creates a tuple from items (typically input,target) by applying to them each list of Transform (or Pipeline) in tfms. Note that if tfms contains only one list of tfms, the items given by DataSource will be tuples of one element.
n_inp is the number of elements in the tuples that should be considered part of the input and will default to 1 if tfms consists of one set of transforms, len(tfms)-1 otherwise. In most cases, the number of elements in the tuples returned by DataSource will be 2 (for input,target), but there can be 3 (Siamese networks or tabular data), in which case we need to be able to determine where the inputs end and the targets begin.
items = [1,2,3,4]
# Two pipelines -> (input, target) tuples
dsrc = DataSource(items, [[neg_tfm,int2f_tfm], [add(1)]])
t = dsrc[0]
test_eq(t, (-1,2))
test_eq(dsrc[0,1,2], [(-1,2),(-2,3),(-3,4)])
test_eq(dsrc.n_inp, 1)
dsrc.decode(t)
(1.0, 2)
class Norm(Transform):
    def encodes(self, o): return (o-self.m)/self.s
    def decodes(self, o): return (o*self.s)+self.m
    def setups(self, items):
        its = tensor(items).float()
        self.m,self.s = its.mean(),its.std()
items = [1,2,3,4]
nrm = Norm()
dsrc = DataSource(items, [[neg_tfm,int2f_tfm], [neg_tfm,nrm]])
x,y = zip(*dsrc)
test_close(tensor(y).mean(), 0)
test_close(tensor(y).std(), 1)
test_eq(x, (-1,-2,-3,-4,))
test_eq(nrm.m, -2.5)
test_stdout(lambda:show_at(dsrc, 1), '-2')
# Transform attributes are reachable from the `DataSource`, its subsets, and by the
# transform's (presumably snake-cased class) name
test_eq(dsrc.m, nrm.m)
test_eq(dsrc.norm.m, nrm.m)
test_eq(dsrc.train.norm.m, nrm.m)
#hide
#Check filtering is properly applied
class B(Transform):
    def encodes(self, x)->None: return int(x+1)
    def decodes(self, x): return Int(x-1)
add1 = B(split_idx=1)
dsrc = DataSource(items, [neg_tfm, [neg_tfm,int2f_tfm,add1]], splits=[[3],[0,1,2]])
# Indexing the full `DataSource` does not apply `add1`; its `valid` subset (split_idx 1) does
test_eq(dsrc[1], [-2,-2])
test_eq(dsrc.valid[1], [-2,-1])
test_eq(dsrc.valid[[1,1]], [[-2,-1], [-2,-1]])
test_eq(dsrc.train[0], [-4,-4])
test_fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','kid_1.jpg']
tcat = _Cat()
# The vocab is built from the training subset only, so 'kid' is not in it
dsrc = DataSource(test_fns, [[tcat,_lbl]], splits=[[0,1,2], [3,4]])
test_eq(tcat.vocab, ['cat','dog'])
test_eq(dsrc.train, [(1,),(0,),(0,)])
test_eq(dsrc.valid[0], (0,))
test_stdout(lambda: show_at(dsrc.train, 0), "dog")
inp = [0,1,2,3,4]
dsrc = DataSource(inp, tfms=[None])
test_eq(*dsrc[2], 2) # Retrieve one item (by default subset 0 holds every item)
test_eq(dsrc[1,2], [(1,),(2,)]) # Retrieve two items by index
mask = [True,False,False,True,False]
test_eq(dsrc[mask], [(0,),(3,)]) # Retrieve two items by mask
inp = pd.DataFrame(dict(a=[5,1,2,3,4]))
dsrc = DataSource(inp, tfms=attrgetter('a')).subset(0)
test_eq(*dsrc[2], 2) # Retrieve one item (by default subset 0 holds every item)
test_eq(dsrc[1,2], [(1,),(2,)]) # Retrieve two items by index
mask = [True,False,False,True,False]
test_eq(dsrc[mask], [(5,),(3,)]) # Retrieve two items by mask
#test n_inp
inp = [0,1,2,3,4]
dsrc = DataSource(inp, tfms=[None])
test_eq(dsrc.n_inp, 1)
dsrc = DataSource(inp, tfms=[[None],[None],[None]])
test_eq(dsrc.n_inp, 2)
dsrc = DataSource(inp, tfms=[[None],[None],[None]], n_inp=1)
test_eq(dsrc.n_inp, 1)
# splits can be indices
dsrc = DataSource(range(5), tfms=[None], splits=[tensor([0,2]), [1,3,4]])
test_eq(dsrc.subset(0), [(0,),(2,)])
test_eq(dsrc.train, [(0,),(2,)]) # Subset 0 is aliased to `train`
test_eq(dsrc.subset(1), [(1,),(3,),(4,)])
test_eq(dsrc.valid, [(1,),(3,),(4,)]) # Subset 1 is aliased to `valid`
test_eq(*dsrc.valid[2], 4)
#assert '[(1,),(3,),(4,)]' in str(dsrc) and '[(0,),(2,)]' in str(dsrc)
dsrc
(#5) [(0,),(1,),(2,),(3,),(4,)]
# splits can be boolean masks (they don't have to cover all items, but must be disjoint)
splits = [[False,True,True,False,True], [True,False,False,False,False]]
dsrc = DataSource(range(5), tfms=[None], splits=splits)
test_eq(dsrc.train, [(1,),(2,),(4,)])
test_eq(dsrc.valid, [(0,)])
# apply transforms to all items
tfm = [[lambda x: x*2,lambda x: x+1]]
splits = [[1,2],[0,3,4]]
dsrc = DataSource(range(5), tfm, splits=splits)
test_eq(dsrc.train,[(3,),(5,)])
test_eq(dsrc.valid,[(1,),(7,),(9,)])
test_eq(dsrc.train[False,True], [(5,)])
# only transform subset 1
class _Tfm(Transform):
    split_idx=1
    def encodes(self, x): return x*2
    def decodes(self, x): return Str(x//2)
dsrc = DataSource(range(5), [_Tfm()], splits=[[1,2],[0,3,4]])
test_eq(dsrc.train,[(1,),(2,)])
test_eq(dsrc.valid,[(0,),(6,),(8,)])
test_eq(dsrc.train[False,True], [(2,)])
dsrc
(#5) [(0,),(1,),(2,),(3,),(4,)]
#A context manager to change the split_idx and apply the validation transform on the training set
ds = dsrc.train
with ds.set_split_idx(1):
    test_eq(ds,[(2,),(4,)])
# Outside the `with`, the training subset is back to its own split_idx
test_eq(dsrc.train,[(1,),(2,)])
#hide
#Test DataSource pickles
dsrc1 = pickle.loads(pickle.dumps(dsrc))
test_eq(dsrc.train, dsrc1.train)
test_eq(dsrc.valid, dsrc1.valid)
dsrc = DataSource(range(5), [_Tfm(),noop], splits=[[1,2],[0,3,4]])
test_eq(dsrc.train,[(1,1),(2,2)])
test_eq(dsrc.valid,[(0,0),(6,3),(8,4)])
start = torch.arange(0,50)
tds = DataSource(start, [A()])
tdl = TfmdDL(tds, after_item=NegTfm(), bs=4)
b = tdl.one_batch()
test_eq(tdl.decode_batch(b), ((0,),(1,),(2,),(3,)))
test_stdout(tdl.show_batch, "0\n1\n2\n3")
# only transform subset 1
class _Tfm(Transform):
    split_idx=1
    def encodes(self, x): return x*2
dsrc = DataSource(range(8), [None], splits=[[1,2,5,7],[0,3,4,6]])
dbch = dsrc.databunch(bs=4, after_batch=_Tfm(), shuffle_train=False)
test_eq(dbch.train_dl, [(tensor([1,2,5, 7]),)])
# The validation batch size defaults to 3*4//2 = 6, so the 4 validation items form one batch
test_eq(dbch.valid_dl, [(tensor([0,6,8,12]),)])
test_eq(dbch.n_inp, 1)
items = [1,2,3,4]
dsrc = DataSource(items, [[neg_tfm,int2f_tfm]])
show_doc(DataSource.decode)
DataSource.decode[source]
DataSource.decode(o,full=True)
Compose decode of all tuple_tfms then all tfms on i
test_eq(*dsrc[0], -1)
test_eq(*dsrc.decode((-1,)), 1)
show_doc(DataSource.show)
test_stdout(lambda:dsrc.show(dsrc[1]), '-2')
# only transform subset 0
class _Tfm1(Transform):
    split_idx=0
    def encodes(self, x): return x*3
# `_Tfm` (x2) only applies to the validation subset, `_Tfm1` (x3) only to the training subset
dsrc = DataSource(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
test_eq(dsrc.train, [(3,),(6,),(15,),(21,)])
test_eq(dsrc.valid, [(0,),(6,),(8,),(12,)])
#export
def test_set(dsrc, test_items, rm_tfms=0):
    "Create a test set from `test_items` using validation transforms of `dsrc`"
    # Only the input `TfmdList`s are reused (test items have no targets), with `split_idx=1`
    test_tls = [tl._new(test_items, split_idx=1) for tl in dsrc.tls[:dsrc.n_inp]]
    # `rm_tfms[i]` is how many leading transforms to drop from input `i`'s pipeline
    rm_tfms = tuplify(rm_tfms, match=test_tls)
    for i,j in enumerate(rm_tfms): test_tls[i].tfms.fs = test_tls[i].tfms.fs[j:]
    # Every element kept is an input: pass `n_inp` explicitly, otherwise `DataSource`
    # would treat the last one as a target whenever there are 2+ inputs
    return DataSource(tls=test_tls, n_inp=len(test_tls))
class _Tfm1(Transform):
    split_idx=0
    def encodes(self, x): return x*3
dsrc = DataSource(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
test_eq(dsrc.train, [(3,),(6,),(15,),(21,)])
test_eq(dsrc.valid, [(0,),(6,),(8,),(12,)])
#Transforms of the validation set are applied
tst = test_set(dsrc, [1,2,3])
test_eq(tst, [(2,),(4,),(6,)])
#hide
#Test with various input lengths
# Three pipelines -> n_inp=2, so the test set keeps two elements per item
dsrc = DataSource(range(8), [[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
tst = test_set(dsrc, [1,2,3])
test_eq(tst, [(2,2),(4,4),(6,6)])
dsrc = DataSource(range(8), [[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]], n_inp=1)
tst = test_set(dsrc, [1,2,3])
test_eq(tst, [(2,),(4,),(6,)])
#hide
#Test with rm_tfms
dsrc = DataSource(range(8), [[_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]])
tst = test_set(dsrc, [1,2,3])
test_eq(tst, [(4,),(8,),(12,)])
dsrc = DataSource(range(8), [[_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]])
tst = test_set(dsrc, [1,2,3], rm_tfms=1)
test_eq(tst, [(2,),(4,),(6,)])
# Per-input removal: drop 1 transform from input 0, none from input 1
dsrc = DataSource(range(8), [[_Tfm(),_Tfm()], [_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]], n_inp=2)
tst = test_set(dsrc, [1,2,3], rm_tfms=(1,0))
test_eq(tst, [(2,4),(4,8),(6,12)])
#export
@delegates(TfmdDL.__init__)
def test_dl(dbunch, test_items, rm_type_tfms=0, **kwargs):
    "Create a test dataloader from `test_items` using validation transforms of `dbunch`"
    valid_ds = dbunch.valid_ds
    # Only a `DataSource` has transforms to reuse; any other validation dataset means raw items
    if isinstance(valid_ds, DataSource):
        test_ds = test_set(valid_ds, test_items, rm_tfms=rm_type_tfms)
    else:
        test_ds = test_items
    # `kwargs` (e.g. replacement `after_item`) override the validation loader's settings
    return dbunch.valid_dl.new(test_ds, **kwargs)
dsrc = DataSource(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
dbunch = dsrc.databunch(bs=4)
# Validation transforms apply: `_Tfm` doubles, `_Tfm1` (training only) does not run
tst_dl = test_dl(dbunch, [2,3,4,5])
test_eq(list(tst_dl), [(tensor([ 4, 6, 8, 10]),)])
#Test you can change transforms
tst_dl = test_dl(dbunch, [2,3,4,5], after_item=add1)
test_eq(list(tst_dl), [(tensor([ 5, 7, 9, 11]),)])
#hide
from nbdev.export import notebook2script
notebook2script()
Converted 00_test.ipynb. Converted 01_core_foundation.ipynb. Converted 01a_core_utils.ipynb. Converted 01b_core_dispatch.ipynb. Converted 01c_core_transform.ipynb. Converted 02_core_script.ipynb. Converted 03_torchcore.ipynb. Converted 03a_layers.ipynb. Converted 04_data_load.ipynb. Converted 05_data_core.ipynb. Converted 06_data_transforms.ipynb. Converted 07_data_block.ipynb. Converted 08_vision_core.ipynb. Converted 09_vision_augment.ipynb. Converted 09a_vision_data.ipynb. Converted 09b_vision_utils.ipynb. Converted 10_pets_tutorial.ipynb. Converted 11_vision_models_xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_learner.ipynb. Converted 13a_metrics.ipynb. Converted 14_callback_schedule.ipynb. Converted 14a_callback_data.ipynb. Converted 15_callback_hook.ipynb. Converted 15a_vision_models_unet.ipynb. Converted 16_callback_progress.ipynb. Converted 17_callback_tracker.ipynb. Converted 18_callback_fp16.ipynb. Converted 19_callback_mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision_learner.ipynb. Converted 22_tutorial_imagenette.ipynb. Converted 23_tutorial_transfer_learning.ipynb. Converted 30_text_core.ipynb. Converted 31_text_data.ipynb. Converted 32_text_models_awdlstm.ipynb. Converted 33_text_models_core.ipynb. Converted 34_callback_rnn.ipynb. Converted 35_tutorial_wikitext.ipynb. Converted 36_text_models_qrnn.ipynb. Converted 37_text_learner.ipynb. Converted 38_tutorial_ulmfit.ipynb. Converted 40_tabular_core.ipynb. Converted 41_tabular_model.ipynb. Converted 42_tabular_rapids.ipynb. Converted 50_data_block_examples.ipynb. Converted 60_medical_imaging.ipynb. Converted 65_medical_text.ipynb. Converted 70_callback_wandb.ipynb. Converted 71_callback_tensorboard.ipynb. Converted 90_notebook_core.ipynb. Converted 91_notebook_export.ipynb. Converted 92_notebook_showdoc.ipynb. Converted 93_notebook_export2html.ipynb. Converted 94_notebook_test.ipynb. Converted 95_index.ipynb. Converted 96_data_external.ipynb. 
Converted 97_utils_test.ipynb. Converted notebook2jekyll.ipynb. Converted xse_resnext.ipynb.