In this tutorial, we will be looking at the training script of timm. There are various features that timm has to offer, and the ones covered in this tutorial are listed below:
- Auto augmentation
- AugMix (with JSD loss)
- Distributed training on multiple GPUs
- Mixed precision training
- Auxiliary batch norm (SplitBatchNorm)
- Synchronized batch norm (SyncBN)
- Mixup and CutMix
timm also supports multiple optimizers & schedulers. In this tutorial, we will only be looking at the above 7 features and see how you could utilize timm to use these features for your own experiments on a custom dataset.
As part of this tutorial, we will first start out with a general introduction to the training script and look at the various key steps that occur inside this script at a high level. Then, we will look at some of the details of the above 7 features to get a further understanding of train.py.
The training script in timm can accept ~100 arguments. You can find out more about these by running python train.py --help. These arguments define Dataset/Model parameters, Optimizer parameters, Learning Rate scheduler parameters, Augmentation and regularization, Batch Norm parameters, Model Exponential Moving Average parameters, and some miscellaneous parameters such as --seed, --tta etc.
As part of this tutorial, we will be looking at how the training script makes use of these parameters from a high-level view. This should make it easier for you to run your own experiments on ImageNet or any other custom dataset using timm.
The only argument required by the timm training script is the path to the training data, such as ImageNet, which is structured in the following way:
imagenette2-320
├── train
│ ├── n01440764
│ ├── n02102040
│ ├── n02979186
│ ├── n03000684
│ ├── n03028079
│ ├── n03394916
│ ├── n03417042
│ ├── n03425413
│ ├── n03445777
│ └── n03888257
└── val
├── n01440764
├── n02102040
├── n02979186
├── n03000684
├── n03028079
├── n03394916
├── n03417042
├── n03425413
├── n03445777
└── n03888257
So, to start training on this imagenette2-320 dataset, we could just run something like python train.py <path_to_imagenette2-320_dataset>.
The various default args in the training script are set up for you, and what gets passed to the training script looks something like this:
Namespace(aa=None, amp=False, apex_amp=False, aug_splits=0, batch_size=32, bn_eps=None, bn_momentum=None, bn_tf=False, channels_last=False, clip_grad=None, color_jitter=0.4, cooldown_epochs=10, crop_pct=None, cutmix=0.0, cutmix_minmax=None, data_dir='../imagenette2-320', dataset='', decay_epochs=30, decay_rate=0.1, dist_bn='', drop=0.0, drop_block=None, drop_connect=None, drop_path=None, epochs=200, eval_metric='top1', gp=None, hflip=0.5, img_size=None, initial_checkpoint='', input_size=None, interpolation='', jsd=False, local_rank=0, log_interval=50, lr=0.01, lr_cycle_limit=1, lr_cycle_mul=1.0, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, mean=None, min_lr=1e-05, mixup=0.0, mixup_mode='batch', mixup_off_epoch=0, mixup_prob=1.0, mixup_switch_prob=0.5, model='resnet101', model_ema=False, model_ema_decay=0.9998, model_ema_force_cpu=False, momentum=0.9, native_amp=False, no_aug=False, no_prefetcher=False, no_resume_opt=False, num_classes=None, opt='sgd', opt_betas=None, opt_eps=None, output='', patience_epochs=10, pin_mem=False, pretrained=False, ratio=[0.75, 1.3333333333333333], recount=1, recovery_interval=0, remode='const', reprob=0.0, resplit=False, resume='', save_images=False, scale=[0.08, 1.0], sched='step', seed=42, smoothing=0.1, split_bn=False, start_epoch=None, std=None, sync_bn=False, torchscript=False, train_interpolation='random', train_split='train', tta=0, use_multi_epochs_loader=False, val_split='validation', validation_batch_size_multiplier=1, vflip=0.0, warmup_epochs=3, warmup_lr=0.0001, weight_decay=0.0001, workers=4)
Notice that args is a Namespace, which means we can set more attributes along the way if needed by doing something like args.new_variable="some_value".
To get a one-line description of each of these arguments, we can simply run python train.py --help.
NOTE: We will be looking at these parameters in greater detail as part of this tutorial. Do note that some random augmentations such as `color_jitter` and `hflip` are enabled by default, but there is a `no-aug` parameter in case you want to turn off all training data augmentations. Also, the default optimizer `opt` is 'sgd', but it is possible to change it; timm offers a vast number of optimizers to train your models with.
In this section, we will look at the various steps that occur inside the training script from a high-level perspective. These steps have been outlined below in the correct order:
1. Set up distributed training parameters; when training on multiple GPUs, args.distributed is True.
2. Create the model to be trained using the timm.create_model function.
3. Set up the data config based on the model's default config, which looks something like: {'url': '', 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'conv1', 'classifier': 'fc'}
4. Set up the number of augmentation batch splits.
5. If the number of augmentation batch splits is greater than one, convert the model's BatchNorm layers to SplitBatchNorm2d layers.
NOTE: I feel this needs a little more explanation. In general, when we train a model, we apply the data augmentation to the complete batch and then compute the batch norm statistics, such as mean and variance, from this complete batch. But as introduced in this paper, it is sometimes beneficial to split the data into groups and use a separate Batch Normalization layer for each group, so that the groups are normalized independently throughout the training process. This is referred to as auxiliary batch norm in the paper and is implemented as SplitBatchNorm2d in timm.
NOTE: Let's assume that the number of augmentation batch splits is set to two. In that case, we split the data into two groups: one without any augmentations (referred to as clean) and one with augmentations. We then use two separate batch normalization layers to normalize the two groups throughout the training process.
6. If we are using multiple GPUs for training, set up either Apex syncBN or PyTorch native SyncBatchNorm to get Synchronized Batch Normalization. This means that rather than normalizing the data on each individual GPU, we normalize the whole batch, spread across multiple GPUs, at once.
7. Make the model exportable using torch.jit if requested.
8. Initialize the optimizer based on the arguments passed to the training script.
9. Set up mixed precision training, either using apex.amp or native torch AMP (torch.cuda.amp.autocast).
10. Load model weights if resuming from a model checkpoint.
11. Set up an exponential moving average of the model weights. This is similar to Stochastic Weight Averaging.
12. Set up distributed training based on the parameters from step 1.
13. Set up the learning rate scheduler.
14. Create the training and validation datasets.
15. Set up Mixup/CutMix data augmentation.
16. Convert the training dataset to an AugmixDataset if the number of augmentation batch splits from step 4 is greater than 1.
17. Create the training data loader and validation data loader (a minimal sketch of the model, data config, and loader creation is shown after this list).
NOTE: Transforms/Augmentations also get created inside the training dataloader.
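To make some of these steps concrete, here is a minimal sketch (not the actual train.py code, and assuming a reasonably recent timm version) that creates the model, resolves its data config, and builds a training loader for imagenette2-320:

```python
import timm
from timm.data import ImageDataset, create_loader, resolve_data_config

# Create the model and inspect its default config (steps 2 and 3 above)
model = timm.create_model('resnet101', pretrained=False, num_classes=10)
print(model.default_cfg)

# Resolve the data config (input size, mean/std, interpolation, ...) from the model
data_config = resolve_data_config({}, model=model)

# Build the training dataset and loader; augmentations are created inside create_loader
train_dataset = ImageDataset('../imagenette2-320/train')
train_loader = create_loader(
    train_dataset,
    input_size=data_config['input_size'],
    batch_size=32,
    is_training=True,          # enables the training-time augmentation pipeline
    use_prefetcher=False,      # keep it simple; train.py uses the CUDA prefetcher by default
    mean=data_config['mean'],
    std=data_config['std'],
)
```

Note that create_loader builds the training augmentations internally, which is why the note above says the transforms get created inside the training dataloader.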
timm features
To enable auto augmentation during training, run:
python train.py ./imagenette2-320 --aa 'v0'
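Under the hood, the --aa string is handed to timm's transform factory. Here is a hedged sketch (not the exact train.py wiring) of building the same 'v0' AutoAugment training transform with timm.data.create_transform:

```python
from timm.data import create_transform

# Build a training transform with the AutoAugment 'v0' policy,
# roughly what --aa 'v0' enables inside the training dataloader.
train_transform = create_transform(
    input_size=224,
    is_training=True,
    auto_augment='v0',   # same policy string as the --aa argument
)
print(train_transform)
```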
A brief introduction to AugMix has been presented here. To enable AugMix during training, simply run:
python train.py ./imagenette2-320 --aug-splits 3 --jsd
timm also supports augmix with RandAugment and AutoAugment like so:
python train.py ./imagenette2-320 --aug-splits 3 --jsd --aa rand-m9-mstd0.5-inc1
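Roughly speaking, --aug-splits 3 --jsd corresponds to wrapping the training dataset in AugMixDataset and training with the Jensen-Shannon divergence consistency loss. A hedged sketch, where the dataset path and smoothing value are just example choices:

```python
from timm.data import AugMixDataset, ImageDataset
from timm.loss import JsdCrossEntropy

num_aug_splits = 3                                        # --aug-splits 3
dataset_train = ImageDataset('./imagenette2-320/train')   # example path
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=0.1)   # --jsd
```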
To train models on multiple GPUs, simply replace python train.py with ./distributed_train.sh <num-gpus> like so:
./distributed_train.sh 4 ./imagenette2-320 --aug-splits 3 --jsd
This trains the model using AugMix data augmentation on 4 GPUs.
To enable mixed precision training, simply add the --amp flag. timm will automatically enable mixed precision training using either apex or native PyTorch AMP.
python train.py ../imagenette2-320 --aug-splits 3 --jsd --amp
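When the native PyTorch path is used, the training loop roughly does the following (a simplified sketch, not timm's actual implementation):

```python
import torch

def train_one_epoch_amp(model, loader, optimizer, loss_fn, device='cuda'):
    """Simplified native-AMP training loop, roughly what --amp enables."""
    scaler = torch.cuda.amp.GradScaler()
    model.train()
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():     # forward pass runs in mixed precision
            loss = loss_fn(model(inputs), targets)
        scaler.scale(loss).backward()       # scale the loss to avoid fp16 underflow
        scaler.step(optimizer)              # unscale gradients and step the optimizer
        scaler.update()
```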
SplitBatchNorm
From the paper,
Batch normalization serves as an essential component for many state-of-the-art computer vision models. Specifically, BN normalizes input features by the mean and variance computed within each mini-batch. **One intrinsic assumption of utilizing BN is that the input features should come from a single or similar distributions.** This normalization behavior could be problematic if the mini-batch contains data from different distributions, therefore resulting in inaccurate statistics estimation.
To disentangle this mixture distribution into two simpler ones respectively for the clean and adversarial images, we hereby propose an auxiliary BN to guarantee its normalization statistics are exclusively performed on the adversarial examples.
To enable split batch norm, run:
python train.py ./imagenette2-320 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --split-bn
Using the above command, timm now uses a separate batch normalization layer for each augmentation split.
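Internally, this corresponds to converting the model's BatchNorm layers with convert_splitbn_model. A hedged sketch, where the model name and split count are just example choices:

```python
import timm
from timm.models import convert_splitbn_model

num_aug_splits = 3
model = timm.create_model('resnet101', num_classes=10)
# Replace each BatchNorm2d with a SplitBatchNorm2d that keeps separate
# statistics for each augmentation split (at least 2 splits).
model = convert_splitbn_model(model, max(num_aug_splits, 2))
```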
Synchronized batch norm is only used when training on multiple GPUs. From Papers with Code:
Synchronized Batch Normalization (SyncBN) is a type of batch normalization used for multi-GPU training. Standard batch normalization only normalizes the data within each device (GPU). SyncBN normalizes the input within the whole mini-batch.
To enable it, simply add the --sync-bn flag like so:
./distributed_train.sh 4 ../imagenette2-320 --aug-splits 3 --jsd --sync-bn
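When apex is not used, the native PyTorch path essentially performs the standard SyncBatchNorm conversion (a minimal sketch; it only has an effect when the model is trained with DistributedDataParallel):

```python
import torch
import timm

model = timm.create_model('resnet101', num_classes=10)
# Convert all BatchNorm layers to SyncBatchNorm; statistics are then
# synchronized across GPUs once the model is wrapped in DistributedDataParallel.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
```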
To enable either Mixup or CutMix, simply add the --mixup or --cutmix flag with an alpha value.
The default probability of applying the augmentation is 1.0. If you need to change it, use the --mixup-prob argument with a new value.
For example, to enable mixup,
python train.py ../imagenette2-320 --mixup 0.5
python train.py ../imagenette2-320 --mixup 0.5 --mixup-prob 0.7
Or for Cutmix,
python train.py ../imagenette2-320 --cutmix 0.5
python train.py ../imagenette2-320 --cutmix 0.5 --mixup-prob 0.7
It is also possible to enable both,
python train.py ../imagenette2-320 --mixup 0.5 --cutmix 0.5 --mixup-switch-prob 0.3
The above command enables both Mixup and CutMix as data augmentation techniques (each with alpha = 0.5) and applies one of them to each batch with the default probability of 1.0. With --mixup-switch-prob 0.3, it switches from Mixup to CutMix with 30% probability, i.e. Mixup is used 70% of the time and CutMix 30% of the time.
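For reference, the command above roughly corresponds to the following timm.data.Mixup configuration (a hedged sketch; train.py builds this object from the args, and the number of classes here is just the imagenette value):

```python
from timm.data import Mixup

mixup_fn = Mixup(
    mixup_alpha=0.5,        # --mixup 0.5
    cutmix_alpha=0.5,       # --cutmix 0.5
    prob=1.0,               # --mixup-prob (default)
    switch_prob=0.3,        # --mixup-switch-prob 0.3
    mode='batch',
    label_smoothing=0.1,
    num_classes=10,         # imagenette2-320 has 10 classes
)
# Inside the training loop: inputs, targets = mixup_fn(inputs, targets)
```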
There is also a parameter to turn off the Mixup/Cutmix augmentation after a certain epoch:
python train.py ../imagenette2-320 --mixup 0.5 --cutmix 0.5 --mixup-switch-prob 0.3 --mixup-off-epoch 10
The above command only applies the Mixup/Cutmix data augmentation for the first 10 epochs.