By Tanishq Abraham
In this notebook, we will implement Denoising Diffusion Probabilistic Models (DDPM), introduced in a seminal paper (Ho et al., 2020) in the diffusion model literature.
A one-sentence summary: Train a denoising model conditioned on the amount of noise present in the image, and generate samples by iteratively denoising from pure noise to a final sample.
Let's get started with the implementation!
Here are all our imports. The unet file is taken from lucidrains' DDPM implementation so that we can focus on the training process rather than on architectural details.
from fastai.vision.all import *
from fastai.vision.gan import *
from unet import Unet
from copy import deepcopy
Let's load our data. We'll work with the famous CIFAR10 dataset.
bs = 256 # batch size
size = 32 # image size
path = untar_data(URLs.CIFAR)
We use the highly flexible DataBlock API in fastai to create our DataLoaders.
Note that we start with pure noise, generated by the aptly named generate_noise function.
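For reference, generate_noise is imported from fastai.vision.gan; it ignores the item it is given (here, an image filename) and simply returns Gaussian noise of the requested size. A minimal sketch of what it does (an approximation, not necessarily fastai's exact definition):

def generate_noise(fn, size=100):
    # ignore the input item and return pure Gaussian noise of the given shape
    return TensorImage(torch.randn(size))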
dblock = DataBlock(blocks=(TransformBlock, ImageBlock),
                   get_x=partial(generate_noise, size=(3, size, size)),
                   get_items=get_image_files,
                   splitter=IndexSplitter(range(bs)),
                   item_tfms=Resize(size),
                   batch_tfms=Normalize.from_stats(torch.tensor([0.5]), torch.tensor([0.5])))
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=8)
A key aspect of diffusion models is that the model's input and output have the same shape:
xb, yb = next(iter(dls.train))
assert xb.shape == yb.shape
DDPM is trained quite simply in a few steps:

1. Select a random timestep $t$ for each image in the batch.
2. Mix each original image $x_0$ with Gaussian noise $\epsilon$ in proportions determined by the noise schedule at timestep $t$ (see the equation below).
3. Pass the noisy image $x_t$ and the timestep $t$ to the model, which predicts the noise.
4. Train with an MSE loss between the predicted noise and the true noise $\epsilon$.
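With $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$, the noisy image at timestep $t$ is given by the standard DDPM forward-noising identity, which is exactly what the callback below computes:

$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$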
We will implement this in a callback. The callback will randomly select the timestep and create the noisy image before setting up our input and ground truth tensors for the model forward pass and loss calculation.
After training, we need to sample from the model. This is an iterative denoising process that starts from pure noise: we repeatedly remove a fraction of the noise predicted by the neural network, following the noise schedule in reverse of what was used during training. This is also handled in our callback.
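Concretely, each step of Algorithm 2 in the paper computes

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I)$$

with $z = 0$ at the final step. The callback below uses an algebraically equivalent formulation that first estimates $x_0$, clamps it to $[-1, 1]$, and then computes the posterior mean; the direct Algorithm 2 update is kept as a comment.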
class DDPMCallback(Callback):
    def __init__(self, n_steps, beta_min, beta_max, tensor_type=TensorImage):
        store_attr()

    def before_fit(self):
        self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.dls.device) # variance schedule, linearly increasing with timestep
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.sigma = torch.sqrt(self.beta)

    def before_batch_training(self):
        eps = self.tensor_type(self.xb[0]) # our dataloader's x is pure noise, reused here as the noise eps
        x0 = self.yb[0] # original images, x_0
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long) # select random timesteps
        alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)
        xt = torch.sqrt(alpha_bar_t)*x0 + torch.sqrt(1-alpha_bar_t)*eps # noisify the image
        self.learn.xb = (xt, t) # input to our model is the noisy image and the timestep
        self.learn.yb = (eps,) # ground truth is the noise

    def before_batch_sampling(self):
        xt = self.tensor_type(self.xb[0]) # start from pure noise, x_T
        for t in reversed(range(self.n_steps)):
            t_batch = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)
            z = torch.randn(xt.shape, device=xt.device) if t > 0 else torch.zeros(xt.shape, device=xt.device)
            alpha_t = self.alpha[t] # noise level at the current timestep
            alpha_bar_t = self.alpha_bar[t]
            sigma_t = self.sigma[t]
            alpha_bar_t_1 = self.alpha_bar[t-1] if t > 0 else torch.tensor(1., device=xt.device) # alpha_bar at t-1 (1. at t=0; float so sqrt works)
            beta_bar_t = 1 - alpha_bar_t
            beta_bar_t_1 = 1 - alpha_bar_t_1
            x0hat = (xt - torch.sqrt(beta_bar_t) * self.model(xt, t_batch))/torch.sqrt(alpha_bar_t) # one-shot estimate of the clean image
            x0hat = torch.clamp(x0hat, -1, 1)
            xt = x0hat * torch.sqrt(alpha_bar_t_1)*(1-alpha_t)/beta_bar_t + xt * torch.sqrt(alpha_t)*beta_bar_t_1/beta_bar_t + sigma_t*z # posterior mean of x_(t-1) plus noise
            #xt = 1/torch.sqrt(alpha_t) * (xt - (1-alpha_t)/torch.sqrt(1-alpha_bar_t) * self.model(xt, t_batch)) + sigma_t*z # equivalent: predict x_(t-1) directly per Algorithm 2 in the paper
        self.learn.pred = (xt,)
        raise CancelBatchException

    def before_batch(self):
        if not hasattr(self, 'gather_preds'): self.before_batch_training() # normal training/validation
        else: self.before_batch_sampling() # get_preds/show_results adds a gather_preds callback
Let's now initialize our model:
model = Unet(dim=32)
Now we can create a fastai Learner with our DataLoaders, Callback (with the appropriate number of timesteps and noise schedule) and the simple MSE loss that we use to train DDPM.
ddpm_learner = Learner(dls, model, cbs=[DDPMCallback(n_steps=1000, beta_min=0.0001, beta_max=0.02)], loss_func=nn.MSELoss())
Let's use fastai's amazing LR finder to select a good LR for training:
ddpm_learner.lr_find()
SuggestedLRs(valley=4.365158383734524e-05)
And now let's train with a one-cycle LR schedule:
ddpm_learner.fit_one_cycle(10,3e-4)
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 0.327154 | 0.183927 | 00:26 |
| 1 | 0.072795 | 0.055295 | 00:27 |
| 2 | 0.050636 | 0.040971 | 00:27 |
| 3 | 0.044012 | 0.047171 | 00:26 |
| 4 | 0.041888 | 0.037453 | 00:26 |
| 5 | 0.040037 | 0.043432 | 00:27 |
| 6 | 0.038795 | 0.041311 | 00:26 |
| 7 | 0.038451 | 0.038509 | 00:26 |
| 8 | 0.037811 | 0.044651 | 00:27 |
| 9 | 0.038354 | 0.046532 | 00:26 |
ddpm_learner.recorder.plot_loss()
ddpm_learner.save('cifar10-test')
Path('/home/tmabraham/.fastai/data/cifar10/models/cifar10-test.pth')
ddpm_learner = ddpm_learner.load('cifar10-test')
Since we implemented sampling in the callback, we can simply call fastai's built-in get_preds function to get our predictions and decode them into an image (rescaling from [-1, 1] back to [0, 255]).
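Concretely, since we normalized with Normalize.from_stats(0.5, 0.5), decoding just inverts the normalization $(x - 0.5)/0.5$:

$$x_{\text{decoded}} = 0.5\,x + 0.5$$

mapping $[-1, 1]$ back to $[0, 1]$, which show then renders as $[0, 255]$ pixel values.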
preds, targ = ddpm_learner.get_preds()
dls.after_batch.decode((preds,))[0][0].show()
Alternatively, we can use fastai's built-in show_results function:
ddpm_learner.show_results()
Awesome, we have some images vaguely resembling CIFAR10 images!
Another useful thing to check is the model's prediction of the fully denoised image at a given timestep. During sampling, we remove only a fraction of the predicted noise at each iteration, but we can also subtract the full noise prediction in one shot to see the model's best guess at the clean image. Of course, at higher noise levels this guess will be inaccurate, but at lower noise levels it should be quite accurate.
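Rearranging the forward-noising equation gives the one-shot estimate that the code below computes:

$$\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}$$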
eps = TensorImage(xb) # the x's from our batch are pure noise, reused here as eps
x0 = yb # original images
batch_size = x0.shape[0]
with torch.no_grad():
    t = torch.randint(0, ddpm_learner.ddpm.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
    alpha_bar_t = ddpm_learner.ddpm.alpha_bar[t].reshape(-1, 1, 1, 1)
    xt = torch.sqrt(alpha_bar_t)*x0 + torch.sqrt(1-alpha_bar_t)*eps # noisy images
    x0hat = (xt - torch.sqrt(1-alpha_bar_t)*ddpm_learner.model(xt, t))/torch.sqrt(alpha_bar_t) # predicted denoised images
Timestep 73 is relatively close to 0, so the image is only lightly noised, but the noise is still visible. Let's find which batch element was assigned this timestep:
np.where((t==73).cpu())
(array([42]),)
Now we can see the original clean image (x0), the noisy image (xt), and the model's attempt to remove the noise (x0hat):
ctxs = get_grid(3,1,3)
ax1 = dls.after_batch.decode((x0,))[0][42].show(ctx=ctxs[0], title='Original')
ax2 = dls.after_batch.decode((xt,))[0][42].show(ctx=ctxs[1], title='Noisy')
ax3 = dls.after_batch.decode((x0hat,))[0][42].show(ctx=ctxs[2], title='Predicted denoised')