# Example usage: python run_fastai.py /home/paperspace/ILSVRC/Data/CLS-LOC/ -a resnext_50_32x4d --epochs 1 -j 4 -b 64 --fp16
import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.utils.data.distributed
from fastai.conv_learner import *  # fastai 0.7-style star import: RandomFlip, RandomCrop, tfms_from_stats, ImageClassifierData, Callback, LoggingCallback, ...

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--save-dir', type=str, default=Path.home()/'imagenet_training',
help='Directory to save logs and models.')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--cycle-len', default=1, type=float, metavar='N',
                    help='length of learning-rate cycle to run, in epochs (default: 1)')
# parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
# help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)')
# parser.add_argument('--print-freq', '-p', default=10, type=int,
# metavar='N', help='print frequency (default: 10)')
# parser.add_argument('--resume', default='', type=str, metavar='PATH',
# help='path to latest checkpoint (default: none)')
# parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
# help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
parser.add_argument('--fp16', action='store_true', help='Run model fp16 mode.')
parser.add_argument('--use-tta', action='store_true', help='Validate model with TTA at the end of training.')
parser.add_argument('--train-128', action='store_true', help='Train model on 128x128 images first. TODO: allow custom epochs and LR')
parser.add_argument('--sz', default=224, type=int, help='Size of transformed image.')
# parser.add_argument('--decay-int', default=30, type=int, help='Decay LR by 10 every decay-int epochs')
parser.add_argument('--use_clr', type=str,
                    help='CLR params "div,pct,max_mom,min_mom" as a comma-delimited string, e.g. "20,2,0.95,0.85" (see the parse sketch below)')
parser.add_argument('--loss-scale', type=float, default=1,
                    help='Static loss scale; positive powers of 2 can improve fp16 convergence.')
parser.add_argument('--prof', dest='prof', action='store_true', help='Only run a few iters for profiling.')
parser.add_argument('--dist-url', default='file://sync.file', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
parser.add_argument('--world-size', default=1, type=int,
help='Number of GPUs to use. Can either be manually set ' +
'or automatically set by using \'python -m multiproc\'.')
parser.add_argument('--rank', default=0, type=int,
help='Used for multi-process training. Can either be manually set ' +
'or automatically set by using \'python -m multiproc\'.')
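
# Hypothetical helper (not in the original script): a sketch of how the
# --use_clr string might be unpacked downstream into CLR parameters.
def parse_use_clr(clr_str):
    """Split 'div,pct,max_mom,min_mom' (e.g. '20,2,0.95,0.85') into floats."""
    div, pct, max_mom, min_mom = (float(x) for x in clr_str.split(','))
    return div, pct, (max_mom, min_mom)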
def fast_loader(data_path, size):
aug_tfms = [
RandomFlip(),
# RandomRotate(4),
# RandomLighting(0.05, 0.05),
RandomCrop(size)
]
    # CIFAR-10 labels and per-channel mean/std stats; `classes` is currently unused here
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    cifar10_stats = (np.array([0.4914, 0.48216, 0.44653]), np.array([0.24703, 0.24349, 0.26159]))
    tfms = tfms_from_stats(cifar10_stats, size, aug_tfms=aug_tfms, pad=size//8)  # pad relative to the requested size, not args.sz
data = ImageClassifierData.from_paths(data_path, val_name='test', tfms=tfms,
bs=args.batch_size, num_workers=args.workers)
    if args.distributed:
        # DistributedSampler expects the dataset, not the DataLoader
        train_sampler = torch.utils.data.distributed.DistributedSampler(data.trn_ds)
else:
train_sampler = None
# TODO: Need to test train_sampler on distributed machines
    # Use the default PyTorch DataLoader instead of fastai's; ~20% faster
data.trn_dl = torch.utils.data.DataLoader(
data.trn_ds, batch_size=data.bs, shuffle=(train_sampler is None),
num_workers=data.num_workers, pin_memory=True, sampler=train_sampler)
data.trn_dl = DataPrefetcher(data.trn_dl)
data.val_dl = torch.utils.data.DataLoader(
data.val_ds,
batch_size=data.bs, shuffle=False,
num_workers=data.num_workers, pin_memory=True)
data.val_dl = DataPrefetcher(data.val_dl, stop_early=args.prof)
return data, train_sampler
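
# A hypothetical call-site sketch (the real one lives in the training code
# elsewhere in this script); the --train-128 warm-up pass would pass a smaller size:
#
#   data, train_sampler = fast_loader(args.data, 128 if args.train_128 else args.sz)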
# Seems to speed up training by ~2%
class DataPrefetcher():
def __init__(self, loader, stop_early=False):
self.loader = loader
self.dataset = loader.dataset
self.stream = torch.cuda.Stream()
self.stop_early = stop_early
self.next_input = None
self.next_target = None
def __len__(self):
return len(self.loader)
def preload(self):
try:
self.next_input, self.next_target = next(self.loaditer)
except StopIteration:
self.next_input = None
self.next_target = None
return
with torch.cuda.stream(self.stream):
            # `async` became a reserved word in Python 3.7; non_blocking is the current
            # name, and it needs the pin_memory=True loaders set up in fast_loader
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
def __iter__(self):
count = 0
self.loaditer = iter(self.loader)
self.preload()
while self.next_input is not None:
torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input
target = self.next_target
self.preload()
count += 1
yield input, target
if self.stop_early and (count > 50):
break
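
# Why the prefetcher helps: .cuda(non_blocking=True) on pinned-memory tensors
# runs the copy on the side stream while the previous batch computes on the
# default stream, overlapping host-to-device transfer with compute. A minimal
# usage sketch (hypothetical names, not from the original script):
#
#   loader = torch.utils.data.DataLoader(ds, batch_size=64, pin_memory=True)
#   for inp, tgt in DataPrefetcher(loader):
#       out = model(inp)  # inp/tgt are already on the GPU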
# Adapted from the topk accuracy helper in the reference main.py
def top5(output, target):
    """Computes top-5 accuracy for a batch of logits."""
    batch_size = target.size(0)
    _, pred = output.topk(5, 1, True, True)  # indices of the 5 largest logits per sample
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return correct.float().sum() / batch_size  # .float() avoids integer division
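
# A hedged generalization (not used by this script, which only needs k=5): a
# single topk call can report accuracy at several cutoffs, as the reference
# main.py does.
def topk_accuracy(output, target, ks=(1, 5)):
    """Return accuracy@k for each k in ks, given (batch, classes) logits."""
    batch_size = target.size(0)
    _, pred = output.topk(max(ks), 1, True, True)
    pred = pred.t()  # (max_k, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # a sample counts toward accuracy@k if the target is among its first k predictions
    return [correct[:k].float().sum() / batch_size for k in ks]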
class ValLoggingCallback(Callback):
def __init__(self, save_path):
super().__init__()
        self.save_path = save_path
def on_train_begin(self):
self.batch = 0
self.epoch = 0
        self.f = open(self.save_path, "a", 1)  # line-buffered append, so logs survive a crash
    def on_epoch_end(self, metrics):
        log_str = f'\tEpoch:{self.epoch}\ttrn_loss:{self.last_loss}'
        # metrics is [val_loss, acc, top5, ...]; the trailing '' key absorbs a fourth metric
        for (k, v) in zip(['val_loss', 'acc', 'top5', ''], metrics):
            log_str += f'\t{k}:{v}'
        self.log(log_str)
        self.epoch += 1
    def on_batch_end(self, metrics):
        self.last_loss = metrics  # fastai 0.7 passes the batch training loss here
        self.batch += 1
def on_train_end(self):
self.log("\ton_train_end")
self.f.close()
def log(self, string):
self.f.write(time.strftime("%Y-%m-%dT%H:%M:%S")+"\t"+string+"\n")
# Logging + saving models
def save_args(name, save_dir):
    # only rank 0 writes logs/checkpoints, so distributed workers don't collide
    if (args.rank != 0) or not args.save_dir: return {}
log_dir = f'{save_dir}/training_logs'
os.makedirs(log_dir, exist_ok=True)
return {
'best_save_name': f'{name}_best_model',
'cycle_save_name': f'{name}',
'callbacks': [
LoggingCallback(f'{log_dir}/{name}_log.txt'),
ValLoggingCallback(f'{log_dir}/{name}_val_log.txt')
]
}
def save_sched(sched, save_dir):
if (args.rank != 0) or not args.save_dir: return {}
log_dir = f'{save_dir}/training_logs'
sched.save_path = log_dir
sched.plot_loss()
sched.plot_lr()
def update_model_dir(learner, base_dir):
learner.tmp_path = f'{base_dir}/tmp'
os.makedirs(learner.tmp_path, exist_ok=True)
learner.models_path = f'{base_dir}/models'
os.makedirs(learner.models_path, exist_ok=True)
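
# A hedged wiring sketch for the helpers above, assuming the fastai 0.7 API this
# script uses elsewhere (ConvLearner, and a fit() that accepts the best_save_name /
# cycle_save_name / callbacks kwargs returned by save_args):
#
#   learn = ConvLearner.from_model_data(model, data)
#   update_model_dir(learn, args.save_dir)
#   learn.fit(args.lr, args.epochs, cycle_len=args.cycle_len,
#             **save_args(args.arch, args.save_dir))
#   save_sched(learn.sched, args.save_dir)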