#default_exp callback.tracker

#export
from fastai2.basics import *
from fastai2.callback.progress import *
from fastai2.callback.fp16 import MixedPrecision

from nbdev.showdoc import *
from fastai2.test_utils import *

# export
class TerminateOnNaNCallback(Callback):
    "A `Callback` that terminates training if loss is NaN."
    run_before=Recorder
    def after_batch(self):
        "Test if `last_loss` is NaN and interrupt training."
        if torch.isinf(self.loss) or torch.isnan(self.loss): raise CancelFitException

learn = synth_learner()
learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())

assert len(learn.recorder.losses) < 10 * len(learn.dls.train)
for l in learn.recorder.losses:
    assert not torch.isinf(l) and not torch.isnan(l)

# export
class TrackerCallback(Callback):
    "A `Callback` that keeps track of the best value in `monitor`."
    remove_on_fetch,run_after = True,Recorder
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0.):
        if comp is None: comp = np.less if 'loss' in monitor or 'error' in monitor else np.greater
        if comp == np.less: min_delta *= -1  #flip the sign so `val - min_delta` tightens the comparison in both directions
        self.monitor,self.comp,self.min_delta = monitor,comp,min_delta

    def begin_fit(self):
        "Prepare the monitored value"
        self.run = not hasattr(self, "lr_finder") and not hasattr(self, "gather_preds")  #don't run during lr_find or get_preds
        self.best = float('inf') if self.comp == np.less else -float('inf')
        assert self.monitor in self.recorder.metric_names[1:]
        self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)

    def after_epoch(self):
        "Compare the last value to the best up to now"
        val = self.recorder.values[-1][self.idx]
        if self.comp(val - self.min_delta, self.best): self.best,self.new_best = val,True
        else: self.new_best = False

    def after_fit(self): self.run=True
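A subclass only has to call `super().after_epoch()` and then react to `self.new_best` (and `self.best`), which is exactly what the callbacks later in this notebook do. As a minimal sketch, assuming the definitions above: `LogBestCallback` is a made-up name for illustration only and is not exported or part of the library.

#A minimal sketch (not exported): `LogBestCallback` is hypothetical, shown only to illustrate subclassing `TrackerCallback`.
class LogBestCallback(TrackerCallback):
    "Record the epochs at which `monitor` reached a new best value."
    def begin_fit(self):
        super().begin_fit()
        self.best_epochs = []
    def after_epoch(self):
        super().after_epoch()  #updates `self.best` and `self.new_best`
        if self.new_best: self.best_epochs.append(self.epoch)

learn = synth_learner(n_trn=2, cbs=LogBestCallback(monitor='valid_loss'))
learn.fit(3)
learn.log_best.best_epochs  #epochs where `valid_loss` improved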
#hide
class FakeRecords(Callback):
    run_after=Recorder
    run_before=TrackerCallback
    def __init__(self, monitor, values): self.monitor,self.values = monitor,values

    def begin_fit(self):   self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)
    def after_epoch(self): self.recorder.values[-1][self.idx] = self.values[self.epoch]

class TestTracker(Callback):
    run_after=TrackerCallback
    def begin_fit(self): self.bests,self.news = [],[]
    def after_epoch(self):
        self.bests.append(self.tracker.best)
        self.news.append(self.tracker.new_best)

#hide
learn = synth_learner(n_trn=2, cbs=TestTracker())
cbs=[TrackerCallback(monitor='valid_loss'), FakeRecords('valid_loss', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.1])
test_eq(learn.test_tracker.news,  [True,True])

#With a min_delta
cbs=[TrackerCallback(monitor='valid_loss', min_delta=0.15), FakeRecords('valid_loss', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.2])
test_eq(learn.test_tracker.news,  [True,False])

#hide
#By default, a metric has to increase to count as an improvement.
def tst_metric(out,targ): return F.mse_loss(out,targ)
learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)
cbs=[TrackerCallback(monitor='tst_metric'), FakeRecords('tst_metric', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.2])
test_eq(learn.test_tracker.news,  [True,False])

#This can be overridden by passing `comp=np.less`.
learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)
cbs=[TrackerCallback(monitor='tst_metric', comp=np.less), FakeRecords('tst_metric', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.1])
test_eq(learn.test_tracker.news,  [True,True])

#hide
#A tracker callback is not run during an lr_find
from fastai2.callback.schedule import *
learn = synth_learner(n_trn=2, cbs=TrackerCallback(monitor='tst_metric'), metrics=tst_metric)
learn.lr_find(num_it=15, show_plot=False)
assert not hasattr(learn, 'new_best')

# export
@log_args
class EarlyStoppingCallback(TrackerCallback):
    "A `TrackerCallback` that terminates training when the monitored quantity stops improving."
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)
        self.patience = patience

    def begin_fit(self): self.wait = 0; super().begin_fit()
    def after_epoch(self):
        "Compare the value monitored to its best score and maybe stop training."
        super().after_epoch()
        if self.new_best: self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                print(f'No improvement since epoch {self.epoch-self.wait}: early stopping')
                raise CancelFitException()

learn = synth_learner(n_trn=2, metrics=F.mse_loss)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='mse_loss', min_delta=0.1, patience=2))

learn.validate()

learn = synth_learner(n_trn=2)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))

#hide
test_eq(len(learn.recorder.values), 3)

# export
@log_args
class SaveModelCallback(TrackerCallback):
    "A `TrackerCallback` that saves the best model during training and loads it at the end."
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False, add_save=None, with_opt=False):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)
        store_attr(self, 'fname,every_epoch,add_save,with_opt')

    def _save(self, name):
        self.learn.save(name, with_opt=self.with_opt)
        if self.add_save is not None:
            with self.add_save.open('wb') as f: self.learn.save(f, with_opt=self.with_opt)

    def after_epoch(self):
        "Compare the value monitored to its best score and save if best."
        if self.every_epoch: self._save(f'{self.fname}_{self.epoch}')
        else: #every improvement
            super().after_epoch()
            if self.new_best:
                print(f'Better model found at epoch {self.epoch} with {self.monitor} value: {self.best}.')
                self._save(f'{self.fname}')

    def after_fit(self, **kwargs):
        "Load the best model."
        if not self.every_epoch: self.learn.load(f'{self.fname}')

learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')
learn.fit(n_epoch=2, cbs=SaveModelCallback())
assert (Path.cwd()/'tmp/models/model.pth').exists()
learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))
for i in range(2): assert (Path.cwd()/f'tmp/models/model_{i}.pth').exists()
shutil.rmtree(Path.cwd()/'tmp')

# export
@log_args
class ReduceLROnPlateau(TrackerCallback):
    "A `TrackerCallback` that reduces learning rate when a metric has stopped improving."
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, factor=10., min_lr=0):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)
        self.patience,self.factor,self.min_lr = patience,factor,min_lr

    def begin_fit(self): self.wait = 0; super().begin_fit()
    def after_epoch(self):
        "Compare the value monitored to its best score and reduce LR by `factor` if no improvement."
        super().after_epoch()
        if self.new_best: self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                old_lr = self.opt.hypers[-1]['lr']
                for h in self.opt.hypers: h['lr'] = max(h['lr'] / self.factor, self.min_lr)
                self.wait = 0
                if self.opt.hypers[-1]["lr"] < old_lr:
                    print(f'Epoch {self.epoch}: reducing lr to {self.opt.hypers[-1]["lr"]}')

learn = synth_learner(n_trn=2)
learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))

#hide
test_eq(learn.opt.hypers[-1]['lr'], 1e-8)

learn = synth_learner(n_trn=2)
learn.fit(n_epoch=6, lr=5e-8, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2, min_lr=1e-8))

#hide
test_eq(learn.opt.hypers[-1]['lr'], 1e-8)
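These tracker callbacks are often combined: `ReduceLROnPlateau` lowers the learning rate when progress stalls, `EarlyStoppingCallback` stops training if it still does not recover, and `SaveModelCallback` keeps and reloads the best weights seen. The cell below is a minimal sketch using the synthetic learner from this notebook; the `fname='best'`, patience values, learning rate and epoch count are arbitrary choices for illustration, not recommended defaults.

#A minimal sketch (not exported): combines the callbacks defined above with arbitrary values.
learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')
cbs = [SaveModelCallback(monitor='valid_loss', fname='best'),                    #save and reload the best weights
       ReduceLROnPlateau(monitor='valid_loss', min_delta=0.01, patience=2),      #lower the lr after 2 bad epochs
       EarlyStoppingCallback(monitor='valid_loss', min_delta=0.01, patience=4)]  #stop after 4 bad epochs
learn.fit(n_epoch=20, lr=1e-3, cbs=cbs)
shutil.rmtree(Path.cwd()/'tmp')  #clean up the model saved by this example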
#hide
from nbdev.export import notebook2script
notebook2script()