By Tanishq Abraham and Thomas Capelle
In this notebook, we will implement the method from Denoising Diffusion Probabilistic Models (DDPM), a seminal paper in the diffusion model literature, and extend it to generate images conditioned on a class label.
A one-sentence summary: train a denoising model conditioned on the amount of noise present in the image (and, here, on the image's label), then generate samples by iteratively denoising pure noise into a final sample conditioned on the chosen label.
The final model is capable of generating an image from a label!
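For reference, here are the equations the code below implements. They use the DDPM paper's notation, with the class label $y$ added as an extra input to the noise-prediction network $\epsilon_\theta$. The label conditioning is this notebook's extension and is not part of the original paper. The paper indexes timesteps $t = 1, \dots, T$, while the code uses 0-based indices `0, ..., n_steps - 1`.

```latex
\begin{aligned}
\beta_t &= \beta_{\min} + \frac{t-1}{T-1}\,(\beta_{\max} - \beta_{\min}), \qquad \alpha_t = 1 - \beta_t, \qquad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s \\
x_t &= \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \\
\mathcal{L} &= \mathbb{E}_{x_0, y, t, \epsilon}\left[ \lVert \epsilon - \epsilon_\theta(x_t, t, y) \rVert^2 \right] \\
x_{t-1} &= \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}}\, \epsilon_\theta(x_t, t, y) \right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I), \quad \sigma_t = \sqrt{\beta_t}
\end{aligned}
```

Sampling starts from $x_T \sim \mathcal{N}(0, I)$. No noise is added on the last step, i.e. $z = 0$ when $t = 1$ in the paper's indexing.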
Let's get started with the implementation!
Here are all our imports. The `unet` file is taken from lucidrains' DDPM implementation, so that we can focus on the training process rather than on architectural details.
```python
from fastai.vision.all import *
from fastai.vision.gan import *
from unet import Unet
from copy import deepcopy
```
Let's load our data. We'll work with the famous MNIST dataset.
```python
bs = 256 # batch size
size = 32 # image size
path = untar_data(URLs.MNIST)
```
We use the highly flexible DataBlock API in fastai to create our DataLoaders.
Since we want a conditional model, we use the labelled dataset. Each item is an MNIST image paired with its digit label (the name of its parent folder), and the model is trained conditioned on that label. The validation split is just the first `bs` images, i.e. a single batch.
```python
dblock = DataBlock(blocks = (ImageBlock(cls=PILImageBW), CategoryBlock()),
                   get_items = get_image_files,
                   get_y = lambda p: p.parent.name,
                   splitter = IndexSplitter(range(bs)),
                   item_tfms=Resize(size),
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5]), torch.tensor([0.5])))
dls = dblock.dataloaders(path, path=path, bs=bs)
```
```python
xb, yb = next(iter(dls.train))
xb.shape, yb.shape
```
(torch.Size([256, 1, 32, 32]), torch.Size([256]))
```python
dls.show_batch(max_n=8)
```
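The `Normalize.from_stats(0.5, 0.5)` batch transform in the DataBlock above rescales pixels from $[0, 1]$ (fastai's `IntToFloatTensor` divides by 255) to $[-1, 1]$. That puts the images on the same scale as the standard Gaussian noise used by DDPM. The sketch below shows the same arithmetic in plain PyTorch, outside fastai, and also the inverse mapping used to display images.

```python
import torch

# Pixel intensities after conversion to float are in [0, 1].
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

# Same arithmetic as Normalize.from_stats(mean=0.5, std=0.5): (x - mean) / std
mean, std = 0.5, 0.5
normalized = (pixels - mean) / std      # maps [0, 1] to [-1, 1]
decoded = normalized * std + mean       # inverse mapping back to [0, 1]
print(normalized)
```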
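The callback is the same as the unconditional DDPM callback, with two differences. During training, the label is passed to the model as a third input. During sampling, each sample in the batch is conditioned on a fixed label that cycles through the digits 0 to 9.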
```python
class ConditionalDDPMCallback(Callback):
    def __init__(self, n_steps, beta_min, beta_max, tensor_type=TensorImage):
        store_attr()

    def before_fit(self):
        self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.dls.device) # variance schedule, linearly increased with timestep
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.sigma = torch.sqrt(self.beta)

    def before_batch_training(self):
        x0 = self.xb[0] # original images, x_0
        eps = self.tensor_type(torch.randn(x0.shape, device=x0.device)) # noise, eps ~ N(0, I)
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long) # select random timesteps
        alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)
        xt = torch.sqrt(alpha_bar_t)*x0 + torch.sqrt(1-alpha_bar_t)*eps # noisify the image
        self.learn.xb = (xt, t, self.yb[0]) # input to our model is noisy image, timestep and label
        self.learn.yb = (eps,) # ground truth is the noise

    def before_batch_sampling(self):
        x = self.xb[0] # the batch is only used for its shape and device
        xt = self.tensor_type(torch.randn(x.shape, device=x.device)) # start from pure noise x_T, a full batch at once!
        batch_size = xt.shape[0]
        label = torch.arange(10, dtype=torch.long, device=xt.device).repeat(batch_size//10 + 1).flatten()[0:batch_size] # labels 0,1,...,9,0,1,...
        for t in progress_bar(reversed(range(self.n_steps)), total=self.n_steps, leave=False):
            t_batch = torch.full((batch_size,), t, device=xt.device, dtype=torch.long)
            z = torch.randn(xt.shape, device=xt.device) if t > 0 else torch.zeros(xt.shape, device=xt.device)
            alpha_t = self.alpha[t] # get noise level at current timestep
            alpha_bar_t = self.alpha_bar[t]
            sigma_t = self.sigma[t]
            xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * self.model(xt, t_batch, label=label)) + sigma_t*z # predict x_(t-1) in accordance to Algorithm 2 in paper
        self.learn.pred = (xt,)
        raise CancelBatchException

    def before_batch(self):
        if not hasattr(self, 'gather_preds'): self.before_batch_training() # during fit (training and validation)
        else: self.before_batch_sampling() # during get_preds
```
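To get a feel for the linear variance schedule that `before_fit` builds, here is a standalone sketch. It recomputes `alpha_bar` outside fastai with the values passed to the Learner below (1000 steps, `beta_min=0.0001`, `beta_max=0.02`). It then prints how much of the clean image ($\sqrt{\bar\alpha_t}$) and of the noise ($\sqrt{1-\bar\alpha_t}$) make up $x_t$ at a few timesteps. The chosen timesteps are arbitrary.

```python
import torch

n_steps, beta_min, beta_max = 1000, 1e-4, 0.02

# Same schedule as ConditionalDDPMCallback.before_fit (on CPU here)
beta = torch.linspace(beta_min, beta_max, n_steps)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)

for t in (0, 73, 500, 999):
    signal = alpha_bar[t].sqrt().item()        # coefficient on x_0
    noise = (1 - alpha_bar[t]).sqrt().item()   # coefficient on eps
    print(f"t={t:4d}  signal={signal:.4f}  noise={noise:.4f}")
```

At the last timestep almost no signal is left, which is why sampling can start from pure Gaussian noise.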
We have to add the conditioning to the Unet. To do so, we subclass it and add a learned label embedding to the timestep embedding in the forward pass.
```python
@delegates(Unet)
class ConditionalUnet(Unet):
    def __init__(self, dim, num_classes=None, **kwargs):
        super().__init__(dim=dim, **kwargs)
        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, dim * 4) # same width as the time embedding

    def forward(self, x, time, label=None):
        x = self.init_conv(x)
        t = self.time_mlp(time)
        if label is not None:
            t = t + self.label_emb(label) # out-of-place add avoids modifying a tensor autograd may need
        return super().forward_blocks(x, t)
```
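The conditioning trick is to add a learned per-class vector to the timestep embedding. Every block that consumes the time embedding then also sees the label. The sketch below isolates that idea with a toy sinusoidal embedding followed by an MLP. `TimeLabelEmbedding` and its layer sizes are hypothetical stand-ins, not the layers from `unet.py`. They only assume, as above, that the time embedding is `dim * 4` wide.

```python
import math
import torch
from torch import nn

class TimeLabelEmbedding(nn.Module):
    "Hypothetical module: sinusoidal timestep embedding + MLP, plus a label embedding."
    def __init__(self, dim, num_classes):
        super().__init__()
        self.dim = dim
        self.time_mlp = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim * 4))
        self.label_emb = nn.Embedding(num_classes, dim * 4)  # same width as the time embedding

    def sinusoidal(self, t):
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / (half - 1))
        args = t.float()[:, None] * freqs[None, :]
        return torch.cat([args.sin(), args.cos()], dim=-1)  # shape (batch, dim)

    def forward(self, t, label=None):
        emb = self.time_mlp(self.sinusoidal(t))
        if label is not None:
            emb = emb + self.label_emb(label)  # inject class information
        return emb

emb = TimeLabelEmbedding(dim=32, num_classes=10)
t = torch.tensor([10, 10])
out_same = emb(t, torch.tensor([3, 3]))   # same timestep, same label
out_mixed = emb(t, torch.tensor([3, 7]))  # same timestep, different labels
print(out_same.shape)  # torch.Size([2, 128])
```

Adding the embeddings, rather than concatenating them, keeps the embedding width unchanged. That is why the subclass can reuse the parent's blocks without modification.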
Let's now initialize our model:
```python
model = ConditionalUnet(dim=32, channels=1, num_classes=10).cuda()
```
Now we can create a fastai Learner with our DataLoaders, Callback (with the appropriate number of timesteps and noise schedule) and the simple MSE loss that we use to train DDPM.
```python
ddpm_learner = Learner(dls, model,
                       cbs=[ConditionalDDPMCallback(n_steps=1000, beta_min=0.0001, beta_max=0.02, tensor_type=TensorImageBW)],
                       loss_func=nn.MSELoss()).to_fp16()
```
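What the callback plus `nn.MSELoss` compute on each training batch can be written as one plain PyTorch function. This is a sketch for illustration only. `ddpm_loss` and `TinyEpsModel` are hypothetical names. The toy model, a single convolution that ignores `t` and the label, only stands in for `ConditionalUnet` so that the snippet runs without `unet.py`.

```python
import torch
from torch import nn
import torch.nn.functional as F

class TinyEpsModel(nn.Module):
    "Hypothetical stand-in for ConditionalUnet: ignores t and label."
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, x, t, label=None):
        return self.conv(x)

def ddpm_loss(model, x0, label, alpha_bar):
    "Noise x0 at random timesteps and regress the added noise with MSE."
    t = torch.randint(0, alpha_bar.shape[0], (x0.shape[0],), device=x0.device)
    eps = torch.randn_like(x0)
    a = alpha_bar[t].reshape(-1, 1, 1, 1)
    xt = a.sqrt() * x0 + (1 - a).sqrt() * eps
    return F.mse_loss(model(xt, t, label), eps)

alpha_bar = torch.cumprod(1 - torch.linspace(1e-4, 0.02, 1000), dim=0)
model = TinyEpsModel()
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
x0 = torch.rand(8, 1, 32, 32) * 2 - 1   # fake images in [-1, 1]
label = torch.randint(0, 10, (8,))

opt.zero_grad()
loss = ddpm_loss(model, x0, label, alpha_bar)
loss.backward()
opt.step()
print(loss.item())
```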
Let's use fastai's amazing LR finder to select a good LR for training:
```python
ddpm_learner.lr_find()
```
SuggestedLRs(valley=5.248074739938602e-05)
And now let's train with one-cycle LR schedule:
```python
ddpm_learner.fit_one_cycle(10, 3e-4)
```
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 0.135299 | 0.097977 | 00:46 |
| 1 | 0.038314 | 0.048950 | 00:46 |
| 2 | 0.027058 | 0.028015 | 00:45 |
| 3 | 0.023399 | 0.024194 | 00:45 |
| 4 | 0.021188 | 0.025359 | 00:46 |
| 5 | 0.019909 | 0.023826 | 00:45 |
| 6 | 0.018747 | 0.023043 | 00:45 |
| 7 | 0.018231 | 0.019069 | 00:45 |
| 8 | 0.017847 | 0.020772 | 00:45 |
| 9 | 0.017604 | 0.026471 | 00:45 |
```python
ddpm_learner.recorder.plot_loss()
```
Since we implemented sampling in the callback, we can simply call fastai's built-in `get_preds` function to get our predictions.
```python
preds = ddpm_learner.get_preds()
```
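Because `get_preds` only triggers `before_batch_sampling`, sampling is just Algorithm 2 of the paper run over one batch. The sketch below is a standalone version of that loop in plain PyTorch, under these assumptions:

- `sample` is a hypothetical helper, not a fastai API.
- Everything runs on CPU.
- `zero_model` is a dummy noise predictor used only to exercise the loop, so its output is not a digit.

```python
import torch

@torch.no_grad()
def sample(model, shape, labels, n_steps=1000, beta_min=1e-4, beta_max=0.02):
    "Sketch of DDPM ancestral sampling (Algorithm 2), conditioned on `labels`."
    beta = torch.linspace(beta_min, beta_max, n_steps)
    alpha = 1 - beta
    alpha_bar = torch.cumprod(alpha, dim=0)
    sigma = beta.sqrt()
    xt = torch.randn(shape)  # start from pure noise x_T
    for t in reversed(range(n_steps)):
        t_batch = torch.full((shape[0],), t, dtype=torch.long)
        z = torch.randn(shape) if t > 0 else torch.zeros(shape)  # no noise on the last step
        eps_hat = model(xt, t_batch, label=labels)
        xt = (xt - (1 - alpha[t]) / (1 - alpha_bar[t]).sqrt() * eps_hat) / alpha[t].sqrt() + sigma[t] * z
    return xt

# Dummy noise predictor that always predicts zero noise, just to run the loop.
zero_model = lambda x, t, label=None: torch.zeros_like(x)
labels = torch.arange(10)
samples = sample(zero_model, (10, 1, 32, 32), labels, n_steps=50)
print(samples.shape)  # torch.Size([10, 1, 32, 32])
```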
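The callback conditions the batch on a label vector that cycles through the digits, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, ...]. With 10 images per row, each column of the grid therefore corresponds to one digit.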
```python
nrows, ncols = 5, 10 # 10 columns, one per digit label
axs = subplots(nrows, ncols)[1].flat
for pred, ax in zip(preds[0], axs):
    pred.show(ax=ax)
```
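A quick check of why the columns line up with the digits. The label vector below is built exactly as in `before_batch_sampling`. When the first 50 samples are laid out 10 per row, every column holds a single label.

```python
import torch

batch_size = 256
# Same construction as in ConditionalDDPMCallback.before_batch_sampling
label = torch.arange(10, dtype=torch.long).repeat(batch_size // 10 + 1).flatten()[0:batch_size]

grid = label[:50].reshape(5, 10)  # 5 rows x 10 columns
print(grid[0])  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
```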
Awesome, we got MNIST digits! Each column was generated from the same label, so it should contain the same digit.
Another useful check is the model's prediction of the fully denoised image at a given timestep. During sampling we remove only a fraction of the predicted noise at each step. We can instead subtract all of the predicted noise at once to get a one-shot estimate of the clean image. At high noise levels (large timesteps) this estimate will be inaccurate, but at low noise levels it should be quite accurate.
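Rearranging the forward-noising equation $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon$ for $x_0$, and replacing the true noise $\epsilon$ with the model's estimate, gives the one-shot estimate that the next cell computes as `x0hat`:

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\; \epsilon_\theta(x_t, t, y)}{\sqrt{\bar\alpha_t}}
```

The division by $\sqrt{\bar\alpha_t}$ amplifies any error in $\epsilon_\theta$. Near $t = 0$ we have $\bar\alpha_t \approx 1$, so the estimate is stable. At large $t$, $\bar\alpha_t$ is close to 0 and small errors are blown up, which is why the estimate degrades at high noise levels.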
```python
eps = TensorImageBW(torch.randn(xb.shape, device=xb.device))
x0 = xb # original images
batch_size = x0.shape[0]
with torch.no_grad():
    t = torch.randint(0, ddpm_learner.conditional_ddpm.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
    alpha_bar_t = ddpm_learner.conditional_ddpm.alpha_bar[t].reshape(-1, 1, 1, 1)
    xt = torch.sqrt(alpha_bar_t)*x0 + torch.sqrt(1-alpha_bar_t)*eps # noisy images
    x0hat = (xt - torch.sqrt(1-alpha_bar_t)*ddpm_learner.model(xt, t, label=yb))/torch.sqrt(alpha_bar_t) # predicted denoised images, conditioned on the true labels
```
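Samples that drew a small timestep are only lightly noised. For example, in our run the sample at index 252 drew timestep 73. That is close to 0, so the image is less noisy, although the noise is still visible:
```python
np.where((t==73).cpu())
```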
(array([252]),)
Now we can compare the original clean image (`x0`), the noisy image (`xt`), and the model's attempt to remove the noise (`x0hat`) for one sample:
```python
idx = 5 # which sample in the batch to inspect
ctxs = get_grid(3,1,3)
ax1 = x0[idx].show(ctx=ctxs[0], title='Original')
ax2 = xt[idx].show(ctx=ctxs[1], title=f'Noisy (t={t[idx].item()})')
ax3 = x0hat[idx].show(ctx=ctxs[2], title='Predicted denoised')
```