#| default_exp diffusion
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'
#| export
from miniai.imports import *
from einops import rearrange
from fastprogress import progress_bar
torch.set_printoptions(precision=4, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'
mpl.rcParams['figure.dpi'] = 70
import logging
logging.disable(logging.WARNING)
set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8
xl,yl = 'image','label'
name = "fashion_mnist"
bs = 512
dsd = load_dataset(name)
#| export
def abar(t): return (t*math.pi/2).cos()**2
def inv_abar(x): return x.sqrt().acos()*2/math.pi
def noisify(x0):
    device = x0.device
    n = len(x0)
    # sample one continuous timestep per image, then mix signal and noise
    t = torch.rand(n,).to(x0).clamp(0,0.999)
    ε = torch.randn(x0.shape, device=device)
    abar_t = abar(t).reshape(-1, 1, 1, 1).to(device)
    xt = abar_t.sqrt()*x0 + (1-abar_t).sqrt()*ε
    return (xt, t.to(device)), ε
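As a quick sanity check (illustrative only, not exported): the cosine schedule starts at ᾱ ≈ 1 (no noise) and decays to ᾱ ≈ 0 (pure noise), and `inv_abar` inverts it.
t = torch.linspace(0., 0.999, 5)
abar(t), inv_abar(abar(t))  # ᾱ falls from 1 toward 0; inv_abar recovers t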
def collate_ddpm(b): return noisify(default_collate(b)[xl])
def dl_ddpm(ds): return DataLoader(ds, batch_size=bs, collate_fn=collate_ddpm, num_workers=4)
@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))-0.5 for o in b[xl]]
tds = dsd.with_transform(transformi)
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))
dl = dls.train
(xt,t),eps = b = next(iter(dl))
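To see what the collation produces, we can show a few noised images titled with their timesteps (a minimal sketch; the exact `show_images`/`fc.map_ex` arguments are an assumption):
show_images(xt[:25], imsize=1.5, titles=fc.map_ex(t[:25], '{:.02f}'))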
The U-Net below is based on the Hugging Face Diffusers implementation.
#| export
def timestep_embedding(tsteps, emb_dim, max_period=10000):
    exponent = -math.log(max_period) * torch.linspace(0, 1, emb_dim//2, device=tsteps.device)
    emb = tsteps[:,None].float() * exponent.exp()[None,:]
    emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
    return F.pad(emb, (0,1,0,0)) if emb_dim%2==1 else emb
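A quick illustrative check (assuming only the definition above): a batch of `n` timesteps maps to an `(n, emb_dim)` matrix whose two halves are sines and cosines at geometrically spaced frequencies.
timestep_embedding(torch.linspace(0, 1, 4), 16).shape  # torch.Size([4, 16])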
#| export
def pre_conv(ni, nf, ks=3, stride=1, act=nn.SiLU, norm=None, bias=True):
    layers = nn.Sequential()
    if norm: layers.append(norm(ni))
    if act : layers.append(act())
    layers.append(nn.Conv2d(ni, nf, stride=stride, kernel_size=ks, padding=ks//2, bias=bias))
    return layers
#| export
def upsample(nf): return nn.Sequential(nn.Upsample(scale_factor=2.), nn.Conv2d(nf, nf, 3, padding=1))
#| export
def lin(ni, nf, act=nn.SiLU, norm=None, bias=True):
    layers = nn.Sequential()
    if norm: layers.append(norm(ni))
    if act : layers.append(act())
    layers.append(nn.Linear(ni, nf, bias=bias))
    return layers
# This version is giving poor results - use the cell below instead
class SelfAttention(nn.Module):
    def __init__(self, ni, attn_chans):
        super().__init__()
        self.attn = nn.MultiheadAttention(ni, ni//attn_chans, batch_first=True)
        self.norm = nn.BatchNorm2d(ni)
    def forward(self, x):
        n,c,h,w = x.shape
        x = self.norm(x).view(n, c, -1).transpose(1, 2)
        x = self.attn(x, x, x, need_weights=False)[0]
        return x.transpose(1,2).reshape(n,c,h,w)
#| export
class SelfAttention(nn.Module):
    def __init__(self, ni, attn_chans, transpose=True):
        super().__init__()
        self.nheads = ni//attn_chans
        self.scale = math.sqrt(ni/self.nheads)
        self.norm = nn.LayerNorm(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)
        self.t = transpose
    def forward(self, x):
        n,c,s = x.shape
        if self.t: x = x.transpose(1, 2)
        x = self.norm(x)
        x = self.qkv(x)
        # split heads: (n, seq, heads*3d) -> (n*heads, seq, 3d), then chunk into q,k,v
        x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads)
        q,k,v = torch.chunk(x, 3, dim=-1)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1)@v
        x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads)
        x = self.proj(x)
        if self.t: x = x.transpose(1, 2)
        return x
#| export
class SelfAttention2D(SelfAttention):
    def forward(self, x):
        n,c,h,w = x.shape
        return super().forward(x.view(n, c, -1)).reshape(n,c,h,w)
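A hypothetical smoke test (sizes chosen for illustration): attention over the flattened spatial positions preserves the `(n,c,h,w)` shape.
sa = SelfAttention2D(32, attn_chans=8)  # 32 channels -> 4 heads of 8 channels each
sa(torch.randn(2, 32, 16, 16)).shape  # torch.Size([2, 32, 16, 16])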
#| export
class EmbResBlock(nn.Module):
    def __init__(self, n_emb, ni, nf=None, ks=3, act=nn.SiLU, norm=nn.BatchNorm2d, attn_chans=0):
        super().__init__()
        if nf is None: nf = ni
        self.emb_proj = nn.Linear(n_emb, nf*2)
        self.conv1 = pre_conv(ni, nf, ks, act=act, norm=norm)
        self.conv2 = pre_conv(nf, nf, ks, act=act, norm=norm)
        self.idconv = fc.noop if ni==nf else nn.Conv2d(ni, nf, 1)
        self.attn = False
        if attn_chans: self.attn = SelfAttention2D(nf, attn_chans)
    def forward(self, x, t):
        inp = x
        x = self.conv1(x)
        # project the embedding to a per-channel scale and shift (FiLM-style conditioning)
        emb = self.emb_proj(F.silu(t))[:, :, None, None]
        scale,shift = torch.chunk(emb, 2, dim=1)
        x = x*(1+scale) + shift
        x = self.conv2(x)
        x = x + self.idconv(inp)
        if self.attn: x = x + self.attn(x)
        return x
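Another hypothetical smoke test: the block maps `(n,ni,h,w)` to `(n,nf,h,w)`, with the embedding injected as a per-channel scale and shift.
blk = EmbResBlock(n_emb=128, ni=32)
blk(torch.randn(2, 32, 16, 16), torch.randn(2, 128)).shape  # torch.Size([2, 32, 16, 16])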
#| export
def saved(m, blk):
    "Patch `m.forward` so every call also appends its output to `blk.saved`"
    m_ = m.forward
    @wraps(m.forward)
    def _f(*args, **kwargs):
        res = m_(*args, **kwargs)
        blk.saved.append(res)
        return res
    m.forward = _f
    return m
#| export
class DownBlock(nn.Module):
    def __init__(self, n_emb, ni, nf, add_down=True, num_layers=1, attn_chans=0):
        super().__init__()
        self.resnets = nn.ModuleList([saved(EmbResBlock(n_emb, ni if i==0 else nf, nf, attn_chans=attn_chans), self)
                                      for i in range(num_layers)])
        self.down = saved(nn.Conv2d(nf, nf, 3, stride=2, padding=1), self) if add_down else nn.Identity()
    def forward(self, x, t):
        self.saved = []
        for resnet in self.resnets: x = resnet(x, t)
        x = self.down(x)
        return x
#| export
class UpBlock(nn.Module):
    def __init__(self, n_emb, ni, prev_nf, nf, add_up=True, num_layers=2, attn_chans=0):
        super().__init__()
        self.resnets = nn.ModuleList(
            [EmbResBlock(n_emb, (prev_nf if i==0 else nf)+(ni if (i==num_layers-1) else nf), nf, attn_chans=attn_chans)
             for i in range(num_layers)])
        self.up = upsample(nf) if add_up else nn.Identity()
    def forward(self, x, t, ups):
        # each resnet consumes one saved skip activation, concatenated on the channel dim
        for resnet in self.resnets: x = resnet(torch.cat([x, ups.pop()], dim=1), t)
        return self.up(x)
#| export
class EmbUNetModel(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, nfs=(224,448,672,896), num_layers=1, attn_chans=8, attn_start=1):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        self.n_temb = nf = nfs[0]
        n_emb = nf*4
        self.emb_mlp = nn.Sequential(lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
                                     lin(n_emb, n_emb))
        self.downs = nn.ModuleList()
        n = len(nfs)
        for i in range(n):
            ni = nf
            nf = nfs[i]
            self.downs.append(DownBlock(n_emb, ni, nf, add_down=i!=n-1, num_layers=num_layers,
                                        attn_chans=0 if i<attn_start else attn_chans))
        self.mid_block = EmbResBlock(n_emb, nfs[-1])
        rev_nfs = list(reversed(nfs))
        nf = rev_nfs[0]
        self.ups = nn.ModuleList()
        for i in range(n):
            prev_nf = nf
            nf = rev_nfs[i]
            ni = rev_nfs[min(i+1, len(nfs)-1)]
            self.ups.append(UpBlock(n_emb, ni, prev_nf, nf, add_up=i!=n-1, num_layers=num_layers+1,
                                    attn_chans=0 if i>=n-attn_start else attn_chans))
        self.conv_out = pre_conv(nfs[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d, bias=False)
    def forward(self, inp):
        x,t = inp
        temb = timestep_embedding(t, self.n_temb)
        emb = self.emb_mlp(temb)
        x = self.conv_in(x)
        saved = [x]
        for block in self.downs: x = block(x, emb)
        saved += [p for o in self.downs for p in o.saved]
        x = self.mid_block(x, emb)
        for block in self.ups: x = block(x, emb, saved)
        return self.conv_out(x)
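An illustrative shape check with a deliberately tiny config (hypothetical sizes): the U-Net maps a noised batch plus its timesteps to an ε-prediction of the same shape.
m = EmbUNetModel(in_channels=1, out_channels=1, nfs=(32,64), num_layers=1)
m((torch.randn(2, 1, 32, 32), torch.rand(2))).shape  # torch.Size([2, 1, 32, 32])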
lr = 1e-2
epochs = 25
opt_func = partial(optim.Adam, eps=1e-5)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
model = EmbUNetModel(in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=2)
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
| loss | epoch | train |
|---|---|---|
| 0.150 | 0 | train |
| 0.086 | 0 | eval |
| 0.069 | 1 | train |
| 0.171 | 1 | eval |
| 0.057 | 2 | train |
| 0.071 | 2 | eval |
| 0.050 | 3 | train |
| 0.055 | 3 | eval |
| 0.045 | 4 | train |
| 0.050 | 4 | eval |
| 0.043 | 5 | train |
| 0.073 | 5 | eval |
| 0.041 | 6 | train |
| 0.044 | 6 | eval |
| 0.039 | 7 | train |
| 0.044 | 7 | eval |
| 0.038 | 8 | train |
| 0.043 | 8 | eval |
| 0.038 | 9 | train |
| 0.058 | 9 | eval |
| 0.038 | 10 | train |
| 0.044 | 10 | eval |
| 0.036 | 11 | train |
| 0.042 | 11 | eval |
| 0.035 | 12 | train |
| 0.038 | 12 | eval |
| 0.035 | 13 | train |
| 0.039 | 13 | eval |
| 0.034 | 14 | train |
| 0.036 | 14 | eval |
| 0.034 | 15 | train |
| 0.036 | 15 | eval |
| 0.034 | 16 | train |
| 0.034 | 16 | eval |
| 0.034 | 17 | train |
| 0.035 | 17 | eval |
| 0.033 | 18 | train |
| 0.033 | 18 | eval |
| 0.033 | 19 | train |
| 0.033 | 19 | eval |
| 0.033 | 20 | train |
| 0.033 | 20 | eval |
| 0.033 | 21 | train |
| 0.032 | 21 | eval |
| 0.032 | 22 | train |
| 0.034 | 22 | eval |
| 0.032 | 23 | train |
| 0.032 | 23 | eval |
| 0.032 | 24 | train |
| 0.033 | 24 | eval |
from miniai.fid import ImageEval
cmodel = torch.load('models/data_aug2.pkl')
del cmodel[8]
del cmodel[7]
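The classifier's final two layers are removed so it outputs feature activations rather than class logits; `ImageEval` then computes FID and KID in that feature space.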
@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))*2-1 for o in b[xl]]
bs = 2048
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)
dt = dls.train
xb,yb = next(iter(dt))
ie = ImageEval(cmodel, dls, cbs=[DeviceCB()])
sz = (2048,1,32,32)
#| export
def ddim_step(x_t, noise, abar_t, abar_t1, bbar_t, bbar_t1, eta, sig, clamp=True):
    # note: the incoming `sig` argument is ignored and recomputed from `eta` here
    sig = ((bbar_t1/bbar_t).sqrt() * (1-abar_t/abar_t1).sqrt()) * eta
    x_0_hat = ((x_t-(1-abar_t).sqrt()*noise) / abar_t.sqrt())
    if clamp: x_0_hat = x_0_hat.clamp(-1,1)
    if bbar_t1<=sig**2+0.01: sig=0.  # set to zero if very small or NaN
    x_t = abar_t1.sqrt()*x_0_hat + (bbar_t1-sig**2).sqrt()*noise
    x_t += sig * torch.randn(x_t.shape).to(x_t)
    return x_0_hat,x_t
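This implements the DDIM update of Song et al.: writing $\bar\beta = 1-\bar\alpha$, each step estimates $\hat x_0$ from the predicted noise and then re-noises it to the previous timestep, with $\eta$ interpolating between deterministic DDIM ($\eta=0$) and DDPM-like stochastic sampling ($\eta=1$):

$$\hat x_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta}{\sqrt{\bar\alpha_t}}, \qquad x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat x_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta + \sigma_t z,$$

$$\sigma_t = \eta\,\sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}\,\sqrt{1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}}, \qquad z\sim\mathcal N(0, I).$$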
#| export
@torch.no_grad()
def sample(f, model, sz, steps, eta=1., clamp=True):
    model.eval()
    ts = torch.linspace(1-1/steps,0,steps)
    x_t = torch.randn(sz).cuda()
    preds = []
    for i,t in enumerate(progress_bar(ts)):
        t = t[None].cuda()
        abar_t = abar(t)
        noise = model((x_t, t))
        abar_t1 = abar(t-1/steps) if t>=1/steps else torch.tensor(1)
        x_0_hat,x_t = f(x_t, noise, abar_t, abar_t1, 1-abar_t, 1-abar_t1, eta, 1-((i+1)/100), clamp=clamp)
        preds.append(x_0_hat.float().cpu())
    return preds
# set_seed(42)
preds = sample(ddim_step, model, sz, steps=100, eta=1.)
s = (preds[-1]*2)
s.min(),s.max(),s.shape
(tensor(-1.0918), tensor(1.4292), torch.Size([2048, 1, 32, 32]))
show_images(s[:25].clamp(-1,1), imsize=1.5)
ie.fid(s),ie.kid(s),s.shape
(4.058064770194278, 0.010895456187427044, torch.Size([2048, 1, 32, 32]))
preds = sample(ddim_step, model, sz, steps=100, eta=1.)
ie.fid(preds[-1]*2)
5.320260029850715
preds = sample(ddim_step, model, sz, steps=50, eta=1.)
ie.fid(preds[-1]*2)
5.243807277315682
preds = sample(ddim_step, model, sz, steps=50, eta=1.)
ie.fid(preds[-1]*2)
4.963977301033992
def collate_ddpm(b):
    b = default_collate(b)
    (xt,t),eps = noisify(b[xl])
    return (xt,t,b[yl]),eps
@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))-0.5 for o in b[xl]]
tds = dsd.with_transform(transformi)
dls = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))
dl = dls.train
(xt,t,c),eps = b = next(iter(dl))
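For class conditioning we keep the same U-Net and simply add an `nn.Embedding` of the label, summed with the time embedding in the forward pass.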
class CondUNetModel(nn.Module):
    def __init__(self, n_classes, in_channels=3, out_channels=3, nfs=(224,448,672,896), num_layers=1):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        self.n_temb = nf = nfs[0]
        n_emb = nf*4
        self.cond_emb = nn.Embedding(n_classes, n_emb)
        self.emb_mlp = nn.Sequential(lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
                                     lin(n_emb, n_emb))
        self.downs = nn.ModuleList()
        for i in range(len(nfs)):
            ni = nf
            nf = nfs[i]
            self.downs.append(DownBlock(n_emb, ni, nf, add_down=i!=len(nfs)-1, num_layers=num_layers))
        self.mid_block = EmbResBlock(n_emb, nfs[-1])
        rev_nfs = list(reversed(nfs))
        nf = rev_nfs[0]
        self.ups = nn.ModuleList()
        for i in range(len(nfs)):
            prev_nf = nf
            nf = rev_nfs[i]
            ni = rev_nfs[min(i+1, len(nfs)-1)]
            self.ups.append(UpBlock(n_emb, ni, prev_nf, nf, add_up=i!=len(nfs)-1, num_layers=num_layers+1))
        self.conv_out = pre_conv(nfs[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d, bias=False)
    def forward(self, inp):
        x,t,c = inp
        temb = timestep_embedding(t, self.n_temb)
        cemb = self.cond_emb(c)
        emb = self.emb_mlp(temb) + cemb
        x = self.conv_in(x)
        saved = [x]
        for block in self.downs: x = block(x, emb)
        saved += [p for o in self.downs for p in o.saved]
        x = self.mid_block(x, emb)
        for block in self.ups: x = block(x, emb, saved)
        return self.conv_out(x)
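As before, an illustrative shape check with a tiny hypothetical config, now passing the class labels as a third input:
m = CondUNetModel(10, in_channels=1, out_channels=1, nfs=(32,64), num_layers=1)
m((torch.randn(2, 1, 32, 32), torch.rand(2), torch.randint(0, 10, (2,)))).shape  # torch.Size([2, 1, 32, 32])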
lr = 1e-2
epochs = 25
opt_func = partial(optim.Adam, eps=1e-5)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
model = CondUNetModel(10, in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=2)
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
learn.fit(epochs)
| loss | epoch | train |
|---|---|---|
| 0.178 | 0 | train |
| 0.099 | 0 | eval |
| 0.072 | 1 | train |
| 0.066 | 1 | eval |
| 0.053 | 2 | train |
| 0.053 | 2 | eval |
| 0.047 | 3 | train |
| 0.050 | 3 | eval |
| 0.045 | 4 | train |
| 0.045 | 4 | eval |
| 0.042 | 5 | train |
| 0.048 | 5 | eval |
| 0.041 | 6 | train |
| 0.060 | 6 | eval |
| 0.039 | 7 | train |
| 0.042 | 7 | eval |
| 0.037 | 8 | train |
| 0.039 | 8 | eval |
| 0.037 | 9 | train |
| 0.051 | 9 | eval |
| 0.036 | 10 | train |
| 0.039 | 10 | eval |
| 0.035 | 11 | train |
| 0.041 | 11 | eval |
| 0.035 | 12 | train |
| 0.041 | 12 | eval |
| 0.034 | 13 | train |
| 0.035 | 13 | eval |
| 0.034 | 14 | train |
| 0.035 | 14 | eval |
| 0.034 | 15 | train |
| 0.036 | 15 | eval |
| 0.033 | 16 | train |
| 0.037 | 16 | eval |
| 0.033 | 17 | train |
| 0.032 | 17 | eval |
| 0.032 | 18 | train |
| 0.036 | 18 | eval |
| 0.032 | 19 | train |
| 0.033 | 19 | eval |
| 0.032 | 20 | train |
| 0.032 | 20 | eval |
| 0.032 | 21 | train |
| 0.033 | 21 | eval |
| 0.032 | 22 | train |
| 0.033 | 22 | eval |
| 0.031 | 23 | train |
| 0.032 | 23 | eval |
| 0.031 | 24 | train |
| 0.033 | 24 | eval |
sz = (256,1,32,32)
#| export
@torch.no_grad()
def cond_sample(c, f, model, sz, steps, eta=1.):
    model.eval()
    ts = torch.linspace(1-1/steps,0,steps)
    x_t = torch.randn(sz).cuda()
    # broadcast the single class id across the whole batch
    c = x_t.new_full((sz[0],), c, dtype=torch.int32)
    preds = []
    for i,t in enumerate(progress_bar(ts)):
        t = t[None].cuda()
        abar_t = abar(t)
        noise = model((x_t, t, c))
        abar_t1 = abar(t-1/steps) if t>=1/steps else torch.tensor(1)
        x_0_hat,x_t = f(x_t, noise, abar_t, abar_t1, 1-abar_t, 1-abar_t1, eta, 1-((i+1)/100))
        preds.append(x_0_hat.float().cpu())
    return preds
lbls = dsd['train'].features[yl].names
lbls
['T - shirt / top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
set_seed(42)
cid = 0
preds = cond_sample(cid, ddim_step, model, sz, steps=100, eta=1.)
s = (preds[-1]*2)
show_images(s[:25].clamp(-1,1), imsize=1.5, suptitle=lbls[cid])
set_seed(42)
cid = 0
preds = cond_sample(cid, ddim_step, model, sz, steps=100, eta=0.)
s = (preds[-1]*2)
show_images(s[:25].clamp(-1,1), imsize=1.5, suptitle=lbls[cid])
import nbdev; nbdev.nbdev_export()