# Notebook setup: IPython magics, working directory and imports.
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# (Re)load the autoreload extension and use mode 2, so imported modules are
# reloaded before each execution and edits to them are picked up.
%reload_ext autoreload
%autoreload 2
import os
# Move the working directory up one level; the relative path 'data/cifar10'
# used later is resolved against it (presumably the `nb_007` and `models`
# imports below are, too -- confirm against the repo layout).
# NOTE(review): this is not idempotent -- re-running this cell climbs one more
# directory each time, which would break the relative paths below.
os.chdir('..')
# Star import: names used below that are not imported explicitly in this file
# (torch, Path, Learner, accuracy, pad, crop, flip_lr, normalize_funcs,
# cifar_stats, data_from_imagefolder) presumably all come from here.
from nb_007 import *
# Model constructor used for every learner below (presumably a Wide ResNet
# with 22 layers -- confirm in `models`).
from models import wrn_22
# Data pipeline setup for the CIFAR-10 image folder.
# Enable cuDNN's benchmark mode (lets cuDNN auto-select convolution
# algorithms; see the PyTorch backend docs).
torch.backends.cudnn.benchmark = True
# Build a pair of functions from the CIFAR statistics; presumably the first
# normalizes a batch and the second undoes it. `cifar_denorm` is not used in
# the code visible here.
cifar_norm,cifar_denorm = normalize_funcs(*cifar_stats)
# Two transform lists: the first pads by 4, crops to size 32 with row/col
# position ranges (0,1), and flips left-right with p=0.5; the second is empty.
# Presumably these are (training transforms, validation transforms) -- TODO confirm.
tfms = ([pad(padding=4), crop(size=32, row_pct=(0,1), col_pct=(0,1)), flip_lr(p=0.5)], [])
# Build the data object from 'data/cifar10' (relative to the current working
# directory). `valid='test'` presumably names the sub-folder used as the
# validation set; `cifar_norm` is passed as the batch-level `tfms`.
data = data_from_imagefolder(Path('data/cifar10'), valid='test', ds_tfms=tfms, tfms=cifar_norm)
# Training experiments on the `data` object built above.
# Run 1: fresh wrn_22 model, tracked with the `accuracy` metric, converted with
# `to_fp16()` (presumably mixed-precision training -- confirm in nb_007).
learn = Learner(data, wrn_22(), metrics=accuracy).to_fp16()
# One-cycle schedule; presumably 25 epochs, peak learning rate 3e-3, weight
# decay 0.4, with pct_start=0.45 controlling the schedule's first phase.
learn.fit_one_cycle(25, 3e-3, wd=0.4, pct_start=0.45)
# Timed 30-unit run (IPython `%time` reports the wall/CPU time of the statement).
# NOTE(review): this calls fit_one_cycle on the same `learn` object as the
# line above, so when both lines are executed in order it continues from the
# already-trained model rather than from a fresh one -- confirm this is
# intended (these may have been alternative notebook cells).
%time learn.fit_one_cycle(30, 3e-3, wd=0.4)
# Run 2: rebind `learn` to a new learner with a fresh model and no `to_fp16()`
# call, then do a single short run with the same learning rate and weight decay.
learn = Learner(data, wrn_22(), metrics=accuracy)
learn.fit_one_cycle(1, 3e-3, wd=0.4)