#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'


# In[ ]:


#!pip install -q diffusers datasets wandb lpips timm


# In[ ]:


import wandb
wandb.login()


# # Set Up Dataloaders

# In[ ]:


#@title imports
import wandb
import torch
import torchvision
from torch import nn
from torchvision import transforms as T
from torchvision.transforms import functional as TF
from fastai.data.all import *
from fastai.vision.all import *
from fastai.callback.wandb import *
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.lookahead import Lookahead
from einops import rearrange

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')


# In[ ]:


def RmsLookahead(params, alpha=0.5, k=6, *args, **kwargs):
    rmsprop = RMSpropTF(params, *args, **kwargs)
    return Lookahead(rmsprop, alpha, k)


# In[ ]:


# Training config
num_epochs = 10
bs = 128       # the batch size
lr_max = 1e-4  # the max learning rate
#opt_func = partial(OptimWrapper, opt=torch.optim.AdamW, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3, lr=lr_max)
opt_func = partial(OptimWrapper, opt=RmsLookahead, weight_decay=1e-3, lr=lr_max)

# Model config
sz = 64
in_channels = 3
depths = [2, 4, 4]
channels = [128, 256, 512]
self_attn_depths = [False, False, True]
dropout_rate = 0.05
num_samples = 64

name = 'FastDiffusion_KDiff_CelebA_OptSamp'           # the name of the run
wandb_project = 'FastDiffusion_KDiff_CelebA_OptSamp'  # the wandb project name (specify this to enable wandb)
wandb_save_model = False                              # save the model to wandb
dataset_name = 'CelebA'                               # wandb name for the dataset used
comments = 'Celeb faces using kdiff-based unet and optimizer-based sampling.'  # comments logged to wandb
demo_imgs_dir = './demo_images'
metrics_dir = './metrics'

# Model save/load
checkpoints_dir = './checkpoints'
model_path = Path(checkpoints_dir + '/' + name + '.pt')
model_path.parent.mkdir(exist_ok=True)


# In[ ]:


# Perceptual loss
import lpips


# In[ ]:


#@title dataset from HF
from torchvision import transforms as T
from datasets import load_dataset

dataset = load_dataset('huggan/CelebA-faces')
tfm = T.Compose([T.Resize(sz), T.CenterCrop(sz)])

def transforms(examples):
    examples["image"] = [tfm(image.convert("RGB")) for image in examples["image"]]
    return examples

dataset = dataset.with_transform(transforms)['train']


# In[ ]:


# Example 64px image
dataset[0]['image']


# In[ ]:


# Classes for the crappified (blurred + noised) input image
class PILImageNoised(PILImage): pass
class TensorImageNoised(TensorImage):
    def show(self, ctx=None, **kwargs):
        super().show(ctx=ctx, **kwargs)
PILImageNoised._tensor_cls = TensorImageNoised

# Transform (TODO experiment)
class Crappify(Transform):
    def encodes(self, x: TensorImageNoised):
        x = IntToFloatTensor()(x)
        blurred = T.GaussianBlur(3)(x)  # Add some blur
        noise_amount = torch.rand(x.shape[0], device=x.device)  # per-image noise level
        noise = torch.rand_like(x)
        noised = torch.lerp(blurred, noise, noise_amount.view(-1, 1, 1, 1)) * 255
        return noised

# Dataloader
dblock = DataBlock(blocks=(ImageBlock(cls=PILImageNoised), ImageBlock(cls=PILImage)),
                   get_items=lambda pth: range(len(dataset)),  # Gets the indexes
                   getters=[lambda idx: np.array(dataset[idx]['image'])] * 2,
                   batch_tfms=[Crappify])
# dls = dblock.dataloaders('', bs=128)
dls = dblock.dataloaders('', bs=bs)  # Halve the batch size to save memory when using an extra net for perceptual loss
dls.show_batch()
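# In[ ]:


# (Added sketch, not in the original notebook) Sanity-check the Crappify mixing
# math on dummy tensors: torch.lerp(blurred, noise, t) computes
# blurred*(1-t) + noise*t, so t=0 keeps the (blurred) image and t=1 gives pure noise.
xb = torch.rand(4, 3, sz, sz)                       # stand-in batch in [0, 1]
blurred = T.GaussianBlur(3)(xb)
t = torch.tensor([0.0, 0.25, 0.5, 1.0])             # one mix level per image
mixed = torch.lerp(blurred, torch.rand_like(xb), t.view(-1, 1, 1, 1))
assert torch.allclose(mixed[0], blurred[0])         # t=0 -> unchanged (still blurred)
assert 0.0 <= mixed.min() and mixed.max() <= 1.0    # both inputs live in [0, 1]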
# # Model and training

# In[ ]:


# 2D blur kernels and resampling layers (kdiff-style, per the run comments above)
_kernels = {
    'linear':
        [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    'cubic':
        [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
         0.43359375, 0.11328125, -0.03515625, -0.01171875],
    'lanczos3':
        [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
         -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
         0.44638532400131226, 0.13550527393817902, -0.066637322306633,
         -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
_kernels['bilinear'] = _kernels['linear']
_kernels['bicubic'] = _kernels['cubic']


class Downsample2d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect'):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([_kernels[kernel]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer('kernel', kernel_1d.T @ kernel_1d)

    def forward(self, x):
        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv2d(x, weight, stride=2)


class Upsample2d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect'):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([_kernels[kernel]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer('kernel', kernel_1d.T @ kernel_1d)

    def forward(self, x):
        x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
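# In[ ]:


# (Added sketch, not in the original notebook) Shape check for the resampling
# layers: Downsample2d halves the spatial size with an anti-aliased blur kernel,
# and Upsample2d doubles it with the matching transposed convolution.
x = torch.randn(2, 8, 64, 64)
assert Downsample2d()(x).shape == (2, 8, 32, 32)
assert Upsample2d()(x).shape == (2, 8, 128, 128)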
# In[ ]:


def orthogonal_(module):
    nn.init.orthogonal_(module.weight)
    return module


class ResidualBlock(nn.Module):
    def __init__(self, *main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, input):
        return self.main(input) + self.skip(input)


class ResConvBlock(nn.Module):
    def __init__(self, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
        super().__init__()
        layers = [
            nn.GroupNorm(num_groups=max(1, c_in // group_size), num_channels=c_in),
            nn.GELU(),
            nn.Conv2d(c_in, c_mid, 3, padding=1),
            nn.Dropout2d(dropout_rate, inplace=True),
            nn.GroupNorm(num_groups=max(1, c_mid // group_size), num_channels=c_mid),
            nn.GELU(),
            nn.Conv2d(c_mid, c_out, 3, padding=1),
            nn.Dropout2d(dropout_rate, inplace=True)
        ]
        self.main = nn.Sequential(*layers)
        self.skip = nn.Identity() if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))

    def forward(self, input):
        skip = self.skip(input)
        input = self.main(input)
        return input + skip


class SelfAttention2d(nn.Module):
    def __init__(self, c_in, n_head, norm, dropout_rate=0.):
        super().__init__()
        assert c_in % n_head == 0
        self.norm_in = norm(c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv2d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input):
        n, c, h, w = input.shape
        qkv = self.qkv_proj(self.norm_in(input))
        qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)
        scale = k.shape[3] ** -0.25
        att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
        att = self.dropout(att)
        y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
        return input + self.out_proj(y)


class DBlock(nn.Sequential):
    def __init__(self, n_layers, c_in, c_mid, c_out, group_size=32, head_size=64,
                 dropout_rate=0., downsample=False, self_attn=False):
        modules = [nn.Identity()]
        for i in range(n_layers):
            my_c_in = c_in if i == 0 else c_mid
            my_c_out = c_mid if i < n_layers - 1 else c_out
            modules.append(ResConvBlock(my_c_in, c_mid, my_c_out, group_size, dropout_rate))
            if self_attn:
                norm = lambda c_in: nn.GroupNorm(num_groups=max(1, my_c_out // group_size), num_channels=c_in)
                modules.append(SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
        super().__init__(*modules)
        self.set_downsample(downsample)

    def set_downsample(self, downsample):
        self[0] = Downsample2d() if downsample else nn.Identity()
        return self


class UBlock(nn.Sequential):
    def __init__(self, n_layers, c_in, c_mid, c_out, group_size=32, head_size=64,
                 dropout_rate=0., upsample=False, self_attn=False):
        modules = []
        for i in range(n_layers):
            my_c_in = c_in if i == 0 else c_mid
            my_c_out = c_mid if i < n_layers - 1 else c_out
            modules.append(ResConvBlock(my_c_in, c_mid, my_c_out, group_size, dropout_rate))
            if self_attn:
                norm = lambda c_in: nn.GroupNorm(num_groups=max(1, my_c_out // group_size), num_channels=c_in)
                modules.append(SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
        modules.append(nn.Identity())
        super().__init__(*modules)
        self.set_upsample(upsample)

    def forward(self, input, skip=None):
        if skip is not None:
            input = torch.cat([input, skip], dim=1)
        return super().forward(input)

    def set_upsample(self, upsample):
        self[-1] = Upsample2d() if upsample else nn.Identity()
        return self


# In[ ]:


class UNet(nn.Module):
    def __init__(self, depths, channels, self_attn_depths, c_in=3, dropout_rate=0.0):
        super().__init__()
        self.proj_in = nn.Conv2d(c_in, channels[0], 1)
        self.proj_out = nn.Conv2d(channels[0], c_in, 1)
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)
        d_blocks, u_blocks = [], []
        for i in range(len(depths)):
            my_c_in = channels[max(0, i - 1)]
            d_blocks.append(DBlock(depths[i], my_c_in, channels[i], channels[i], downsample=i > 0,
                                   self_attn=self_attn_depths[i], dropout_rate=dropout_rate))
        for i in range(len(depths)):
            my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
            my_c_out = channels[max(0, i - 1)]
            u_blocks.append(UBlock(depths[i], my_c_in, channels[i], my_c_out, upsample=i > 0,
                                   self_attn=self_attn_depths[i], dropout_rate=dropout_rate))
        self.d_blocks = nn.ModuleList(d_blocks)
        self.u_blocks = nn.ModuleList(reversed(u_blocks))

    def forward(self, input):
        input = self.proj_in(input)
        skips = []
        for block in self.d_blocks:
            input = block(input)
            skips.append(input)
        for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
            input = block(input, skip if i > 0 else None)
        input = self.proj_out(input)
        return input


model = UNet(
    depths=depths,
    channels=channels,
    c_in=in_channels,
    self_attn_depths=self_attn_depths,
    dropout_rate=dropout_rate
).to(device)


# In[ ]:


#model
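# In[ ]:


# (Added sketch, not in the original notebook) Forward-pass sanity check: with
# depths=[2, 4, 4] the UNet downsamples twice internally (64 -> 32 -> 16) and,
# being an image-to-image denoiser, returns a tensor the same shape as its input.
with torch.no_grad():
    out = model(torch.randn(2, in_channels, sz, sz, device=device))
assert out.shape == (2, in_channels, sz, sz)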
# In[ ]:


# Create a learner and pick an LR
# learn = Learner(dls, model, loss_func=MSELossFlat())
loss_fn_alex = lpips.LPIPS(net='alex').to(device)
loss_fn_mse = MSELossFlat()

def combined_loss(preds, y):
    return loss_fn_alex.forward(preds, y).mean() + loss_fn_mse(preds, y)

learn = Learner(dls, model, loss_func=combined_loss, opt_func=opt_func)
#learn.lr_find()


# In[ ]:


def sample_based_on_optimizer(model, xt, optim_fn, n_steps, lr_max, pct_start):
    div = 1e5
    div_final = 1e5
    lr_sched_fn = combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)
    optim = OptimWrapper(opt=optim_fn([xt], lr=lr_max))
    prev_noise_var = 1e6
    eps = None
    xt.requires_grad = True
    for i in range(n_steps):
        pos = i / n_steps
        lr = lr_sched_fn(pos)
        optim.set_hyper('lr', lr)
        with torch.no_grad():
            net_output = model(xt)
            flip_output = TF.hflip(model(TF.hflip(xt)))  # average with a horizontally-flipped prediction
            denoised_output = (net_output + flip_output) / 2.0
            noise_pred = xt - denoised_output
        # noise_pred is basically the grad, so GD on this should find a minimum!
        xt.grad = noise_pred.float()
        grad_mean = xt.grad.mean()
        var_diff = (noise_pred.float().var() - prev_noise_var).abs()
        # Early stopping once the noise variance stops changing
        if i == n_steps - 1 or (i > 0 and grad_mean < 0 and var_diff < 1e-5):
            xt = net_output
            #print(var_diff)
            #print('Stopping at:', i)
            break
        prev_noise_var = noise_pred.float().var()
        optim.step()
        optim.zero_grad()  # Not really needed since we're setting .grad ourselves, but anyway...
        if grad_mean < 0.0:
            #print(grad_mean)
            xt = xt + torch.randn_like(xt) * 0.0025  # inject a little noise when the mean grad goes negative
    return xt


def generate_opt_based_samples(model, xt):
    # A few coarse SGD steps, then refine with the RMSprop+Lookahead optimizer
    sgd_optim_fn = partial(torch.optim.SGD, momentum=0.01)
    xt = xt.clone()
    xt = sample_based_on_optimizer(model, xt=xt, optim_fn=sgd_optim_fn, n_steps=10, lr_max=5e-2, pct_start=0.5)
    xt = sample_based_on_optimizer(model, xt=xt, optim_fn=RmsLookahead, n_steps=30, lr_max=5e-2, pct_start=0.25)
    pred_image = xt
    return pred_image


# In[ ]:


def generate_samples_grid(model):
    xt = torch.rand(num_samples, in_channels, sz, sz).to(device)
    pred_image = generate_opt_based_samples(model, xt)
    im = torchvision.utils.make_grid(pred_image.detach().cpu(), nrow=8).permute(1, 2, 0).clip(0, 1) * 255
    im = PILImage.fromarray(np.array(im).astype(np.uint8))
    return im


# In[ ]:


#@title Callback for logging samples (shown later in the sampling section)
from PIL import Image as PILImage  # NB: shadows fastai's PILImage; the dataloaders above are already built, so this only affects generate_samples_grid

class LogSamplesBasicCallback(Callback):
    def after_epoch(self):
        model = self.learn.model
        im = generate_samples_grid(model)
        wandb.log({'Sample generations basic': wandb.Image(im)})

    def after_step(self):
        if self.train_iter % 100 == 0:  # Also log every 100 training steps
            model = self.learn.model
            im = generate_samples_grid(model)
            wandb.log({'Sample generations basic': wandb.Image(im)})


# In[ ]:


def init_wandb():
    import wandb
    log_config = {}
    log_config['num_epochs'] = num_epochs
    log_config['max_lr'] = lr_max
    log_config['comments'] = comments
    log_config['dataset'] = dataset_name
    wandb.init(project=wandb_project, config=log_config, save_code=False)


# In[ ]:


init_wandb()
learn.fit_one_cycle(num_epochs, lr_max=lr_max, cbs=[WandbCallback(n_preds=8), LogSamplesBasicCallback()])
wandb.finish()


# In[ ]:


learn.show_results()


# In[ ]:


torch.save(learn.model.state_dict(), str(model_path))


# # Sampling

# In[ ]:


learn.model.load_state_dict(torch.load(str(model_path)))


# In[ ]:


generate_samples_grid(learn.model)


# In[ ]:
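# (Added sketch, not in the original notebook) Optionally save the final sample
# grid to demo_imgs_dir, which the config defines but the notebook never uses.
os.makedirs(demo_imgs_dir, exist_ok=True)
generate_samples_grid(learn.model).save(f'{demo_imgs_dir}/{name}_samples.png')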