#| default_exp training

#|export
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from pathlib import Path
from torch import tensor,nn
import torch.nn.functional as F
from fastcore.test import test_close

torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'

path_data = Path('data')
path_gz = path_data/'mnist.pkl.gz'
with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])

n,m = x_train.shape
c = y_train.max()+1  # number of classes
nh = 50              # number of hidden units

class Model(nn.Module):
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # NB: a plain Python list, so these layers are *not* registered as submodules
        self.layers = [nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out)]
    def __call__(self, x):
        for l in self.layers: x = l(x)
        return x

model = Model(m, nh, 10)
pred = model(x_train)
pred.shape

# Naive log-softmax: numerically unstable, since exp() can overflow
def log_softmax(x): return (x.exp()/(x.exp().sum(-1,keepdim=True))).log()

log_softmax(pred)

# Equivalent, using log(a/b) = log(a) - log(b)
def log_softmax(x): return x - x.exp().sum(-1,keepdim=True).log()

# The LogSumExp trick: subtract the row max before exponentiating to avoid overflow
def logsumexp(x):
    m = x.max(-1)[0]
    return m + (x-m[:,None]).exp().sum(-1).log()

# Final version, using PyTorch's built-in logsumexp
def log_softmax(x): return x - x.logsumexp(-1,keepdim=True)

test_close(logsumexp(pred), pred.logsumexp(-1))
sm_pred = log_softmax(pred)
sm_pred

y_train[:3]
sm_pred[0,5],sm_pred[1,0],sm_pred[2,4]   # the first three targets are 5,0,4
sm_pred[[0,1,2], y_train[:3]]            # integer-array indexing picks out the target-class log-probs

def nll(input, target): return -input[range(target.shape[0]), target].mean()

loss = nll(sm_pred, y_train)
loss

test_close(F.nll_loss(F.log_softmax(pred, -1), y_train), loss, 1e-3)
test_close(F.cross_entropy(pred, y_train), loss, 1e-3)

loss_func = F.cross_entropy

bs = 50                # batch size
xb = x_train[0:bs]     # a mini-batch from x
preds = model(xb)      # predictions
preds[0], preds.shape

yb = y_train[0:bs]
yb
loss_func(preds, yb)
preds.argmax(dim=1)

#|export
def accuracy(out, yb): return (out.argmax(dim=1)==yb).float().mean()

accuracy(preds, yb)

lr = 0.5   # learning rate
epochs = 3 # how many epochs to train for

#|export
def report(loss, preds, yb): print(f'{loss:.2f}, {accuracy(preds, yb):.2f}')

xb,yb = x_train[:bs],y_train[:bs]
preds = model(xb)
report(loss_func(preds, yb), preds, yb)

# Manual training loop: update each layer's weights and biases by hand
for epoch in range(epochs):
    for i in range(0, n, bs):
        s = slice(i, min(n,i+bs))
        xb,yb = x_train[s],y_train[s]
        preds = model(xb)
        loss = loss_func(preds, yb)
        loss.backward()
        with torch.no_grad():
            for l in model.layers:
                if hasattr(l, 'weight'):
                    l.weight -= l.weight.grad * lr
                    l.bias   -= l.bias.grad   * lr
                    l.weight.grad.zero_()
                    l.bias.grad.zero_()
    report(loss, preds, yb)

# nn.Module registers any module assigned as an attribute
m1 = nn.Module()
m1.foo = nn.Linear(3,4)
m1
list(m1.named_children())
m1.named_children()  # a generator, not a list
list(m1.parameters())

class MLP(nn.Module):
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        self.l1 = nn.Linear(n_in,nh)
        self.l2 = nn.Linear(nh,n_out)
        self.relu = nn.ReLU()
    def forward(self, x): return self.l2(self.relu(self.l1(x)))

model = MLP(m, nh, 10)
model.l1
model

for name,l in model.named_children(): print(f"{name}: {l}")
for p in model.parameters(): print(p.shape)

def fit():
    for epoch in range(epochs):
        for i in range(0, n, bs):
            s = slice(i, min(n,i+bs))
            xb,yb = x_train[s],y_train[s]
            preds = model(xb)
            loss = loss_func(preds, yb)
            loss.backward()
            with torch.no_grad():
                for p in model.parameters(): p -= p.grad * lr
                model.zero_grad()
        report(loss, preds, yb)

fit()

# A minimal reimplementation of nn.Module's attribute registration
class MyModule:
    def __init__(self, n_in, nh, n_out):
        self._modules = {}
        self.l1 = nn.Linear(n_in,nh)
        self.l2 = nn.Linear(nh,n_out)
    def __setattr__(self,k,v):
        if not k.startswith("_"): self._modules[k] = v
        super().__setattr__(k,v)
    def __repr__(self): return f'{self._modules}'
    def parameters(self):
        for l in self._modules.values(): yield from l.parameters()

mdl = MyModule(m,nh,10)
mdl
for p in mdl.parameters(): print(p.shape)
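# Quick check (a sketch, not part of the original notebook): since MyModule
# records submodules in __setattr__, its parameters() should find exactly the
# same tensors as an equivalent nn.Module-based MLP of the same shape.
assert sum(p.numel() for p in mdl.parameters()) == sum(p.numel() for p in MLP(m, nh, 10).parameters())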
from functools import reduce

layers = [nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10)]

# add_module registers layers from a list under explicit names
class Model(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers
        for i,l in enumerate(self.layers): self.add_module(f'layer_{i}', l)
    def forward(self, x): return reduce(lambda val,layer: layer(val), self.layers, x)

model = Model(layers)
model
model(xb).shape

# nn.ModuleList does the registration for us
class SequentialModel(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
    def forward(self, x):
        for l in self.layers: x = l(x)
        return x

model = SequentialModel(layers)
model
fit()

model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))
fit()
loss_func(model(xb), yb), accuracy(model(xb), yb)
model

# Factor the update step out into an optimizer
class Optimizer():
    def __init__(self, params, lr=0.5): self.params,self.lr = list(params),lr
    def step(self):
        with torch.no_grad():
            for p in self.params: p -= p.grad * self.lr
    def zero_grad(self):
        for p in self.params: p.grad.data.zero_()

model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))
opt = Optimizer(model.parameters())

for epoch in range(epochs):
    for i in range(0, n, bs):
        s = slice(i, min(n,i+bs))
        xb,yb = x_train[s],y_train[s]
        preds = model(xb)
        loss = loss_func(preds, yb)
        loss.backward()
        opt.step()
        opt.zero_grad()
    report(loss, preds, yb)

# PyTorch's own SGD optimizer does the same thing
from torch import optim

def get_model():
    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))
    return model, optim.SGD(model.parameters(), lr=lr)

model,opt = get_model()
loss_func(model(xb), yb)

for epoch in range(epochs):
    for i in range(0, n, bs):
        s = slice(i, min(n,i+bs))
        xb,yb = x_train[s],y_train[s]
        preds = model(xb)
        loss = loss_func(preds, yb)
        loss.backward()
        opt.step()
        opt.zero_grad()
    report(loss, preds, yb)

#|export
class Dataset():
    def __init__(self, x, y): self.x,self.y = x,y
    def __len__(self): return len(self.x)
    def __getitem__(self, i): return self.x[i],self.y[i]

train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)
assert len(train_ds)==len(x_train)
assert len(valid_ds)==len(x_valid)

xb,yb = train_ds[0:5]
assert xb.shape==(5,28*28)
assert yb.shape==(5,)
xb,yb

model,opt = get_model()

for epoch in range(epochs):
    for i in range(0, n, bs):
        xb,yb = train_ds[i:min(n,i+bs)]
        preds = model(xb)
        loss = loss_func(preds, yb)
        loss.backward()
        opt.step()
        opt.zero_grad()
    report(loss, preds, yb)

# A first DataLoader: iterate over a Dataset in fixed-size slices
class DataLoader():
    def __init__(self, ds, bs): self.ds,self.bs = ds,bs
    def __iter__(self):
        for i in range(0, len(self.ds), self.bs): yield self.ds[i:i+self.bs]

train_dl = DataLoader(train_ds, bs)
valid_dl = DataLoader(valid_ds, bs)
xb,yb = next(iter(valid_dl))
xb.shape
yb
plt.imshow(xb[0].view(28,28))
yb[0]

model,opt = get_model()

def fit():
    for epoch in range(epochs):
        for xb,yb in train_dl:
            preds = model(xb)
            loss = loss_func(preds, yb)
            loss.backward()
            opt.step()
            opt.zero_grad()
        report(loss, preds, yb)

fit()
loss_func(model(xb), yb), accuracy(model(xb), yb)

import random

# A Sampler yields indices, optionally shuffled
class Sampler():
    def __init__(self, ds, shuffle=False): self.n,self.shuffle = len(ds),shuffle
    def __iter__(self):
        res = list(range(self.n))
        if self.shuffle: random.shuffle(res)
        return iter(res)

from itertools import islice

ss = Sampler(train_ds)
it = iter(ss)
for o in range(5): print(next(it))
list(islice(ss, 5))

ss = Sampler(train_ds, shuffle=True)
list(islice(ss, 5))

import fastcore.all as fc

# A BatchSampler chunks a Sampler's indices into batches
class BatchSampler():
    def __init__(self, sampler, bs, drop_last=False): fc.store_attr()
    def __iter__(self): yield from fc.chunked(iter(self.sampler), self.bs, drop_last=self.drop_last)

batchs = BatchSampler(ss, 4)
list(islice(batchs, 5))
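# A sketch (not in the original notebook) of what drop_last controls: with
# drop_last=True the final short chunk is discarded. 50,000 isn't divisible
# by 7, so the two samplers below differ by exactly one batch.
bs_odd = 7
full  = list(BatchSampler(Sampler(train_ds), bs_odd, drop_last=False))
exact = list(BatchSampler(Sampler(train_ds), bs_odd, drop_last=True))
len(full)-len(exact), len(full[-1]), len(exact[-1])   # (1, 6, 7)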
# collate turns a list of (x,y) pairs into a pair of stacked tensors
def collate(b):
    xs,ys = zip(*b)
    return torch.stack(xs),torch.stack(ys)

class DataLoader():
    def __init__(self, ds, batchs, collate_fn=collate): fc.store_attr()
    def __iter__(self): yield from (self.collate_fn(self.ds[i] for i in b) for b in self.batchs)

train_samp = BatchSampler(Sampler(train_ds, shuffle=True ), bs)
valid_samp = BatchSampler(Sampler(valid_ds, shuffle=False), bs)

train_dl = DataLoader(train_ds, batchs=train_samp)
valid_dl = DataLoader(valid_ds, batchs=valid_samp)

xb,yb = next(iter(valid_dl))
plt.imshow(xb[0].view(28,28))
yb[0]
xb.shape,yb.shape

model,opt = get_model()
fit()

import torch.multiprocessing as mp

# Tensors accept a list of indices, so a whole batch can be fetched in one call...
train_ds[[3,6,8,1]]
train_ds.__getitem__([3,6,8,1])
# ...which means batches can be fetched in parallel across worker processes
for o in map(train_ds.__getitem__, ([3,6],[8,1])): print(o)

class DataLoader():
    def __init__(self, ds, batchs, n_workers=1, collate_fn=collate): fc.store_attr()
    def __iter__(self):
        with mp.Pool(self.n_workers) as ex:
            yield from ex.map(self.ds.__getitem__, iter(self.batchs))

train_dl = DataLoader(train_ds, batchs=train_samp, n_workers=2)
it = iter(train_dl)
xb,yb = next(it)
xb.shape,yb.shape

# PyTorch provides all of these pieces
#|export
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, BatchSampler

train_samp = BatchSampler(RandomSampler(train_ds), bs, drop_last=False)
valid_samp = BatchSampler(SequentialSampler(valid_ds), bs, drop_last=False)

train_dl = DataLoader(train_ds, batch_sampler=train_samp, collate_fn=collate)
valid_dl = DataLoader(valid_ds, batch_sampler=valid_samp, collate_fn=collate)

model,opt = get_model()
fit()
loss_func(model(xb), yb), accuracy(model(xb), yb)

# Equivalent, passing the sampler and batch size separately
train_dl = DataLoader(train_ds, bs, sampler=RandomSampler(train_ds), collate_fn=collate)
valid_dl = DataLoader(valid_ds, bs, sampler=SequentialSampler(valid_ds), collate_fn=collate)

# Simplest of all: let shuffle=True and num_workers handle everything
train_dl = DataLoader(train_ds, bs, shuffle=True, drop_last=True, num_workers=2)
valid_dl = DataLoader(valid_ds, bs, shuffle=False, num_workers=2)

model,opt = get_model()
fit()
loss_func(model(xb), yb), accuracy(model(xb), yb)

# Since our Dataset indexes with a list directly, a BatchSampler can even be
# passed as `sampler`, so each fetched "item" is already a whole batch
train_ds[[4,6,7]]
train_dl = DataLoader(train_ds, sampler=train_samp)
valid_dl = DataLoader(valid_ds, sampler=valid_samp)
xb,yb = next(iter(train_dl))
xb.shape,yb.shape

# Final fit: train each epoch, then report loss/accuracy on the validation
# set, weighting each batch by its size so the last short batch counts fairly
#|export
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for epoch in range(epochs):
        model.train()
        for xb,yb in train_dl:
            loss = loss_func(model(xb), yb)
            loss.backward()
            opt.step()
            opt.zero_grad()
        model.eval()
        with torch.no_grad():
            tot_loss,tot_acc,count = 0.,0.,0
            for xb,yb in valid_dl:
                pred = model(xb)
                n = len(xb)
                count += n
                tot_loss += loss_func(pred,yb).item()*n
                tot_acc  += accuracy(pred,yb).item()*n
        print(epoch, tot_loss/count, tot_acc/count)
    return tot_loss/count, tot_acc/count

#|export
def get_dls(train_ds, valid_ds, bs, **kwargs):
    return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),
            DataLoader(valid_ds, batch_size=bs*2, **kwargs))

train_dl,valid_dl = get_dls(train_ds, valid_ds, bs)
model,opt = get_model()
%time loss,acc = fit(5, model, loss_func, opt, train_dl, valid_dl)

import nbdev; nbdev.nbdev_export()
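# Closing sanity check (a sketch, not part of the original notebook): get_dls
# builds the validation loader with twice the training batch size, which is
# safe because no activations need to be stored for backprop during eval.
xb,yb = next(iter(valid_dl))
assert len(xb)==bs*2
with torch.no_grad(): print(loss_func(model(xb), yb), accuracy(model(xb), yb))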