#| default_exp fid #|export import pickle,gzip,math,os,time,shutil,torch,random import fastcore.all as fc,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt from collections.abc import Mapping from pathlib import Path from operator import attrgetter,itemgetter from functools import partial from copy import copy from contextlib import contextmanager from scipy import linalg from fastcore.foundation import L import torchvision.transforms.functional as TF,torch.nn.functional as F from torch import tensor,nn,optim from torch.utils.data import DataLoader,default_collate from torch.nn import init from torch.optim import lr_scheduler from torcheval.metrics import MulticlassAccuracy from datasets import load_dataset,load_dataset_builder from miniai.datasets import * from miniai.conv import * from miniai.learner import * from miniai.activations import * from miniai.init import * from miniai.sgd import * from miniai.resnet import * from miniai.augment import * from miniai.accel import * from fastcore.test import test_close from torch import distributions torch.set_printoptions(precision=2, linewidth=140, sci_mode=False) torch.manual_seed(1) mpl.rcParams['image.cmap'] = 'gray_r' import logging logging.disable(logging.WARNING) set_seed(42) if fc.defaults.cpus>8: fc.defaults.cpus=8 xl,yl = 'image','label' name = "fashion_mnist" bs = 512 @inplace def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))*2-1 for o in b[xl]] dsd = load_dataset(name) tds = dsd.with_transform(transformi) dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus) b = xb,yb = next(iter(dls.train)) cbs = [DeviceCB(), MixedPrecision()] model = torch.load('models/data_aug2.pkl') learn = Learner(model, dls, F.cross_entropy, cbs=cbs, opt_func=None) def append_outp(hook, mod, inp, outp): if not hasattr(hook,'outp'): hook.outp = [] hook.outp.append(to_cpu(outp)) hcb = HooksCallback(append_outp, mods=[learn.model[6]], on_valid=True) learn.fit(1, train=False, cbs=[hcb]) feats = hcb.hooks[0].outp[0].float()[:64] feats.shape del(learn.model[8]) del(learn.model[7]) feats,y = learn.capture_preds() feats = feats.float() feats.shape,y betamin,betamax,n_steps = 0.0001,0.02,1000 beta = torch.linspace(betamin, betamax, n_steps) alpha = 1.-beta alphabar = alpha.cumprod(dim=0) sigma = beta.sqrt() def noisify(x0, ᾱ): device = x0.device n = len(x0) t = torch.randint(0, n_steps, (n,), dtype=torch.long) ε = torch.randn(x0.shape, device=device) ᾱ_t = ᾱ[t].reshape(-1, 1, 1, 1).to(device) xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε return (xt, t.to(device)), ε def collate_ddpm(b): return noisify(default_collate(b)[xl], alphabar) def dl_ddpm(ds): return DataLoader(ds, batch_size=bs, collate_fn=collate_ddpm, num_workers=4) dls2 = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test'])) from diffusers import UNet2DModel class UNet(UNet2DModel): def forward(self, x): return super().forward(*x).sample smodel = torch.load('models/fashion_ddpm_mp.pkl').cuda() @torch.no_grad() def sample(model, sz, alpha, alphabar, sigma, n_steps): device = next(model.parameters()).device x_t = torch.randn(sz, device=device) preds = [] for t in reversed(range(n_steps)): t_batch = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long) z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(device) ᾱ_t1 = alphabar[t-1] if t > 0 else torch.tensor(1) b̄_t = 1 - alphabar[t] b̄_t1 = 1 - ᾱ_t1 x_0_hat = ((x_t - b̄_t.sqrt() * model((x_t, t_batch)))/alphabar[t].sqrt()) x_t = x_0_hat * ᾱ_t1.sqrt()*(1-alpha[t])/b̄_t + x_t * alpha[t].sqrt()*b̄_t1/b̄_t + sigma[t]*z preds.append(x_0_hat.cpu()) return preds %%time samples = sample(smodel, (256, 1, 32, 32), alpha, alphabar, sigma, n_steps) s = samples[-1]*2-1 show_images(s[:16], imsize=1.5) clearn = TrainLearner(model, DataLoaders([],[(s,yb)]), loss_func=fc.noop, cbs=[DeviceCB()], opt_func=None) feats2,y2 = clearn.capture_preds() feats2 = feats2.float().squeeze() feats2.shape means = feats.mean(0) means.shape covs = feats.T.cov() covs.shape #|export def _sqrtm_newton_schulz(mat, num_iters=100): mat_nrm = mat.norm() mat = mat.double() Y = mat/mat_nrm n = len(mat) I = torch.eye(n, n).to(mat) Z = torch.eye(n, n).to(mat) for i in range(num_iters): T = (3*I - Z@Y)/2 Y,Z = Y@T,T@Z res = Y*mat_nrm.sqrt() if ((mat-(res@res)).norm()/mat_nrm).abs()<=1e-6: break return res #|export def _calc_stats(feats): feats = feats.squeeze() return feats.mean(0),feats.T.cov() def _calc_fid(m1,c1,m2,c2): # csr = _sqrtm_newton_schulz(c1@c2) csr = tensor(linalg.sqrtm(c1@c2, 256).real) return (((m1-m2)**2).sum() + c1.trace() + c2.trace() - 2*csr.trace()).item() s1,s2 = _calc_stats(feats),_calc_stats(feats2) _calc_fid(*s1, *s2) #|export def _squared_mmd(x, y): def k(a,b): return (a@b.transpose(-2,-1)/a.shape[-1]+1)**3 m,n = x.shape[-2],y.shape[-2] kxx,kyy,kxy = k(x,x), k(y,y), k(x,y) kxx_sum = kxx.sum([-1,-2])-kxx.diagonal(0,-1,-2).sum(-1) kyy_sum = kyy.sum([-1,-2])-kyy.diagonal(0,-1,-2).sum(-1) kxy_sum = kxy.sum([-1,-2]) return kxx_sum/m/(m-1) + kyy_sum/n/(n-1) - kxy_sum*2/m/n #|export def _calc_kid(x, y, maxs=50): xs,ys = x.shape[0],y.shape[0] n = max(math.ceil(min(xs/maxs, ys/maxs)), 4) mmd = 0. for i in range(n): cur_x = x[round(i*xs/n) : round((i+1)*xs/n)] cur_y = y[round(i*ys/n) : round((i+1)*ys/n)] mmd += _squared_mmd(cur_x, cur_y) return (mmd/n).item() _calc_kid(feats, feats2) #|export class ImageEval: def __init__(self, model, dls, cbs=None): self.learn = TrainLearner(model, dls, loss_func=fc.noop, cbs=cbs, opt_func=None) self.feats = self.learn.capture_preds()[0].float().cpu().squeeze() self.stats = _calc_stats(self.feats) def get_feats(self, samp): self.learn.dls = DataLoaders([],[(samp, tensor([0]))]) return self.learn.capture_preds()[0].float().cpu().squeeze() def fid(self, samp): return _calc_fid(*self.stats, *_calc_stats(self.get_feats(samp))) def kid(self, samp): return _calc_kid(self.feats, self.get_feats(samp)) ie = ImageEval(model, learn.dls, cbs=[DeviceCB()]) %%time ie.fid(s) %%time ie.kid(s) xs = L.range(0,1000,50)+[975,990,999] plt.plot(xs, [ie.fid(samples[i].clamp(-0.5,0.5)*2) for i in xs]); xs = L.range(0,1000,50)+[975,990,999] plt.plot(xs, [ie.kid(samples[i].clamp(-0.5,0.5)*2) for i in xs]); ie.fid(xb) ie.kid(xb) from pytorch_fid.inception import InceptionV3 a = tensor([1,2,3]) a.repeat((3,1)) class IncepWrap(nn.Module): def __init__(self): super().__init__() self.m = InceptionV3(resize_input=True) def forward(self, x): return self.m(x.repeat(1,3,1,1))[0] tds = dsd.with_transform(transformi) dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus) ie = ImageEval(IncepWrap(), dls, cbs=[DeviceCB()]) %%time ie.fid(s) ie.fid(xb) %%time ie.kid(s) ie.kid(xb) import nbdev; nbdev.nbdev_export()