#|default_exp augment

#|export
import torch,random
import fastcore.all as fc
from torch import nn
from torch.nn import init

from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *

import pickle,gzip,math,os,time,shutil
import matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager

import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch import tensor,optim
from torch.utils.data import DataLoader,default_collate
from torch.optim import lr_scheduler
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder
from fastcore.test import test_close
from torch import distributions

torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'

import logging
logging.disable(logging.WARNING)

set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8

# Setup: Fashion-MNIST, standardized with the dataset mean and std
xl,yl = 'image','label'
name = "fashion_mnist"
bs = 1024
xmean,xstd = 0.28, 0.35

@inplace
def transformi(b): b[xl] = [(TF.to_tensor(o)-xmean)/xstd for o in b[xl]]

dsd = load_dataset(name)
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

metrics = MetricsCB(accuracy=MulticlassAccuracy())
astats = ActivationStats(fc.risinstance(GeneralRelu))
cbs = [DeviceCB(), metrics, ProgressCB(plot=True), astats]
act_gr = partial(GeneralRelu, leak=0.1, sub=0.4)
iw = partial(init_weights, leaky=0.1)

set_seed(42)
lr,epochs = 6e-2,5

# Baseline: ResBlock stack ending with Flatten + Linear on the final 1x1 feature map
def get_model(act=nn.ReLU, nfs=(16,32,64,128,256,512), norm=nn.BatchNorm2d):
    layers = [ResBlock(1, 16, ks=5, stride=1, act=act, norm=norm)]
    layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
    layers += [nn.Flatten(), nn.Linear(nfs[-1], 10, bias=False), nn.BatchNorm1d(10)]
    return nn.Sequential(*layers)

lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched)]
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)

# Global average pooling: average over the spatial dimensions instead of flattening
class GlobalAvgPool(nn.Module):
    def forward(self, x): return x.mean((-2,-1))

def get_model2(act=nn.ReLU, nfs=(16,32,64,128,256), norm=nn.BatchNorm2d):
    layers = [ResBlock(1, 16, ks=5, stride=1, act=act, norm=norm)]
    layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
    layers += [ResBlock(256, 512, act=act, norm=norm), GlobalAvgPool()]
    layers += [nn.Linear(512, 10, bias=False), nn.BatchNorm1d(10)]
    return nn.Sequential(*layers)

#|export
# Model summary: per-module parameter counts and approximate MFLOPS via hooks
def _flops(x, h, w):
    if x.dim()<3: return x.numel()
    if x.dim()==4: return x.numel()*h*w

@fc.patch
def summary(self:Learner):
    res = '|Module|Input|Output|Num params|MFLOPS|\n|--|--|--|--|--|\n'
    totp,totf = 0,0
    def _f(hook, mod, inp, outp):
        nonlocal res,totp,totf
        nparms = sum(o.numel() for o in mod.parameters())
        totp += nparms
        *_,h,w = outp.shape
        flops = sum(_flops(o, h, w) for o in mod.parameters())/1e6
        totf += flops
        res += f'|{type(mod).__name__}|{tuple(inp[0].shape)}|{tuple(outp.shape)}|{nparms}|{flops:.1f}|\n'
    with Hooks(self.model, _f) as hooks: self.fit(1, lr=1, cbs=SingleBatchCB())
    print(f"Tot params: {totp}; MFLOPS: {totf:.1f}")
    if fc.IN_NOTEBOOK:
        from IPython.display import Markdown
        return Markdown(res)
    else: print(res)
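# A rough sanity check of the MFLOPS estimate used by `summary` (illustrative,
# not part of the original notebook): for a 4D conv weight applied over an
# (h, w) output grid, _flops counts weight.numel()*h*w multiply-accumulates,
# while 1D/2D parameters (biases, linear weights, batchnorm) just contribute
# their element count.
_w = torch.empty(32, 16, 3, 3)      # hypothetical conv weight: 32 out, 16 in, 3x3 kernel
_flops(_w, 14, 14)/1e6              # ~0.9 MFLOPS for a 14x14 output grid
_flops(torch.empty(32), 14, 14)     # a 1D parameter counts only its 32 elements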
TrainLearner(get_model2(), dls, F.cross_entropy, lr=lr, cbs=[DeviceCB()]).summary()

set_seed(42)
model = get_model2(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)

# Pool the 256-channel features directly instead of adding a final 512-channel block
def get_model3(act=nn.ReLU, nfs=(16,32,64,128,256), norm=nn.BatchNorm2d):
    layers = [ResBlock(1, 16, ks=5, stride=1, act=act, norm=norm)]
    layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
    layers += [GlobalAvgPool(), nn.Linear(256, 10, bias=False), nn.BatchNorm1d(10)]
    return nn.Sequential(*layers)

TrainLearner(get_model3(), dls, F.cross_entropy, lr=lr, cbs=[DeviceCB()]).summary()

[o.shape for o in get_model3()[0].parameters()]

set_seed(42)
model = get_model3(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)

# Same head as get_model3, but with a plain conv stem instead of a ResBlock
def get_model4(act=nn.ReLU, nfs=(16,32,64,128,256), norm=nn.BatchNorm2d):
    layers = [conv(1, 16, ks=5, stride=1, act=act, norm=norm)]
    layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
    layers += [GlobalAvgPool(), nn.Linear(256, 10, bias=False), nn.BatchNorm1d(10)]
    return nn.Sequential(*layers)

[o.shape for o in get_model4()[0].parameters()]

TrainLearner(get_model4(), dls, F.cross_entropy, lr=lr, cbs=[DeviceCB()]).summary()

set_seed(42)
model = get_model4(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)

# Data augmentation: random crop (with padding) and horizontal flip, applied per batch
# to the training set only via BatchTransformCB
from torchvision import transforms

def tfm_batch(b, tfm_x=fc.noop, tfm_y=fc.noop): return tfm_x(b[0]),tfm_y(b[1])

tfms = nn.Sequential(transforms.RandomCrop(28, padding=4),
                     transforms.RandomHorizontalFlip())
augcb = BatchTransformCB(partial(tfm_batch, tfm_x=tfms), on_val=False)

model = get_model()
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=[SingleBatchCB(), augcb])
learn.fit(1)
xb,yb = learn.batch
show_images(xb[:16], imsize=1.5)

#| export
@fc.patch
@fc.delegates(show_images)
def show_image_batch(self:Learner, max_n=9, cbs=None, **kwargs):
    self.fit(1, cbs=[SingleBatchCB()]+fc.L(cbs))
    show_images(self.batch[0][:max_n], **kwargs)

learn.show_image_batch(max_n=16, imsize=1.5)

tfms = nn.Sequential(transforms.RandomCrop(28, padding=1),
                     transforms.RandomHorizontalFlip())
augcb = BatchTransformCB(partial(tfm_batch, tfm_x=tfms), on_val=False)

set_seed(42)
epochs = 20
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched), augcb]
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)

mdl_path = Path('models')
mdl_path.mkdir(exist_ok=True)
torch.save(learn.model, mdl_path/'data_aug.pkl')

#| export
# Capture inputs, predictions and targets during a validation pass
class CapturePreds(Callback):
    def before_fit(self, learn): self.all_inps,self.all_preds,self.all_targs = [],[],[]
    def after_batch(self, learn):
        self.all_inps.append(to_cpu(learn.batch[0]))
        self.all_preds.append(to_cpu(learn.preds))
        self.all_targs.append(to_cpu(learn.batch[1]))
    def after_fit(self, learn):
        self.all_preds,self.all_targs,self.all_inps = map(torch.cat, [self.all_preds,self.all_targs,self.all_inps])

#| export
@fc.patch
def capture_preds(self: Learner, cbs=None, inps=False):
    cp = CapturePreds()
    self.fit(1, train=False, cbs=[cp]+fc.L(cbs))
    res = cp.all_preds,cp.all_targs
    if inps: res = res+(cp.all_inps,)
    return res
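# Illustrative sketch (not from the original notebook): `capture_preds` can be
# combined with a list of test-time transforms and the resulting predictions
# averaged. `tta_preds` and `tta_tfms` are hypothetical names; the cells below
# use only a single horizontal flip.
def tta_preds(learn, tta_tfms):
    preds = []
    for t in tta_tfms:
        cb = BatchTransformCB(partial(tfm_batch, tfm_x=t), on_val=True)
        p,targs = learn.capture_preds(cbs=[cb])
        preds.append(p)
    return torch.stack(preds).mean(0), targs

# e.g. avg_preds,targs = tta_preds(learn, [fc.noop, TF.hflip])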
# Test-time augmentation (TTA): average predictions over the original and a
# horizontally flipped validation set
ap1, at = learn.capture_preds()
ttacb = BatchTransformCB(partial(tfm_batch, tfm_x=TF.hflip), on_val=True)
ap2, at = learn.capture_preds(cbs=[ttacb])
ap1.shape,ap2.shape,at.shape

ap = torch.stack([ap1,ap2]).mean(0).argmax(1)
round((ap==at).float().mean().item(), 3)

# Random erase: overwrite a random patch with Gaussian noise matching the batch statistics
xb,_ = next(iter(dls.train))
xbt = xb[:16]
xm,xs = xbt.mean(),xbt.std()
xbt.min(), xbt.max()

pct = 0.2
szx = int(pct*xbt.shape[-2])
szy = int(pct*xbt.shape[-1])
stx = int(random.random()*(1-pct)*xbt.shape[-2])
sty = int(random.random()*(1-pct)*xbt.shape[-1])
stx,sty,szx,szy

init.normal_(xbt[:,:,stx:stx+szx,sty:sty+szy], mean=xm, std=xs)
show_images(xbt, imsize=1.5)
xbt.min(), xbt.max()

#|export
def _rand_erase1(x, pct, xm, xs, mn, mx):
    szx = int(pct*x.shape[-2])
    szy = int(pct*x.shape[-1])
    stx = int(random.random()*(1-pct)*x.shape[-2])
    sty = int(random.random()*(1-pct)*x.shape[-1])
    init.normal_(x[:,:,stx:stx+szx,sty:sty+szy], mean=xm, std=xs)
    x.clamp_(mn, mx)

xb,_ = next(iter(dls.train))
xbt = xb[:16]
_rand_erase1(xbt, 0.2, xbt.mean(), xbt.std(), xbt.min(), xbt.max())
show_images(xbt, imsize=1.5)
xbt.mean(),xbt.std(),xbt.min(),xbt.max()

#|export
def rand_erase(x, pct=0.2, max_num=4):
    xm,xs,mn,mx = x.mean(),x.std(),x.min(),x.max()
    num = random.randint(0, max_num)
    for i in range(num): _rand_erase1(x, pct, xm, xs, mn, mx)
    # print(num)
    return x

xb,_ = next(iter(dls.train))
xbt = xb[:16]
rand_erase(xbt, 0.2, 4)
show_images(xbt, imsize=1.5)

#|export
class RandErase(nn.Module):
    def __init__(self, pct=0.2, max_num=4):
        super().__init__()
        self.pct,self.max_num = pct,max_num
    def forward(self, x): return rand_erase(x, self.pct, self.max_num)

tfms = nn.Sequential(transforms.RandomCrop(28, padding=1),
                     transforms.RandomHorizontalFlip(),
                     RandErase())
augcb = BatchTransformCB(partial(tfm_batch, tfm_x=tfms), on_val=False)

model = get_model()
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=[DeviceCB(), SingleBatchCB(), augcb])
learn.fit(1)
xb,yb = learn.batch
show_images(xb[:16], imsize=1.5)

epochs = 50
lr = 2e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched), augcb]
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)

# Random copy: overwrite a random patch with another patch from the same image
xb,_ = next(iter(dls.train))
xbt = xb[:16]
szx = int(pct*xbt.shape[-2])
szy = int(pct*xbt.shape[-1])
stx1 = int(random.random()*(1-pct)*xbt.shape[-2])
sty1 = int(random.random()*(1-pct)*xbt.shape[-1])
stx2 = int(random.random()*(1-pct)*xbt.shape[-2])
sty2 = int(random.random()*(1-pct)*xbt.shape[-1])
stx1,sty1,stx2,sty2,szx,szy

xbt[:,:,stx1:stx1+szx,sty1:sty1+szy] = xbt[:,:,stx2:stx2+szx,sty2:sty2+szy]
show_images(xbt, imsize=1.5)

#|export
def _rand_copy1(x, pct):
    szx = int(pct*x.shape[-2])
    szy = int(pct*x.shape[-1])
    stx1 = int(random.random()*(1-pct)*x.shape[-2])
    sty1 = int(random.random()*(1-pct)*x.shape[-1])
    stx2 = int(random.random()*(1-pct)*x.shape[-2])
    sty2 = int(random.random()*(1-pct)*x.shape[-1])
    x[:,:,stx1:stx1+szx,sty1:sty1+szy] = x[:,:,stx2:stx2+szx,sty2:sty2+szy]
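# Quick check (illustrative, not in the original notebook): unlike rand_erase,
# which fills the patch with Gaussian noise and then has to clamp back into the
# original range, copying a patch from elsewhere in the same image can only
# produce values that are already present, so no clamping is needed.
_xb,_yb = next(iter(dls.train))
_xchk = _xb[:16].clone()
_rand_copy1(_xchk, 0.2)
assert _xchk.min() >= _xb[:16].min() and _xchk.max() <= _xb[:16].max()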
xb,_ = next(iter(dls.train))
xbt = xb[:16]
_rand_copy1(xbt, 0.2)
show_images(xbt, imsize=1.5)

#|export
def rand_copy(x, pct=0.2, max_num=4):
    num = random.randint(0, max_num)
    for i in range(num): _rand_copy1(x, pct)
    # print(num)
    return x

xb,_ = next(iter(dls.train))
xbt = xb[:16]
rand_copy(xbt, 0.2, 4)
show_images(xbt, imsize=1.5)

#|export
class RandCopy(nn.Module):
    def __init__(self, pct=0.2, max_num=4):
        super().__init__()
        self.pct,self.max_num = pct,max_num
    def forward(self, x): return rand_copy(x, self.pct, self.max_num)

tfms = nn.Sequential(transforms.RandomCrop(28, padding=1),
                     transforms.RandomHorizontalFlip(),
                     RandCopy())
augcb = BatchTransformCB(partial(tfm_batch, tfm_x=tfms), on_val=False)

model = get_model()
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=[DeviceCB(), SingleBatchCB(), augcb])
learn.fit(1)
xb,yb = learn.batch
show_images(xb[:16], imsize=1.5)

# Ensembling: train two models with the same recipe and average their predictions
set_seed(1)
epochs = 25
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched), augcb]
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)

model2 = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn2 = TrainLearner(model2, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn2.fit(epochs)

mdl_path = Path('models')
torch.save(learn.model, mdl_path/'randcopy1.pkl')
torch.save(learn2.model, mdl_path/'randcopy2.pkl')

cp1 = CapturePreds()
learn.fit(1, train=False, cbs=cp1)
cp2 = CapturePreds()
learn2.fit(1, train=False, cbs=cp2)
ap = torch.stack([cp1.all_preds,cp2.all_preds]).mean(0).argmax(1)
round((ap==cp1.all_targs).float().mean().item(), 3)

# Dropout from scratch: a Binomial keep-mask (keep probability 1-p), scaled by 1/(1-p)
p = 0.1
dist = distributions.binomial.Binomial(probs=1-p)
dist.sample((10,))

class Dropout(nn.Module):
    def __init__(self, p=0.1):
        super().__init__()
        self.p = p
    def forward(self, x):
        if not self.training: return x
        dist = distributions.binomial.Binomial(tensor(1.0).to(x.device), probs=1-self.p)
        return x * dist.sample(x.size()) * 1/(1-self.p)

def get_dropmodel(act=nn.ReLU, nfs=(16,32,64,128,256,512), norm=nn.BatchNorm2d, drop=0.0):
    layers = [ResBlock(1, 16, ks=5, stride=1, act=act, norm=norm), nn.Dropout2d(drop)]
    layers += [ResBlock(nfs[i], nfs[i+1], act=act, norm=norm, stride=2) for i in range(len(nfs)-1)]
    layers += [nn.Flatten(), Dropout(drop), nn.Linear(nfs[-1], 10, bias=False), nn.BatchNorm1d(10)]
    return nn.Sequential(*layers)

set_seed(42)
epochs = 5
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched)]
model = get_dropmodel(act_gr, norm=nn.BatchNorm2d, drop=0.1).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)

# Test-time dropout: keep Dropout layers in train mode during evaluation
class TTD_CB(Callback):
    def before_epoch(self, learn):
        learn.model.apply(lambda m: m.train() if isinstance(m, (nn.Dropout,nn.Dropout2d)) else None)

# Re-scale inputs to [-1, 1] and retrain with the augmentation callback
@inplace
def transformi(b): b[xl] = [(TF.to_tensor(o)*2-1) for o in b[xl]]

tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

set_seed(42)
epochs = 20
lr = 1e-2
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched), augcb]
model = get_model(act_gr, norm=nn.BatchNorm2d).apply(iw)
learn = TrainLearner(model, dls, F.cross_entropy, lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)
learn.fit(epochs)
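# Quick check (illustrative): with the redefined `transformi` above, inputs now
# lie roughly in [-1, 1] (TF.to_tensor gives [0, 1], then *2-1) rather than
# being standardized with xmean/xstd.
_xb,_yb = next(iter(dls.train))
_xb.min(), _xb.max()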
torch.save(learn.model, 'models/data_aug2.pkl')

import nbdev; nbdev.nbdev_export()