"""Autoencoder and VAE experiments on Fashion-MNIST (miniai course-style script).

Trains a plain MLP autoencoder, then a variational autoencoder, on flattened
28x28 Fashion-MNIST images; visualises reconstructions and images decoded
from random latent vectors.
"""
import os

# Pin the job to the second GPU before torch initialises CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import timm, torch, random, datasets, math, fastcore.all as fc, numpy as np, matplotlib as mpl, matplotlib.pyplot as plt
import k_diffusion as K, torchvision.transforms as T
import torchvision.transforms.functional as TF, torch.nn.functional as F
from torch.utils.data import DataLoader, default_collate
from pathlib import Path
from torch.nn import init
from fastcore.foundation import L
from torch import nn, tensor
from datasets import load_dataset
from operator import itemgetter
from torcheval.metrics import MulticlassAccuracy, Mean, Metric
from functools import partial
from torch.optim import lr_scheduler
from torch import optim
from einops import rearrange

from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.training import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *
from miniai.augment import *
from miniai.accel import *

torch.set_printoptions(precision=4, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'
mpl.rcParams['figure.dpi'] = 70

import logging
logging.disable(logging.WARNING)

set_seed(42)  # miniai helper — NOTE(review): presumably seeds python/numpy/torch; supersedes manual_seed above
if fc.defaults.cpus > 8: fc.defaults.cpus = 8

# ----------------------------------------------------------------- data
xl, yl = 'image', 'label'
name = "fashion_mnist"
bs = 256
dsd = load_dataset(name)

@inplace
def transformi(b):
    """Flatten each PIL image to a 784-vector; the target is the input itself
    (autoencoding), so the label column is overwritten with the images."""
    img = [TF.to_tensor(o).flatten() for o in b[xl]]
    b[yl] = b[xl] = img

tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=8)
dl = dls.valid
xb, yb = b = next(iter(dl))

# Layer widths: input (28*28), hidden, latent.
ni, nh, nl = 784, 400, 200

def lin(ni, nf, act=nn.SiLU, norm=nn.BatchNorm1d, bias=True):
    """Linear layer followed by an optional activation and optional norm.

    Pass act=None / norm=None to drop either stage (e.g. the output layer).
    """
    layers = nn.Sequential(nn.Linear(ni, nf, bias=bias))
    if act : layers.append(act())
    if norm: layers.append(norm(nf))
    return layers

def init_weights(m, leaky=0.):
    """Kaiming-init conv/linear weights; a no-op for every other module type."""
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
        init.kaiming_normal_(m.weight, a=leaky)

iw = partial(init_weights, leaky=0.2)

# ---------------------------------------------------------- autoencoder
class Autoenc(nn.Module):
    """Plain MLP autoencoder: 784 -> 400 -> 400 -> 200 latent -> 784 logits."""
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(lin(ni, nh), lin(nh, nh), lin(nh, nl))
        self.dec = nn.Sequential(lin(nl, nh), lin(nh, nh), lin(nh, ni, act=None))
        # FIX: was `iw(self)`, which is a silent no-op — init_weights only acts
        # on Conv/Linear instances and Autoenc is neither. apply() recurses into
        # submodules so every nn.Linear actually receives the Kaiming init.
        self.apply(iw)

    def forward(self, x):
        x = self.enc(x)
        return self.dec(x)

opt_func = partial(optim.Adam, eps=1e-5)
Learner(Autoenc(), dls, nn.BCEWithLogitsLoss(), cbs=[DeviceCB(), MixedPrecision()], opt_func=opt_func).lr_find()

lr = 3e-2
epochs = 20
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
model = Autoenc()
learn = Learner(model, dls, nn.BCEWithLogitsLoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)

# Reconstructions: model outputs logits, so sigmoid before display.
with torch.no_grad(): t = to_cpu(model(xb.cuda()).float())
show_images(xb[:9].reshape(-1, 1, 28, 28), imsize=1.5, title='Original')
show_images(t[:9].reshape(-1, 1, 28, 28).sigmoid(), imsize=1.5, title='Autoenc')

# Decode N(0,1) latents with the *plain* autoencoder — nothing constrains its
# latent space to match that prior, so poor samples are expected here.
noise = torch.randn(16, nl).cuda()
with torch.no_grad(): generated_images = model.dec(noise).sigmoid()
show_images(generated_images.reshape(-1, 1, 28, 28), imsize=1.5)

# sd vae is 3 down, 1 no-down, mid, conv, sampling, conv, mid, 3 up, 1 no-up
class VAE(nn.Module):
    """Variational autoencoder: the encoder emits (mu, log-variance) heads; the
    decoder maps a reparameterised sample z = mu + exp(lv/2)*eps to 784 logits."""
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(lin(ni, nh), lin(nh, nh))
        self.mu, self.lv = lin(nh, nl, act=None), lin(nh, nl, act=None)
        self.dec = nn.Sequential(lin(nl, nh), lin(nh, nh), lin(nh, ni, act=None))
        # FIX: same defect as Autoenc — must recurse with apply(), not iw(self).
        self.apply(iw)

    def forward(self, x):
        x = self.enc(x)
        mu, lv = self.mu(x), self.lv(x)
        # Reparameterisation trick: sample via a deterministic transform of eps.
        z = mu + (0.5 * lv).exp() * torch.randn_like(lv)
        return self.dec(z), mu, lv

def kld_loss(inp, x):
    """KL(q(z|x) || N(0, I)) in closed form, averaged over all latent elements."""
    x_hat, mu, lv = inp
    return -0.5 * (1 + lv - mu.pow(2) - lv.exp()).mean()

def bce_loss(inp, x):
    """Reconstruction term, computed on the decoder's logits (inp[0])."""
    return F.binary_cross_entropy_with_logits(inp[0], x)

def vae_loss(inp, x): return kld_loss(inp, x) + bce_loss(inp, x)

# Shape of the KL term as a function of log-variance (with mu fixed at 0):
# minimised at lv=0, i.e. unit variance.
x = torch.linspace(-3, 3, 100)
plt.figure(figsize=(4, 3))
plt.plot(x, -0.5 * (1 + x - x.exp()))

class FuncMetric(Mean):
    """Running mean (via torcheval's Mean) of an arbitrary fn(inp, targets),
    letting kld/bce be tracked separately while training on their sum."""
    def __init__(self, fn, device=None):
        super().__init__(device=device)
        self.fn = fn

    def update(self, inp, targets):
        # weighted_sum / weights are Mean's internal accumulators; each batch
        # contributes its value with weight 1.
        self.weighted_sum += self.fn(inp, targets)
        self.weights += 1

metrics = MetricsCB(kld=FuncMetric(kld_loss), bce=FuncMetric(bce_loss))
opt_func = partial(optim.Adam, eps=1e-5)
lr = 3e-2
epochs = 20
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), metrics, BatchSchedCB(sched), MixedPrecision()]
model = VAE()
learn = Learner(model, dls, vae_loss, lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)

with torch.no_grad(): t, mu, lv = to_cpu(model(xb.cuda()))
t = t.float()
show_images(xb[:9].reshape(-1, 1, 28, 28), imsize=1.5, title='Original')
show_images(t[:9].reshape(-1, 1, 28, 28).sigmoid(), imsize=1.5, title='VAE')

# Decode fresh N(0,1) latents — now matching the VAE's trained prior.
noise = torch.randn(16, nl).cuda()
with torch.no_grad(): ims = model.dec(noise).sigmoid()
show_images(ims.reshape(-1, 1, 28, 28), imsize=1.5)