%reload_ext autoreload
%autoreload 2

#export
from nb_005 import *

PATH = Path('data/dogscats')
arch = tvm.resnet34

size,lr = 224,3e-3
data_norm,data_denorm = normalize_funcs(*imagenet_stats)
tfms = get_transforms(do_flip=True, max_rotate=10, max_zoom=1.2, max_lighting=0.3, max_warp=0.15)
data = data_from_imagefolder(PATH, bs=64, ds_tfms=tfms, num_workers=8, tfms=data_norm, size=size)

#export
HookFunc = Callable[[Model, Tensors, Tensors], Any]

class Hook():
    "Create a hook on `m` with `hook_func`."
    def __init__(self, m:Model, hook_func:HookFunc, is_forward:bool=True):
        self.hook_func,self.stored = hook_func,None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    def hook_fn(self, module:Model, input:Tensors, output:Tensors):
        "Apply `hook_func` to `module`, `input`, `output` and store the result."
        input  = (o.detach() for o in input ) if is_listy(input ) else input.detach()
        output = (o.detach() for o in output) if is_listy(output) else output.detach()
        self.stored = self.hook_func(module, input, output)

    def remove(self):
        if not self.removed:
            self.hook.remove()
            self.removed = True

class Hooks():
    "Create several hooks with `hook_func` on the modules in `ms`."
    def __init__(self, ms:Collection[Model], hook_func:HookFunc, is_forward:bool=True):
        self.hooks = [Hook(m, hook_func, is_forward) for m in ms]

    def __getitem__(self, i:int) -> Hook: return self.hooks[i]
    def __len__(self) -> int: return len(self.hooks)
    def __iter__(self): return iter(self.hooks)
    @property
    def stored(self): return [o.stored for o in self]

    def remove(self):
        for h in self.hooks: h.remove()

def hook_output (module:Model) -> Hook: return Hook(module, lambda m,i,o: o)
def hook_outputs(modules:Collection[Model]) -> Hooks: return Hooks(modules, lambda m,i,o: o)

#export
class HookCallback(LearnerCallback):
    "Callback that registers the given hooks at the start of training and removes them at the end."
    def __init__(self, learn:Learner, modules:Sequence[Model]=None, do_remove:bool=True):
        super().__init__(learn)
        self.modules,self.do_remove = modules,do_remove

    def on_train_begin(self, **kwargs):
        if not self.modules:
            self.modules = [m for m in flatten_model(self.learn.model) if hasattr(m, 'weight')]
        self.hooks = Hooks(self.modules, self.hook)

    def on_train_end(self, **kwargs):
        if self.do_remove: self.remove()

    def remove(self):
        # Guard against `__del__` firing before `on_train_begin` has created the hooks.
        if getattr(self, 'hooks', None): self.hooks.remove()
    def __del__(self): self.remove()

class ActivationStats(HookCallback):
    "Callback that records the mean and std of the activations of every hooked module."
    def on_train_begin(self, **kwargs):
        super().on_train_begin(**kwargs)
        self.stats = []

    def hook(self, m:Model, i:Tensors, o:Tensors) -> Tuple[Rank0Tensor,Rank0Tensor]:
        return o.mean().item(),o.std().item()
    def on_batch_end(self, **kwargs): self.stats.append(self.hooks.stored)
    # (n_batches, n_modules, 2) -> (2, n_modules, n_batches): stats[0] are means, stats[1] stds.
    def on_train_end(self, **kwargs): self.stats = tensor(self.stats).permute(2,1,0)

def idx_dict(a): return {v:k for k,v in enumerate(a)}

learn = ConvLearner(data, arch, wd=1e-2, metrics=accuracy, path=PATH, callback_fns=ActivationStats)
learn.fit_one_cycle(1, lr)

ms = learn.activation_stats.modules
d = idx_dict(ms)
ln = d[learn.model[1][8]]; ln
plt.plot(learn.activation_stats.stats[1][ln].numpy());  # std of that module's activations, per batch

learn.save('e1')
learn = ConvLearner(data, arch, wd=1e-2, metrics=accuracy, path=PATH)
learn.load('e1')

classes = data.valid_ds.classes
preds,y = learn.TTA()
preds,y = learn.get_preds()  # plain predictions; these are what `interp` below uses

#export
def calc_loss(y_pred:Tensor, y_true:Tensor, loss_class:type=nn.CrossEntropyLoss, bs:int=64):
    "Calculate the per-item loss between `y_pred` and `y_true` using `loss_class`."
    loss_dl = DataLoader(TensorDataset(tensor(y_pred),tensor(y_true)), bs)
    with torch.no_grad():
        return torch.cat([loss_class(reduction='none')(*b) for b in loss_dl])
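# Two illustrative sketches (not exported). First, the hook machinery defined above,
# run on a throwaway torch.nn model rather than a Learner; `toy` and its layer sizes
# are invented purely for this example. `hook_output` stores the module's detached
# output on each forward pass, and removing the hook leaves the model unchanged.
toy = nn.Sequential(nn.Linear(4,8), nn.ReLU(), nn.Linear(8,2))
h = hook_output(toy[1])   # capture the ReLU's activations
toy(torch.randn(3,4))     # one forward pass populates `h.stored`
acts = h.stored           # detached tensor of shape (3, 8)
h.remove()                # deregister so later passes run hook-free

# Second, a quick sanity check of `calc_loss`: with the default nn.CrossEntropyLoss
# it expects raw logits of shape (n, c) and integer targets of shape (n,), and
# returns one un-reduced loss per item. The tensors are random, purely illustrative.
logits  = torch.randn(6, 2)
targets = torch.randint(0, 2, (6,))
calc_loss(logits, targets).shape  # -> torch.Size([6])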
classification models" def __init__(self, data:DataBunch, y_pred:Tensor, y_true:Tensor, loss_class:type=nn.CrossEntropyLoss, sigmoid:bool=True): self.data,self.y_pred,self.y_true,self.loss_class = data,y_pred,y_true,loss_class self.losses = calc_loss(y_pred, y_true, loss_class=loss_class) self.probs = preds.sigmoid() if sigmoid else preds self.pred_class = self.probs.argmax(dim=1) def top_losses(self, k, largest=True): "`k` largest(/smallest) losses" return self.losses.topk(k, largest=largest) def plot_top_losses(self, k, largest=True, figsize=(12,12)): "Show images in `top_losses` along with their loss, label, and prediction" tl = self.top_losses(k,largest) classes = self.data.classes rows = math.ceil(math.sqrt(k)) fig,axes = plt.subplots(rows,rows,figsize=figsize) for i,idx in enumerate(self.top_losses(k, largest=largest)[1]): t=data.valid_ds[idx] t[0].show(ax=axes.flat[i], title= f'{classes[self.pred_class[idx]]}/{classes[t[1]]} / {self.losses[idx]:.2f} / {self.probs[idx][0]:.2f}') def confusion_matrix(self): "Confusion matrix as an `np.ndarray`" x=torch.arange(0,data.c) cm = ((self.pred_class==x[:,None]) & (self.y_true==x[:,None,None])).sum(2) return cm.cpu().numpy() def plot_confusion_matrix(self, normalize:bool=False, title:str='Confusion matrix', cmap:Any="Blues", figsize:tuple=None): "Plot the confusion matrix" # This function is copied from the scikit docs cm = self.confusion_matrix() plt.figure(figsize=figsize) plt.imshow(cm, interpolation='nearest', cmap=cmap) plt.title(title) plt.colorbar() tick_marks = np.arange(len(classes)) plt.xticks(tick_marks, self.data.classes, rotation=45) plt.yticks(tick_marks, self.data.classes) if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] thresh = cm.max() / 2. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): plt.text(j, i, cm[i, j], horizontalalignment="center", color="white" if cm[i, j] > thresh else "black") plt.tight_layout() plt.ylabel('True label') plt.xlabel('Predicted label') interp = ClassificationInterpretation(data, preds, y, loss_class=nn.CrossEntropyLoss) interp.top_losses(9) interp.plot_top_losses(9) interp.confusion_matrix() interp.plot_confusion_matrix()