#default_exp callback.training

#export
from fastai2.basics import *
from fastai2.callback.progress import *
from fastai2.callback.fp16 import *

#hide
from nbdev.showdoc import *
from fastai2.test_utils import *

#export
@log_args
class ShortEpochCallback(Callback):
    "Fit just `pct` of an epoch, then stop"
    def __init__(self, pct=0.01, short_valid=True): self.pct,self.short_valid = pct,short_valid
    def after_batch(self):
        # Once `pct` of the batches have run, cancel training (and optionally validation)
        if self.iter/self.n_iter < self.pct: return
        if self.training:    raise CancelTrainException
        if self.short_valid: raise CancelValidException

learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback())

learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))

#export
@log_args
class GradientAccumulation(Callback):
    "Accumulate gradients before updating weights"
    # Run toward the end of the callback list, but before MixedPrecision
    toward_end,run_before = True,MixedPrecision
    def __init__(self, n_acc=32): store_attr(self, 'n_acc')
    def before_fit(self): self.count=0
    def after_backward(self):
        self.count += find_bs(self.learn.yb)
        if self.count < self.n_acc: raise CancelBatchException()  # skip the weight update
        else: self.count=0

    _docs = dict(before_fit="Set counter to 0",
                 after_backward="Skip weight update if we have not seen enough items")

learn = synth_learner()

learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=2*learn.dls.bs))
# ensure train_loss decreased
assert learn.recorder.values[-1][0] < learn.recorder.values[0][0]

learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1e6))
# ensure valid_loss didn't change (same weights)
assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]

#export
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def set_bn_eval(m:nn.Module, use_eval=True)->None:
    "Set bn layers in eval mode for all recursive children of `m`."
    for l in m.children():
        if isinstance(l, bn_types) and not next(l.parameters()).requires_grad:
            if use_eval: l.eval()
            else:        l.train()
        set_bn_eval(l, use_eval)  # propagate `use_eval` to nested children

class BnFreeze(Callback):
    "Freeze moving average statistics in all non-trainable batchnorm layers."
    def before_epoch(self): set_bn_eval(self.model)

#slow
from fastai2.vision.all import *
path = untar_data(URLs.MNIST_TINY)
dls  = ImageDataLoaders.from_folder(path, valid_pct=0.2)

#slow
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)

#slow
m = learn1.model[0][1].running_mean.clone()

#slow
# Without BnFreeze, the frozen batchnorm's running stats are still updated during training
learn1.fit(1, lr=0.02)
test_ne(learn1.model[0][1].running_mean, m)

#slow
# With BnFreeze, the running stats stay unchanged
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)
m = learn1.model[0][1].running_mean.clone()
learn1.fit(1, lr=0.02)
test_eq(learn1.model[0][1].running_mean, m)

#hide
from nbdev.export import notebook2script
notebook2script()
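
As a quick illustration (a sketch only, not part of the exported module): `set_bn_eval` only switches BatchNorm layers whose parameters are frozen, leaving trainable ones untouched. The small `nn.Sequential` model below is an assumption made purely for this example.

# Illustrative sketch (not exported): frozen BN layers go to eval mode, trainable ones stay in train mode
tst = nn.Sequential(nn.BatchNorm1d(4), nn.Linear(4,4), nn.BatchNorm1d(4))
for p in tst[0].parameters(): p.requires_grad_(False)  # freeze only the first BN layer
set_bn_eval(tst)
assert not tst[0].training  # frozen BN switched to eval mode
assert tst[2].training      # trainable BN left in train mode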