By Tanishq Abraham and Thomas Capelle
In this notebook, we will implement Denoising Diffusion Probabilistic Models, a seminal paper in the diffusion model literature.
A one-sentence summary: train a denoising model conditioned on the amount of noise present in the image, and generate samples by iteratively denoising from pure noise to a final sample, conditioned on the label of the image.
The final model is capable of generating an image from a label!
Let's get started with the implementation!
Here are all our imports. The unet file is taken from lucidrains' DDPM implementation just to focus on implementing the training process rather than architectural details.
from fastai.vision.all import *
from fastai.vision.gan import *
from unet import Unet
from copy import deepcopy
from data import *
Let's load our data. We'll work with the famous CIFAR-10 dataset.
# Training hyper-parameters: batch size, image side length (pixels) and number of epochs.
bs, size, epochs = 512, 32, 100
# Download (if needed) and extract the fastai CIFAR dataset; `path` is its root folder.
path = untar_data(URLs.CIFAR)
We use the highly flexible DataBlock API in fastai to create our DataLoaders.
Note that sampling starts from pure noise, generated by the aptly named generate_noise function.
We use a labelled dataset and train a model conditioned on the class label.
# Images with a category target; the category is the name of the image's parent folder.
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock()),
    get_items=get_image_files,
    get_y=lambda p: p.parent.name,
    splitter=IndexSplitter(range(bs)),  # the first `bs` files form the validation set
    item_tfms=Resize(size),
    batch_tfms=Normalize.from_stats(0.5, 0.5),  # (x - 0.5) / 0.5: maps [0, 1] to [-1, 1]
)
dls = dblock.dataloaders(path, bs=bs)

# Sanity-check one batch: values should span [-1, 1].
xb, yb = dls.one_batch()
dls.show_batch()
xb.max(), xb.min(), xb.mean(), xb.std()
(TensorImage(1., device='cuda:0'), TensorImage(-1., device='cuda:0'), TensorImage(-0.0532, device='cuda:0'), TensorImage(0.4992, device='cuda:0'))
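Same as the unconditional DDPM callback, with two differences. First, the class label is passed to the model, and it is dropped for 10% of the batches so the model also learns unconditional predictions. Second, sampling can use classifier-free guidance via `cfg_scale`.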
class ConditionalDDPMCallback(Callback):
    """Class-conditional DDPM training and sampling, implemented as a fastai callback.

    During ``fit`` (both training and validation batches) each batch ``(x0, label)``
    is replaced by model inputs ``(x_t, t, label)`` and target ``eps``. Here ``x_t``
    is ``x0`` noised with Gaussian noise ``eps`` to a random timestep ``t``. With
    probability 0.1 the label of the whole batch is replaced by ``None``, so the
    model also learns an unconditional noise prediction (used by classifier-free
    guidance at sampling time).

    During ``Learner.get_preds`` each batch is instead replaced by samples generated
    from pure noise over all ``n_steps`` reverse steps, using labels
    ``0, 1, ..., 9`` repeated across the batch.

    Args:
        n_steps: number of diffusion timesteps.
        beta_min: first value of the linear variance schedule.
        beta_max: last value of the linear variance schedule.
        cfg_scale: classifier-free guidance scale; ``0`` (default) disables guidance.
    """

    def __init__(self, n_steps, beta_min, beta_max, cfg_scale=0):
        store_attr()
        self.tensor_type = TensorImage

    def before_fit(self):
        # Linear variance schedule beta_t, alpha_t = 1 - beta_t, cumulative product
        # alpha_bar_t, and the per-step sampling std sigma_t = sqrt(beta_t).
        self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.dls.device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.sigma = torch.sqrt(self.beta)

    def sample_timesteps(self, x, dtype=torch.long):
        "Draw one uniformly random timestep in [0, n_steps) per item of `x`."
        return torch.randint(self.n_steps, (x.shape[0],), device=x.device, dtype=dtype)

    def generate_noise(self, x):
        "Standard Gaussian noise shaped like `x`, wrapped as `self.tensor_type`."
        return self.tensor_type(torch.randn_like(x))

    def noise_image(self, x, eps, t):
        "Forward process: sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * eps, per batch item."
        alpha_bar_t = self.alpha_bar[t][:, None, None, None]
        return torch.sqrt(alpha_bar_t)*x + torch.sqrt(1-alpha_bar_t)*eps

    def before_batch_training(self):
        "Replace the batch with (noisy image, timestep, label) inputs and the noise as target."
        x0 = self.xb[0]  # clean images
        # Drop the labels of the whole batch 10% of the time (classifier-free guidance training).
        y0 = self.yb[0] if np.random.random() > 0.1 else None
        eps = self.generate_noise(x0)
        t = self.sample_timesteps(x0)
        xt = self.noise_image(x0, eps, t)
        self.learn.xb = (xt, t, y0)  # model input: noisy image, timestep, label (or None)
        self.learn.yb = (eps,)       # ground truth: the noise that was added

    def sampling_algo(self, xt, t, label=None):
        """One reverse step: map the batch `xt` at integer timestep `t` to timestep t-1.

        The model's noise prediction is turned into an estimate ``x0hat`` of the clean
        image, clamped to [-1, 1]. The next sample is then the mean of the DDPM posterior
        q(x_{t-1} | x_t, x0hat) plus ``sigma_t * z`` noise (no noise at ``t == 0``).
        """
        t_batch = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)
        z = self.generate_noise(xt) if t > 0 else torch.zeros_like(xt)
        alpha_t = self.alpha[t]
        alpha_bar_t = self.alpha_bar[t]
        sigma_t = self.sigma[t]
        alpha_bar_t_1 = self.alpha_bar[t-1] if t > 0 else torch.tensor(1, device=xt.device)
        beta_bar_t = 1 - alpha_bar_t
        beta_bar_t_1 = 1 - alpha_bar_t_1
        predicted_noise = self.model(xt, t_batch, label=label)
        if self.cfg_scale > 0:
            # Classifier-free guidance: uncond + cfg_scale * (cond - uncond).
            uncond_predicted_noise = self.model(xt, t_batch, label=None)
            predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, self.cfg_scale)
        x0hat = (xt - torch.sqrt(beta_bar_t) * predicted_noise)/torch.sqrt(alpha_bar_t)
        x0hat = torch.clamp(x0hat, -1, 1)
        xt = x0hat * torch.sqrt(alpha_bar_t_1)*(1-alpha_t)/beta_bar_t + xt * torch.sqrt(alpha_t)*beta_bar_t_1/beta_bar_t + sigma_t*z
        return xt

    def sample(self):
        "Generate a full batch (shaped like `self.xb[0]`) from pure noise, labels 0..9 repeated."
        xt = self.generate_noise(self.xb[0])
        label = torch.arange(10, dtype=torch.long, device=xt.device).repeat(xt.shape[0]//10 + 1).flatten()[0:xt.shape[0]]
        for t in progress_bar(reversed(range(self.n_steps)), total=self.n_steps, leave=False):
            xt = self.sampling_algo(xt, t, label)
        return xt

    def before_batch_sampling(self):
        "Inside `get_preds`: store generated samples as the prediction and skip the normal forward pass."
        xt = self.sample()
        self.learn.pred = (xt,)
        raise CancelBatchException

    def after_validate(self):
        "Every 4th epoch of `fit`, sample a batch and log its first 36 images to wandb."
        # `after_validate` also fires at the end of `get_preds`, where `before_batch_sampling`
        # has already produced samples; skip it there to avoid a second full reverse-diffusion
        # run and an extra wandb log.
        if hasattr(self, 'gather_preds'): return
        if (self.epoch+1) % 4 != 0: return
        with torch.no_grad():
            xt = self.sample()
        # NOTE(review): `torch.tensor(im)` copies each image into a plain tensor, dropping the
        # `TensorImage` subclass; presumably needed by `wandb.Image` -- confirm before removing.
        wandb.log({"preds": [wandb.Image(torch.tensor(im)) for im in xt[0:36]]})

    def before_batch(self):
        # `gather_preds` is only present while `get_preds` runs: sample there, train otherwise.
        if not hasattr(self, 'gather_preds'): self.before_batch_training()
        else: self.before_batch_sampling()
class EMA(Callback):
    """Exponential moving average (EMA) of the model weights.

    Before epoch ``int(pct_start * n_epoch)`` the EMA copy simply tracks the live
    weights. After that, each training batch updates it as
    ``ema = beta * ema + (1 - beta) * live``. At the end of ``fit`` the EMA weights
    are loaded into ``learn.model``.

    NOTE(review): callbacks that reload weights in their own ``after_fit`` and run
    after this one (e.g. ``SaveModelCallback`` loading its best checkpoint) will
    overwrite the EMA weights -- check callback order.
    """

    def __init__(self, beta=0.995, pct_start=0.3):
        store_attr()

    def before_fit(self):
        self.ema_model = deepcopy(self.model).eval().requires_grad_(False)
        self.step_start_ema = int(self.pct_start*self.n_epoch)  # epoch from which averaging starts
        self.step = 0  # training batches seen; previously read before ever being assigned

    def update_model_average(self):
        "Blend the live parameters into the EMA parameters, tensor by tensor."
        for current_params, ma_params in zip(self.model.parameters(), self.ema_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self):
        "Either copy (warm-up epochs) or average the live weights into the EMA model."
        if self.epoch < self.step_start_ema:
            self.reset_parameters()
        else:
            self.update_model_average()
        self.step += 1

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def after_batch(self):
        # Only update after training batches. The learner sets `pred` on every batch, so the
        # previous `hasattr(self, 'pred')` guard was always true and the EMA never updated.
        if not self.training: return
        self.step_ema()

    def after_fit(self):
        # Load into the learner's model: assigning `self.model` on a callback only creates a
        # callback attribute and leaves `learn.model` untouched.
        self.learn.model.load_state_dict(self.ema_model.state_dict())

    # Kept for backward compatibility: `after_training` is not a fastai event name.
    after_training = after_fit
We have to add the conditioning to the Unet. To do so, we subclass it and, in the forward pass, add a learned embedding of the label to the time embedding.
@delegates(Unet)
class ConditionalUnet(Unet):
    """`Unet` that can additionally be conditioned on an integer class label.

    When `num_classes` is given, a learned `nn.Embedding` of width `dim * 4` is
    added to the output of `time_mlp` before the U-Net blocks run. Without
    `num_classes` no label embedding exists, so `label` must then be left as None.
    """
    def __init__(self, dim, num_classes=None, **kwargs):
        super().__init__(dim=dim, **kwargs)
        if num_classes is None:
            return
        self.label_emb = nn.Embedding(num_classes, dim * 4)

    def forward(self, x, time, label=None):
        h = self.init_conv(x)
        emb = self.time_mlp(time)
        if label is not None:
            emb += self.label_emb(label)  # inject the class conditioning into the time embedding
        return super().forward_blocks(h, emb)
Let's now initialize our model:
# 3 input channels, 10 label classes for conditioning; `dim` is forwarded to the Unet and
# also sets the label-embedding width (dim * 4). The model is moved to the GPU.
model = ConditionalUnet(dim=32, channels=3, num_classes=10).cuda()
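Now we can create a fastai Learner with our DataLoaders, our callbacks and a simple loss. The callbacks are the DDPM callback (with the appropriate number of timesteps and noise schedule) and EMA. Note that we use the L1 loss here, whereas the DDPM paper uses the mean squared error.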
import wandb
from fastai.callback.wandb import WandbCallback
# Learner: DDPM callback (1000 steps, linear beta schedule 1e-4 -> 0.02, guidance scale 3)
# plus EMA of the weights; the model is trained to predict the added noise with an L1 loss.
ddpm_learner = Learner(
    dls,
    model,
    loss_func=nn.L1Loss(),
    cbs=[
        ConditionalDDPMCallback(n_steps=1000, beta_min=0.0001, beta_max=0.02, cfg_scale=3),
        EMA(),
    ],
)
Let's use fastai's amazing LR finder to select a good LR for training:
# ddpm_learner.lr_find()
And now let's train with one-cycle LR schedule:
# Start the wandb run that WandbCallback and ConditionalDDPMCallback.after_validate log into.
wandb.init(project="ddpm_fastai", group="cifar10", tags=["fp", "ema"])
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving. wandb: Currently logged in as: capecape. Use `wandb login --relogin` to force relogin
/home/tcapelle/wandb/fastdiffusion/nbs/tcapelle/wandb/run-20221007_132515-9xdzrfdb
# Train for `epochs` epochs with a one-cycle schedule peaking at lr=3e-4. SaveModelCallback saves
# the weights whenever train_loss improves (file "cifar10"); WandbCallback logs metrics and model.
# NOTE(review): SaveModelCallback may reload its best checkpoint at the end of fit, which would
# replace any EMA weights swapped in by the EMA callback -- check callback order.
ddpm_learner.fit_one_cycle(epochs, 3e-4, cbs =[SaveModelCallback(monitor="train_loss", fname="cifar10"),
WandbCallback(log_preds=False, log_model=True)])
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 0.750259 | 0.709923 | 00:42 |
| 1 | 0.632055 | 0.573980 | 00:40 |
| 2 | 0.490789 | 0.438364 | 00:40 |
| 3 | 0.356443 | 0.307267 | 03:25 |
| 4 | 0.270513 | 0.249062 | 00:40 |
| 5 | 0.221594 | 0.210910 | 00:40 |
| 6 | 0.190886 | 0.186583 | 00:40 |
| 7 | 0.168048 | 0.167804 | 03:25 |
| 8 | 0.152539 | 0.157809 | 00:40 |
| 9 | 0.144119 | 0.152306 | 00:40 |
| 10 | 0.137331 | 0.146878 | 00:40 |
| 11 | 0.133565 | 0.139846 | 03:25 |
| 12 | 0.131783 | 0.137660 | 00:40 |
| 13 | 0.128458 | 0.145122 | 00:40 |
| 14 | 0.125062 | 0.137213 | 00:40 |
| 15 | 0.122997 | 0.133304 | 03:25 |
| 16 | 0.121221 | 0.124890 | 00:40 |
| 17 | 0.118911 | 0.134336 | 00:40 |
| 18 | 0.118563 | 0.137339 | 00:40 |
| 19 | 0.117204 | 0.123301 | 03:25 |
| 20 | 0.115335 | 0.125933 | 00:40 |
| 21 | 0.112230 | 0.115233 | 00:40 |
| 22 | 0.111855 | 0.123500 | 00:40 |
| 23 | 0.110091 | 0.133139 | 03:25 |
| 24 | 0.109199 | 0.125465 | 00:40 |
| 25 | 0.109592 | 0.100159 | 00:40 |
| 26 | 0.108045 | 0.126357 | 00:40 |
| 27 | 0.105945 | 0.115808 | 03:25 |
| 28 | 0.105502 | 0.105228 | 00:40 |
| 29 | 0.105925 | 0.119752 | 00:40 |
| 30 | 0.106559 | 0.114725 | 00:40 |
| 31 | 0.104310 | 0.114205 | 03:25 |
| 32 | 0.104533 | 0.118345 | 00:40 |
| 33 | 0.102721 | 0.120694 | 00:40 |
| 34 | 0.102772 | 0.115649 | 00:40 |
| 35 | 0.102187 | 0.115252 | 03:25 |
| 36 | 0.102728 | 0.111810 | 00:40 |
| 37 | 0.100627 | 0.097210 | 00:40 |
| 38 | 0.100899 | 0.109537 | 00:40 |
| 39 | 0.100359 | 0.121335 | 03:25 |
| 40 | 0.101219 | 0.107469 | 00:40 |
| 41 | 0.099695 | 0.118630 | 00:41 |
| 42 | 0.099988 | 0.112241 | 00:40 |
| 43 | 0.099661 | 0.107124 | 03:25 |
| 44 | 0.100341 | 0.106861 | 00:40 |
| 45 | 0.098759 | 0.111692 | 00:40 |
| 46 | 0.098534 | 0.104519 | 00:40 |
| 47 | 0.098645 | 0.107390 | 03:25 |
| 48 | 0.099272 | 0.121547 | 00:40 |
| 49 | 0.097441 | 0.102439 | 00:40 |
| 50 | 0.098245 | 0.116344 | 00:40 |
| 51 | 0.098801 | 0.112504 | 03:25 |
| 52 | 0.097479 | 0.104128 | 00:40 |
| 53 | 0.098048 | 0.105118 | 00:40 |
| 54 | 0.097966 | 0.101353 | 00:40 |
| 55 | 0.097155 | 0.098459 | 03:25 |
| 56 | 0.097588 | 0.096302 | 00:40 |
| 57 | 0.096654 | 0.093959 | 00:40 |
| 58 | 0.096842 | 0.118672 | 00:40 |
| 59 | 0.096856 | 0.109159 | 03:25 |
| 60 | 0.096999 | 0.107374 | 00:40 |
| 61 | 0.097314 | 0.105345 | 00:40 |
| 62 | 0.096345 | 0.101585 | 00:40 |
| 63 | 0.097237 | 0.107552 | 03:26 |
| 64 | 0.096387 | 0.110463 | 00:40 |
| 65 | 0.096522 | 0.107351 | 00:40 |
| 66 | 0.096289 | 0.103139 | 00:40 |
| 67 | 0.095761 | 0.108441 | 03:26 |
| 68 | 0.096019 | 0.113256 | 00:40 |
| 69 | 0.095945 | 0.107207 | 00:40 |
| 70 | 0.095901 | 0.101774 | 00:40 |
| 71 | 0.096231 | 0.101389 | 03:25 |
| 72 | 0.096459 | 0.095912 | 00:41 |
| 73 | 0.095942 | 0.115093 | 00:41 |
| 74 | 0.095968 | 0.099037 | 00:40 |
| 75 | 0.095558 | 0.099857 | 03:25 |
| 76 | 0.095226 | 0.101043 | 00:40 |
| 77 | 0.095414 | 0.106881 | 00:40 |
| 78 | 0.095608 | 0.104538 | 00:40 |
| 79 | 0.095830 | 0.102535 | 03:26 |
| 80 | 0.095843 | 0.106510 | 00:40 |
| 81 | 0.094907 | 0.106933 | 00:40 |
| 82 | 0.095509 | 0.102437 | 00:40 |
| 83 | 0.096072 | 0.108544 | 03:25 |
| 84 | 0.095380 | 0.105316 | 00:41 |
| 85 | 0.094950 | 0.100542 | 00:40 |
| 86 | 0.095622 | 0.103235 | 00:40 |
| 87 | 0.094806 | 0.106054 | 03:25 |
| 88 | 0.094979 | 0.099068 | 00:40 |
| 89 | 0.095287 | 0.099797 | 00:40 |
| 90 | 0.095047 | 0.094034 | 00:40 |
| 91 | 0.095369 | 0.109976 | 03:26 |
| 92 | 0.094981 | 0.099022 | 00:40 |
| 93 | 0.095371 | 0.110711 | 00:40 |
| 94 | 0.095437 | 0.106877 | 00:40 |
| 95 | 0.094820 | 0.099612 | 03:25 |
| 96 | 0.094549 | 0.099947 | 00:41 |
| 97 | 0.095064 | 0.107513 | 00:40 |
| 98 | 0.094772 | 0.104572 | 00:40 |
| 99 | 0.094053 | 0.104388 | 03:25 |
Better model found at epoch 0 with train_loss value: 0.7502593398094177. Better model found at epoch 1 with train_loss value: 0.6320550441741943. Better model found at epoch 2 with train_loss value: 0.4907890260219574.
Better model found at epoch 3 with train_loss value: 0.3564431965351105. Better model found at epoch 4 with train_loss value: 0.27051302790641785. Better model found at epoch 5 with train_loss value: 0.22159354388713837. Better model found at epoch 6 with train_loss value: 0.19088611006736755.
Better model found at epoch 7 with train_loss value: 0.16804789006710052. Better model found at epoch 8 with train_loss value: 0.1525385081768036. Better model found at epoch 9 with train_loss value: 0.14411911368370056. Better model found at epoch 10 with train_loss value: 0.13733148574829102.
Better model found at epoch 11 with train_loss value: 0.13356511294841766. Better model found at epoch 12 with train_loss value: 0.13178279995918274. Better model found at epoch 13 with train_loss value: 0.1284581422805786. Better model found at epoch 14 with train_loss value: 0.12506239116191864.
Better model found at epoch 15 with train_loss value: 0.12299712002277374. Better model found at epoch 16 with train_loss value: 0.12122058123350143. Better model found at epoch 17 with train_loss value: 0.11891119927167892. Better model found at epoch 18 with train_loss value: 0.1185627207159996.
Better model found at epoch 19 with train_loss value: 0.11720366775989532. Better model found at epoch 20 with train_loss value: 0.11533458530902863. Better model found at epoch 21 with train_loss value: 0.11222967505455017. Better model found at epoch 22 with train_loss value: 0.11185462772846222.
Better model found at epoch 23 with train_loss value: 0.11009092628955841. Better model found at epoch 24 with train_loss value: 0.10919912904500961. Better model found at epoch 26 with train_loss value: 0.10804483294487.
Better model found at epoch 27 with train_loss value: 0.10594470798969269. Better model found at epoch 28 with train_loss value: 0.10550212115049362.
Better model found at epoch 31 with train_loss value: 0.10430967062711716. Better model found at epoch 33 with train_loss value: 0.1027207300066948.
Better model found at epoch 35 with train_loss value: 0.10218702256679535. Better model found at epoch 37 with train_loss value: 0.10062697529792786.
Better model found at epoch 39 with train_loss value: 0.10035884380340576. Better model found at epoch 41 with train_loss value: 0.09969516843557358.
Better model found at epoch 43 with train_loss value: 0.09966123104095459. Better model found at epoch 45 with train_loss value: 0.09875939041376114. Better model found at epoch 46 with train_loss value: 0.09853431582450867.
Better model found at epoch 49 with train_loss value: 0.09744056314229965.
Better model found at epoch 55 with train_loss value: 0.09715472906827927. Better model found at epoch 57 with train_loss value: 0.09665428847074509.
Better model found at epoch 62 with train_loss value: 0.09634526073932648.
Better model found at epoch 66 with train_loss value: 0.0962885320186615.
Better model found at epoch 67 with train_loss value: 0.09576054662466049.
Better model found at epoch 75 with train_loss value: 0.09555771201848984. Better model found at epoch 76 with train_loss value: 0.09522632509469986.
Better model found at epoch 81 with train_loss value: 0.09490688145160675.
IOPub message rate exceeded. The Jupyter server will temporarily stop sending output to the client in order to avoid crashing it. To change this limit, set the config variable `--ServerApp.iopub_msg_rate_limit`. Current values: ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec) ServerApp.rate_limit_window=3.0 (secs)
Better model found at epoch 96 with train_loss value: 0.09454851597547531.
Better model found at epoch 99 with train_loss value: 0.09405282139778137.
# Plot the training and validation loss curves recorded during fit.
ddpm_learner.recorder.plot_loss()
Since we implemented sampling in the Callback, we can simply call fastai's built-in get_preds function to get our predictions.
# ConditionalDDPMCallback.before_batch_sampling replaces each validation batch's prediction with
# freshly generated samples, so preds[0] holds generated images rather than model outputs.
preds = ddpm_learner.get_preds()
We are passing a label vector that looks like [0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,...], so every consecutive group of 10 samples covers each class once.
# Display the first generated sample, mapped from [-1, 1] back to [0, 1].
wandb.Image(torch.tensor(0.5*preds[0][0]+0.5)).image
p = preds[0]  # all generated samples
p.shape
torch.Size([512, 3, 32, 32])
# Per-channel mean of the samples, averaged over batch, height and width.
p.mean(dim=(0,2,3))
TensorImage([-0.0303, -0.0196, -0.0900])
# Show the first nrows * ncols samples in a grid. Samples were generated with labels 0..9
# repeated, so with one column per class each column holds a single class.
nrows = 5
ncols = len(dls.vocab)  # previously int(math.ceil(25/10)) == 3, which was never used
axs = subplots(nrows, ncols)[1].flat
for i, (pred, ax) in enumerate(zip(preds[0], axs)):
    # Map from the normalized [-1, 1] range to [0, 1]; title only the first row with class names.
    ((pred+1)/2).show(ax=ax, title=dls.vocab[i] if i < ncols else None)
def log_table(rows=10):
    """Log generated samples from the global `preds` to wandb as a table, one column per class.

    Samples were generated with labels 0, 1, ..., 9 repeated, so each consecutive chunk of
    `len(dls.vocab)` images contains one sample per class, in vocab order. Only the first
    `rows` complete chunks are logged. A trailing partial chunk is skipped, since it could
    not fill every column of the table.
    """
    n_classes = len(dls.vocab)
    table = wandb.Table(columns=list(dls.vocab))
    # Slice the chunks instead of iterating all of them and filtering by index.
    for row in preds[0].split(n_classes)[:max(rows, 0)]:
        if len(row) < n_classes:
            break  # incomplete last chunk
        table.add_data(*[wandb.Image(img.permute(1,2,0).squeeze().cpu().numpy()) for img in row])
    wandb.log({"pred_table":table})
# Log the sample table, then close the wandb run.
log_table()
wandb.finish()
VBox(children=(Label(value='37.925 MB of 37.925 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…
| epoch | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ |
| eps_0 | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| lr_0 | ▁▂▂▃▄▅▆▇███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁ |
| mom_0 | ██▇▆▅▄▃▂▁▁▁▁▁▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇█████ |
| raw_loss | █▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| sqr_mom_0 | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| train_loss | █▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| train_samples_per_sec | ████▁███████████████████████████████████ |
| valid_loss | █▆▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| wd_0 | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| epoch | 100.0 |
| eps_0 | 1e-05 |
| lr_0 | 0.0 |
| mom_0 | 0.95 |
| raw_loss | 0.09004 |
| sqr_mom_0 | 0.99 |
| train_loss | 0.09405 |
| train_samples_per_sec | 1659.48473 |
| valid_loss | 0.10457 |
| wd_0 | 0.01 |
./wandb/run-20221007_132515-9xdzrfdb/logs
Another useful thing to check is the prediction of the completely denoised image at some timestep. Our sampling takes our prediction of noise in the image but takes only a fraction of it to remove from the noisy image during the iterative process. But we can also try to see the full denoising prediction by fully subtracting out the prediction. Of course, at higher noise levels this will be inaccurate, but at lower noise levels it should be quite accurate.
# Noise a batch of clean images to random timesteps, then let the model (called without a
# label, i.e. unconditionally) predict that noise. Inverting the forward process with the
# prediction gives the model's one-shot estimate of the clean images.
x0 = xb  # clean, normalized images from dls.one_batch()
eps = TensorImage(torch.randn(x0.shape, device=x0.device))
batch_size = len(x0)
with torch.no_grad():
    t = torch.randint(0, ddpm_learner.conditional_ddpm.n_steps, (batch_size,),
                      dtype=torch.long, device=x0.device)
    alpha_bar_t = ddpm_learner.conditional_ddpm.alpha_bar[t][:, None, None, None]
    xt = alpha_bar_t.sqrt()*x0 + (1-alpha_bar_t).sqrt()*eps  # noisy images
    x0hat = (xt - (1-alpha_bar_t).sqrt()*ddpm_learner.model(xt, t)) / alpha_bar_t.sqrt()  # predicted clean images
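Timestep 73 is close to 0, so an image noised to that step is only mildly noisy, although the noise is still visible. In this particular run no sample was drawn at exactly t=73, as the empty result below shows.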
# Batch indices whose randomly drawn timestep is exactly 73.
np.where((t==73).cpu())[0]
array([], dtype=int64)
Now we can see the original clean image (x0), the noisy image (xt), and the model's attempt to remove the noise (x0hat)
# Side by side for one batch item: the clean image, its noisy version, and the model's estimate.
idx = 383
ctxs = get_grid(3,1,3)
ax1, ax2, ax3 = (
    dls.after_batch.decode((imgs,))[0][idx].show(ctx=ctx, title=title)  # undo normalization, then plot
    for imgs, ctx, title in zip((x0, xt, x0hat), ctxs, ('Original', 'Noisy', 'Predicted denoised'))
)