#default_exp callback.training
#export
from fastai2.basics import *
from fastai2.callback.progress import *
from fastai2.callback.fp16 import *
#hide
from nbdev.showdoc import *
from fastai2.test_utils import *
Callbacks that customize training behavior: cutting epochs short, accumulating gradients, and freezing batchnorm statistics.
#export
@log_args
class ShortEpochCallback(Callback):
    "Fit just `pct` of an epoch, then stop"
    def __init__(self,pct=0.01,short_valid=True):
        # pct: fraction of iterations (`iter/n_iter`) to run before stopping
        # short_valid: whether the validation phase is also cut short
        self.pct = pct
        self.short_valid = short_valid

    def after_batch(self):
        progress = self.iter/self.n_iter
        # Keep going until the requested fraction of iterations has been reached.
        if progress < self.pct: return
        # Training is always stopped early; validation only when `short_valid` is set.
        if self.training: raise CancelTrainException
        if self.short_valid: raise CancelValidException
learn = synth_learner()
# Defaults (`pct=0.01`, `short_valid=True`): training stops once 1% of the iterations are done,
# and validation is cut short the same way.
learn.fit(1, cbs=ShortEpochCallback())
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 00:00 |
learn = synth_learner()
# `short_valid=False`: training is still cut short, but validation is never cancelled.
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 12.395771 | 00:00 |
# export
@log_args
class GradientAccumulation(Callback):
    "Accumulate gradients before updating weights"
    # Callback-ordering attributes (presumably: schedule late, and ahead of `MixedPrecision`).
    toward_end = True
    run_before = MixedPrecision

    def __init__(self, n_acc=32): store_attr(self, 'n_acc')

    def before_fit(self):
        # Every `fit` starts with an empty accumulation counter.
        self.count = 0

    def after_backward(self):
        # Count items (batch size), not batches, whose gradients have been accumulated.
        self.count += find_bs(self.learn.yb)
        below_threshold = self.count < self.n_acc
        if below_threshold:
            # Cancelling the batch here skips the weight update.
            raise CancelBatchException()
        # Threshold reached: let this step go through and start a new accumulation window.
        self.count = 0

    _docs = dict(before_fit="Set counter to 0",
                 after_backward="Skip weight update if we have not seen enough items")
learn = synth_learner()
# Accumulate gradients over two batches' worth of items before each optimizer step.
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=2*learn.dls.bs))
# ensure train_loss decreased
assert learn.recorder.values[-1][0] < learn.recorder.values[0][0]
# With a threshold this large the counter presumably never reaches `n_acc` in these 2 epochs,
# so every batch is cancelled before the weight update and the weights stay fixed.
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1e6))
# ensure valid_loss didn't change (same weights)
assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 10.566907 | 3.633753 | 00:00 |
| 1 | 5.525984 | 0.397483 | 00:00 |
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 0.476599 | 0.397483 | 00:00 |
| 1 | 0.478213 | 0.397483 | 00:00 |
#export
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def set_bn_eval(m:nn.Module, use_eval=True)->None:
    "Set frozen bn layers in all recursive children of `m` to eval mode (or back to train mode if `use_eval=False`)."
    for l in m.children():
        if isinstance(l, bn_types):
            # A BN layer counts as frozen when its first parameter doesn't require grad.
            # A parameterless BN (`affine=False`) has nothing to train, so it is treated as frozen
            # instead of crashing with `StopIteration`.
            p = next(l.parameters(), None)
            if p is None or not p.requires_grad:
                if use_eval: l.eval()
                else:        l.train()
        # Propagate `use_eval` so nested BN layers get the same mode as top-level ones.
        set_bn_eval(l, use_eval)
class BnFreeze(Callback):
    "Freeze moving average statistics in all non-trainable batchnorm layers."
    # NOTE(review): fastai's `TrainEvalCallback` calls `model.train()` in `before_train`, which
    # re-enables running-stat updates. A `before_epoch` hook fires before that and gets undone,
    # so hook `before_train` and run after `TrainEvalCallback` (confirm against installed fastai2).
    run_after=TrainEvalCallback
    def before_train(self):
        "Put frozen batchnorm layers in eval mode so their running statistics are not updated."
        set_bn_eval(self.model)
`BnFreeze` is useful when you'd like to train two separate models that share a common feature extractor (body) and differ only in the head attached for transfer learning.
`Learner.freeze()` doesn't suffice here. The BatchNorm layers are trainable by default, and the running mean and variance of the batches are still tracked. For the feature extractors to fully match, you need to set `train_bn=False` and also freeze these statistics, which is precisely what `BnFreeze` does.
#slow
from fastai2.vision.all import *
# Fetch and extract the MNIST_TINY sample dataset.
path = untar_data(URLs.MNIST_TINY)
# Build DataLoaders from the folder, presumably holding out a random 20% for validation.
dls = ImageDataLoaders.from_folder(path, valid_pct=0.2)
We first demonstrate the mismatch of the running statistics when using only `train_bn=False`, by creating a `Learner`...
#slow
# Pretrained resnet18 with `train_bn=False` but no `BnFreeze` callback;
# the DataLoaders are deep-copied so this learner doesn't share them with later ones.
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)
...and grab the first BatchNorm layer and store its running mean:
#slow
# Clone so the snapshot doesn't change if the buffer is updated in place during training.
m = learn1.model[0][1].running_mean.clone()
After fitting for one epoch, you can see that the running mean has changed:
#slow
learn1.fit(1, lr=0.02)
# Without BnFreeze, the running mean is expected to differ from the snapshot after training.
test_ne(learn1.model[0][1].running_mean, m)
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 1.058304 | 0.713414 | 00:02 |
When we use the BnFreeze callback, the running statistics will not be changed during training. This is often important for getting good results from transfer learning.
#slow
# Same setup as before, but with the `BnFreeze` callback attached.
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)
m = learn1.model[0][1].running_mean.clone()
learn1.fit(1, lr=0.02)
# With BnFreeze, the running mean is expected to be unchanged after training.
test_eq(learn1.model[0][1].running_mean, m)
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 0.540841 | 0.432421 | 00:02 |
#hide
from nbdev.export import notebook2script
# nbdev: convert the project's notebooks into library modules (prints one "Converted ..." line each).
notebook2script()
Converted 00_torch_core.ipynb. Converted 01_layers.ipynb. Converted 02_data.load.ipynb. Converted 03_data.core.ipynb. Converted 04_data.external.ipynb. Converted 05_data.transforms.ipynb. Converted 06_data.block.ipynb. Converted 07_vision.core.ipynb. Converted 08_vision.data.ipynb. Converted 09_vision.augment.ipynb. Converted 09b_vision.utils.ipynb. Converted 09c_vision.widgets.ipynb. Converted 10_tutorial.pets.ipynb. Converted 11_vision.models.xresnet.ipynb. Converted 12_optimizer.ipynb. Converted 13_callback.core.ipynb. Converted 13a_learner.ipynb. Converted 13b_metrics.ipynb. Converted 14_callback.schedule.ipynb. Converted 14a_callback.data.ipynb. Converted 15_callback.hook.ipynb. Converted 15a_vision.models.unet.ipynb. Converted 16_callback.progress.ipynb. Converted 17_callback.tracker.ipynb. Converted 18_callback.fp16.ipynb. Converted 18a_callback.training.ipynb. Converted 19_callback.mixup.ipynb. Converted 20_interpret.ipynb. Converted 20a_distributed.ipynb. Converted 21_vision.learner.ipynb. Converted 22_tutorial.imagenette.ipynb. Converted 23_tutorial.vision.ipynb. Converted 24_tutorial.siamese.ipynb. Converted 24_vision.gan.ipynb. Converted 30_text.core.ipynb. Converted 31_text.data.ipynb. Converted 32_text.models.awdlstm.ipynb. Converted 33_text.models.core.ipynb. Converted 34_callback.rnn.ipynb. Converted 35_tutorial.wikitext.ipynb. Converted 36_text.models.qrnn.ipynb. Converted 37_text.learner.ipynb. Converted 38_tutorial.text.ipynb. Converted 39_tutorial.transformers.ipynb. Converted 40_tabular.core.ipynb. Converted 41_tabular.data.ipynb. Converted 42_tabular.model.ipynb. Converted 43_tabular.learner.ipynb. Converted 44_tutorial.tabular.ipynb. Converted 45_collab.ipynb. Converted 46_tutorial.collab.ipynb. Converted 50_tutorial.datablock.ipynb. Converted 60_medical.imaging.ipynb. Converted 61_tutorial.medical_imaging.ipynb. Converted 65_medical.text.ipynb. Converted 70_callback.wandb.ipynb. Converted 71_callback.tensorboard.ipynb. 
Converted 72_callback.neptune.ipynb. Converted 73_callback.captum.ipynb. Converted 74_callback.cutmix.ipynb. Converted 97_test_utils.ipynb. Converted 99_pytorch_doc.ipynb. Converted index.ipynb. Converted tutorial.ipynb.