import os
os.environ['CUDA_VISIBLE_DEVICES']='1'
import timm, torch, random, datasets, math, fastcore.all as fc, numpy as np, matplotlib as mpl, matplotlib.pyplot as plt
import k_diffusion as K, torchvision.transforms as T
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch.utils.data import DataLoader,default_collate
from pathlib import Path
from torch.nn import init
from fastcore.foundation import L
from torch import nn,tensor
from datasets import load_dataset
from operator import itemgetter
from torcheval.metrics import MulticlassAccuracy,Mean,Metric
from functools import partial
from torch.optim import lr_scheduler
from torch import optim
from einops import rearrange
from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.training import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *
from miniai.augment import *
from miniai.accel import *
# Notebook-wide display settings, seeding, logging and dataset loading.
torch.set_printoptions(precision=4, linewidth=140, sci_mode=False)
torch.manual_seed(1)
# Plot images with an inverted grey colormap at a small DPI.
mpl.rcParams['image.cmap'] = 'gray_r'
mpl.rcParams['figure.dpi'] = 70
import logging
# Disable all logging calls at WARNING level and below.
logging.disable(logging.WARNING)
# NOTE(review): presumably this reseeds torch (and other RNGs), which would make the
# torch.manual_seed(1) above redundant — confirm against miniai's set_seed.
set_seed(42)
# Cap fastcore's default worker count at 8.
if fc.defaults.cpus>8: fc.defaults.cpus=8
# Column names of the image input and the label in the dataset.
xl,yl = 'image','label'
name = "fashion_mnist"
bs = 256
# Load every split of the dataset with the Hugging Face `datasets` library.
dsd = load_dataset(name)
0%| | 0/2 [00:00<?, ?it/s]
@inplace
def transformi(b):
    "Replace each image in the batch with a flat float tensor and use the same tensors as the target."
    flat = []
    for im in b[xl]:
        flat.append(TF.to_tensor(im).flatten())
    # Autoencoder setup: the target is the input itself.
    b[yl] = flat
    b[xl] = flat
# Apply `transformi` lazily to every split, build DataLoaders and grab one validation batch.
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=8)
dl = dls.valid
b = next(iter(dl))
xb, yb = b
# Layer sizes: input pixels (28*28), hidden width, latent width.
ni, nh, nl = 784, 400, 200
def lin(ni, nf, act=nn.SiLU, norm=nn.BatchNorm1d, bias=True):
    "Build `Linear(ni, nf)`, optionally followed by `act()` and then `norm(nf)`, as an `nn.Sequential`."
    mods = [nn.Linear(ni, nf, bias=bias)]
    if act:
        mods.append(act())
    if norm:
        mods.append(norm(nf))
    return nn.Sequential(*mods)
def init_weights(m, leaky=0.):
    "Kaiming-normal initialise `m.weight` (negative slope `leaky`) if `m` is a conv or linear layer; otherwise do nothing."
    if not isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
        return
    init.kaiming_normal_(m.weight, a=leaky)
# `init_weights` with the negative slope preset to 0.2.
iw = partial(init_weights, leaky=0.2)
class Autoenc(nn.Module):
    "Fully-connected autoencoder: `ni` -> `nh` -> `nh` -> `nl` latent -> `nh` -> `nh` -> `ni` logits."
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(lin(ni, nh), lin(nh, nh), lin(nh, nl))
        # Final layer has no activation so the output stays as logits for BCEWithLogitsLoss.
        self.dec = nn.Sequential(lin(nl, nh), lin(nh, nh), lin(nh, ni, act=None))
        # Calling `iw(self)` only inspects the Autoenc itself, which is neither a Linear nor a Conv,
        # so no layer was ever initialised. `apply` runs `iw` on every submodule.
        self.apply(iw)
    def forward(self, x):
        "Encode `x` into the `nl`-dim latent and decode it back to per-pixel logits."
        x = self.enc(x)
        return self.dec(x)
# Adam with eps raised from PyTorch's 1e-8 default to 1e-5 (presumably for stability under mixed precision).
opt_func = partial(optim.Adam, eps=1e-5)
# LR range test on a throwaway model; presumably used to choose `lr` below.
Learner(Autoenc(), dls, nn.BCEWithLogitsLoss(), cbs=[DeviceCB(), MixedPrecision()], opt_func=opt_func).lr_find()
lr = 3e-2
epochs = 20
# OneCycle schedule sized to the total number of training batches across all epochs.
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
model = Autoenc()
# The model emits logits and the targets are the input images, so BCE-with-logits is the reconstruction loss.
learn = Learner(model, dls, nn.BCEWithLogitsLoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
| loss | epoch | train |
|---|---|---|
| 0.528 | 0 | train |
| 0.472 | 0 | eval |
| 0.400 | 1 | train |
| 0.350 | 1 | eval |
| 0.316 | 2 | train |
| 0.299 | 2 | eval |
| 0.286 | 3 | train |
| 0.285 | 3 | eval |
| 0.276 | 4 | train |
| 0.277 | 4 | eval |
| 0.271 | 5 | train |
| 0.273 | 5 | eval |
| 0.268 | 6 | train |
| 0.270 | 6 | eval |
| 0.266 | 7 | train |
| 0.267 | 7 | eval |
| 0.265 | 8 | train |
| 0.267 | 8 | eval |
| 0.264 | 9 | train |
| 0.266 | 9 | eval |
| 0.263 | 10 | train |
| 0.264 | 10 | eval |
| 0.262 | 11 | train |
| 0.264 | 11 | eval |
| 0.262 | 12 | train |
| 0.263 | 12 | eval |
| 0.261 | 13 | train |
| 0.262 | 13 | eval |
| 0.261 | 14 | train |
| 0.261 | 14 | eval |
| 0.260 | 15 | train |
| 0.261 | 15 | eval |
| 0.260 | 16 | train |
| 0.261 | 16 | eval |
| 0.260 | 17 | train |
| 0.261 | 17 | eval |
| 0.259 | 18 | train |
| 0.260 | 18 | eval |
| 0.259 | 19 | train |
| 0.260 | 19 | eval |
# Compare validation images with their reconstructions, then decode random latents.
# Use the device the model actually lives on instead of hard-coding `.cuda()`.
dev = next(model.parameters()).device
with torch.no_grad(): t = to_cpu(model(xb.to(dev)).float())
show_images(xb[:9].reshape(-1,1,28,28), imsize=1.5, title='Original');
show_images(t[:9].reshape(-1,1,28,28).sigmoid(), imsize=1.5, title='Autoenc');
# Draw the noise from the CPU RNG (as before) and then move it, so results on GPU are unchanged.
noise = torch.randn(16, nl).to(dev)
with torch.no_grad(): generated_images = model.dec(noise).sigmoid()
show_images(generated_images.reshape(-1, 1, 28, 28), imsize=1.5)
# For reference, the SD (Stable Diffusion) VAE is: 3 down blocks, 1 non-downsampling block, mid block, conv, sampling, conv, mid block, 3 up blocks, 1 non-upsampling block
class VAE(nn.Module):
    "Fully-connected VAE: encoder -> (mu, log-variance) heads -> reparameterised sample -> decoder logits."
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(lin(ni, nh), lin(nh, nh))
        # Two heads without activation: posterior mean and log-variance of the `nl`-dim latent.
        self.mu,self.lv = lin(nh, nl, act=None),lin(nh, nl, act=None)
        self.dec = nn.Sequential(lin(nl, nh), lin(nh, nh), lin(nh, ni, act=None))
        # Calling `iw(self)` only inspects the VAE itself, which is neither a Linear nor a Conv,
        # so no layer was ever initialised. `apply` runs `iw` on every submodule.
        self.apply(iw)
    def forward(self, x):
        "Return `(logits, mu, lv)`: decoder logits for a sampled latent plus the posterior parameters."
        x = self.enc(x)
        mu,lv = self.mu(x),self.lv(x)
        # Reparameterisation trick: z = mu + sigma * eps with sigma = exp(lv / 2), eps ~ N(0, 1).
        z = mu + (0.5*lv).exp()*torch.randn_like(lv)
        return self.dec(z),mu,lv
def kld_loss(inp, x):
    "KL divergence of N(mu, exp(lv)) from N(0, 1), averaged over every latent element; `x` is unused."
    _, mu, lv = inp
    per_elem = 1 + lv - mu.pow(2) - lv.exp()
    return -0.5 * per_elem.mean()
def bce_loss(inp, x):
    "Mean binary cross-entropy between the reconstruction logits `inp[0]` and the target `x`."
    logits = inp[0]
    return F.binary_cross_entropy_with_logits(logits, x)
def vae_loss(inp, x):
    "VAE objective: unweighted sum of the KL regulariser and the reconstruction BCE."
    kl = kld_loss(inp, x)
    recon = bce_loss(inp, x)
    return kl + recon
# KL term of `kld_loss` for a single latent with mu = 0, plotted against its log-variance;
# it is minimised at lv = 0, i.e. unit variance.
x = torch.linspace(-3, 3, 100)
plt.figure(figsize=(4, 3))
plt.plot(x, -0.5 * (1 + x - torch.exp(x)));
Question: What would happen if the variance of the latents were very low? What if they were very high?
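Bing: If the variance of the latents were very low, then the encoder distribution would be very peaked and concentrated around the mean. This would make the latent space less diverse and expressive, and limit the ability of the decoder to reconstruct the data accurately. It would also make it harder to generate new data that are different from the training data. (Correction: a very low variance actually makes reconstruction *easier*, because the VAE degenerates towards a plain autoencoder. The real costs are that the KL term grows without bound as the log-variance goes to minus infinity, and that samples drawn from N(0, 1) then land in latent regions the decoder was never trained on.)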
If the variance of the latents were very high, then the encoder distribution would be very spread out and diffuse. This would make the latent space more noisy and random, and reduce the correlation between the latent codes and the data. It would also make it easier to generate new data that are unrealistic or nonsensical.
class FuncMetric(Mean):
    "torcheval `Mean` that tracks the per-sample average of `fn(preds, targets)`, e.g. one component of a loss."
    def __init__(self, fn, device=None):
        super().__init__(device=device)
        self.fn = fn
    def update(self, inp, targets):
        "Accumulate `fn(inp, targets)` for one batch, weighted by the batch size; returns `self`."
        # The loss functions used here return batch means. Weighting each by its batch size makes the
        # result a per-sample average, so a short final batch no longer counts as much as a full one.
        n = len(targets)
        self.weighted_sum += self.fn(inp, targets) * n
        self.weights += n
        return self
# Report the KL and BCE parts of the VAE loss separately, alongside the combined loss.
metrics = MetricsCB(kld=FuncMetric(kld_loss), bce=FuncMetric(bce_loss))
# Same optimiser and OneCycle schedule settings as the plain autoencoder run.
opt_func = partial(optim.Adam, eps=1e-5)
lr = 3e-2
epochs = 20
# Total number of training batches across all epochs, used as the schedule length.
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), metrics, BatchSchedCB(sched), MixedPrecision()]
model = VAE()
# `vae_loss` receives the model's (logits, mu, lv) tuple and the target images.
learn = Learner(model, dls, vae_loss, lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
| bce | kld | loss | epoch | train |
|---|---|---|---|---|
| 0.546 | 0.426 | 0.973 | 0 | train |
| 0.495 | 0.269 | 0.762 | 0 | eval |
| 0.445 | 0.100 | 0.545 | 1 | train |
| 0.407 | 0.049 | 0.455 | 1 | eval |
| 0.382 | 0.034 | 0.416 | 2 | train |
| 0.363 | 0.038 | 0.401 | 2 | eval |
| 0.356 | 0.035 | 0.391 | 3 | train |
| 0.349 | 0.038 | 0.387 | 3 | eval |
| 0.347 | 0.036 | 0.383 | 4 | train |
| 0.344 | 0.036 | 0.380 | 4 | eval |
| 0.337 | 0.032 | 0.369 | 5 | train |
| 0.334 | 0.029 | 0.363 | 5 | eval |
| 0.330 | 0.029 | 0.358 | 6 | train |
| 0.328 | 0.029 | 0.357 | 6 | eval |
| 0.325 | 0.029 | 0.353 | 7 | train |
| 0.324 | 0.029 | 0.353 | 7 | eval |
| 0.321 | 0.028 | 0.350 | 8 | train |
| 0.321 | 0.029 | 0.350 | 8 | eval |
| 0.320 | 0.029 | 0.349 | 9 | train |
| 0.318 | 0.029 | 0.347 | 9 | eval |
| 0.317 | 0.030 | 0.347 | 10 | train |
| 0.317 | 0.029 | 0.346 | 10 | eval |
| 0.316 | 0.030 | 0.345 | 11 | train |
| 0.316 | 0.030 | 0.346 | 11 | eval |
| 0.314 | 0.030 | 0.345 | 12 | train |
| 0.314 | 0.030 | 0.344 | 12 | eval |
| 0.313 | 0.030 | 0.344 | 13 | train |
| 0.313 | 0.030 | 0.343 | 13 | eval |
| 0.312 | 0.031 | 0.343 | 14 | train |
| 0.312 | 0.031 | 0.343 | 14 | eval |
| 0.311 | 0.031 | 0.342 | 15 | train |
| 0.311 | 0.031 | 0.342 | 15 | eval |
| 0.311 | 0.031 | 0.342 | 16 | train |
| 0.311 | 0.031 | 0.341 | 16 | eval |
| 0.310 | 0.031 | 0.341 | 17 | train |
| 0.310 | 0.031 | 0.341 | 17 | eval |
| 0.310 | 0.031 | 0.341 | 18 | train |
| 0.310 | 0.031 | 0.341 | 18 | eval |
| 0.310 | 0.031 | 0.341 | 19 | train |
| 0.310 | 0.031 | 0.341 | 19 | eval |
# Compare validation images with (sampled) VAE reconstructions, then decode latents drawn from N(0, 1).
# Use the device the model actually lives on instead of hard-coding `.cuda()`.
dev = next(model.parameters()).device
with torch.no_grad(): t,mu,lv = to_cpu(model(xb.to(dev)))
t = t.float()
show_images(xb[:9].reshape(-1,1,28,28), imsize=1.5, title='Original');
show_images(t[:9].reshape(-1,1,28,28).sigmoid(), imsize=1.5, title='VAE');
# Draw the noise from the CPU RNG (as before) and then move it, so results on GPU are unchanged.
noise = torch.randn(16, nl).to(dev)
with torch.no_grad(): ims = model.dec(noise).sigmoid()
show_images(ims.reshape(-1, 1, 28, 28), imsize=1.5)