%reload_ext autoreload
%autoreload 2
from nb_005 import *
# Root of the STL-10 copy; the datasets below read its 'train' and 'valid' subfolders.
PATH = Path('data/stl10')
# Per-channel normalization statistics. These are presumably the ImageNet RGB values
# expected by the pretrained torchvision backbone used further down — confirm.
data_mean = tensor([0.485, 0.456, 0.406])
data_std = tensor([0.229, 0.224, 0.225])
# Forward/inverse normalization pair: data_norm feeds the DataBunch, data_denorm undoes it for display.
data_norm, data_denorm = normalize_funcs(data_mean, data_std)
# Folder-based datasets (one per split).
train_ds = ImageDataset.from_folder(PATH/'train')
valid_ds = ImageDataset.from_folder(PATH/'valid')
# Quick look at the first validation image and its tensor shape.
x = valid_ds[0][0]
x.show()
x.shape
# Augmentation + batching. Images are resized to `size` pixels. The active config uses
# flips, small rotations, lighting changes, warp and zoom.
size=96
tfms = get_transforms(do_flip=True, max_rotate=10, max_lighting=0.2, max_warp=0.15, max_zoom=1.2)
# Milder alternative without warp/zoom:
# tfms = get_transforms(do_flip=True, max_rotate=10, max_lighting=0.2)
# NOTE(review): an earlier experiment also passed padding_mode='zeros' to
# transform_datasets; confirm that argument is still supported before re-enabling it.
tds = transform_datasets(train_ds, valid_ds, tfms, size=size)
data = DataBunch.create(*tds, bs=32, num_workers=8, tfms=data_norm)

def _show_norm_batch(imgs, rows=4, cols=4, figsize=(12,12)):
    "Denormalize and plot the first `rows*cols` images of the batch `imgs` on a grid."
    _,axs = plt.subplots(rows, cols, figsize=figsize)
    # zip stops at whichever runs out first, so a batch smaller than the grid
    # leaves panels empty instead of raising IndexError.
    for img,ax in zip(imgs, axs.flatten()): show_image(data_denorm(img.cpu()), ax)

# One validation batch, then one training batch (the latter with augmentation applied).
(x,y) = next(iter(data.valid_dl))
_show_norm_batch(x)
(x,y) = next(iter(data.train_dl))
_show_norm_batch(x)

# The same training item drawn 16 times. Presumably each access to tds[0][1] re-samples
# the random transforms, so every panel shows a different augmentation — confirm.
_,axs = plt.subplots(4,4,figsize=(12,12))
for ax in axs.flat: tds[0][1][0].show(ax)
# Two-stage fine-tuning of a torchvision ResNet on `data`.
# Stage 1 trains with the learner's initial (presumably partially frozen) setup;
# stage 2 unfreezes everything and uses per-group learning rates.
from torchvision.models import resnet18, resnet34, resnet50
arch = resnet50
lr = 2e-3
# NOTE(review): the meaning of the third positional argument (2) is not visible here;
# check the ConvLearner signature in nb_005 (possibly a cut point / head config).
# The trailing comments record options that were tried and disabled:
# train_bn=False, a BnFreeze callback, and an SGD(momentum=0.9) optimizer.
learn = ConvLearner(data, arch, 2 , wd=1e-2) #, train_bn=False #, callback_fns=[BnFreeze]
# , opt_fn=partial(optim.SGD, momentum=0.9))
learn.metrics = [accuracy]
# Split the model into layer groups at m[0][6] and m[1]. This lets slice(lo, hi)
# learning rates below give different groups different rates. Presumably m[0] is
# the backbone and m[1] the head — confirm.
learn.split(lambda m: (m[0][6], m[1]))
# Learning-rate range test; inspect the plot to sanity-check the chosen `lr`.
lr_find(learn)
learn.recorder.plot()
# Stage 1: one-cycle schedule for 6 epochs with max learning rate `lr`.
learn.fit_one_cycle(6, slice(lr))
# Checkpoint stage 1. Reloading immediately means the cells below can be
# re-run from this saved state.
learn.save('0')
learn.load('0')
# Stage 2: train all layer groups. lr/100 goes to the earliest group and lr to the last;
# a short warm-up (pct_start=0.05) is used, presumably because weights are already trained.
learn.unfreeze()
lr = 2e-4
learn.fit_one_cycle(6, slice(lr/100,lr), pct_start=0.05)
learn.save('1')
# Training/validation loss curves recorded during fitting.
learn.recorder.plot_losses()
import pandas as pd
# Rebuild the train/valid datasets from the split listed in default.csv instead of
# the folder layout. Judging by how the columns are used below: '0' is an image path
# relative to PATH, '1' is its label, and '2' names the split ('valid' vs. anything else).
csv = pd.read_csv(PATH/'default.csv')
is_valid = csv['2'].eq('valid')
valid_df = csv[is_valid]
train_df = csv[~is_valid]
# Sanity check: split sizes from the CSV, and the folder-based validation size for comparison.
len(valid_df),len(train_df)
len(valid_ds)
train_fns  = np.array(train_df['0'])
train_lbls = np.array(train_df['1'])
valid_fns  = np.array(valid_df['0'])
valid_lbls = np.array(valid_df['1'])
# Resolve relative paths against the dataset root.
train_fns = [PATH/fn for fn in train_fns]
valid_fns = [PATH/fn for fn in valid_fns]
train_ds = ImageDataset(train_fns,train_lbls)
# Pass the training classes explicitly, presumably so both datasets share one label->index mapping.
valid_ds = ImageDataset(valid_fns,valid_lbls, classes=train_ds.classes)