from fastai.gen_doc.nbdoc import *
from fastai.vision import *
from fastai.text import *
from fastai.callbacks import *
from fastai.basic_train import *
from fastai.train import *
from fastai import callbacks
fastai's training loop is highly extensible, with a rich callback system. See the callback docs if you're interested in writing your own callback. See below for a list of callbacks that are provided with fastai, grouped by the module they're defined in.
Every callback that is passed to Learner with the callback_fns parameter will be automatically stored as an attribute. The attribute name is snake-cased, so for instance ActivationStats will appear as learn.activation_stats (assuming your object is named learn).
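The attribute name comes from the callback's class name. Below is a minimal sketch of that naming rule. It assumes a two-pass regex conversion along the lines of fastai's internal `camel2snake` helper. `ToyLearner` and the two placeholder callback classes are hypothetical: they only borrow the fastai names to show the attribute assignment.

```python
import re

# Pass 1 splits before a capitalised word ("CSVLogger" -> "CSV_Logger").
# Pass 2 splits a lowercase letter/digit followed by a capital ("ModelCallback" -> "Model_Callback").
_CAMEL_RE1 = re.compile(r"(.)([A-Z][a-z]+)")
_CAMEL_RE2 = re.compile(r"([a-z0-9])([A-Z])")

def camel2snake(name: str) -> str:
    """Convert a CamelCase class name to snake_case."""
    s1 = _CAMEL_RE1.sub(r"\1_\2", name)
    return _CAMEL_RE2.sub(r"\1_\2", s1).lower()

class ToyLearner:
    """Hypothetical stand-in for Learner: builds each callback_fn and stores
    the instance under its snake-cased class name."""
    def __init__(self, callback_fns=()):
        self.callbacks = []
        for cb_fn in callback_fns:
            cb = cb_fn(self)
            self.callbacks.append(cb)
            setattr(self, camel2snake(type(cb).__name__), cb)

# Placeholder callbacks (not the real fastai classes).
class ActivationStats:
    def __init__(self, learn):
        self.learn = learn

class CSVLogger:
    def __init__(self, learn):
        self.learn = learn

learn = ToyLearner(callback_fns=[ActivationStats, CSVLogger])
print(learn.activation_stats, learn.csv_logger)
```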
LRFinder
Use Leslie Smith's learning rate finder to find a good learning rate for training your model. Let's see an example on a sample of the MNIST dataset with a simple CNN.
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])
learn = simple_learner()
learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.recorder.plot()
In this example, a learning rate around 2e-2 seems like the right fit.
lr = 2e-2
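The finder runs a short training sweep in which the learning rate grows exponentially while the loss is recorded. The sweep stops once the loss blows up. The toy sketch below replaces the network with the one-parameter loss f(w) = w², which makes the behaviour easy to check. The 4× divergence threshold and the sweep bounds are assumptions chosen for the sketch. Unlike this sketch, fastai also smooths the recorded loss before plotting it.

```python
import math

def exp_lr_schedule(start_lr, end_lr, num_it):
    """Learning rates growing geometrically from start_lr to end_lr over num_it steps."""
    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]

def lr_range_test(start_lr=1e-7, end_lr=10.0, num_it=100, w0=1.0):
    """Toy LR range test: gradient descent on f(w) = w**2 with an increasing LR.
    Stops once the loss exceeds 4x the best loss seen so far (assumed divergence rule)."""
    w, best = w0, math.inf
    lrs, losses = [], []
    for lr in exp_lr_schedule(start_lr, end_lr, num_it):
        w -= lr * 2 * w          # gradient of w**2 is 2w
        loss = w * w
        lrs.append(lr)
        losses.append(loss)
        best = min(best, loss)
        if loss > 4 * best:      # loss has blown up: end the sweep
            break
    return lrs, losses

lrs, losses = lr_range_test()
```

On this quadratic, gradient descent diverges once the learning rate exceeds 1, so the sweep ends shortly after that point. On a real network, you would read the plot and pick a value where the loss is still falling steeply, as with the 2e-2 chosen above.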
OneCycleScheduler
Train with Leslie Smith's 1cycle annealing method. Let's train our simple learner using the one cycle policy.
learn.fit_one_cycle(3, lr)
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.119191 | 0.071195 | 0.972522 | 00:02 |
| 1 | 0.057419 | 0.042737 | 0.984298 | 00:02 |
| 2 | 0.031792 | 0.028259 | 0.987733 | 00:02 |
The learning rate and the momentum were changed during the epochs as follows (see the OneCycleScheduler documentation page for details).
learn.recorder.plot_lr(show_moms=True)
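The shape of that plot can be reproduced with a short sketch. It assumes fastai v1's `fit_one_cycle` defaults as I understand them: `div_factor=25`, `pct_start=0.3`, `moms=(0.95, 0.85)`, a final learning rate of `max_lr/(div_factor*1e4)`, and cosine annealing in both phases. Check the OneCycleScheduler docs before relying on the exact numbers. `one_cycle_schedule` is a hypothetical helper, not a fastai function.

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_schedule(n_iter, lr_max, moms=(0.95, 0.85), div_factor=25.0,
                       pct_start=0.3, final_div=None):
    """Per-iteration (lr, momentum) lists for a two-phase 1cycle policy."""
    final_div = div_factor * 1e4 if final_div is None else final_div
    n1 = max(1, int(n_iter * pct_start))   # warm-up length
    n2 = n_iter - n1                        # annealing length
    lrs, mom_vals = [], []
    for i in range(n1):                     # phase 1: LR up, momentum down
        pct = i / max(1, n1 - 1)
        lrs.append(annealing_cos(lr_max / div_factor, lr_max, pct))
        mom_vals.append(annealing_cos(moms[0], moms[1], pct))
    for i in range(n2):                     # phase 2: LR down, momentum back up
        pct = i / max(1, n2 - 1)
        lrs.append(annealing_cos(lr_max, lr_max / final_div, pct))
        mom_vals.append(annealing_cos(moms[1], moms[0], pct))
    return lrs, mom_vals

lrs, mom_vals = one_cycle_schedule(n_iter=300, lr_max=2e-2)
```

Momentum moves opposite to the learning rate. When the learning rate is high, the lower momentum keeps the updates from overshooting.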
MixUpCallback
Data augmentation using the method from mixup: Beyond Empirical Risk Minimization. Adding mixup in fastai takes a single call:
learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy]).mixup()
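Mixup blends each example with a random partner from the same batch, using a weight λ drawn from Beta(α, α). The sketch below works on plain lists and is not fastai's implementation. Its details are assumptions:

* It uses α = 0.4.
* It takes max(λ, 1 − λ) so the original example stays dominant.
* It mixes one-hot targets. For cross-entropy, this gives the same loss as mixing the two per-target losses.

`mixup_batch` is hypothetical.

```python
import random

def mixup_batch(xs, ys, alpha=0.4, rng=random):
    """Mix each example with a shuffled partner from the same batch.
    xs: list of feature vectors; ys: list of one-hot target vectors."""
    n = len(xs)
    perm = list(range(n))
    rng.shuffle(perm)
    mixed_x, mixed_y, lams = [], [], []
    for i, j in enumerate(perm):
        lam = rng.betavariate(alpha, alpha)
        lam = max(lam, 1 - lam)  # keep the original example dominant
        mixed_x.append([lam * a + (1 - lam) * b for a, b in zip(xs[i], xs[j])])
        mixed_y.append([lam * a + (1 - lam) * b for a, b in zip(ys[i], ys[j])])
        lams.append(lam)
    return mixed_x, mixed_y, lams

rng = random.Random(0)
xs = [[0.0, 0.0], [1.0, 1.0], [2.0, 4.0]]
ys = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
mx, my, lams = mixup_batch(xs, ys, alpha=0.4, rng=rng)
```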
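CSVLogger
Log the results of training to a CSV file by passing the CSVLogger callback to the Learner.
learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy, error_rate], callback_fns=[CSVLogger])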
learn.fit(3)
| epoch | train_loss | valid_loss | accuracy | error_rate | time |
|---|---|---|---|---|---|
| 0 | 0.119083 | 0.108034 | 0.959274 | 0.040726 | 00:02 |
| 1 | 0.078156 | 0.071208 | 0.973013 | 0.026987 | 00:02 |
| 2 | 0.056985 | 0.045835 | 0.984789 | 0.015211 | 00:02 |
You can then read the logged CSV file back.
learn.csv_logger.read_logged_file()
|  | epoch | train_loss | valid_loss | accuracy | error_rate | time |
|---|---|---|---|---|---|---|
| 0 | 0 | 0.119083 | 0.108034 | 0.959274 | 0.040726 | NaN |
| 1 | 1 | 0.078156 | 0.071208 | 0.973013 | 0.026987 | NaN |
| 2 | 2 | 0.056985 | 0.045835 | 0.984789 | 0.015211 | NaN |
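The underlying idea is one CSV row per epoch. The minimal sketch below uses Python's `csv` module and writes to an in-memory buffer so that it runs anywhere. `SimpleCSVLogger` is hypothetical and does not reproduce fastai's file layout. The values are copied from the training table above.

```python
import csv
import io

class SimpleCSVLogger:
    """Hypothetical minimal logger: writes one CSV row per epoch to a text stream."""
    def __init__(self, stream, fieldnames):
        self.writer = csv.DictWriter(stream, fieldnames=fieldnames)
        self.writer.writeheader()

    def on_epoch_end(self, **stats):
        self.writer.writerow(stats)

buf = io.StringIO()
fields = ["epoch", "train_loss", "valid_loss", "accuracy", "error_rate"]
logger = SimpleCSVLogger(buf, fields)
for row in [(0, 0.119083, 0.108034, 0.959274, 0.040726),
            (1, 0.078156, 0.071208, 0.973013, 0.026987),
            (2, 0.056985, 0.045835, 0.984789, 0.015211)]:
    logger.on_epoch_end(**dict(zip(fields, row)))

buf.seek(0)                       # rewind and read the log back
rows = list(csv.DictReader(buf))  # every value comes back as a string
```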
GeneralScheduler
Create your own multi-stage annealing schemes with a convenient API. To illustrate, let's implement a two-phase schedule.
def fit_odd_schedule(learn, lr):
    n = len(learn.data.train_dl)  # iterations per epoch
    phases = [TrainingPhase(n).schedule_hp('lr', lr, anneal=annealing_cos),        # phase 1: 1 epoch, cosine annealing
              TrainingPhase(n*2).schedule_hp('lr', lr, anneal=annealing_poly(2))]  # phase 2: 2 epochs, degree-2 polynomial
    sched = GeneralScheduler(learn, phases)
    learn.callbacks.append(sched)
    total_epochs = 3  # 1 + 2 epochs, covering both phases
    learn.fit(total_epochs)
learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)
fit_odd_schedule(learn, 1e-3)
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.171962 | 0.154716 | 0.947498 | 00:02 |
| 1 | 0.133720 | 0.132249 | 0.957802 | 00:02 |
| 2 | 0.132928 | 0.129927 | 0.957802 | 00:02 |
learn.recorder.plot_lr()
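The sketch below recomputes that two-phase curve outside fastai. It assumes that passing a single value to `schedule_hp` anneals from that value down to 0, which is my reading of fastai v1. The annealing functions are re-implemented locally, so they only mirror the fastai names, and `two_phase_schedule` is hypothetical.

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def annealing_poly(degree):
    """Return a polynomial annealing function of the given degree."""
    def anneal(start, end, pct):
        return end + (start - end) * (1 - pct) ** degree
    return anneal

def two_phase_schedule(n, lr):
    """Phase 1: cosine from lr towards 0 over n iterations (one epoch).
    Phase 2: restart at lr, then degree-2 decay towards 0 over 2n iterations."""
    phases = [(n, annealing_cos), (2 * n, annealing_poly(2))]
    lrs = []
    for length, anneal in phases:
        for i in range(length):
            lrs.append(anneal(lr, 0.0, i / length))
    return lrs

lrs = two_phase_schedule(n=100, lr=1e-3)
```

The restart at the phase boundary produces the jump back to the peak learning rate after the first epoch.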
MixedPrecision
Use fp16 to take advantage of tensor cores on recent NVIDIA GPUs for a speedup of 200% or more.
HookCallback
Convenient wrapper for registering and automatically deregistering PyTorch hooks. Also contains a predefined hook callback, ActivationStats.
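The register/deregister pattern can be illustrated without PyTorch. In the sketch below, `ToyLayer` mimics the forward-hook signature `hook(module, input, output)` and returns a handle with `remove()`. `ActivationStatsSketch` records the mean and standard deviation of each layer's output, in the spirit of ActivationStats. All three classes are hypothetical.

```python
import statistics

class HookHandle:
    """Returned by register_forward_hook; remove() detaches the hook."""
    def __init__(self, hooks, fn):
        self.hooks, self.fn = hooks, fn

    def remove(self):
        if self.fn in self.hooks:
            self.hooks.remove(self.fn)

class ToyLayer:
    """Hypothetical layer that scales its input and calls its forward hooks."""
    def __init__(self, scale):
        self.scale = scale
        self.hooks = []

    def register_forward_hook(self, fn):
        self.hooks.append(fn)
        return HookHandle(self.hooks, fn)

    def __call__(self, xs):
        out = [self.scale * x for x in xs]
        for fn in list(self.hooks):
            fn(self, xs, out)
        return out

class ActivationStatsSketch:
    """Records (mean, std) of each layer's output while active; removes hooks on exit."""
    def __init__(self, layers):
        self.layers, self.stats, self.handles = layers, [], []

    def __enter__(self):
        for layer in self.layers:
            self.handles.append(layer.register_forward_hook(self._hook))
        return self

    def _hook(self, module, inp, out):
        self.stats.append((statistics.mean(out), statistics.pstdev(out)))

    def __exit__(self, *exc):
        for h in self.handles:
            h.remove()           # deregister automatically, even on error
        self.handles = []
        return False

layers = [ToyLayer(2.0), ToyLayer(0.5)]
with ActivationStatsSketch(layers) as act_stats:
    x = [1.0, 2.0, 3.0, 4.0]
    for layer in layers:
        x = layer(x)
```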
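RNNTrainer
Callback that takes care of the tweaks needed to train an RNN, such as resetting the hidden state at the start of each epoch and adding activation regularization (AR) and temporal activation regularization (TAR) to the loss.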
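Two common RNN training tweaks are activation regularization (AR) and temporal activation regularization (TAR). AR penalizes large activations. TAR penalizes large changes between consecutive time steps. Each penalty is weighted by a coefficient, called `alpha` and `beta` here. The pure-Python sketch below assumes a sequence of hidden-state vectors for a single example. The helper names are hypothetical, and fastai computes these penalties on tensors instead.

```python
def ar_penalty(h, alpha):
    """Activation regularization: alpha * mean of squared activations.
    h is a sequence of hidden-state vectors (seq_len x hidden_size)."""
    vals = [v * v for step in h for v in step]
    return alpha * sum(vals) / len(vals)

def tar_penalty(h, beta):
    """Temporal activation regularization: beta * mean squared change between steps."""
    if len(h) < 2:
        return 0.0
    diffs = [(b - a) ** 2 for prev, cur in zip(h, h[1:]) for a, b in zip(prev, cur)]
    return beta * sum(diffs) / len(diffs)

hidden = [[1.0, 1.0], [2.0, 2.0], [4.0, 4.0]]
loss = 0.3  # hypothetical base loss
loss += ar_penalty(hidden, alpha=2.0) + tar_penalty(hidden, beta=1.0)
```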
TerminateOnNaNCallback
Stop training if the loss becomes NaN.
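The check itself is simple, as the minimal sketch below shows. The loop consumes per-batch losses and stops at the first NaN. `train_until_nan` is a hypothetical helper: a real callback would signal the training loop to stop rather than return a count.

```python
import math

def train_until_nan(losses):
    """Consume per-batch losses and stop at the first NaN.
    Returns the number of batches processed before stopping."""
    for i, loss in enumerate(losses):
        if math.isnan(loss):
            print(f"Batch {i}: loss is NaN, terminating training.")
            return i
        # ... the optimizer step would happen here ...
    return len(losses)

n_done = train_until_nan([0.9, 0.7, float("nan"), 0.5])
```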
EarlyStoppingCallback
Stop training when a given metric or the validation loss stops improving.
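Early stopping is usually driven by a patience counter. The sketch below defines an improvement as beating the best value so far by more than `min_delta`. Training stops once more than `patience` epochs pass without improvement. This is a sketch of the general technique, and fastai's exact comparison rules may differ. `early_stopping_epoch` is hypothetical.

```python
import math

def early_stopping_epoch(values, patience=0, min_delta=0.0, mode="min"):
    """Return the epoch index at which training would stop, or None if it never does."""
    sign = 1.0 if mode == "min" else -1.0   # flip so that smaller is always better
    best, wait = math.inf, 0
    for epoch, v in enumerate(values):
        current = sign * v
        if current < best - min_delta:
            best, wait = current, 0         # improvement: reset the counter
        else:
            wait += 1
            if wait > patience:
                return epoch
    return None

print(early_stopping_epoch([1.0, 0.8, 0.79, 0.81, 0.82, 0.83], patience=1))
```

A larger `min_delta` treats tiny gains as no improvement, so training stops earlier.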
SaveModelCallback
Save the model at every epoch, or only the best model according to a given metric or the validation loss.
learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)
learn.fit_one_cycle(3, 1e-4, callbacks=[SaveModelCallback(learn, every='epoch', monitor='accuracy')])
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.665244 | 0.642582 | 0.816487 | 00:02 |
| 1 | 0.508492 | 0.471950 | 0.937684 | 00:02 |
| 2 | 0.438286 | 0.435377 | 0.941119 | 00:02 |
!ls ~/.fastai/data/mnist_sample/models
best.pth bestmodel_2.pth model_1.pth model_4.pth stage-1.pth bestmodel_0.pth bestmodel_3.pth model_2.pth model_5.pth tmp.pth bestmodel_1.pth model_0.pth model_3.pth one_epoch.pth trained_model.pth
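The two saving modes lead to different file names. The sketch below assumes a default checkpoint name of `bestmodel`, with `every='epoch'` appending the epoch index. That assumption is consistent with the `bestmodel_0.pth` to `bestmodel_2.pth` files in the listing above, which also contains checkpoints from other runs. `saved_files` is hypothetical and only returns names instead of writing files.

```python
def saved_files(values, every="improvement", name="bestmodel", mode="max"):
    """List the checkpoint names a SaveModelCallback-like tracker would write."""
    files = []
    best = None
    for epoch, v in enumerate(values):
        if every == "epoch":
            files.append(f"{name}_{epoch}")   # one file per epoch
        else:
            improved = best is None or (v > best if mode == "max" else v < best)
            if improved:
                best = v
                files.append(name)            # overwrites the same file each time
    return files

acc = [0.816487, 0.937684, 0.941119]  # accuracy column from the table above
print(saved_files(acc, every="epoch"))  # ['bestmodel_0', 'bestmodel_1', 'bestmodel_2']
# Hypothetical validation losses: only epochs 0 and 1 improve.
print(saved_files([0.64, 0.47, 0.52], every="improvement", mode="min"))  # ['bestmodel', 'bestmodel']
```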
ReduceLROnPlateauCallback
Reduce the learning rate by a given factor each time a given metric or the validation loss stops improving.
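The sketch below uses the same patience logic as early stopping. When the monitored value (lower is better) fails to improve for more than `patience` epochs, the learning rate is multiplied by `factor` and the counter resets. The values `factor=0.2` and `patience=1` are example choices, not documented defaults. `plateau_lrs` is hypothetical.

```python
def plateau_lrs(values, lr, factor=0.2, patience=0, min_delta=0.0):
    """Return the learning rate in effect after each epoch."""
    best, wait, out = float("inf"), 0, []
    for v in values:
        if v < best - min_delta:
            best, wait = v, 0        # improvement: keep the current LR
        else:
            wait += 1
            if wait > patience:
                lr *= factor         # plateau: shrink the LR and start counting again
                wait = 0
        out.append(lr)
    return out

print(plateau_lrs([1.0, 0.9, 0.95, 0.96, 0.85], lr=1e-2, factor=0.2, patience=1))
```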
PeakMemMetric
GPU and general RAM profiling callback.
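The general-RAM half of the idea can be sketched with Python's `tracemalloc` module. Tracing starts when an epoch begins, and the current and peak traced allocations are read when it ends. This covers only memory allocated through Python, with no GPU measurement, so it is not how PeakMemMetric works internally. `PeakRAMSketch` is hypothetical.

```python
import tracemalloc

class PeakRAMSketch:
    """Hypothetical callback: measures Python-heap usage during an epoch."""
    def on_epoch_begin(self):
        tracemalloc.start()

    def on_epoch_end(self):
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        return {"current_bytes": current, "peak_bytes": peak}

cb = PeakRAMSketch()
cb.on_epoch_begin()
tmp = bytearray(10**6)   # a transient 1 MB allocation, e.g. a batch buffer
del tmp
stats = cb.on_epoch_end()
```

The peak captures the transient buffer even though it was freed before the epoch ended. That is why peak and current usage are reported separately.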
StopAfterNBatches
Stop training after n batches of the first epoch.
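Stopping after a few batches can be handy for a quick smoke test of a pipeline. In the minimal sketch below, a callback returns `True` from its batch-end hook once `n` batches have run, and the loop breaks out. Both `StopAfterNBatchesSketch` and `run_first_epoch` are hypothetical.

```python
class StopAfterNBatchesSketch:
    """Hypothetical callback: asks the loop to stop once n batches are done."""
    def __init__(self, n_batches=2):
        self.n_batches = n_batches

    def on_batch_end(self, num_batch):
        return num_batch + 1 >= self.n_batches   # True means "stop now"

def run_first_epoch(batches, cb):
    processed = []
    for i, batch in enumerate(batches):
        processed.append(batch)   # the forward/backward pass would go here
        if cb.on_batch_end(i):
            break
    return processed

done = run_first_epoch(range(10), StopAfterNBatchesSketch(n_batches=3))
```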
train and basic_train
GradientClipping
Clip gradients during training.
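The sketch below shows global-norm clipping, modelled on PyTorch's `torch.nn.utils.clip_grad_norm_`. It assumes that GradientClipping clips by the total L2 norm of the gradients. If that norm exceeds `max_norm`, every gradient is scaled by the same factor. The small `eps` guards against division by zero. `clip_grad_norm` is a hypothetical pure-Python helper operating on a flat list of gradients.

```python
import math

def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale gradients in place so their global L2 norm is at most max_norm.
    Returns the norm measured before clipping."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = max_norm / (total_norm + eps)
    if coef < 1.0:                       # only shrink, never enlarge
        for i in range(len(grads)):
            grads[i] *= coef
    return total_norm

g = [3.0, 4.0]
norm_before = clip_grad_norm(g, max_norm=1.0)
```

Scaling all gradients by one factor preserves the update direction and only limits its size.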