Contribution from @fredguth, https://github.com/fredguth/fastai_playground
from fastai.torch_core import *
from fastai.callback import *
from fastai.basic_train import Learner, LearnerCallback
from fastai import *
from fastai.vision import *
# __all__ = ['TerminateOnNaN', 'EarlyStopping', 'SaveModel']
# Download the MNIST sample dataset and build an ImageDataBunch with
# random-pad train transforms and no validation transforms (fastai v1 API).
path = untar_data(URLs.MNIST_SAMPLE)
path
data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []))
The callbacks below are based on the Keras callbacks of the same name: https://github.com/keras-team/keras/blob/master/keras/callbacks.py
# NOTE(review): `ConvLearner` / `tvm` are early fastai v1 names (later
# `create_cnn` / `models`) — confirm against the installed fastai version.
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
# lr=1e4 is deliberately enormous so the loss diverges to NaN (see the
# TerminateOnNaN run further down for the guarded version).
learn.fit(2,1e4)
The callback below is heavily influenced by the Keras callback of the same name.
class TerminateOnNaN(Callback):
    "A `Callback` that terminates training if the loss is NaN."
    def __init__(self):
        # Set once a NaN loss is seen; short-circuits every later batch/epoch.
        self.stop = False
    def on_batch_end(self, last_loss, epoch, num_batch, **kwargs:Any)->bool:
        "Return True (stop training) as soon as `last_loss` is NaN."
        if self.stop: return True  # to skip validation after stopping during training
        if torch.isnan(last_loss):
            print(f'Epoch/Batch ({epoch}/{num_batch}): Invalid loss, terminating training.')
            self.stop = True
            return True
    def on_epoch_end(self, **kwargs:Any)->bool:
        "Propagate the stop flag at epoch end so the fit loop halts."
        return self.stop
# Same divergent setup, now with TerminateOnNaN attached: the huge lr (1e4)
# makes the loss NaN and training stops early instead of continuing on garbage.
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.fit(2,1e4, callbacks=[TerminateOnNaN()])
# Tiny lr (1e-42): essentially no learning, so monitored metrics plateau —
# useful for demonstrating early stopping below.
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.fit(3,1e-42)
The callback below is essentially a simplified port of the Keras EarlyStopping callback to fastai/PyTorch.
@dataclass
class TrackerCallback(LearnerCallback):
    "A `LearnerCallback` that keeps track of the best value in `monitor`."
    monitor:str='val_loss'  # name of the quantity to track
    mode:str='auto'         # 'min', 'max', or 'auto' (direction inferred from `monitor`)
    def __post_init__(self):
        if self.mode not in ['auto', 'min', 'max']:
            # fix: instances have no `__name__`; use the class name instead
            warn(f'{self.__class__.__name__} mode {self.mode} is invalid, falling back to "auto" mode.')
            self.mode = 'auto'
        mode_dict = {'min': np.less, 'max':np.greater}
        # 'auto': losses improve downward, every other metric upward
        mode_dict['auto'] = np.less if 'loss' in self.monitor else np.greater
        self.operator = mode_dict[self.mode]
    def on_train_begin(self, **kwargs:Any)->None:
        "Initialize `best` to the worst possible value for the chosen direction."
        self.best = float('inf') if self.operator == np.less else -float('inf')
    def get_monitor_value(self):
        "Return the latest value of the monitored quantity, or None if unavailable."
        values = {'trn_loss':self.learn.recorder.losses[-1:][0].cpu().numpy(),
                  'val_loss':self.learn.recorder.val_losses[-1:][0]}
        for i, name in enumerate(self.learn.recorder.names[3:]):
            # fix: was reading the notebook-global `learn`; use the attached learner
            values[name] = self.learn.recorder.metrics[-1:][0][i]
        if values.get(self.monitor) is None:
            # fix: `self.__name__` raised AttributeError; use the class name
            warn(f'{self.__class__.__name__} conditioned on metric `{self.monitor}` which is not available. Available metrics are: {", ".join(map(str, self.learn.recorder.names[1:]))}')
        return values.get(self.monitor)
@dataclass
class EarlyStopping(TrackerCallback):
    "A `LearnerCallback` that terminates training when monitored quantity stops improving."
    min_delta:float=0  # smallest change that counts as an improvement (fix: was mis-annotated `int`; fractional deltas like 0.01 are used)
    patience:int=0     # number of epochs with no improvement to tolerate before stopping
    def __post_init__(self):
        super().__post_init__()
        # In 'min' mode an improvement must go *below* best by min_delta,
        # so flip the sign once here.
        if self.operator == np.less: self.min_delta *= -1
    def on_train_begin(self, **kwargs:Any)->None:
        "Reset the no-improvement counter."
        self.wait = 0
        super().on_train_begin(**kwargs)
    def on_epoch_end(self, epoch, **kwargs:Any)->bool:
        "Return True (stop training) after `patience` epochs without improvement."
        current = self.get_monitor_value()
        if current is None: return
        if self.operator(current - self.min_delta, self.best):
            self.best,self.wait = current,0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                print(f'Epoch {epoch}: early stopping')
                return True
# EarlyStopping demo: with lr=1e-42 accuracy never improves by >= 0.01,
# so training should stop after `patience`=3 stagnant epochs instead of 50.
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.callback_fns.append(partial(EarlyStopping, monitor='accuracy', min_delta=0.01, patience=3))
learn.fit(50,1e-42)
# Baseline run without the callback, for comparison.
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy] )
learn.fit(5,1e-42)
The best epoch was #1, but the model left in memory is from epoch #4 — motivating a callback that saves (and restores) the best model.
@dataclass
class SaveModel(TrackerCallback):
    "A `LearnerCallback` that saves the model when monitored quantity is best."
    every:str='improvement'  # 'improvement' saves only on a new best; 'epoch' saves every epoch
    name:str='bestmodel'     # base filename for the saved model(s)
    def __post_init__(self):
        if self.every not in ['improvement', 'epoch']:
            # fix: bare `every` was an unresolved name inside the f-string (NameError)
            warn(f'SaveModel every {self.every} is invalid, falling back to "improvement".')
            self.every = 'improvement'
        super().__post_init__()
    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        "Save on every epoch, or only when the monitored value improves."
        # fix: save through self.learn, not the notebook-global `learn`
        if self.every=="epoch": self.learn.save(f'{self.name}_{epoch}')
        else: #every="improvement"
            current = self.get_monitor_value()
            if current is not None and self.operator(current, self.best):
                self.best = current
                self.learn.save(f'{self.name}')
    def on_train_end(self, **kwargs):
        "In 'improvement' mode, restore the best saved model at training end."
        # fix: load through self.learn, not the notebook-global `learn`
        if self.every=="improvement": self.learn.load(f'{self.name}')
# SaveModel demo: every='epoch' writes one checkpoint per epoch.
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.callback_fns.append(partial(SaveModel, every='epoch'))
learn.fit(5,1e-42)
# Notebook shell magic (not valid plain Python): list saved checkpoints.
!ls {path}/models/
# every='improvement' (the default) keeps a single best model by accuracy.
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.callback_fns.append(partial(SaveModel, monitor='accuracy'))
learn.fit(5,1e-2)
!ls {path}/models/
# Re-validate after training: the restored best model's metrics should match
# the best epoch. NOTE(review): `loss_fn` is the old fastai v1 attribute name
# (later `loss_func`) — confirm against the installed version.
validate(learn.model, learn.data.valid_dl, learn.loss_fn, metrics=[accuracy])