from fastai2.torch_basics import *
from fastai2.test import *
from fastai2.layers import *
from fastai2.data.all import *
from fastai2.data.block import *
from fastai2.optimizer import *
from fastai2.learner import *
from fastai2.metrics import *
from fastai2.callback.all import *
from fastai2.vision.all import *
# Locate the Imagenette-160 data via `untar_data` and collect its image files.
source = untar_data(URLs.IMAGENETTE_160)
items = get_image_files(source)

# Train/valid split computed by GrandparentSplitter with 'val' as the
# validation folder name.
split_idx = GrandparentSplitter(valid_name='val')(items)

# Two transform pipelines: one producing the image, one producing the label
# (parent_label followed by Categorize).
tfms = [
    PILImage.create,
    [parent_label, Categorize()],
]

# Per-item transforms, handed to the databunch as `after_item`.
item_img_tfms = [
    ToTensor(),
    FlipItem(0.5),
    RandomResizedCrop(128, min_scale=0.35),
]
dsrc = DataSource(items, tfms, splits=split_idx)

# Mean / std triplets passed through broadcast_vec(1, 4, ...) before being
# unpacked into Normalize.
imagenet_stats = broadcast_vec(
    1, 4,
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
)

# Per-batch transforms, handed to the databunch as `after_batch`.
batch_tfms = [Cuda(), IntToFloatTensor(), Normalize(*imagenet_stats)]
dbch = dsrc.databunch(
    after_item=item_img_tfms,
    after_batch=batch_tfms,
    bs=64,
    num_workers=0,
)
Get a batch
# Pull a single batch from the databunch; `b` is consumed by the following cells.
b = dbch.one_batch()
Cast to retained types, decode after batch and before batch
# Run the training dataloader's `decode` on the batch and rebind `b` to the result.
b = dbch.train_dl.decode(b)
If the batch knows how to show itself at this stage, go for it (used for tabular data).
# Only batches exposing a `show` method are displayed here (at most 9 items);
# any other batch falls through to the manual decoding below.
if hasattr(b, 'show'):
    b.show(max_n=9)
Grab the samples and decode after_item
# Call the dataloader's private `_decode_batch` helper on `b` with arguments
# (9, False); `db` is later iterated as a sequence of samples indexed [0]/[1].
# NOTE(review): meaning of the positional args 9 / False is not visible here —
# presumably max samples and a flag; confirm against the dataloader source.
db = dbch.train_dl._decode_batch(b, 9, False)
Get the contexts to show the batch (in this case, a figure with 9 axes), then show each of the objects on it.
# Inspect the type of the first element of the first decoded sample.
type(db[0][0])
local.transform.TensorImage
# Scratch check: draw one value from `random.random()` (used below as a coin flip).
random.random()
0.30263058643764706
# Scratch check: look up a vocab entry at an index drawn by random.randint(0,9).
dbch.vocab[random.randint(0,9)]
'n01440764'
# Show the decoded samples on a 3x3 grid of axes, one sample per axis.
# The loop body is indented so that every statement using `x` / `ax` runs
# once per (sample, axis) pair.
_,axs = plt.subplots(3,3, figsize=(9,10))
for x,ax in zip(db,axs.flatten()):
    # Draw both elements of the sample on the same axis.
    x[0].show(ctx=ax)
    x[1].show(ctx=ax)
    # Stand-in for a prediction: keep the second element of the sample when the
    # coin flip is below 0.5, otherwise build a Category from a vocab entry
    # picked at a random index.
    r = x[1] if random.random() < 0.5 else Category(dbch.vocab[random.randint(0,9)])
    # Green when the stand-in equals x[1], red otherwise.
    r.show(ctx=ax, color='green' if r==x[1] else 'red')
# Inspect the type of the title object on `ax` (the last axis bound by the loop above).
type(ax.title)
matplotlib.text.Text
# Display the title object of `ax`; the dangling attribute access (trailing
# dot) was removed so the expression is valid and evaluates to the Text object.
ax.title
Text(0.5, 1, 'n03888257\nn03888257')
# Attempt to obtain display contexts from the first element of the batch and
# then have the training dataset show each decoded sample on its context.
# NOTE(review): the recorded output for this cell is
# "AttributeError: 'TensorImage' object has no attribute 'get_ctxs'", so the
# first line fails as written and the second never runs — confirm which API
# is meant to supply the contexts.
ctxs = b[0].get_ctxs(max_n=9)
ctxs = [dbch.train_dl.dataset.show(o, ctx=ctx) for o,ctx in zip(db, ctxs)]
--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) <ipython-input-12-9f171d50a336> in <module> ----> 1 ctxs = b[0].get_ctxs(max_n=9) 2 ctxs = [dbch.train_dl.dataset.show(o, ctx=ctx) for o,ctx in zip(db, ctxs)] AttributeError: 'TensorImage' object has no attribute 'get_ctxs'
# Build a learner from the databunch with `resnet34` passed as the architecture argument.
learn = cnn_learner(dbch, resnet34)
Get one batch
# Rebind `b` to a new batch from the databunch (replaces the decoded batch from earlier cells).
b = dbch.one_batch()
Get the corresponding predictions
# Run get_preds over a one-element list wrapping batch `b`; keep the first
# returned value as `preds` and discard the second.
preds,_ = learn.get_preds(dl=[b])
Decode the predictions with the loss function
# Apply the loss function's `decodes` to the predictions when it defines one;
# otherwise fall back to `noop` (presumably leaves `preds` unchanged — confirm).
preds = getattr(learn.loss_func, "decodes", noop)(preds)
At this stage we have two batches to show together: b and (b[0],preds). There are two ways of doing this: superposed or aligned.
# Pair the batch's first element with the decoded predictions.
b_out = (b[0], preds)
# Superposed display: the contexts returned by the first show_batch call are
# passed to the second call via `ctxs`.
ctxs = dbch.show_batch(b=b, max_n=9)
dbch.show_batch(b=b_out, max_n=9, ctxs=ctxs)
# Show the two batches independently: no `ctxs` is passed, so neither call
# reuses the other's contexts.
dbch.show_batch(b=b, max_n=9)
dbch.show_batch(b=b_out, max_n=9)
# Aligned display: a 3x6 grid whose flattened axes are split by parity, so
# each axis given to the first batch (even positions) sits next to the axis
# given to the second batch (odd positions).
_,axs = plt.subplots(3, 6, figsize=(18,9))
# A stray character after the first slice made this line a syntax error; removed.
ctxs1,ctxs2 = axs.flatten()[::2],axs.flatten()[1::2]
dbch.show_batch(b=b, max_n=9, ctxs=ctxs1)
dbch.show_batch(b=b_out, max_n=9, ctxs=ctxs2)