from fastai.basic_train import *
from fastai.gen_doc.nbdoc import *
from fastai import *
from fastai.vision import *
basic_train wraps together the data (in a DataBunch object) with a PyTorch model to define a Learner object. This is where the basic training loop is defined for the fit function. The Learner object is the entry point for most of the Callback functions that customize this training loop in different ways (and are made available through the train module), notably:
- Learner.lr_find will launch an LR range test that will help you select a good learning rate.
- Learner.fit_one_cycle will launch a training using the 1cycle policy, to help you train your model quickly.
- Learner.to_fp16 will convert your model to half precision and help you launch a training in mixed precision.

show_doc(Learner, title_level=2)
class Learner[source]
Learner(data:DataBunch, model:Module, opt_func:Callable='Adam', loss_func:Callable=None, metrics:Collection[Callable]=None, true_wd:bool=True, bn_wd:bool=True, wd:Floats=0.01, train_bn:bool=True, path:str=None, model_dir:str='models', callback_fns:Collection[Callable]=None, callbacks:Collection[Callback]=<factory>, layer_groups:ModuleList=None)
Train model using data to minimize loss_func with optimizer opt_func.
The main purpose of Learner is to train model using Learner.fit. After every epoch, all metrics will be printed, and will also be available to callbacks.
The default weight decay will be wd, which will be handled using the method from Fixing Weight Decay Regularization in Adam if true_wd is set (otherwise it's L2 regularization). If bn_wd is False then weight decay will be removed from batchnorm layers, as recommended in Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour. You can ensure that batchnorm layer learnable params are trained even for frozen layer groups, by enabling train_bn.
To use discriminative layer training, pass an nn.Module for each layer group as layer_groups; each group can then be optimized with different settings.
Any model files created will be saved in path/model_dir.
You can pass a list of callbacks that you have already created, or (more commonly) simply pass a list of callback functions to callback_fns and each function will be called (passing self) on object initialization, with the results stored as callback objects. For a walk-through, see the training overview page. You may also want to use an application to fit your model, e.g. using the create_cnn method:
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.fit(1)
Total time: 00:09
epoch  train_loss  valid_loss  accuracy  time
1      0.067861    0.032277    0.992149  00:09
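For instance, here's a minimal sketch of attaching a callback at creation time through callback_fns (illustrative only; it assumes the ShowGraph callback provided by fastai is in scope):
# ShowGraph plots the losses while the model trains
learn = create_cnn(data, models.resnet18, metrics=accuracy, callback_fns=[ShowGraph])
learn.fit(1)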
show_doc(Learner.fit)
Uses discriminative layer training if multiple learning rates or weight decay values are passed. To control training behaviour, use the callback system or one or more of the pre-defined callbacks.
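For instance, a minimal sketch of passing a range of learning rates (one per layer group, via slice) and a single weight decay; values are illustrative only:
# a range of learning rates across layer groups, and one weight decay for all of them
learn.fit(1, lr=slice(1e-5, 1e-3), wd=1e-4)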
show_doc(Learner.fit_one_cycle)
fit_one_cycle[source]
fit_one_cycle(learn:Learner, cyc_len:int, max_lr:Union[float,Collection[float],slice]=slice(None, 0.003, None), moms:Point=(0.95, 0.85), div_factor:float=25.0, pct_start:float=0.3, wd:float=None, callbacks:Optional[Collection[Callback]]=None, **kwargs)
Fit a model following the 1cycle policy.
Uses the OneCycleScheduler callback.
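For instance, with illustrative values for the scheduling parameters:
# peak LR of 1e-3, momentum cycled between 0.9 and 0.8, starting LR of max_lr/10,
# and half of the iterations spent in the increasing phase of the cycle
learn.fit_one_cycle(1, max_lr=1e-3, moms=(0.9, 0.8), div_factor=10., pct_start=0.5)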
show_doc(Learner.lr_find)
Runs the learning rate finder defined in LRFinder, as discussed in Cyclical Learning Rates for Training Neural Networks.
show_doc(Learner.get_preds)
get_preds[source]
get_preds(ds_type:DatasetType=<DatasetType.Valid: 2>, with_loss:bool=False, n_batch:Optional[int]=None, pbar:Union[MasterBar,ProgressBar,NoneType]=None) → List[Tensor]
Return predictions and targets on the valid, train, or test set, depending on ds_type.
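For instance, a minimal sketch of getting predictions on the validation set:
# returns the model outputs and the targets for the whole validation set
preds, y = learn.get_preds(ds_type=DatasetType.Valid)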
show_doc(Learner.validate)
validate[source]
validate(dl=None,callbacks=None,metrics=None)
Validate on dl with potential callbacks and metrics.
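For instance (with accuracy as the only metric, as configured above):
# returns the validation loss followed by each metric
val_loss, val_acc = learn.validate()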
show_doc(Learner.TTA, full_name = 'TTA')
TTA[source]
TTA(learn:Learner, beta:float=0.4, scale:float=1.35, ds_type:DatasetType=<DatasetType.Valid: 2>, with_loss:bool=False) → Tensors
Applies Test Time Augmentation to learn on the dataset ds_type. We take the average of our regular predictions (with a weight beta) with the average of predictions obtained through augmented versions of the training set (with a weight 1-beta). The transforms decided for the training set are applied with a few changes: scale controls the scale for zoom (which isn't random), the cropping isn't random but we make sure to get the four corners of the image, and flipping isn't random but applied once on each of those corner images (so that makes 8 augmented versions in total).
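For instance, a minimal sketch on the validation set:
# averaged regular and augmented predictions, plus the targets
preds, y = learn.TTA(ds_type=DatasetType.Valid, beta=0.4)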
show_doc(Learner.to_fp16)
Uses the MixedPrecision callback to train in mixed precision (i.e. forward and backward passes using fp16, with weight updates using fp32), using all NVIDIA recommendations for ensuring speed and accuracy.
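For instance, a minimal sketch (it assumes a GPU with good fp16 support, e.g. one with tensor cores):
# to_fp16 returns the Learner, so it can be chained at creation time
learn = create_cnn(data, models.resnet18, metrics=accuracy).to_fp16()
learn.fit_one_cycle(1)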
When fitting a model you can pass a list of learning rates (and/or weight decay amounts), which will apply a different rate to each layer group (i.e. the parameters of each module in self.layer_groups). See the Universal Language Model Fine-tuning for Text Classification paper for details and experimental results in NLP (we also frequently use them successfully in computer vision, but have not published a paper on this topic yet). When working with a Learner on which you've called split, you can set hyperparameters in four ways:
1. param = [val1, val2 ..., valn] (n = number of layer groups)
2. param = val
3. param = slice(start, end)
4. param = slice(end)

If we choose way 1, we must specify exactly as many values as there are layer groups. If we choose way 2, the chosen value will be repeated for all layer groups. See Learner.lr_range for an explanation of the slice syntax; all four forms are sketched below.
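The values here are illustrative only, for a learner with 3 layer groups, using max_lr as the hyperparameter:
learn.fit_one_cycle(1, max_lr=[1e-4, 1e-3, 1e-2])  # way 1: one value per layer group
learn.fit_one_cycle(1, max_lr=1e-3)                # way 2: same value for every group
learn.fit_one_cycle(1, max_lr=slice(1e-5, 1e-3))   # way 3: geometrically spaced from start to end
learn.fit_one_cycle(1, max_lr=slice(1e-3))         # way 4: end/3 for every group but the last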
Here's an example of how to use discriminative learning rates (note that you don't actually need to manually call Learner.split in this case, since fastai uses this exact function as the default split for resnet18; this is just to show how to customize it):
# creates 3 layer groups
learn.split(lambda m: (m[0][6], m[1]))
# only randomly initialized head now trainable
learn.freeze()
learn.fit_one_cycle(1)
Total time: 00:08
epoch  train_loss  valid_loss  accuracy  time
1      0.036884    0.023377    0.993621  00:08
# all layers now trainable
learn.unfreeze()
# optionally, separate LR and WD for each group
learn.fit_one_cycle(1, max_lr=(1e-4, 1e-3, 1e-2), wd=(1e-4,1e-4,1e-1))
Total time: 00:11
epoch  train_loss  valid_loss  accuracy  time
1      0.025823    0.008318    0.997547  00:11
show_doc(Learner.lr_range)
Rather than manually setting an LR for every group, it's often easier to use Learner.lr_range. This is a convenience method that returns one learning rate for each layer group. If you pass slice(start,end) then the first group's learning rate is start, the last is end, and the remaining are evenly geometrically spaced.
If you pass just slice(end) then the last group's learning rate is end, and all the other groups are end/3. For instance (for our learner that has 3 layer groups):
learn.lr_range(slice(1e-5,1e-3)), learn.lr_range(slice(3e-4))
(array([1.e-05, 1.e-04, 1.e-03]), array([1.e-04, 1.e-04, 3.e-04]))
show_doc(Learner.unfreeze)
Sets every layer group to trainable (i.e. requires_grad=True).
show_doc(Learner.freeze)
Sets every layer group except the last to untrainable (i.e. requires_grad=False).
show_doc(Learner.freeze_to)
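freeze_to(n) freezes the first n layer groups and leaves the rest trainable. For instance (illustrative, with the 3-group split used above):
# freeze the first two layer groups, so only the last group is trained
learn.freeze_to(2)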
show_doc(Learner.split)
A convenience method that sets layer_groups based on the result of split_model. If split_on is a function, it calls that function and passes the result to split_model (see above for example).
Simply call Learner.save and Learner.load to save and load models. Only the parameters are saved, not the actual architecture (so you'll need to create your model in the same way before loading weights back in). Models are saved to the path/model_dir directory.
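For instance (the name 'stage-1' is just illustrative):
learn.save('stage-1')   # writes path/models/stage-1.pth
learn.load('stage-1')   # restores those weights into the same architecture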
show_doc(Learner.load)
show_doc(Learner.save)
save[source]
save(name:PathOrStr,return_path:bool=False) →Union[NoneType,str]
Save model with name to self.model_dir, and return path if return_path.
show_doc(Learner.show_results)
show_results[source]
show_results(ds_type=<DatasetType.Valid: 2>, rows:int=5, **kwargs)
Show rows result of predictions on ds_type dataset.
show_doc(Learner.predict)
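A minimal sketch for the vision learner above (the exact return values of predict depend on the application and fastai version, so this is illustrative only):
img = learn.data.train_ds[0][0]   # a single Image from the training set
pred = learn.predict(img)         # predicted category plus the raw model outputs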
show_doc(Learner.create_unet, doc_string=False)
show_doc(Learner.init)
init[source]
init(init)
Initializes all weights (except batchnorm) using function init, which will often be from PyTorch's nn.init module.
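For instance, using one of the initializers from torch.nn.init (illustrative):
# re-initialize the non-batchnorm weights with Kaiming normal initialization
learn.init(nn.init.kaiming_normal_)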
show_doc(Learner.mixup)
mixup[source]
mixup(learn:Learner,alpha:float=0.4,stack_x:bool=False,stack_y:bool=True) →Learner
Add mixup https://arxiv.org/abs/1710.09412 to learn.
Uses MixUpCallback.
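For instance (alpha is the parameter of the Beta distribution used to mix pairs of inputs):
# mixup returns the Learner, so it can be chained at creation time
learn = create_cnn(data, models.resnet18, metrics=accuracy).mixup(alpha=0.4)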
show_doc(Learner.pred_batch)
pred_batch[source]
pred_batch(ds_type:DatasetType=<DatasetType.Valid: 2>, pbar:Union[MasterBar,ProgressBar,NoneType]=None) → List[Tensor]
Return output of the model on one batch from valid, train, or test set, depending on ds_type.
Get the first batch of predictions. Mainly useful for debugging and quick tests.
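For instance:
# model outputs for a single batch of the validation set
preds = learn.pred_batch(ds_type=DatasetType.Valid)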
show_doc(Learner.create_opt)
create_opt[source]
create_opt(lr:Floats,wd:Floats=0.0)
Create optimizer with lr learning rate and wd weight decay.
You generally won't need to call this yourself - it's called internally to create the optimizer before fitting the model.
show_doc(Learner.dl)
dl[source]
dl(ds_type:DatasetType=<DatasetType.Valid: 2>)
Return DataLoader for DatasetType ds_type.
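For instance, to grab a batch from the training set manually:
train_dl = learn.dl(DatasetType.Train)
x, y = next(iter(train_dl))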
show_doc(Recorder, title_level=2)
class Recorder[source]
Recorder(learn:Learner) ::LearnerCallback
A LearnerCallback that records epoch, loss, opt and metric data during training.
A Learner creates a Recorder object automatically - you do not need to explicitly pass it to callback_fns - because other callbacks rely on it being available. It stores the smoothed loss, hyperparameter values, and metrics for each batch, and provides plotting methods for each. Note that Learner automatically sets an attribute with the snake-cased name of each callback, so you can access this one through Learner.recorder, as shown below.
show_doc(Recorder.plot)
plot[source]
plot(skip_start:int=10,skip_end:int=5)
Plot learning rate and losses, trimmed between skip_start and skip_end.
This is mainly used with the learning rate finder, since it shows a scatterplot of loss vs learning rate.
learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
show_doc(Recorder.plot_losses)
Note that validation losses are only calculated once per epoch, whereas training losses are calculated after every batch.
learn.fit_one_cycle(2)
learn.recorder.plot_losses()
Total time: 00:15
epoch  train_loss  valid_loss  accuracy  time
1      0.110780    0.047797    0.981845  00:07
2      0.046339    0.038317    0.987733  00:07
show_doc(Recorder.plot_lr)
learn.recorder.plot_lr(show_moms=True)
show_doc(Recorder.plot_metrics)
Note that metrics are only collected at the end of each epoch, so you'll need to train at least two epochs to have anything to show here.
learn.recorder.plot_metrics()
You don't call these yourself - they're called by fastai's callback system automatically to enable the class's functionality.
show_doc(Recorder.on_backward_begin)
on_backward_begin[source]
on_backward_begin(smooth_loss:Tensor, **kwargs:Any)
Record the loss before any other callback has a chance to modify it.
show_doc(Recorder.on_batch_begin)
on_batch_begin[source]
on_batch_begin(train, **kwargs:Any)
Record learning rate and momentum at beginning of batch.
show_doc(Recorder.on_epoch_end)
on_epoch_end[source]
on_epoch_end(epoch:int, num_batch:int, smooth_loss:Tensor, last_metrics='Collection', **kwargs:Any) → bool
Save epoch info: num_batch, smooth_loss, metrics.
show_doc(Recorder.on_train_begin)
on_train_begin[source]
on_train_begin(pbar:PBar, metrics_names:StrList, **kwargs:Any)
Initialize recording status at beginning of training.
show_doc(fit)
Note that you have to create the Optimizer yourself if you call this function, whereas Learner.fit creates it for you automatically.
show_doc(train_epoch)
train_epoch[source]
train_epoch(model:Module,dl:DataLoader,opt:Optimizer,loss_func:LossFunction)
Simple training of model for 1 epoch of dl using optim opt and loss function loss_func.
You won't generally need to call this yourself - it's what fit calls for each epoch.
show_doc(validate)
validate[source]
validate(model:Module, dl:DataLoader, loss_func:OptLossFunc=None, cb_handler:Optional[CallbackHandler]=None, pbar:Union[MasterBar,ProgressBar,NoneType]=None, average=True, n_batch:Optional[int]=None) → Iterator[Tuple[IntOrTensor,Ellipsis]]
Calculate loss and metrics for the validation set.
This is what fit calls after each epoch. You can call it if you want to run inference on a DataLoader manually.
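For instance, a minimal sketch using the learner created above (with average=True, the default, this returns the mean loss; metrics are only computed when a CallbackHandler carrying them is passed):
val_loss = validate(learn.model, learn.data.valid_dl, loss_func=learn.loss_func)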
show_doc(get_preds)
get_preds[source]
get_preds(model:Module, dl:DataLoader, pbar:Union[MasterBar,ProgressBar,NoneType]=None, cb_handler:Optional[CallbackHandler]=None, activ:Module=None, loss_func:OptLossFunc=None, n_batch:Optional[int]=None) → List[Tensor]
Return a tuple of predictions and targets (and optionally the losses, if loss_func is given) computed on dl, using at most n_batch batches.
show_doc(loss_batch)
loss_batch[source]
loss_batch(model:Module, xb:Tensor, yb:Tensor, loss_func:OptLossFunc=None, opt:OptOptimizer=None, cb_handler:Optional[CallbackHandler]=None) → Tuple[Union[Tensor,int,float,str]]
Calculate loss and metrics for a batch, call out to callbacks as necessary.
show_doc(LearnerCallback, title_level=3)
show_doc(Learner.tta_only)
_tta_only[source]
_tta_only(learn:Learner, ds_type:DatasetType=<DatasetType.Valid: 2>, scale:float=1.35) → Iterator[List[Tensor]]
Computes the outputs for several augmented inputs, for TTA.
show_doc(Recorder.format_stats)
show_doc(Recorder.add_metrics)
add_metrics[source]
add_metrics(metrics)
show_doc(Recorder.add_metric_names)
add_metric_names[source]
add_metric_names(names)