from fastai.torch_core import *
from fastai.callback import *
from fastai.basic_train import Learner, LearnerCallback
from fastai import *
from fastai.vision import *

# __all__ = ['TerminateOnNaN', 'EarlyStopping', 'SaveModel']

path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []))

# Deliberately absurd learning rate (1e4) to provoke a NaN loss for the demo.
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.fit(2, 1e4)


class TerminateOnNaN(Callback):
    "A `Callback` that terminates training if the loss becomes NaN."

    def __init__(self):
        # Latched once a NaN is seen, so validation is skipped for the rest of the run.
        self.stop = False

    def on_batch_end(self, last_loss, epoch, num_batch, **kwargs:Any)->None:
        "Check `last_loss` after every batch; return True to halt training on NaN."
        if self.stop: return True  # to skip validation after stopping during training
        if torch.isnan(last_loss):
            print(f'Epoch/Batch ({epoch}/{num_batch}): Invalid loss, terminating training.')
            self.stop = True
            return True

    def on_epoch_end(self, **kwargs:Any)->None:
        "Propagate the stop flag so the epoch loop also terminates."
        return self.stop


learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.fit(2, 1e4, callbacks=[TerminateOnNaN()])

# Tiny learning rate (1e-42): training makes no progress, exercising the trackers below.
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.fit(3, 1e-42)


@dataclass
class TrackerCallback(LearnerCallback):
    "A `LearnerCallback` that keeps track of the best value in `monitor`."
    monitor:str='val_loss'  # quantity to track: 'trn_loss', 'val_loss', or a metric name
    mode:str='auto'         # 'min', 'max', or 'auto' (inferred from `monitor`)

    def __post_init__(self):
        "Resolve `mode` to a numpy comparison operator, warning on invalid input."
        if self.mode not in ['auto', 'min', 'max']:
            # BUGFIX: instances have no `__name__`; use the class name for the warning.
            warn(f'{self.__class__.__name__} mode {self.mode} is invalid, falling back to "auto" mode.')
            self.mode = 'auto'
        mode_dict = {'min': np.less, 'max':np.greater}
        # 'auto': losses should decrease, everything else should increase.
        mode_dict['auto'] = np.less if 'loss' in self.monitor else np.greater
        self.operator = mode_dict[self.mode]

    def on_train_begin(self, **kwargs:Any)->None:
        "Initialize `best` to the worst possible value for the chosen comparison."
        self.best = float('inf') if self.operator == np.less else -float('inf')

    def get_monitor_value(self):
        "Return the latest recorded value of `monitor`, or None (with a warning) if unavailable."
        values = {'trn_loss':self.learn.recorder.losses[-1:][0].cpu().numpy(),
                  'val_loss':self.learn.recorder.val_losses[-1:][0]}
        # recorder.names[3:] are the metric names - TODO confirm against recorder layout.
        for i, name in enumerate(self.learn.recorder.names[3:]):
            # BUGFIX: previously read the *global* `learn` instead of this callback's learner.
            values[name] = self.learn.recorder.metrics[-1:][0][i]
        if values.get(self.monitor) is None:
            # BUGFIX: `self.__name__` -> `self.__class__.__name__` (AttributeError otherwise).
            warn(f'{self.__class__.__name__} conditioned on metric `{self.monitor}` which is not available. Available metrics are: {", ".join(map(str, self.learn.recorder.names[1:]))}')
        return values.get(self.monitor)


@dataclass
class EarlyStopping(TrackerCallback):
    "A `LearnerCallback` that terminates training when monitored quantity stops improving."
    min_delta:float=0  # minimum change to count as an improvement (was mistyped `int`)
    patience:int=0     # number of epochs with no improvement to tolerate before stopping

    def __post_init__(self):
        super().__post_init__()
        # For a minimized quantity, an "improvement" must be at least min_delta *below* best.
        if self.operator == np.less:
            self.min_delta *= -1

    def on_train_begin(self, **kwargs:Any)->None:
        "Reset the no-improvement counter at the start of training."
        self.wait = 0
        super().on_train_begin(**kwargs)

    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        "Stop training (return True) once `patience` epochs pass without improvement."
        current = self.get_monitor_value()
        if current is None: return
        if self.operator(current - self.min_delta, self.best):
            self.best, self.wait = current, 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                print(f'Epoch {epoch}: early stopping')
                return True


learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.callback_fns.append(partial(EarlyStopping, monitor='accuracy', min_delta=0.01, patience=3))
learn.fit(50, 1e-42)

learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.fit(5, 1e-42)


@dataclass
class SaveModel(TrackerCallback):
    "A `LearnerCallback` that saves the model when monitored quantity is best."
    every:str='improvement'  # 'improvement': save only on new best; 'epoch': save every epoch
    name:str='bestmodel'     # base filename for saved weights

    def __post_init__(self):
        if self.every not in ['improvement', 'epoch']:
            # BUGFIX: `{every}` referenced an undefined local; must be `self.every`.
            warn(f'SaveModel every {self.every} is invalid, falling back to "improvement".')
            self.every = 'improvement'
        super().__post_init__()

    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        "Save weights for this epoch, or only when the monitored value improved."
        # BUGFIX throughout: use self.learn, not the global `learn`.
        if self.every == "epoch":
            self.learn.save(f'{self.name}_{epoch}')
        else:  # every="improvement"
            current = self.get_monitor_value()
            if current is not None and self.operator(current, self.best):
                self.best = current
                self.learn.save(f'{self.name}')

    def on_train_end(self, **kwargs):
        "When tracking improvement, reload the best weights at the end of training."
        if self.every == "improvement":
            self.learn.load(f'{self.name}')


learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.callback_fns.append(partial(SaveModel, every='epoch'))
learn.fit(5, 1e-42)
# Was IPython shell magic `!ls {path}/models/` - invalid in a .py file; list via pathlib instead.
print(sorted((path/'models').glob('*')))

learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.callback_fns.append(partial(SaveModel, monitor='accuracy'))
learn.fit(5, 1e-2)
print(sorted((path/'models').glob('*')))

validate(learn.model, learn.data.valid_dl, learn.loss_fn, metrics=[accuracy])