#!/usr/bin/env python
# coding: utf-8

# # Attempting a minimal diffusion model

# In[1]:


#import os
#os.environ['CUDA_VISIBLE_DEVICES']='1'
#os.environ['OMP_NUM_THREADS']='1'


# In[2]:


import logging, torch, torchvision, torch.nn.functional as F, torchvision.transforms.functional as TF, matplotlib as mpl
from matplotlib import pyplot as plt
from functools import partial
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torchvision.utils import make_grid
from datasets import load_dataset,load_dataset_builder
from miniai.datasets import *
from miniai.learner import *
from fastprogress import progress_bar
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.adafactor import Adafactor
from timm.optim.lookahead import Lookahead
from fastai.callback.schedule import combined_cos
from fastai.layers import SequentialEx, MergeLayer
from fastai.losses import MSELossFlat
from fastcore.basics import store_attr
from fastai.torch_core import TensorImage
from fastai.optimizer import OptimWrapper
from einops import rearrange, repeat
import PIL
import numpy as np

# Perceptual loss
import lpips


# In[3]:


def RmsLookahead(params, alpha=0.5, k=6, *args, **kwargs):
    opt = RMSpropTF(params, *args, **kwargs)
    return Lookahead(opt, alpha, k)


# In[4]:


def AdamLookahead(params, alpha=0.5, k=6, *args, **kwargs):
    opt = optim.Adam(params, *args, **kwargs)
    return Lookahead(opt, alpha, k)


# In[5]:


opt_func = optim.Adam
lr_max = 1e-3
bs = 256
sz = 28
device = 'cuda'
unet_width = 1
unet_resids = [2,4,4]
num_classes = 10


# In[6]:


mean = 0.2859
std = 0.353

def normalize_img(x):
    #return x
    return (x-mean)/std

def denormalize_img(x):
    #return x
    return (x*std)+mean


# In[7]:


loss_fn_alex = lpips.LPIPS(net='alex').to(device)
loss_fn_mse = MSELossFlat()

resize_sz = 64 if sz < 64 else sz
pl_resizer = torchvision.transforms.Resize(resize_sz)

def combined_loss(preds, y):
    eps, x0, xt = y
    #return loss_fn_mse(preds, eps)
    pred_images = denormalize_img(pl_resizer(preds+xt))
    target_images = denormalize_img(pl_resizer(x0))
    return loss_fn_alex.forward(pred_images, target_images).mean() + loss_fn_mse(preds, eps)

loss_func = combined_loss


# In[8]:


mpl.rcParams['image.cmap'] = 'gray_r'
logging.disable(logging.WARNING)


# Load a dataset:

# In[9]:


x,y = 'image','label'
#name = "mnist"
name = "fashion_mnist"
dsd = load_dataset(name)


# In[10]:


@inplace
def transformi(b): b[x] = [normalize_img(TF.to_tensor(o)) for o in b[x]]


# In[11]:


tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs)
dt = dls.train
xb,yb = next(iter(dt))
xb.shape,yb[:10]


# Define a model:

# In[12]:


def conv2dks7(inc, outc, stride=1): return nn.Conv2d(inc, outc, kernel_size=7, padding=3, stride=stride)
def conv2dks5(inc, outc, stride=1): return nn.Conv2d(inc, outc, kernel_size=5, padding=2, stride=stride)
def conv2dks3(inc, outc, stride=1): return nn.Conv2d(inc, outc, kernel_size=3, padding=1, stride=stride)
def conv2dks1(inc, outc): return nn.Conv2d(inc, outc, kernel_size=1)
def conv2dtrans(inc, outc): return nn.ConvTranspose2d(inc, outc, 4, 2, 1)


# In[13]:


""" code from here: https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py """

def exists(val): return val is not None

def default(val, d):
    if exists(val): return val
    return d() if callable(d) else d

class PixelShuffleUpsample(nn.Module):
    def __init__(self, dim, dim_out=None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)
        self.net = nn.Sequential(conv, nn.SiLU(), nn.PixelShuffle(2))
        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x): return self.net(x)
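# A quick sanity check I've added (not part of the original flow): the block should
# double the spatial size, and the ICNR-style init above copies one kernel across
# each group of 4 output channels, so PixelShuffle starts out as (roughly)
# nearest-neighbour upsampling rather than producing checkerboard noise.

# In[ ]:


up_chk = PixelShuffleUpsample(16)
print(up_chk(torch.randn(2, 16, 14, 14)).shape)  # expect torch.Size([2, 16, 28, 28])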
# In[14]:


# from https://github.com/crowsonkb/k-diffusion/blob/f4e99857772fc3a126ba886aadf795a332774878/k_diffusion/layers.py

_kernels = {
    'linear': [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    'cubic': [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
              0.43359375, 0.11328125, -0.03515625, -0.01171875],
    'lanczos3': [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
                 -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
                 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
                 -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
_kernels['bilinear'] = _kernels['linear']
_kernels['bicubic'] = _kernels['cubic']

class Downsample2d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect'):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([_kernels[kernel]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer('kernel', kernel_1d.T @ kernel_1d)

    def forward(self, x):
        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv2d(x, weight, stride=2)


# In[15]:


norm = partial(torch.nn.GroupNorm, num_groups=8)
#norm = torch.nn.BatchNorm2d
act = torch.nn.SiLU

def conv_layer(conv_fn, inc, outc, stride=1):
    return torch.nn.Sequential(conv_fn(inc, outc), act(), norm(num_channels=outc))

def residual_layer(inc):
    return SequentialEx(conv_layer(conv2dks3, inc, inc), conv_layer(conv2dks3, inc, inc), MergeLayer(False))

def down_layer(inc, outc, resids=1):
    layers = [Downsample2d(), conv_layer(conv2dks3, inc, outc)]
    for i in range(resids): layers.append(residual_layer(outc))
    return torch.nn.Sequential(*layers)

def up_layer(inc, outc, resids=1):
    layers = [conv_layer(conv2dks1, inc, inc//2)]
    for i in range(resids): layers.append(residual_layer(inc//2))
    layers.extend([PixelShuffleUpsample(inc//2, outc), act(), norm(num_channels=outc)])
    return torch.nn.Sequential(*layers)

def out_layer(inc, midc, outc, resids=1):
    layers = [conv_layer(conv2dks1, inc, midc)]
    for i in range(resids): layers.append(residual_layer(midc))
    layers.append(conv_layer(conv2dks1, midc, outc))
    return torch.nn.Sequential(*layers)
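# Another added shape check: a down_layer should halve the spatial dims and an
# up_layer should double them back (channel counts here are just for illustration;
# they must be divisible by the 8 GroupNorm groups).

# In[ ]:


t_chk = down_layer(64, 128, resids=1)(torch.randn(2, 64, 28, 28))
print(t_chk.shape)                                # expect torch.Size([2, 128, 14, 14])
print(up_layer(128, 64, resids=1)(t_chk).shape)   # expect torch.Size([2, 64, 28, 28])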
# In[16]:


class BasicUNet(nn.Module):
    "A minimal UNet implementation."
    def __init__(self, inc, outc, resids=[1,2,2], width:int=1):
        super().__init__()
        self.proj_in = nn.Conv2d(inc, 128*width, 1)
        self.down_layers = torch.nn.ModuleList([
            down_layer(128*width, 128*width, resids[0]),
            down_layer(128*width, 128*width, resids[1]),
            down_layer(128*width, 256*width, resids[2])])
        self.up_layers = torch.nn.ModuleList([
            up_layer(256*width, 256*width, resids[2]),
            up_layer(384*width, 128*width, resids[1]),
            up_layer(256*width, 128*width, resids[0])])
        self.out_layers = out_layer(inc+128*width, 64*width, 32*width, resids[0])
        self.proj_out = nn.Conv2d(32*width, outc, 1)

    def forward(self, x):
        x_orig = x.clone()
        x = self.proj_in(x)
        h = []
        for i, l in enumerate(self.down_layers):
            x = l(x)
            h.append(x)
            #if i < 2: x = F.max_pool2d(x, 2)
        for i, l in enumerate(self.up_layers):
            x_cross = h.pop()
            if x.shape[-2:] != x_cross.shape[-2:]:
                x = F.interpolate(x, x_cross.shape[-2:], mode='nearest')
            x = torch.cat([x_cross,x], dim=1) if i > 0 else x
            x = l(x)
        if x.shape[-2:] != x_orig.shape[-2:]:
            x = F.interpolate(x, x_orig.shape[-2:], mode='nearest')
        x = torch.cat([x_orig,x], dim=1)
        x = self.out_layers(x)
        x = self.proj_out(x)
        # eps can range from about -6 to 6 (a 0-mean, 1-std normal, times 2, basically),
        # so squash the output into that range with a scaled sigmoid
        x = x.sigmoid()*12 - 6.0
        return x


# In[17]:


def generate_random_noise(batch_size):
    #return normalize_img(torch.rand(batch_size, 1, sz, sz))
    return torch.randn(batch_size, 1, sz, sz)

def generate_random_noise_like(x):
    #return normalize_img(torch.rand_like(x))
    return torch.randn_like(x)


# Define the corruption:

# In[18]:


def corrupt(x, amount):
    "Corrupt the input `x` by mixing it with noise according to `amount`"
    noise = generate_random_noise_like(x)
    amount = amount.view(-1, 1, 1, 1)  # Sort shape so broadcasting works
    xt = x*(1-amount) + noise*amount
    # Rearranging xt = x - amount*(x - noise) gives x = xt + amount*(x - noise),
    # so the residual the model must predict is:
    eps = amount*(x-noise)  # x = xt + eps
    return xt, eps
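# A worked check of the algebra above (added): `xt + eps` reconstructs `x` exactly,
# which is why later cells recover images with `xb + pred`.

# In[ ]:


x_chk = torch.randn(4, 1, sz, sz)
xt_chk, eps_chk = corrupt(x_chk, torch.rand(4))
print(torch.allclose(xt_chk + eps_chk, x_chk, atol=1e-6))  # expect True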
# In[19]:


class ConditionalDDPMCallback(Callback):
    def __init__(self, n_steps=1000, beta_min=0.0001, beta_max=0.02, label_conditioning=False):
        store_attr()
        self.tensor_type = TensorImage
        self.label_conditioning = label_conditioning

    def before_fit(self):
        self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(device)  # variance schedule, linearly increased with timestep
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.sigma = torch.sqrt(self.beta)

    def sample_timesteps(self, x, dtype=torch.long):
        return torch.randint(self.n_steps, (x.shape[0],), device=device, dtype=dtype)

    def generate_noise(self, x):
        return generate_random_noise_like(x)

    def noise_image(self, x, eps, t):
        alpha_bar_t = self.alpha_bar[t][:, None, None, None]
        return torch.sqrt(alpha_bar_t)*x + torch.sqrt(1-alpha_bar_t)*eps  # noisify the image

    def before_batch_training(self):
        x0 = self.learn.batch[0]  # original images
        x1 = self.learn.batch[1] if self.label_conditioning else None
        noise_amount = torch.rand(x0.shape[0]).to(device)  # Choose a random corruption amount
        xt, eps = corrupt(x0, noise_amount)
        #eps = self.generate_noise(x0) # noise same shape as x0
        #t = self.sample_timesteps(x0) # select random timesteps
        #xt = self.noise_image(x0, eps, t) # add noise to the image
        self.learn.xb = (xt,) if x1 is None else (xt,x1)  # input to our model is the noisy image (plus labels when conditioning)
        self.learn.yb = (eps,x0,xt)  # train model to predict noise AND a good image from subtracting that noise

    def sampling_algo(self, xt, t):
        t_batch = torch.full((xt.shape[0],), t, device=device, dtype=torch.long)  # (unused: the model here is not timestep-conditioned)
        z = self.generate_noise(xt) if t > 0 else torch.zeros_like(xt)
        alpha_t = self.alpha[t]  # get noise level at current timestep
        alpha_bar_t = self.alpha_bar[t]
        sigma_t = self.sigma[t]
        alpha_bar_t_1 = self.alpha_bar[t-1] if t > 0 else torch.tensor(1, device=device)
        beta_bar_t = 1 - alpha_bar_t
        beta_bar_t_1 = 1 - alpha_bar_t_1
        x0hat = (xt - torch.sqrt(beta_bar_t) * self.model(xt))/torch.sqrt(alpha_bar_t)
        x0hat = torch.clamp(x0hat, -1, 1)
        xt = x0hat * torch.sqrt(alpha_bar_t_1)*(1-alpha_t)/beta_bar_t + xt * torch.sqrt(alpha_t)*beta_bar_t_1/beta_bar_t + sigma_t*z
        return xt

    def sample(self):
        xt = self.generate_noise(self.xb[0])  # a full batch at once!
        for t in progress_bar(reversed(range(self.n_steps)), total=self.n_steps, leave=False):
            xt = self.sampling_algo(xt, t)
        return xt

    def before_batch_sampling(self):
        xt = self.sample()
        self.learn.pred = (xt,)
        raise CancelBatchException

    def before_batch(self):
        if not hasattr(self, 'gather_preds'): self.before_batch_training()
        else: self.before_batch_sampling()
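# Added sanity check: with `eps` known, the `x0hat` formula in `sampling_algo`
# inverts `noise_image` exactly (shown here on CPU with a standalone schedule).

# In[ ]:


beta_chk = torch.linspace(0.0001, 0.02, 1000)
ab_chk = torch.cumprod(1. - beta_chk, dim=0)
x0_chk = torch.randn(4, 1, sz, sz)
eps2_chk = torch.randn_like(x0_chk)
t = 500
xt2_chk = torch.sqrt(ab_chk[t])*x0_chk + torch.sqrt(1 - ab_chk[t])*eps2_chk
x0hat_chk = (xt2_chk - torch.sqrt(1 - ab_chk[t])*eps2_chk)/torch.sqrt(ab_chk[t])
print(torch.allclose(x0hat_chk, x0_chk, atol=1e-5))  # expect True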
# In[20]:


class MyTrainCB(TrainCB):
    def predict(self):
        xb = self.learn.xb
        self.learn.preds = self.learn.model(*xb)

    def get_loss(self):
        self.learn.loss = self.learn.loss_func(self.learn.preds, self.learn.yb)


# Logging callback:

# In[21]:


class LogLossesCB(Callback):
    def __init__(self): self.losses = []
    def after_batch(self): self.losses.append(self.learn.loss.item())
    def after_fit(self): plt.plot(self.losses)


# I chose to write a new LR-scheduling callback:

# In[22]:


class OneCycle(Callback):
    def __init__(self, lr_max):
        div = 25.
        div_final = 1e5
        pct_start = 0.3
        self.lr_sched_fn = combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)
        self.ns = []

    def after_batch(self):
        self.ns.append(bs)
        n_steps = len(self.learn.dls.train) * self.learn.n_epochs * bs
        pos = sum(self.ns)/n_steps
        self.learn.lr = self.lr_sched_fn(pos)

    def before_fit(self):
        self.learn.lr = self.lr_sched_fn(0)


# In[23]:


model = BasicUNet(1, 1, width=unet_width, resids=unet_resids)
cbs = [CudaCB(), ConditionalDDPMCallback(label_conditioning=False), MyTrainCB(), ProgressCB(), LogLossesCB(), OneCycle(lr_max)]
learn = Learner(model, dls, loss_func, lr=lr_max, cbs=cbs, opt_func=opt_func)


# In[24]:


learn.fit(2)


# Viewing the predictions on images with increasing noise levels:

# In[25]:


# Some noisy data
xb = xb[:8].cpu()
amount = torch.linspace(0, 1, xb.shape[0])  # Left to right -> more corruption
noised_x, noise = corrupt(xb, amount)


# In[26]:


with torch.no_grad():
    preds = model(noised_x.cuda()).detach().cpu()
pred_images = denormalize_img(noised_x + preds)


# In[27]:


def show_grid(ax, tens, title=None):
    if title: ax.set_title(title)
    ax.imshow(make_grid(tens.cpu())[0])


# In[28]:


fig, axs = plt.subplots(3, 1, figsize=(11, 6))
show_grid(axs[0], xb, 'Input data')
show_grid(axs[1], noised_x, 'Corrupted data')
show_grid(axs[2], pred_images, 'Network Predictions')


# In[29]:


plt.hist(preds[0].reshape(-1))


# A very basic sampling method (not ideal): just take 5 or 10 equal-sized steps towards the model's prediction.

# In[30]:


# Take one: break the process into `n_steps` steps and move 1/(steps remaining) of the way there each time
n_steps = 5
xb = generate_random_noise(8).to(device)  # Start from random noise
step_history = [xb.detach().cpu()]
pred_output_history = []


# In[31]:


for i in range(n_steps):
    with torch.no_grad():
        pred = model(xb)  # Predict the residual towards the denoised x0
    pred_image = xb + pred
    pred_output_history.append(pred_image.detach().cpu())  # Store model output for plotting
    mix_factor = 1/(n_steps - i)  # How much we move towards the prediction
    xb = xb*(1-mix_factor) + pred_image*mix_factor  # Move part of the way there
    if i < n_steps-1: step_history.append(denormalize_img(xb.detach().cpu()))  # Store step for plotting


# In[32]:


fig, axs = plt.subplots(n_steps, 2, figsize=(15, n_steps), sharex=True)
for i in range(n_steps):
    axs[i, 0].imshow(make_grid(step_history[i])[0])
    axs[i, 1].imshow(make_grid(pred_output_history[i])[0])
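# Added note on the mix_factor schedule above: each iteration moves 1/(steps remaining)
# of the way to the prediction, so the weights ramp up and the final step lands
# exactly on the model's output.

# In[ ]:


print([round(1/(n_steps - i), 3) for i in range(n_steps)])  # [0.2, 0.25, 0.333, 0.5, 1.0]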
# # Class Conditioning

# Giving the model the labels as conditioning.

# In[34]:


class ClassConditionedUNet(nn.Module):
    "Wraps a BasicUNet but adds several input channels for class conditioning"
    def __init__(self, in_channels, out_channels, num_classes=10, class_emb_channels=4, width=1, resids=1):
        super().__init__()
        self.class_emb = nn.Embedding(num_classes, class_emb_channels)
        self.net = BasicUNet(in_channels+class_emb_channels, out_channels, width=width, resids=resids)

    def forward(self, x, class_labels):
        n,c,w,h = x.shape
        class_cond = self.class_emb(class_labels)  # Map to embedding dimension
        class_cond = class_cond.view(n, class_cond.shape[1], 1, 1).expand(n, class_cond.shape[1], w, h)  # Broadcast over the spatial dims
        # Net input is now x and the class conditioning concatenated together
        net_input = torch.cat((x, class_cond), 1)
        return self.net(net_input)


# In[35]:


model = ClassConditionedUNet(1, 1, num_classes=num_classes, width=unet_width, resids=unet_resids)
cbs = [CudaCB(), ConditionalDDPMCallback(label_conditioning=True), MyTrainCB(), ProgressCB(), LogLossesCB(), OneCycle(lr_max)]
learn = Learner(model, dls, loss_func, lr=1e-3, cbs=cbs, opt_func=opt_func)


# In[36]:


learn.fit(50)


# In[37]:


torch.save(learn.model.state_dict(), 'fashion_mnist.pt')


# In[38]:


model.load_state_dict(torch.load('./fashion_mnist.pt'))
model = model.to(device).eval()


# In[39]:


# source: https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0: return t
    return t.view(*t.shape, *((1,) * padding_dims))

def dynamic_thresholding(x_start, dynamic_thresholding_percentile=0.9, max_val=3.0):
    x_start = x_start/max_val
    s = torch.quantile(rearrange(x_start, 'b ... -> b (...)').abs(), dynamic_thresholding_percentile, dim=-1)
    s.clamp_(min=1.0)
    s = right_pad_dims_to(x_start, s)
    x_start = x_start.clamp(-s, s) / s
    return x_start*max_val
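# Added illustration: dynamic thresholding rescales each sample so its 90th-percentile
# magnitude fits inside [-max_val, max_val], squashing outliers instead of hard-clipping.

# In[ ]:


t_in = torch.randn(2, 1, 8, 8) * 5  # deliberately out of range
t_out = dynamic_thresholding(t_in, 0.9, 3.0)
print(t_in.abs().max().item(), t_out.abs().max().item())  # output magnitude capped at 3.0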
# In[45]:


def sample_based_on_optimizer(model, xt, class_labels, optim_fn, n_steps, lr_max, pct_start):
    div = 100000.
    div_final = 1e5
    lr_sched_fn = combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)
    optim = OptimWrapper(opt=optim_fn([xt], lr=lr_max))
    eps = None
    xt.requires_grad = True
    for i in range(n_steps):
        pos = i/n_steps
        lr = lr_sched_fn(pos)
        optim.set_hyper('lr', lr)
        with torch.no_grad():
            in_xt = dynamic_thresholding(xt, 0.9, 3.0)
        eps = model(in_xt, class_labels)
        # Early stopping once the predicted noise has (almost) no variance left
        if eps.float().var() < 0.01 or i == n_steps-1:
            print('Stopping at:', i)
            return xt, eps
        xt.grad = -eps.float() + (xt-in_xt)
        grad_mean = xt.grad.mean()
        if i%10==0: print(i, xt.grad.mean(), xt.grad.var(), xt.float().var(), lr)  # Useful to watch the noise variance
        optim.step()
        optim.zero_grad()
        if i%10==0: print('xt:', i, xt.mean(), xt.max(), xt.min())
        if grad_mean < 0.0: return xt, eps
        #xt = xt + torch.randn_like(xt)*0.0025
        #amount = torch.ones((xt.shape[0], )).to(device)*0.005
        #xt,_ = corrupt(xt, amount)
    return xt, eps

def generate_opt_based_samples(model, xt, class_labels):
    sgd_optim_fn = partial(torch.optim.SGD, momentum=0.01)
    xt = xt.clone()
    #xt, eps = sample_based_on_optimizer(model, xt=xt, class_labels=class_labels, optim_fn=sgd_optim_fn, n_steps=40, lr_max=2e-2, pct_start=0.25)
    xt, eps = sample_based_on_optimizer(model, xt=xt, class_labels=class_labels, optim_fn=sgd_optim_fn, n_steps=15, lr_max=3e-2, pct_start=0.5)
    xt, eps = sample_based_on_optimizer(model, xt=xt, class_labels=class_labels, optim_fn=RmsLookahead, n_steps=20, lr_max=5e-2, pct_start=0.25)
    pred_image = xt+eps
    print(pred_image.min(), pred_image.max(), pred_image.mean())
    return pred_image, eps


# In[46]:


torch.random.manual_seed(1000)
num_examples = 80
num_classes = 10
xt = generate_random_noise(num_examples).to(device)
# Seems to give better results with 0.95 noise_amount rather than 1.0
noise_amount = torch.ones((xt.shape[0], )).to(device)*0.95
xt,_ = corrupt(xt, noise_amount)
class_labels = torch.tensor([[i]*(num_examples//num_classes) for i in range(num_classes)]).flatten().to(device)
pred_image, eps = generate_opt_based_samples(model, xt=xt, class_labels=class_labels)
pred_image = denormalize_img(pred_image).detach().cpu()
xt = xt.detach().cpu()
eps = eps.detach().cpu()
pred_img = pred_image.clip(0, 1)
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(make_grid(pred_img, nrow=8)[0]);
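# The routine above carries a lot of instrumentation; stripped down, the core idea is
# to treat `xt` itself as the parameter and use the predicted noise as its (negative)
# gradient. A minimal sketch of just that idea (my simplification, with no
# thresholding, LR schedule, or early stopping; `_demo` names are placeholders):

# In[ ]:


xt_demo = generate_random_noise(8).to(device).requires_grad_(True)
opt_demo = torch.optim.SGD([xt_demo], lr=3e-2, momentum=0.01)
labels_demo = torch.arange(8, device=device)
for _ in range(15):
    with torch.no_grad():
        eps_demo = model(xt_demo, labels_demo)  # noise the model thinks is still in the image
    xt_demo.grad = -eps_demo                    # step xt in the direction that removes it
    opt_demo.step(); opt_demo.zero_grad()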
# Sampling as before, but over 15 steps and this time with the labels as conditioning:

# In[42]:


n_steps = 15
xb = generate_random_noise(80).cuda()
yb = torch.tensor([[i]*8 for i in range(10)]).flatten().cuda()
for i in range(n_steps):
    noise_amount = torch.ones((xb.shape[0], )).to(device) * (1-(i/n_steps))  # (unused)
    with torch.no_grad():
        pred = model(xb, yb)
    pred_image = xb + pred
    #pred_image = dynamic_thresholding(pred_image, 0.9, 3.0)
    mix_factor = 1/(n_steps - i)
    xb = xb*(1-mix_factor) + pred_image*mix_factor
    #xb = dynamic_thresholding(xb, 0.85)
    # Optional: add a bit of extra noise back at the early steps
    if i < 10: xb,_ = corrupt(xb, torch.ones((xb.shape[0], )).to(device)*0.05)

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(make_grid(denormalize_img(xb).detach().cpu().clip(0, 1), nrow=8)[0]);


# You can try fashion_mnist as the dataset without making any changes. This seems to work (surprisingly, given the lack of fiddling with training and architecture).

# In[43]:


def count_parameters(module): return sum(p.numel() for p in module.parameters())


# In[44]:


count_parameters(model)


# In[ ]:
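# An added usage example: the same helper applied per submodule gives a rough
# breakdown of where the parameters live (module names as defined above).

# In[ ]:


for name, child in model.named_children(): print(name, count_parameters(child))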