# STL-10 image classification: fine-tune a pretrained ResNet-50 with the fastai
# dev-notebook API (nb_005), then rebuild the datasets from a CSV-defined split.
#
# The source was a notebook flattened onto a single line, which meant the first
# `#` commented out everything after it. It is restored here one statement per
# line. The IPython magics are written as get_ipython() calls so the file is
# valid Python syntax. It still needs an IPython kernel to run.

get_ipython().run_line_magic('reload_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')

# Star import supplies Path, tensor, normalize_funcs, ImageDataset,
# get_transforms, transform_datasets, DataBunch, show_image, ConvLearner,
# accuracy, lr_find, np, plt, ... (none of these are imported explicitly here).
from nb_005 import *

import pandas as pd
from torchvision.models import resnet18, resnet34, resnet50

PATH = Path('data/stl10')

# Standard ImageNet per-channel mean/std, matching the pretrained backbone.
data_mean, data_std = map(tensor, ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
data_norm, data_denorm = normalize_funcs(data_mean, data_std)

train_ds = ImageDataset.from_folder(PATH/'train')
valid_ds = ImageDataset.from_folder(PATH/'valid')

# Peek at one raw validation image (bare expressions display in the notebook).
x = valid_ds[0][0]
x.show()
x.shape

size = 96
# A simpler variant was also tried:
#   get_transforms(do_flip=True, max_rotate=10, max_lighting=0.2)
tfms = get_transforms(do_flip=True, max_rotate=10, max_lighting=0.2,
                      max_warp=0.15, max_zoom=1.2)
# padding_mode='zeros' was also tried as an extra argument here.
tds = transform_datasets(train_ds, valid_ds, tfms, size=size)
data = DataBunch.create(*tds, bs=32, num_workers=8, tfms=data_norm)


def _grid_axes(nrows=4, ncols=4, figsize=(12, 12)):
    """Create an nrows x ncols matplotlib figure and return its axes flattened."""
    _, axs = plt.subplots(nrows, ncols, figsize=figsize)
    return axs.flatten()


def _show_batch(xb, **grid_kwargs):
    """Plot the leading images of batch `xb` on a grid, undoing normalization.

    Plotting stops when either the batch or the grid runs out.
    """
    for img, ax in zip(xb, _grid_axes(**grid_kwargs)):
        show_image(data_denorm(img.cpu()), ax)


# One validation batch, then one training batch. x and y are kept at module
# level, as in the original.
(x, y) = next(iter(data.valid_dl))
_show_batch(x)
(x, y) = next(iter(data.train_dl))
_show_batch(x)

# The same training item is drawn 16 times. Presumably each access re-samples
# the random transforms, so this previews augmentation variety. TODO confirm.
for ax in _grid_axes():
    tds[0][1][0].show(ax)

arch = resnet50
lr = 2e-3
# NOTE(review): the third positional argument (2) is presumably the backbone
# cut point. Confirm against nb_005.ConvLearner.
# Options also tried: train_bn=False, callback_fns=[BnFreeze],
#                     opt_fn=partial(optim.SGD, momentum=0.9)
learn = ConvLearner(data, arch, 2, wd=1e-2)
learn.metrics = [accuracy]
# Two layer groups split at model[0][6] and model[1]. The assumption is that
# m[0] is the backbone body and m[1] the head; verify.
learn.split(lambda m: (m[0][6], m[1]))

lr_find(learn)
learn.recorder.plot()

# Stage 1: one-cycle training before unfreeze(), at a single learning rate.
learn.fit_one_cycle(6, slice(lr))
learn.save('0')
learn.load('0')  # lets the notebook resume from here without retraining

# Stage 2: unfreeze, then spread LRs from lr/100 to lr across the layer groups.
learn.unfreeze()
lr = 2e-4
learn.fit_one_cycle(6, slice(lr/100, lr), pct_start=0.05)
learn.save('1')
learn.recorder.plot_losses()

# Rebuild the datasets from the split listed in default.csv:
#   column '0' = file path relative to PATH,
#   column '1' = label,
#   column '2' = split marker ('valid' rows go to validation, others to train).
# NOTE(review): `csv` shadows the stdlib module name. It is kept unchanged in
# case later cells reference it.
csv = pd.read_csv(PATH/'default.csv')
is_valid = csv['2'] == 'valid'
valid_df, train_df = csv[is_valid], csv[~is_valid]
# Displayed in the notebook to compare split sizes with the folder-based dataset.
len(valid_df), len(train_df)
len(valid_ds)

train_fns, train_lbls, valid_fns, valid_lbls = map(
    np.array, (train_df['0'], train_df['1'], valid_df['0'], valid_df['1']))
train_fns = [PATH/o for o in train_fns]
valid_fns = [PATH/o for o in valid_fns]
train_ds = ImageDataset(train_fns, train_lbls)
# Reuse the training class list so label indices agree across both splits.
valid_ds = ImageDataset(valid_fns, valid_lbls, classes=train_ds.classes)