%reload_ext autoreload
%autoreload 2
#export
from nb_005 import *
PATH = Path('data/dogscats')
arch = tvm.resnet34
size,lr = 224,3e-3
data_norm,data_denorm = normalize_funcs(*imagenet_stats)
tfms = get_transforms(do_flip=True, max_rotate=10, max_zoom=1.2, max_lighting=0.3, max_warp=0.15)
data = data_from_imagefolder(PATH, bs=64, ds_tfms=tfms, num_workers=8, tfms=data_norm, size=size)
#export
HookFunc = Callable[[Model, Tensors, Tensors], Any]
class Hook():
"Creates a hook"
def __init__(self, m:Model, hook_func:HookFunc, is_forward:bool=True):
self.hook_func,self.stored = hook_func,None
f = m.register_forward_hook if is_forward else m.register_backward_hook
self.hook = f(self.hook_fn)
self.removed = False
    def hook_fn(self, module:Model, input:Tensors, output:Tensors):
        "Apply `hook_func` to `module`, `input`, `output` and store the result."
        # Detach so stored tensors don't keep the autograd graph alive; use tuples
        # (not generators) so `hook_func` can safely iterate them more than once.
        input  = tuple(o.detach() for o in input ) if is_listy(input ) else input.detach()
        output = tuple(o.detach() for o in output) if is_listy(output) else output.detach()
        self.stored = self.hook_func(module, input, output)
def remove(self):
if not self.removed:
self.hook.remove()
self.removed=True
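# A minimal sketch of `Hook` in isolation (not from the original notebook;
# assumes `nn`/`torch` come through the `nb_005` star import): hook a throwaway
# layer, run a forward pass, read the stored value, then clean up.
tmp_m = nn.Sequential(nn.Linear(10,10), nn.ReLU())
h = Hook(tmp_m[0], lambda m,i,o: o.mean().item())
_ = tmp_m(torch.randn(4,10))
print(h.stored)  # mean activation of the linear layer for this batch
h.remove()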
class Hooks():
"Creates several hooks"
def __init__(self, ms:Collection[Model], hook_func:HookFunc, is_forward:bool=True):
self.hooks = [Hook(m, hook_func, is_forward) for m in ms]
def __getitem__(self,i:int) -> Hook: return self.hooks[i]
def __len__(self) -> int: return len(self.hooks)
def __iter__(self): return iter(self.hooks)
@property
def stored(self): return [o.stored for o in self]
def remove(self):
for h in self.hooks: h.remove()
def hook_output (module:Model) -> Hook: return Hook (module, lambda m,i,o: o)
def hook_outputs(modules:Collection[Model]) -> Hooks: return Hooks(modules, lambda m,i,o: o)
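# Same idea with `hook_outputs` (a sketch on a hypothetical toy model): capture
# the output of every layer of a small model in one forward pass.
tmp_m = nn.Sequential(nn.Linear(10,10), nn.ReLU(), nn.Linear(10,2))
hooks = hook_outputs(list(tmp_m))
_ = tmp_m(torch.randn(4,10))
print([o.shape for o in hooks.stored])  # one detached output per hooked layer
hooks.remove()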
#export
class HookCallback(LearnerCallback):
"Callback that registers given hooks"
def __init__(self, learn:Learner, modules:Sequence[Model]=None, do_remove:bool=True):
super().__init__(learn)
self.modules,self.do_remove = modules,do_remove
def on_train_begin(self, **kwargs):
if not self.modules:
self.modules = [m for m in flatten_model(self.learn.model)
if hasattr(m, 'weight')]
self.hooks = Hooks(self.modules, self.hook)
def on_train_end(self, **kwargs):
if self.do_remove: self.remove()
    def remove(self):
        if getattr(self, 'hooks', None): self.hooks.remove()
    def __del__(self): self.remove()
class ActivationStats(HookCallback):
    "Callback that records the mean and std of activations"
def on_train_begin(self, **kwargs):
super().on_train_begin(**kwargs)
self.stats = []
    def hook(self, m:Model, i:Tensors, o:Tensors) -> Tuple[float,float]:
        "Record the mean and std of the output activations of `m`."
        return o.mean().item(),o.std().item()
def on_batch_end(self, **kwargs): self.stats.append(self.hooks.stored)
def on_train_end(self, **kwargs): self.stats = tensor(self.stats).permute(2,1,0)
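# After training, `ActivationStats.stats` has shape (2, n_modules, n_batches):
# stats[0] holds each hooked module's per-batch activation means, stats[1] the stds.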
def idx_dict(a): return {v:k for k,v in enumerate(a)}  # map each item of `a` to its index
learn = ConvLearner(data, arch, wd=1e-2, metrics=accuracy, path=PATH,
callback_fns=ActivationStats)
learn.fit_one_cycle(1, lr)
ms = learn.activation_stats.modules
d = idx_dict(ms)
ln = d[learn.model[1][8]]; ln  # index of a layer near the end of the head
plt.plot(learn.activation_stats.stats[1][ln].numpy());  # that layer's activation std per batch
learn.save('e1')
learn = ConvLearner(data, arch, wd=1e-2, metrics=accuracy, path=PATH)
learn.load('e1')
bs=64
classes = data.valid_ds.classes
preds,y = learn.TTA()        # predictions with test-time augmentation
preds,y = learn.get_preds()  # plain validation predictions (used below)
#export
def calc_loss(y_pred:Tensor, y_true:Tensor, loss_class:type=nn.CrossEntropyLoss, bs:int=64):
    "Calculate the per-example loss between `y_pred` and `y_true` using `loss_class`"
    loss_dl = DataLoader(TensorDataset(tensor(y_pred),tensor(y_true)), bs)
    with torch.no_grad():
        return torch.cat([loss_class(reduction='none')(*b) for b in loss_dl])
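# Quick sanity check (illustrative): `calc_loss` returns one unreduced loss per
# example, so its length matches the number of validation items.
losses = calc_loss(preds, y)
losses.shape, y.shape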
class ClassificationInterpretation():
"Interpretation methods for classification models"
def __init__(self, data:DataBunch, y_pred:Tensor, y_true:Tensor,
loss_class:type=nn.CrossEntropyLoss, sigmoid:bool=True):
self.data,self.y_pred,self.y_true,self.loss_class = data,y_pred,y_true,loss_class
self.losses = calc_loss(y_pred, y_true, loss_class=loss_class)
        self.probs = y_pred.sigmoid() if sigmoid else y_pred
self.pred_class = self.probs.argmax(dim=1)
def top_losses(self, k, largest=True):
"`k` largest(/smallest) losses"
return self.losses.topk(k, largest=largest)
    def plot_top_losses(self, k, largest=True, figsize=(12,12)):
        "Show images in `top_losses` along with their prediction/actual class, loss, and predicted-class probability"
        tl = self.top_losses(k, largest=largest)
        classes = self.data.classes
        rows = math.ceil(math.sqrt(k))
        fig,axes = plt.subplots(rows, rows, figsize=figsize)
        for i,idx in enumerate(tl[1]):
            t = self.data.valid_ds[idx]
            t[0].show(ax=axes.flat[i], title=
                f'{classes[self.pred_class[idx]]}/{classes[t[1]]} / {self.losses[idx]:.2f} / {self.probs[idx][self.pred_class[idx]]:.2f}')
    def confusion_matrix(self):
        "Confusion matrix as an `np.ndarray`"
        x = torch.arange(0, self.data.c)
        # Broadcast to a (c,c,n) boolean cube and sum over examples:
        # cm[i,j] counts items whose true class is `i` and predicted class is `j`.
        cm = ((self.pred_class==x[:,None]) & (self.y_true==x[:,None,None])).sum(2)
        return cm.cpu().numpy()
    def plot_confusion_matrix(self, normalize:bool=False, title:str='Confusion matrix', cmap:Any="Blues", figsize:tuple=None):
        "Plot the confusion matrix"
        # This function is adapted from the scikit-learn docs
        cm = self.confusion_matrix()
        if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        plt.figure(figsize=figsize)
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(self.data.c)
        plt.xticks(tick_marks, self.data.classes, rotation=45)
        plt.yticks(tick_marks, self.data.classes)
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, f'{cm[i, j]:.2f}' if normalize else cm[i, j],
                     horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")
        plt.tight_layout()
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
interp = ClassificationInterpretation(data, preds, y, loss_class=nn.CrossEntropyLoss)
interp.top_losses(9)
interp.plot_top_losses(9)
interp.confusion_matrix()
interp.plot_confusion_matrix()
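# A small follow-on sketch (not in the original notebook): zero the diagonal of
# the confusion matrix and look up its largest remaining entry to find the most
# frequently confused (true, predicted) class pair.
cm = interp.confusion_matrix()
np.fill_diagonal(cm, 0)
i,j = np.unravel_index(cm.argmax(), cm.shape)
print(f'most confused: true={data.classes[i]}, predicted={data.classes[j]} ({cm[i,j]} times)')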