Now that we've written our own barebones training library, let's make some progress towards exploring diffusion models and building Stable Diffusion from scratch.
We'll start by building and training the model described in the seminal 2020 paper Denoising Diffusion Probabilistic Models (DDPM). For context: although diffusion models were technically invented back in 2015, they flew under the radar until this 2020 paper because they were complicated and difficult to train. The DDPM paper made some crucial simplifying assumptions that make both training and generation much more tractable, as we will see here, and later diffusion models all build on the framework it introduced.
Let's get started and train our own DDPM!
We'll start with some imports.
import pickle,gzip,math,os,time,shutil,torch,random,logging
import fastcore.all as fc,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager
from fastcore.foundation import L
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torch.nn import init
from torch.optim import lr_scheduler
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder
from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *
from miniai.augment import *
mpl.rcParams['image.cmap'] = 'gray'
logging.disable(logging.WARNING)
We will load the dataset from HuggingFace Hub:
x,y = 'image','label'
name = "fashion_mnist"
dsd = load_dataset(name)
To make life simpler (mostly with the model architecture, which repeatedly halves the spatial size), we'll resize the 28x28 images to 32x32:
@inplace
def transformi(b): b[x] = [TF.resize(TF.to_tensor(o), (32,32)) for o in b[x]]
Let's set our batch size, create our DataLoaders with it, and confirm the shapes are correct. Note that although the dataset comes with labels, we don't actually need them for our task of unconditional image generation.
set_seed(42)
bs = 128
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=8)
dt = dls.train
xb,yb = next(iter(dt))
xb.shape,yb[:10]
(torch.Size([128, 1, 32, 32]), tensor([5, 7, 4, 7, 3, 8, 9, 5, 3, 1]))
We will create a U-net. A U-net looks something like this:
The DDPM U-net is a modification of this with some modern tricks like using attention.
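To make the idea concrete, here is a toy downsample/upsample network with a single skip connection. This is purely illustrative (it is not the diffusers architecture, and the real model is much deeper), but it shows the basic U shape: compress the image, process it, expand it back, and concatenate features from the matching resolution on the way up.
# A toy U-net-style model, purely for illustration: one downsampling stage,
# one upsampling stage, and a single skip connection concatenated at the end.
class TinyUnet(nn.Module):
    def __init__(self, ni=1, nf=32):
        super().__init__()
        self.down = nn.Conv2d(ni, nf, 3, stride=2, padding=1)  # 32x32 -> 16x16
        self.mid  = nn.Conv2d(nf, nf, 3, padding=1)
        self.up   = nn.ConvTranspose2d(nf, nf, 2, stride=2)    # 16x16 -> 32x32
        self.out  = nn.Conv2d(nf+ni, ni, 3, padding=1)         # combine skip + upsampled features

    def forward(self, x):
        skip = x                        # keep the full-resolution activations
        x = F.relu(self.down(x))
        x = F.relu(self.mid(x))
        x = F.relu(self.up(x))
        return self.out(torch.cat([x, skip], dim=1))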
We will cover how U-nets are created and how modules like attention work in future lessons. For now, we'll import the U-net from the diffusers library:
from diffusers import UNet2DModel
model = UNet2DModel(in_channels=1, out_channels=1, block_out_channels=(32, 64, 128, 128))
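The diffusers U-net returns an output object rather than a raw tensor; the noise prediction lives in its .sample attribute, which is why the callback below ends its forward pass with .sample. As a quick sanity check of the shapes (this check is just for illustration; the dummy zero timesteps are an assumption made here, not part of the training setup):
# The model takes a batch of (noisy) images and a batch of timesteps;
# the prediction has the same shape as the input images.
with torch.no_grad():
    t = torch.zeros(xb.shape[0], dtype=torch.long)  # dummy timesteps, just for this check
    out = model(xb, t).sample
out.shape   # torch.Size([128, 1, 32, 32])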
DDPM is trained quite simply in a few steps:
1. Randomly select a timestep for each image in the batch.
2. Add noise to the original image, with a variance corresponding to that timestep (later timesteps mean more noise).
3. Pass the noisy image and the timestep to the model.
4. Train with an MSE loss between the model's output and the noise that was added.
We will implement this in a callback. The callback will randomly select the timestep and create the noisy image before setting up our input and ground truth tensors for the model forward pass and loss calculation.
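To see what the noising step does before wrapping it in the callback, here is a small standalone sketch, reusing the xb batch from above and the same linear β schedule the callback defines (the variable names here are just for this sketch):
# Closed-form noising: x_t = sqrt(ᾱ_t)*x_0 + sqrt(1-ᾱ_t)*ε, where ᾱ is the
# cumulative product of (1-β) under a linear β schedule.
n_steps = 1000
beta = torch.linspace(0.0001, 0.02, n_steps)
alphabar = (1. - beta).cumprod(dim=0)

t = torch.randint(0, n_steps, (xb.shape[0],))     # one random timestep per image
abar_t = alphabar[t].reshape(-1, 1, 1, 1)
eps = torch.randn_like(xb)                        # this is what the model must predict
xt = abar_t.sqrt()*xb + (1-abar_t).sqrt()*eps     # the noisy images fed to the model
show_images(xt[:9], figsize=(5,5))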
After training, we need to sample from this model. This is an iterative denoising process starting from pure noise: at each step we remove a little of the noise predicted by the neural network, walking back through the same noise schedule we used during training, from the noisiest timestep down to the cleanest. This is also implemented in our callback.
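Concretely, at each step the sample method below first estimates the clean image from the predicted noise and then steps towards it using the posterior mean from the DDPM paper (this is exactly what the coefficients in the code compute):

$$\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\bar\alpha_t}},\qquad x_{t-1} = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,\hat{x}_0 \;+\; \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t \;+\; \sigma_t z$$

where $z\sim\mathcal{N}(0,I)$ for $t>0$ and $z=0$ at the final step.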
class DDPMCB(TrainCB):
    order = DeviceCB.order+1
    def __init__(self, n_steps, beta_min, beta_max):
        super().__init__()
        self.n_steps,self.βmin,self.βmax = n_steps,beta_min,beta_max
        # variance schedule, linearly increased with timestep
        self.β = torch.linspace(self.βmin, self.βmax, self.n_steps)
        self.α = 1. - self.β
        self.ᾱ = torch.cumprod(self.α, dim=0)
        self.σ = self.β.sqrt()

    def predict(self, learn): learn.preds = learn.model(*learn.batch[0]).sample

    def before_batch(self, learn):
        device = learn.batch[0].device
        ε = torch.randn(learn.batch[0].shape, device=device)  # noise, x_T
        x0 = learn.batch[0]  # original images, x_0
        self.ᾱ = self.ᾱ.to(device)
        n = x0.shape[0]
        # select random timesteps
        t = torch.randint(0, self.n_steps, (n,), device=device, dtype=torch.long)
        ᾱ_t = self.ᾱ[t].reshape(-1, 1, 1, 1).to(device)
        xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε  # noisify the image
        # input to our model is noisy image and timestep, ground truth is the noise
        learn.batch = ((xt, t), ε)

    @torch.no_grad()
    def sample(self, model, sz):
        device = next(model.parameters()).device
        x_t = torch.randn(sz, device=device)
        preds = []
        for t in reversed(range(self.n_steps)):
            t_batch = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
            z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(device)
            ᾱ_t1 = self.ᾱ[t-1] if t > 0 else torch.tensor(1)
            b̄_t = 1 - self.ᾱ[t]
            b̄_t1 = 1 - ᾱ_t1
            noise_pred = model(x_t, t_batch).sample
            # estimate x_0 from the predicted noise, then form the posterior mean
            x_0_hat = ((x_t - b̄_t.sqrt() * noise_pred)/self.ᾱ[t].sqrt()).clamp(-1,1)
            x0_coeff = ᾱ_t1.sqrt()*(1-self.α[t])/b̄_t
            xt_coeff = self.α[t].sqrt()*b̄_t1/b̄_t
            x_t = x_0_hat*x0_coeff + x_t*xt_coeff + self.σ[t]*z
            preds.append(x_t.cpu())
        return preds
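Before moving on, it is worth plotting how quickly this schedule destroys the signal; here is a small sketch that recomputes the same ᾱ values the callback stores:
# ᾱ_t under the linear schedule falls to nearly zero well before the final
# timestep, which is why the early (reverse) sampling steps look like pure noise.
alphabar = (1. - torch.linspace(0.0001, 0.02, 1000)).cumprod(dim=0)
plt.plot(alphabar)
plt.xlabel('timestep')
plt.ylabel('ᾱ_t');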
Okay now we're ready to train a model!
Let's create our Learner. We'll add our callbacks and train with MSE loss.
We specify the number of timesteps and the minimum and maximum variance for the DDPM model.
lr = 4e-3
epochs = 5
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
ddpm_cb = DDPMCB(n_steps=1000, beta_min=0.0001, beta_max=0.02)
cbs = [ddpm_cb, DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched)]
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=optim.Adam)
Now let's run the fit function:
learn.fit(epochs)
| loss | epoch | phase |
|---|---|---|
| 0.059 | 0 | train |
| 0.024 | 0 | eval |
| 0.021 | 1 | train |
| 0.020 | 1 | eval |
| 0.019 | 2 | train |
| 0.017 | 2 | eval |
| 0.017 | 3 | train |
| 0.017 | 3 | eval |
| 0.016 | 4 | train |
| 0.016 | 4 | eval |
mdl_path = Path('models')
mdl_path.mkdir(exist_ok=True)
torch.save(learn.model, mdl_path/'fashion_ddpm.pkl')
learn.model = torch.load(mdl_path/'fashion_ddpm.pkl')
Now that we've trained our model, let's use it to generate some images:
set_seed(42)
samples = ddpm_cb.sample(learn.model, (16, 1, 32, 32))
len(samples)
1000
show_images(-samples[-1], figsize=(5,5))
Let's visualize the sampling process:
%matplotlib auto
import matplotlib.animation as animation
from IPython.display import display, HTML
fig,ax = plt.subplots(figsize=(3,3))
def _show_i(i): return show_image(-samples[i][9], ax=ax, animated=True).get_images()
r = L.range(800,990, 5)+L.range(990,1000)+[999]*10
ims = r.map(_show_i)
animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=3000)
display(HTML(animate.to_html5_video()))
Note that I only show the steps from 800 onwards, since most of the earlier steps are still almost pure noise. This is a limitation of the linear noise schedule when applied to small images, and papers like Improved DDPM suggest other noise schedules for exactly this reason! (Some potential homework: try out the noise schedule from Improved DDPM and see if it helps.)
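If you want a starting point for that homework, here is a sketch of the cosine ᾱ schedule from Improved DDPM (Nichol & Dhariwal, 2021). Treat it as a rough sketch following the paper's definitions rather than a drop-in replacement; in particular, the derived βs are clipped at 0.999 as the paper recommends:
# Cosine schedule: define ᾱ_t directly via a squared cosine, then recover the
# per-step betas from ratios of consecutive ᾱ values.
def cosine_alphabar(n_steps, s=0.008):
    t = torch.linspace(0, n_steps, n_steps+1)
    f = torch.cos((t/n_steps + s)/(1+s) * torch.pi/2)**2
    return f/f[0]                                   # normalize so ᾱ_0 = 1

abar = cosine_alphabar(1000)
beta = (1 - abar[1:]/abar[:-1]).clamp(max=0.999)    # betas implied by this ᾱ schedule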