%matplotlib inline
%reload_ext autoreload
%autoreload 2
import argparse
import os
import shutil
import time
from fastai.transforms import *
from fastai.dataset import *
from fastai.fp16 import *
from fastai.conv_learner import *
from pathlib import *
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import models
import models.cifar10 as cifar10models
from distributed import DistributedDataParallel as DDP
# print(models.cifar10.__dict__)
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
cifar10_names = sorted(name for name in cifar10models.__dict__
if name.islower() and not name.startswith("__")
and callable(cifar10models.__dict__[name]))
model_names = cifar10_names + model_names
print(model_names)
# Example usage: python run_fastai.py /home/paperspace/ILSVRC/Data/CLS-LOC/ -a resnext_50_32x4d --epochs 1 -j 4 -b 64 --fp16
parser = argparse.ArgumentParser(description='PyTorch Cifar10 Training')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--save-dir', type=str, default=Path.home()/'imagenet_training',
help='Directory to save logs and models.')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet56',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet56)')
parser.add_argument('-j', '--workers', default=7, type=int, metavar='N',
help='number of data loading workers (default: 7)')
parser.add_argument('--epochs', default=1, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--cycle-len', default=95, type=float, metavar='N',
help='Length of cycle to run')
# parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
# help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=512, type=int,
metavar='N', help='mini-batch size (default: 512)')
parser.add_argument('--lr', '--learning-rate', default=0.8, type=float,
metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)')
# parser.add_argument('--print-freq', '-p', default=10, type=int,
# metavar='N', help='print frequency (default: 10)')
# parser.add_argument('--resume', default='', type=str, metavar='PATH',
# help='path to latest checkpoint (default: none)')
# parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
# help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
parser.add_argument('--fp16', action='store_true', help='Run model fp16 mode.')
parser.add_argument('--use-tta', default=True, type=bool, help='Validate model with TTA at the end of training.')
parser.add_argument('--train-half', action='store_true', help='Train model on half images. TODO: allow custom epochs and LR')
parser.add_argument('--sz', default=32, type=int, help='Size of transformed image.')
# parser.add_argument('--decay-int', default=30, type=int, help='Decay LR by 10 every decay-int epochs')
parser.add_argument('--use-clr', default='10,13.68,0.95,0.85', type=str,
help='div,pct,max_mom,min_mom. Pass in a string delimited by commas. Ex: "20,2,0.95,0.85"')
parser.add_argument('--loss-scale', type=float, default=128,
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
parser.add_argument('--prof', dest='prof', action='store_true', help='Only run a few iters for profiling.')
parser.add_argument('--dist-url', default='file://sync.file', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
parser.add_argument('--world-size', default=1, type=int,
help='Number of GPUs to use. Can either be manually set ' +
'or automatically set by using \'python -m multiproc\'.')
parser.add_argument('--rank', default=0, type=int,
help='Used for multi-process training. Can either be manually set ' +
'or automatically set by using \'python -m multiproc\'.')
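# A minimal, self-contained sketch of what --loss-scale does in --fp16 training
# (illustrative only; the fastai fp16 utilities used below handle this internally).
# Half-precision gradients can underflow to zero, so the loss is multiplied by a
# power of two before backward() and the gradients are divided by the same factor,
# in fp32, before the optimizer step. Names like `toy_model` are hypothetical;
# this relies on the torch/nn/F imports at the top of the notebook.
def loss_scale_sketch(loss_scale=128):
    toy_model = nn.Linear(8, 2).cuda().half()                      # fp16 model
    master_params = [p.detach().clone().float().requires_grad_()   # fp32 master copy
                     for p in toy_model.parameters()]
    opt = torch.optim.SGD(master_params, lr=0.1)
    x = torch.randn(4, 8).cuda().half()
    y = torch.randint(0, 2, (4,)).cuda()
    loss = F.cross_entropy(toy_model(x).float(), y)
    (loss * loss_scale).backward()                                  # scale up before backward
    for p32, p16 in zip(master_params, toy_model.parameters()):
        p32.grad = p16.grad.detach().float() / loss_scale           # unscale into fp32 grads
    opt.step()                                                      # update fp32 master weights
    for p32, p16 in zip(master_params, toy_model.parameters()):
        p16.data.copy_(p32.data.half())                             # copy back into fp16 model
# loss_scale_sketch()  # uncomment to run on a GPU machine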
class TorchModelData(ModelData):
def __init__(self, path, trn_dl, val_dl, aug_dl=None):
super().__init__(path, trn_dl, val_dl)
self.aug_dl = aug_dl
def torch_loader(data_path, size):
# Data loading code
traindir = os.path.join(data_path, 'train')
valdir = os.path.join(data_path, 'test')
normalize = transforms.Normalize(mean=[0.4914 , 0.48216, 0.44653], std=[0.24703, 0.24349, 0.26159])
scale_size = 40
padding = int((scale_size - size) / 2)
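# With scale_size=40 and size=32 this gives padding=4, i.e. the standard
# CIFAR-10 augmentation of padding 4 pixels and taking a random 32x32 crop.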
train_tfms = transforms.Compose([
transforms.RandomCrop(size, padding=padding),
transforms.ColorJitter(.2,.2,.2),
# transforms.RandomRotation(2),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])
train_dataset = datasets.ImageFolder(traindir, train_tfms)
train_sampler = (torch.utils.data.distributed.DistributedSampler(train_dataset)
if args.distributed else None)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
val_tfms = transforms.Compose([
# transforms.Resize(int(size*1.14)),
# transforms.CenterCrop(size),
transforms.ToTensor(),
normalize,
])
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, val_tfms),
batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)
aug_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, train_tfms),
batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)
train_loader = DataPrefetcher(train_loader)
val_loader = DataPrefetcher(val_loader)
aug_loader = DataPrefetcher(aug_loader)
if args.prof:
train_loader.stop_after = 200
val_loader.stop_after = 0
data = TorchModelData(data_path, train_loader, val_loader, aug_loader)
return data, train_sampler
# Seems to speed up training by ~2%
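# It preloads the next batch and copies it to the GPU on a separate CUDA stream,
# so the host-to-device transfer overlaps with the forward/backward pass of the
# current batch; the main stream waits on the copy stream before using the batch.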
class DataPrefetcher():
def __init__(self, loader, stop_after=None):
self.loader = loader
self.dataset = loader.dataset
self.stream = torch.cuda.Stream()
self.stop_after = stop_after
self.next_input = None
self.next_target = None
def __len__(self):
return len(self.loader)
def preload(self):
try:
self.next_input, self.next_target = next(self.loaditer)
except StopIteration:
self.next_input = None
self.next_target = None
return
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
def __iter__(self):
count = 0
self.loaditer = iter(self.loader)
self.preload()
while self.next_input is not None:
torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input
target = self.next_target
self.preload()
count += 1
yield input, target
if type(self.stop_after) is int and (count > self.stop_after):
break
def top5(output, target):
"""Computes the precision@k for the specified values of k"""
top5 = 5
batch_size = target.size(0)
_, pred = output.topk(top5, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
correct_k = correct[:top5].view(-1).float().sum(0, keepdim=True)
return correct_k.mul_(1.0 / batch_size)
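# Quick sanity check of top5 on dummy data (illustrative only): with random 10-way
# logits, top-5 accuracy should land near 0.5. Uses the torch import above.
# _logits, _targets = torch.randn(64, 10), torch.randint(0, 10, (64,))
# top5(_logits, _targets)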
class ImagenetLoggingCallback(Callback):
def __init__(self, save_path, print_every=50):
super().__init__()
self.save_path=save_path
self.print_every=print_every
def on_train_begin(self):
self.batch = 0
self.epoch = 0
self.f = open(self.save_path, "a", 1)
self.log("\ton_train_begin")
def on_epoch_end(self, metrics):
log_str = f'\tEpoch:{self.epoch}\ttrn_loss:{self.last_loss}'
for (k,v) in zip(['val_loss', 'acc', 'top5', ''], metrics): log_str += f'\t{k}:{v}'
self.log(log_str)
self.epoch += 1
def on_batch_end(self, metrics):
self.last_loss = metrics
self.batch += 1
if self.batch % self.print_every == 0:
self.log(f'Epoch: {self.epoch} Batch: {self.batch} Metrics: {metrics}')
def on_train_end(self):
self.log("\ton_train_end")
self.f.close()
def log(self, string):
self.f.write(time.strftime("%Y-%m-%dT%H:%M:%S")+"\t"+string+"\n")
# Logging + saving models
def save_args(name, save_dir):
if (args.rank != 0) or not args.save_dir: return {}
log_dir = f'{save_dir}/training_logs'
os.makedirs(log_dir, exist_ok=True)
return {
'best_save_name': f'{name}_best_model',
'cycle_save_name': f'{name}',
'callbacks': [
ImagenetLoggingCallback(f'{log_dir}/{name}_log.txt')
]
}
def save_sched(sched, save_dir):
if (args.rank != 0) or not args.save_dir: return {}
log_dir = f'{save_dir}/training_logs'
sched.save_path = log_dir
sched.plot_loss()
sched.plot_lr()
def update_model_dir(learner, base_dir):
learner.tmp_path = f'{base_dir}/tmp'
os.makedirs(learner.tmp_path, exist_ok=True)
learner.models_path = f'{base_dir}/models'
os.makedirs(learner.models_path, exist_ok=True)
['resnet56', 'resnext29_16_64', 'resnext29_8_64', 'dpn107', 'dpn131', 'dpn68', 'dpn92', 'dpn98', 'inceptionresnetv2', 'inceptionresnetv2_conc', 'inceptionv4', 'load', 'load_block17', 'load_block35', 'load_block8', 'load_conv2d', 'load_conv2d_nobn', 'load_linear', 'load_mixed_4a_7a', 'load_mixed_5', 'load_mixed_5b', 'load_mixed_6', 'load_mixed_6a', 'load_mixed_7', 'load_mixed_7a', 'nasnetalarge', 'pre_resnet101', 'pre_resnet152', 'pre_resnet18', 'pre_resnet34', 'pre_resnet50', 'reduce', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101', 'resnext152', 'resnext18', 'resnext34', 'resnext50', 'resnext_101_32x4d', 'resnext_101_64x4d', 'resnext_50_32x4d', 'se_resnet_101', 'se_resnet_152', 'se_resnet_18', 'se_resnet_34', 'se_resnet_50', 'se_resnet_50_conc', 'se_resnext_101', 'se_resnext_152', 'se_resnext_50', 'test', 'test_block17', 'test_block35', 'test_block8', 'test_conv2d', 'test_conv2d_nobn', 'test_mixed_4a_7a', 'test_mixed_5b', 'test_mixed_6a', 'test_mixed_7a', 'wrn_50_2f']
args_input = [
'/home/paperspace/imagenet-fast/fp16/data/cifar10',
'--save-dir', '/home/paperspace/data/cifar_training/test1',
# '-a', 'resnext29_8_64',
# '-j', '6',
# '--prof',
'-b', '512',
# '--sz', '32',
# '--loss-scale', '128',
'--fp16',
'--cycle-len', '110',
# '--epochs', '1',
# '--use-clr', '10,13.68,0.95,0.85',
'--wd', '2e-4',
'--lr', '1',
# '--train-half' # With fp16, iterations are so fast this doesn't matter
]
# This is important for speed
cudnn.benchmark = True
global args
args = parser.parse_args(args_input); args
if args.cycle_len > 1: args.cycle_len = int(args.cycle_len)
args.distributed = args.world_size > 1
args.gpu = 0
if args.distributed:
args.gpu = args.rank % torch.cuda.device_count()
if args.distributed:
torch.cuda.set_device(args.gpu)
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size)
if args.fp16:
assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
# create model
model = cifar10models.__dict__[args.arch] if args.arch in cifar10_names else models.__dict__[args.arch]
if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch))
model = model(pretrained=True)
else:
print("=> creating model '{}'".format(args.arch))
model = model()
=> creating model 'resnet56'
model = model.cuda()
if args.distributed:
model = DDP(model)
if args.train_half:
data, train_sampler = torch_loader(args.data, 16)
else:
data, train_sampler = torch_loader(args.data, args.sz)
learner = Learner.from_model_data(model, data)
# learner.crit = F.nll_loss
learner.crit = F.cross_entropy
learner.metrics = [accuracy]
if args.fp16: learner.half()
if args.prof:
args.epochs = 1
args.cycle_len=.01
if args.use_clr:
args.use_clr = tuple(map(float, args.use_clr.split(',')))
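# Per the --use-clr help string, this tuple is (div, pct, max_mom, min_mom) for
# fastai's use_clr_beta (1cycle-style) schedule: the LR starts at lr/div, peaks at
# args.lr, the final pct% of iterations anneal it down, and momentum cycles
# between max_mom and min_mom.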
# x,y = next(iter(data.trn_dl))
# plt.imshow(np.transpose(x[50], (1, 2, 0)))
# %pdb off
# Half-size images (16x16)
if args.train_half:
save_dir = args.save_dir+'/128'
update_model_dir(learner, save_dir)
sargs = save_args('first_run_128', save_dir)
learner.fit(args.lr,args.epochs, cycle_len=45,
train_sampler=train_sampler,
wds=args.weight_decay,
use_clr_beta=args.use_clr,
loss_scale=args.loss_scale,
**sargs
)
save_sched(learner.sched, save_dir)
data, train_sampler = torch_loader(args.data, args.sz)
learner.set_data(data)
# Full size
update_model_dir(learner, args.save_dir)
sargs = save_args('first_run', args.save_dir)
learner.fit(args.lr,args.epochs, cycle_len=args.cycle_len,
sampler=train_sampler,
wds=args.weight_decay,
use_clr_beta=args.use_clr,
loss_scale=args.loss_scale,
**sargs
)
save_sched(learner.sched, args.save_dir)
print('Finished!')
epoch trn_loss val_loss accuracy
0 1.812441 1.859512 0.3174
1 1.623513 1.485934 0.4409
2 1.43354 1.352941 0.5064
3 1.20538 1.823247 0.44
4 1.047512 1.031067 0.6364
5 0.903786 0.990384 0.6529
6 0.794144 0.937554 0.6923
7 0.719986 0.794384 0.7304
8 0.663548 0.824012 0.7209
9 0.616721 0.684154 0.7641
10 0.599153 0.760171 0.7422
11 0.569865 0.973621 0.6983
12 0.544941 0.826669 0.7262
13 0.524519 0.852246 0.7406
14 0.513213 0.753155 0.7662
15 0.497413 0.846931 0.7504
16 0.477519 0.668112 0.7875
17 0.473021 0.655387 0.783
18 0.466157 0.857789 0.7328
19 0.450976 0.666845 0.7821
20 0.443292 0.860191 0.7377
21 0.439205 1.131203 0.6821
22 0.446406 1.028122 0.7189
23 0.426874 1.017566 0.7207
24 0.427285 0.96748 0.7216
25 0.425151 0.703964 0.7811
26 0.418674 0.983397 0.7092
27 0.418991 0.744101 0.7635
28 0.415415 0.681863 0.7864
29 0.413028 0.61349 0.8049
30 0.41116 0.649954 0.789
31 0.395728 1.480305 0.6203
32 0.402052 0.689628 0.7703
33 0.39418 1.019106 0.7183
34 0.402665 0.668238 0.7811
35 0.408001 0.809439 0.7619
36 0.397961 0.859566 0.7617
37 0.394969 0.698795 0.7666
38 0.390791 0.685463 0.7879
39 0.393455 0.872631 0.7195
40 0.391663 0.686177 0.7817
41 0.395742 0.608006 0.8057
42 0.382299 0.761684 0.7698
43 0.390123 0.580282 0.8171
44 0.383594 0.948203 0.7183
45 0.386429 1.288566 0.675
46 0.386106 0.783413 0.7616
47 0.387071 0.907251 0.7336
48 0.376726 0.665081 0.8034
49 0.389123 0.749052 0.7737
50 0.383904 0.616589 0.801
51 0.374272 0.730153 0.7736
52 0.371789 0.712743 0.7745
53 0.377667 0.703202 0.7764
54 0.36974 0.61975 0.802
55 0.358827 0.551073 0.8208
56 0.356373 0.862586 0.7476
57 0.366815 0.699805 0.7845
58 0.358783 0.637675 0.8054
59 0.357312 0.856862 0.7544
60 0.353695 0.657164 0.7961
61 0.349274 0.767159 0.7683
62 0.346976 0.767261 0.7736
63 0.340436 0.628997 0.7969
64 0.341639 0.806389 0.7749
65 0.334653 0.665875 0.7943
66 0.34402 0.803923 0.7592
67 0.339453 0.633359 0.8014
68 0.332211 0.563867 0.8298
69 0.33266 0.641863 0.7961
70 0.329087 0.530855 0.8222
71 0.323273 0.578356 0.8087
72 0.319474 0.788859 0.7752
73 0.315593 0.51844 0.8351
74 0.314834 0.528026 0.8364
75 0.310698 0.609739 0.8137
76 0.310191 0.567951 0.82
77 0.307431 0.445347 0.8574
78 0.301006 0.643344 0.8068
79 0.301107 0.801222 0.7837
80 0.285171 0.49959 0.835
81 0.282677 0.486095 0.8435
82 0.271954 0.684845 0.8037
83 0.270613 0.485216 0.8463
84 0.262175 0.629983 0.8131
85 0.267298 0.418199 0.8652
86 0.253407 0.545097 0.8377
87 0.250527 0.376923 0.8776
88 0.252519 0.616157 0.8102
89 0.230011 0.489934 0.8478
90 0.218822 0.430929 0.8712
91 0.203619 0.395288 0.8746
92 0.192234 0.372025 0.8803
93 0.180982 0.338984 0.8912
94 0.167879 0.350927 0.8904
95 0.155493 0.402861 0.8785
96 0.148298 0.349918 0.8914
97 0.136534 0.316343 0.9028
98 0.124189 0.3294 0.906
99 0.113916 0.31101 0.9081
100 0.107916 0.299103 0.9114
101 0.097518 0.290329 0.9201
102 0.088652 0.305113 0.9159
103 0.076407 0.297681 0.9155
104 0.065395 0.297327 0.9162
105 0.055489 0.270902 0.9239
106 0.04417 0.259012 0.9284
107 0.034333 0.247566 0.9341
108 0.028114 0.247997 0.9339
109 0.022313 0.247694 0.9345
Finished!
learner.save('cifar10-resnext-lr1')
learner.sched.plot()
learner.lr_find()
learner.sched.plot()
learner.fit(1e-5,1, cycle_len=15,
sampler=train_sampler,
wds=args.weight_decay,
loss_scale=args.loss_scale,
**sargs
)
if args.use_tta:
log_preds,y = learner.TTA()
preds = np.mean(np.exp(log_preds),0)
acc = accuracy(torch.FloatTensor(preds),torch.LongTensor(y))
print('TTA acc:', acc)
with open(args.save_dir+'/tta_accuracy.txt', "a", 1) as f:
f.write(time.strftime("%Y-%m-%dT%H:%M:%S")+f"\tTTA accuracy: {acc}\n")
TTA acc: 0.9389
if args.use_tta:
log_preds,y = learner.TTA()
preds = np.mean(np.exp(log_preds),0)
acc = accuracy(torch.FloatTensor(preds),torch.LongTensor(y))
print('TTA acc:', acc)
with open(args.save_dir+'/tta_accuracy.txt', "a", 1) as f:
f.write(time.strftime("%Y-%m-%dT%H:%M:%S")+f"\tTTA accuracy: {acc}\n")
TTA acc: 0.9402