#!/usr/bin/env python
# coding: utf-8

# ### Using a K-Diffusion model with a pretrained UNet encoder.

# In[1]:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'


# In[2]:

#!pip install -q diffusers datasets wandb lpips timm


# In[3]:

import wandb
wandb.login()


# In[4]:

#@title imports
import wandb
import torch
import torchvision
from torch import nn
from torch import multiprocessing as mp
from torch.utils import data
from torchvision import datasets, transforms
from torchvision import transforms as T
from torchvision.transforms import functional as TF
from fastai.data.all import *
from fastai.vision.all import *
from fastai.callback.wandb import *
from fastai.vision.models.unet import _get_sz_change_idxs
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.lookahead import Lookahead
from timm.optim.lamb import Lamb
from torch.nn.utils import spectral_norm
import timm
import accelerate
from einops import rearrange
from functools import partial
import math
from copy import deepcopy
from pathlib import Path
from tqdm.auto import trange, tqdm
import k_diffusion as K
from k_diffusion.models.image_v1 import *
from datasets import load_dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')


# In[5]:

def RmsLookahead(params, alpha=0.5, k=6, *args, **kwargs):
    rmsprop = RMSpropTF(params, *args, **kwargs)
    return Lookahead(rmsprop, alpha, k)


# In[6]:

def AdamLookahead(params, alpha=0.5, k=6, *args, **kwargs):
    opt = torch.optim.AdamW(params, betas=(0.95, 0.999), eps=1e-6, *args, **kwargs)
    return Lookahead(opt, alpha, k)


# In[7]:

def convnext_large(pretrained:bool=False, **kwargs):
    return timm.create_model('convnext_large_384_in22ft1k', pretrained=pretrained)


# In[8]:

def resnext101_32x16d_wsl(pretrained:bool=False, **kwargs):
    return torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x16d_wsl')


# In[9]:

# Training Config
num_epochs = 15
unfreeze_epoch = int(0.3*num_epochs)
bs = 128  # the batch size
grad_accum_steps = 1  # the number of gradient accumulation steps
max_lr = 5e-4  # the max learning rate
num_workers = 8  # the number of data loader workers
resume = None  # the checkpoint to resume from
save_every = 10000  # save every this many steps
training_seed = None  # the random seed for training
start_method = 'spawn'  # the multiprocessing start method. Options: 'fork', 'forkserver', 'spawn'
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

#encoder_arch = resnext101_32x16d_wsl
#encoder_cut = -2
encoder_arch = convnext_large
encoder_cut = -6
mean, std = imagenet_stats

opt_func = partial(torch.optim.AdamW, lr=max_lr, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3)
#opt_func = partial(RmsLookahead, lr=max_lr, weight_decay=1e-3)
#opt_func = partial(RMSpropTF, lr=max_lr, weight_decay=1e-3)

# Logging Config
sample_n = 64  # the number of images to sample for demo grids
demo_every = 250  # save a demo grid every this many steps
#evaluate_every = 10000  # evaluate every this many steps
evaluate_every = 0  # disabled
evaluate_n = 50000  # the number of samples to draw to evaluate
name = 'KDiff_CelebA_PretrainedEncoderUnet'  # the name of the run
wandb_project = 'FastDiffusion_KDiff_CelebA'  # the wandb project name (specify this to enable wandb)
wandb_save_model = False  # save model to wandb
dataset_name = 'CelebA'  # wandb name for dataset used
comments = 'Pretrained Encoder based Unet run of K-diffusion model on CelebA.'  # comments logged in wandb
demo_imgs_dir = './demo_images'
metrics_dir = './metrics'

# Model Config
sz = 64
size = [sz, sz]
input_channels = 3
patch_size = 1
mapping_out = 256
compress_factor = 4

#UBlock Only
#depths = [8, 8, 8, 4, 4]
#channels = [1024, 512, 256, 256, 128]
#self_attn_depths = [True, True, False, False, False]

#UBlock Only
depths = [4, 4, 2, 2]
channels = [512, 512, 256, 256]
self_attn_depths = [True, True, False, False]

#depths = [4, 4, 2]
#channels = [512, 256, 128]
#self_attn_depths = [True, True, False]

cross_attn_depths = None
has_variance = True
dropout_rate = 0.05
augment_wrapper = True
augment_prob = 0.12
sigma_data = 0.5
sigma_min = 1e-2
sigma_max = 80
skip_stages = 0

# Model Save/Load
checkpoints_dir = './checkpoints'
model_path = Path(checkpoints_dir + '/' + name + '.pt')
model_ema_path = Path(checkpoints_dir + '/' + name + '_ema.pt')
model_path.parent.mkdir(exist_ok=True)


# In[10]:

mp.set_start_method(start_method)
torch.backends.cuda.matmul.allow_tf32 = True


# In[11]:

ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=skip_stages > 0)
accelerator = accelerate.Accelerator(kwargs_handlers=[ddp_kwargs],
                                     gradient_accumulation_steps=grad_accum_steps,
                                     mixed_precision='fp16')
device = accelerator.device
print(f'Process {accelerator.process_index} using device: {device}', flush=True)


# # Model and Training Setup

# In[12]:

def make_sample_density(mean=-1.2, std=1.2):
    # lognormal sigma sampling
    return partial(K.utils.rand_log_normal, loc=mean, scale=std)


# In[13]:

def make_sequential_model(model, cut:int=None):
    "Flatten a model's top-level children into one nn.Sequential, optionally truncated at `cut`."
    flattened = list()
    for child in model.children():
        if isinstance(child, nn.Sequential):
            flattened.extend(child.children())
        else:
            flattened.append(child)
    if cut is None:
        return nn.Sequential(*flattened)
    else:
        return nn.Sequential(*flattened[:cut])


# In[14]:

encoder = encoder_arch(pretrained=True)
encoder = make_sequential_model(encoder, encoder_cut)


# In[15]:

class CompressBlock(layers.ConditionedSequential):
    def __init__(self, feats_in, c_in, c_out, group_size=32, dropout_rate=0.,
                 self_attn=False, cross_attn=False, c_enc=0):
        modules = []
        modules.append(spectral_norm(nn.Conv2d(c_in, c_out, 1)))
        #modules.append(nn.Conv2d(c_in, c_out, 3, padding=1))
        modules.append(nn.Dropout2d(dropout_rate, inplace=True))
        modules.append(K.layers.AdaGN(feats_in, c_out, max(1, c_out // group_size)))
        modules.append(nn.GELU())
        #modules.append(ResConvBlock(feats_in, c_out, c_out, c_out, group_size, dropout_rate))
        modules.append(UBlock(n_layers=1, feats_in=feats_in, c_in=c_out, c_mid=c_out, c_out=c_out,
                              dropout_rate=dropout_rate, upsample=False, self_attn=self_attn,
                              cross_attn=cross_attn, c_enc=c_enc, group_size=group_size))
        super().__init__(*modules)

    def forward(self, input, cond, skip=None):
        if skip is not None:
            input = torch.cat([input, skip], dim=1)
        return super().forward(input, cond)
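# The cell below is an illustration only (nothing in the model uses it): it sketches the
# channel-compression idea behind `CompressBlock` with plain PyTorch. A spectrally-normalised
# 1x1 convolution squeezes the channel count while keeping the spatial size; an unconditioned
# GroupNorm stands in for AdaGN, and the dropout/UBlock parts of the real block are omitted.
# All channel counts and shapes here are arbitrary.

# In[ ]:

_c_in, _c_out = 512 + 384, 96                      # e.g. decoder features concatenated with an encoder skip
_toy_compress = nn.Sequential(
    spectral_norm(nn.Conv2d(_c_in, _c_out, 1)),    # 1x1 conv: mixes channels, keeps H x W
    nn.GroupNorm(max(1, _c_out // 32), _c_out),    # unconditioned stand-in for AdaGN
    nn.GELU(),
)
print(_toy_compress(torch.randn(2, _c_in, 8, 8)).shape)  # -> torch.Size([2, 96, 8, 8])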
# In[16]:

class CustUBlock(nn.Module):
    def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, dropout_rate=0., up_c_in=0,
                 compress_factor=4, upsample=False, self_attn=False, cross_attn=False, c_enc=0):
        super().__init__()
        if up_c_in > 0:
            c_in_big = c_in + up_c_in
            up_c_in_comp = up_c_in//compress_factor
            self.c1_block = CompressBlock(feats_in, c_in_big, up_c_in_comp, dropout_rate=dropout_rate,
                                          self_attn=self_attn, cross_attn=cross_attn, c_enc=c_enc)
            c_in_bigger = c_in + up_c_in_comp + up_c_in
            self.c2_block = CompressBlock(feats_in, c_in_bigger, c_in, dropout_rate=dropout_rate,
                                          self_attn=self_attn, cross_attn=cross_attn, c_enc=c_enc)
        else:
            self.c1_block = nn.Identity()
            self.c2_block = nn.Identity()
        self.u_block = UBlock(n_layers=n_layers, feats_in=feats_in, c_in=c_in, c_mid=c_mid, c_out=c_out,
                              dropout_rate=dropout_rate, upsample=upsample, self_attn=self_attn,
                              cross_attn=cross_attn, c_enc=c_enc)

    def forward(self, input, cond, skip=None):
        if skip is not None:
            skip_comp = torch.cat([input, skip], dim=1)
            skip_comp = self.c1_block(skip_comp, cond)
            input = torch.cat([input, skip, skip_comp], dim=1)
            input = self.c2_block(input, cond)
        return self.u_block(input, cond)


# In[17]:

class OutBlock(nn.Module):
    def __init__(self, n_layers, feats_in, c_in, c_in_comp, c_mid, c_out, dropout_rate=0.,
                 self_attn=False, cross_attn=False, c_enc=0, group_size=32):
        super().__init__()
        self.c1_block = CompressBlock(feats_in, c_in, c_in_comp, dropout_rate=dropout_rate)
        self.u_block1 = UBlock(n_layers=n_layers, feats_in=feats_in, c_in=c_in_comp, c_mid=c_in_comp,
                               c_out=c_in_comp, dropout_rate=dropout_rate, upsample=False,
                               self_attn=self_attn, cross_attn=cross_attn, c_enc=c_enc)
        c_in_big = c_in_comp + c_in
        self.c2_block = CompressBlock(feats_in, c_in_big, c_mid, dropout_rate=dropout_rate)
        self.u_block2 = UBlock(n_layers=n_layers, feats_in=feats_in, c_in=c_mid, c_mid=c_mid,
                               c_out=c_out, dropout_rate=dropout_rate, upsample=False,
                               self_attn=self_attn, cross_attn=cross_attn, c_enc=c_enc)

    def forward(self, input, cond):
        input_comp = self.c1_block(input, cond)
        input_comp = self.u_block1(input_comp, cond, None)
        input_comp = torch.cat([input, input_comp], dim=1)
        input_comp = self.c2_block(input_comp, cond)
        return self.u_block2(input_comp, cond, None)


# In[18]:

class InitBlock(K.layers.ConditionedModule):
    def __init__(self, n_layers, feats_in, c_in, c_out, ks=2, stride=2, group_size=32,
                 dropout_rate=0.0, c_enc=0):
        super().__init__()
        layers = [nn.Conv2d(c_in, c_out, ks, stride=(stride, stride)),
                  K.layers.AdaGN(feats_in, c_out, max(1, c_out // group_size)),
                  DBlock(n_layers=n_layers, feats_in=feats_in, c_in=c_out, c_mid=c_out, c_out=c_out,
                         dropout_rate=dropout_rate, downsample=False, self_attn=False,
                         cross_attn=False, c_enc=c_enc)]
        self.downsampler = K.layers.ConditionedSequential(*layers)

    def forward(self, input, cond):
        return self.downsampler(input, cond)
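# Illustration only: `InitBlock` is instantiated further down with `ks=4, stride=4`, so its first
# convolution maps the 64x64 RGB input straight down to the resolution of the first hooked encoder
# activation. The tiny sketch below shows just that spatial reduction (192 output channels is an
# arbitrary example width; the conditioned AdaGN/DBlock parts are omitted).

# In[ ]:

_init_conv = nn.Conv2d(3, 192, 4, stride=(4, 4))
print(_init_conv(torch.randn(1, 3, 64, 64)).shape)  # -> torch.Size([1, 192, 16, 16])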
# In[19]:

class EnhancedEncoder(K.layers.ConditionedSequential):
    def __init__(self, encoder, imsize, depths, self_attn_depths, cross_attn_depths, feats_in,
                 c_in=3, dropout_rate=0., c_enc=0):
        sizes = model_sizes(encoder, imsize)
        sz_chg_idxs = _get_sz_change_idxs(sizes)
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        self.sizes = sizes
        orig_layers = []
        new_layers = []
        layers = []
        c_out = sizes[sz_chg_idxs[0]][1]
        init_block = InitBlock(n_layers=2, feats_in=feats_in, c_in=c_in, c_out=c_out,
                               dropout_rate=dropout_rate, ks=4, stride=4, c_enc=c_enc)
        #sz_chg_idxs = [0, *sz_chg_idxs]
        self.sz_chg_idxs = sz_chg_idxs
        layers.append(init_block)
        new_layers.append(init_block)
        for i in range(1, len(sz_chg_idxs)+1):
            start_idx = sz_chg_idxs[i-1]
            end_idx = sz_chg_idxs[i] if i < len(sz_chg_idxs) else len(encoder)
            orig_part = encoder[start_idx:end_idx]
            orig_layers.append(orig_part)
            layers.append(orig_part)
            c_in = sizes[min(end_idx-1, len(encoder)-1)][1]
            self_attn = self_attn_depths[len(self_attn_depths)-2-i]
            cross_attn = cross_attn_depths[len(cross_attn_depths)-2-i]
            n_layers = depths[len(depths)-2-i]
            d_block = DBlock(n_layers=n_layers, feats_in=feats_in, c_in=c_in, c_mid=c_in, c_out=c_in,
                             dropout_rate=dropout_rate, downsample=False, self_attn=self_attn,
                             cross_attn=cross_attn, c_enc=c_enc)
            new_layers.append(d_block)
            layers.append(d_block)
        super().__init__(*layers)
        self.orig_layers = orig_layers
        self.new_layers = new_layers

    def toggle_orig_encoder_freeze(self, freeze=True):
        for layer in self.orig_layers:
            for param in layer.parameters():
                param.requires_grad = not freeze


# In[20]:

class CustUNet(K.layers.ConditionedModule):
    "Create a U-Net from a given architecture."
    def __init__(self, encoder, n_out, img_size, depths, channels, self_attn_depths, feats_in,
                 mean, std, dropout_rate=0.0, cross_attn_depths=None, last_cross=True, group_size=32,
                 cross_cond_dim=0, compress_factor=4, **kwargs):
        super().__init__()
        self.normalize = transforms.Normalize(mean=mean, std=std)
        imsize = img_size
        x = dummy_eval(encoder, imsize).detach()
        self.last_cross = last_cross
        self.encoder = EnhancedEncoder(encoder, imsize, depths=depths, self_attn_depths=self_attn_depths,
                                       cross_attn_depths=cross_attn_depths, feats_in=feats_in,
                                       dropout_rate=dropout_rate)
        #sizes = model_sizes(self.encoder, size=imsize)
        #sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        #self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        sizes = self.encoder.sizes
        sz_chg_idxs = list(reversed(self.encoder.sz_chg_idxs))
        self.sfs = list(reversed(self.encoder.sfs))
        ni = sizes[-1][1]
        my_c_out = channels[0]
        #middle_conv = K.layers.ConditionedSequential(ResConvBlock(feats_in, ni, ni*2, my_c_out, dropout_rate=dropout_rate))
        middle_conv = K.layers.ConditionedSequential(ResConvBlock(feats_in, ni, ni//4, my_c_out, dropout_rate=dropout_rate))
        middle_layers = [K.layers.AdaGN(feats_in, ni, max(1, ni // group_size)), nn.GELU(), middle_conv]
        self.middle_block = K.layers.ConditionedSequential(*middle_layers)
        u_blocks = []
        for i in range(0, len(channels)):
            idx = sz_chg_idxs[i-1] if i > 0 else None
            up_c_in = 0 if idx is None else int(sizes[idx][1])
            my_c_out = channels[min(len(channels)-1, i+1)]
            u_block = CustUBlock(depths[i], feats_in, channels[i], channels[i], my_c_out, up_c_in=up_c_in,
                                 upsample=True, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i],
                                 c_enc=cross_cond_dim, dropout_rate=dropout_rate)
            u_blocks.append(u_block)
        self.u_blocks = nn.ModuleList(u_blocks)
        ni = my_c_out
        out_layers = []
        if last_cross:
            ni += in_channels(encoder)
        i = len(channels)-1
        ni_comp = (ni - ni%group_size)//compress_factor
        self.out_block = OutBlock(depths[i], feats_in, ni, ni_comp, channels[i], n_out,
                                  self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i],
                                  c_enc=cross_cond_dim, dropout_rate=dropout_rate, group_size=group_size)

    def __del__(self):
        if hasattr(self, "sfs"):
            self.sfs.remove()

    def forward(self, input, cond):
        input = self.normalize(input)
        orig_input = input
        input = self.encoder(input, cond)
        input = self.middle_block(input, cond)
        for i, block in enumerate(self.u_blocks):
            skip = self.sfs[i-1].stored if i > 0 else None
            input = block(input, cond, skip)
        if orig_input.shape[-2:] != input.shape[-2:]:
            input = F.interpolate(input, orig_input.shape[-2:], mode='bicubic')
        if self.last_cross:
            input = torch.cat([input, orig_input], dim=1)
        input = self.out_block(input, cond)
        return input
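# The construction above leans on fastai's introspection helpers: `model_sizes` pushes a dummy
# batch through the encoder to record activation shapes, and `_get_sz_change_idxs` picks out the
# layer indices where the spatial resolution changes - those are exactly the activations
# `EnhancedEncoder` hooks and `CustUNet` later consumes as skip connections. The cell below is an
# optional inspection sketch (illustration only; it reuses `encoder` and `size` from earlier cells).

# In[ ]:

_enc_sizes = model_sizes(encoder, size)        # activation shapes for a 64x64 dummy input
_skip_idxs = _get_sz_change_idxs(_enc_sizes)   # indices where H/W change -> skip-connection points
print(_skip_idxs)
print([_enc_sizes[i] for i in _skip_idxs])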
# In[21]:

class CustImageDenoiserModelV1(nn.Module):
    def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, encoder, mean, std,
                 cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0,
                 dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False, compress_factor=4):
        super().__init__()
        self.c_in = c_in
        self.unet_cond_dim = unet_cond_dim
        self.patch_size = patch_size
        self.has_variance = has_variance
        self.timestep_embed = K.layers.FourierFeatures(1, feats_in)
        if mapping_cond_dim > 0:
            self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
        self.mapping = MappingNet(feats_in, feats_in)
        n_out = channels[-1]//4
        self.proj_out = nn.Conv2d(n_out, c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)
        if cross_cond_dim == 0:
            cross_attn_depths = [False] * len(self_attn_depths)
        self.u_net = CustUNet(encoder=encoder, n_out=n_out, img_size=(sz, sz), depths=depths,
                              channels=channels, feats_in=feats_in, mean=mean, std=std,
                              self_attn_depths=self_attn_depths, dropout_rate=dropout_rate,
                              cross_attn_depths=cross_attn_depths, cross_cond_dim=cross_cond_dim,
                              compress_factor=compress_factor)

    def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None,
                cross_cond_padding=None, return_variance=False):
        c_noise = sigma.log() / 4
        timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
        mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
        mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
        cond = {'cond': mapping_out}
        if unet_cond is not None:
            input = torch.cat([input, unet_cond], dim=1)
        if cross_cond is not None:
            cond['cross'] = cross_cond
            cond['cross_padding'] = cross_cond_padding
        if self.patch_size > 1:
            input = F.pixel_unshuffle(input, self.patch_size)
        input = self.u_net(input, cond)
        input = self.proj_out(input)
        if self.has_variance:
            input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
        if self.patch_size > 1:
            input = F.pixel_shuffle(input, self.patch_size)
        if self.has_variance and return_variance:
            return input, logvar
        return input

    def set_skip_stages(self, skip_stages):
        return

    def set_patch_size(self, patch_size):
        return


# In[22]:

def make_model():
    model = CustImageDenoiserModelV1(
        c_in=input_channels,
        feats_in=mapping_out,
        depths=depths,
        channels=channels,
        self_attn_depths=self_attn_depths,
        cross_attn_depths=cross_attn_depths,
        patch_size=patch_size,
        dropout_rate=dropout_rate,
        mapping_cond_dim=9 if augment_wrapper else 0,
        unet_cond_dim=0,
        cross_cond_dim=0,
        skip_stages=skip_stages,
        has_variance=has_variance,
        encoder=encoder,
        compress_factor=compress_factor,
        mean=mean,
        std=std,
    )
    if augment_wrapper:
        model = K.augmentation.KarrasAugmentWrapper(model)
    return model


# In[23]:

def make_denoiser_wrapper():
    if not has_variance:
        return partial(K.layers.Denoiser, sigma_data=sigma_data)
    return partial(K.layers.DenoiserWithVariance, sigma_data=sigma_data)


# In[24]:

tfm = transforms.Compose([
    transforms.Resize(sz, interpolation=transforms.InterpolationMode.LANCZOS),
    transforms.CenterCrop(sz),
    K.augmentation.KarrasAugmentationPipeline(augment_prob),
])

def tfms(examples):
    examples["image"] = [tfm(image.convert("RGB")) for image in examples["image"]]
    return examples


# In[25]:

training_set = load_dataset('huggan/CelebA-faces')
tds = training_set.with_transform(tfms)['train']
dls = DataLoaders.from_dsets(tds, bs=bs)
train_dl = dls.train


# In[ ]:


# In[26]:

inner_model = make_model()
if accelerator.is_main_process:
    print('Parameters:', K.utils.n_params(inner_model))
model = make_denoiser_wrapper()(inner_model)


# In[27]:

def init_wandb():
    import wandb
    log_config = {}
    log_config['num_epochs'] = 'N/A'
    log_config['max_lr'] = max_lr
    log_config['comments'] = comments
    log_config['dataset'] = dataset_name
    log_config['parameters'] = K.utils.n_params(inner_model)
    wandb.init(project=wandb_project, config=log_config, save_code=False)
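# A quick data-pipeline sanity check (illustration only). The training loop below unpacks
# `batch[image_key]` into three items, so per image the Karras augmentation pipeline is expected
# to yield the augmented tensor, the un-augmented tensor and the augmentation conditioning vector;
# with the settings above the shapes should be roughly [bs, 3, 64, 64] for the images and [bs, 9]
# for the conditioning used as `mapping_cond`.

# In[ ]:

_batch = next(iter(train_dl))
_reals, _origs, _aug_cond = _batch['image']
print(_reals.shape, _origs.shape, _aug_cond.shape)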
# In[28]:

def init_training_manual_seed(accelerator):
    if training_seed is not None:
        seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes],
                              generator=torch.Generator().manual_seed(training_seed))
        torch.manual_seed(seeds[accelerator.process_index])


# In[29]:

def log_step_to_wandb(epoch, loss, step, sched, ema_decay):
    log_dict = {
        'epoch': epoch,
        'loss': loss.item(),
        'lr': sched.get_last_lr()[0],
        'ema_decay': ema_decay,
    }
    wandb.log(log_dict, step=step)


# In[30]:

def write_progress_to_tqdm(epoch, step, loss):
    tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}')


# In[31]:

model_ema = deepcopy(model).to(device)


# In[32]:

opt = opt_func(inner_model.parameters())
init_training_manual_seed(accelerator)
use_wandb = accelerator.is_main_process and wandb_project
if use_wandb:
    init_wandb()

sched = K.utils.InverseLR(opt, inv_gamma=20000.0, power=1.0, warmup=0.99)
#total_steps = num_epochs * len(train_dl)
#sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=max_lr, total_steps=total_steps, pct_start=0.05)
ema_sched = K.utils.EMAWarmup(power=0.6667, max_value=0.9999)

image_key = 'image'
inner_model, opt, train_dl = accelerator.prepare(inner_model, opt, train_dl)
if use_wandb:
    wandb.watch(inner_model)
sample_density = make_sample_density()
epoch = 0
step = 0


# In[33]:

evaluate_enabled = evaluate_every > 0 and evaluate_n > 0
extractor = None
if evaluate_enabled:
    extractor = K.evaluation.InceptionV3FeatureExtractor(device=device)
    train_iter = iter(train_dl)
    if accelerator.is_main_process:
        print('Computing features for reals...')
    reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1].to(device),
                                                   extractor, evaluate_n, bs)
    if accelerator.is_main_process:
        Path(metrics_dir).mkdir(exist_ok=True)
        metrics_log = K.utils.CSVLogger(f'{name}_metrics.csv', ['step', 'fid', 'kid'])
    del train_iter
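# Illustration only: how quickly the EMA copy starts tracking the online model. `EMAWarmup` ramps
# the decay from 0 towards `max_value`, so early updates copy the weights almost directly while
# later updates average over a long horizon. The probe below steps a throwaway schedule with the
# same settings as `ema_sched` above just to print a few values (it does not touch `ema_sched`);
# get_value()/step() are the same calls the training loop makes.

# In[ ]:

_probe = K.utils.EMAWarmup(power=0.6667, max_value=0.9999)
for _s in range(1, 10001):
    _probe.step()
    if _s in (10, 100, 1000, 10000):
        print(f'step {_s}: ema_decay = {_probe.get_value():.6f}')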
# In[34]:

@torch.no_grad()
def demo(model_ema, step, size):
    with K.utils.eval_mode(model_ema):
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            if accelerator.is_main_process:
                tqdm.write('Sampling...')
            filename = f'{demo_imgs_dir}/{name}_demo_{step:08}.png'
            path = Path(filename)
            path.parent.mkdir(exist_ok=True)
            n_per_proc = math.ceil(sample_n / accelerator.num_processes)
            x = torch.randn([n_per_proc, input_channels, size[0], size[1]], device=device) * sigma_max
            sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
            x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=not accelerator.is_main_process)
            x_0 = accelerator.gather(x_0)[:sample_n]
            # For some reason the images are inverting...
            #x_0 = -x_0
            if accelerator.is_main_process:
                grid = torchvision.utils.make_grid(x_0, nrow=math.ceil(sample_n ** 0.5), padding=0)
                K.utils.to_pil_image(grid).save(filename)
                if use_wandb:
                    wandb.log({'demo_grid': wandb.Image(filename)}, step=step)


# In[35]:

@torch.no_grad()
def evaluate(model_ema, step, size):
    with K.utils.eval_mode(model_ema):
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            if not evaluate_enabled:
                return
            if accelerator.is_main_process:
                tqdm.write('Evaluating...')
            sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)

            def sample_fn(n):
                x = torch.randn([n, input_channels, size[0], size[1]], device=device) * sigma_max
                x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=True)
                return x_0

            fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, evaluate_n, bs)
            if accelerator.is_main_process:
                fid = K.evaluation.fid(fakes_features, reals_features)
                kid = K.evaluation.kid(fakes_features, reals_features)
                print(f'FID: {fid.item():g}, KID: {kid.item():g}')
            if accelerator.is_main_process:
                metrics_log.write(step, fid.item(), kid.item())
                if use_wandb:
                    wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step)


# In[36]:

def save(step, epoch, opt, sched):
    accelerator.wait_for_everyone()
    filename = f'{checkpoints_dir}/{name}_{step:08}.pth'
    if accelerator.is_main_process:
        tqdm.write(f'Saving to {filename}...')
    obj = {
        'model': accelerator.unwrap_model(model.inner_model).state_dict(),
        'model_ema': accelerator.unwrap_model(model_ema.inner_model).state_dict(),
        'opt': opt.state_dict(),
        'sched': sched.state_dict(),
        'ema_sched': ema_sched.state_dict(),
        'epoch': epoch,
        'step': step,
    }
    accelerator.save(obj, filename)


# In[37]:

def toggle_encoder_freeze(model, model_ema, freeze=True):
    model.inner_model.inner_model.u_net.encoder.toggle_orig_encoder_freeze(freeze)
    model_ema.inner_model.inner_model.u_net.encoder.toggle_orig_encoder_freeze(freeze)
    return model, model_ema


# In[38]:

model, model_ema = toggle_encoder_freeze(model, model_ema, freeze=True)


# ## Training Loop

# In[39]:

try:
    ema_decay = None
    while epoch < num_epochs:
        for batch in tqdm(train_dl, disable=not accelerator.is_main_process):
            with accelerator.accumulate(model):
                reals, _, aug_cond = batch[image_key]
                reals = reals.to(device)
                aug_cond = aug_cond.to(device)
                noise = torch.randn_like(reals)
                sigma = sample_density([reals.shape[0]], device=device)
                losses = model.loss(reals, noise, sigma, aug_cond=aug_cond)
                losses_all = accelerator.gather(losses)
                loss = losses_all.mean()
                accelerator.backward(losses.mean())
                opt.step()
                sched.step()
                opt.zero_grad()
                if accelerator.sync_gradients:
                    ema_decay = ema_sched.get_value()
                    #print(next(model.parameters()).device, next(model_ema.parameters()).device)
                    K.utils.ema_update(model, model_ema, ema_decay)
                    ema_sched.step()

            if accelerator.is_main_process and step % 25 == 0:
                write_progress_to_tqdm(epoch, step, loss)
                if use_wandb:
                    log_step_to_wandb(epoch, loss, step, sched, ema_decay)

            if step % demo_every == 0:
                demo(model_ema, step, size)

            if evaluate_enabled and step > 0 and step % evaluate_every == 0:
                evaluate(model_ema, step, size)

            if step > 0 and step % save_every == 0:
                save(step, epoch, opt, sched)

            step += 1

        epoch += 1
        if epoch >= unfreeze_epoch:
            model, model_ema = toggle_encoder_freeze(model, model_ema, freeze=False)
except KeyboardInterrupt:
    pass


# In[40]:

torch.save(model.state_dict(), str(model_path))
torch.save(model_ema.state_dict(), str(model_ema_path))
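# The intermediate checkpoints written by `save()` above bundle the model, EMA model, optimizer and
# both schedules into one dict. The commented cell below is a minimal resume sketch (illustration
# only - the step number in the file name is hypothetical, substitute one that exists on disk).

# In[ ]:

#ckpt = torch.load(f'{checkpoints_dir}/{name}_00010000.pth', map_location=device)
#accelerator.unwrap_model(model.inner_model).load_state_dict(ckpt['model'])
#accelerator.unwrap_model(model_ema.inner_model).load_state_dict(ckpt['model_ema'])
#opt.load_state_dict(ckpt['opt'])
#sched.load_state_dict(ckpt['sched'])
#ema_sched.load_state_dict(ckpt['ema_sched'])
#epoch, step = ckpt['epoch'], ckpt['step']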
# In[ ]:


# # Sampling

# In[42]:

encoder = encoder_arch(pretrained=True)
encoder = make_sequential_model(encoder, encoder_cut)
inner_model = make_model().to(device)
model_ema = make_denoiser_wrapper()(inner_model)


# In[43]:

model_ema.load_state_dict(torch.load(str(model_ema_path)))


# In[44]:

@torch.no_grad()
def sample_lms(model_ema, size):
    with K.utils.eval_mode(model_ema):
        n_per_proc = math.ceil(sample_n / accelerator.num_processes)
        x = torch.randn([n_per_proc, input_channels, size[0], size[1]], device=device) * sigma_max
        sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
        x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=not accelerator.is_main_process)
        x_0 = accelerator.gather(x_0)[:sample_n]
        # For some reason the images are inverting...
        #x_0 = -x_0
        grid = torchvision.utils.make_grid(x_0, nrow=math.ceil(sample_n ** 0.5), padding=0)
        return K.utils.to_pil_image(grid)


# In[45]:

grid = sample_lms(model_ema, size)
fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(grid)


# In[ ]:


# In[ ]:
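# Optional extra (illustration only): the grid above is produced with the LMS sampler; k-diffusion
# ships several alternatives that reuse the same Karras noise schedule. The sketch below swaps in
# the Heun sampler for a smaller grid - it is not tuned, just a starting point for comparing samplers.

# In[ ]:

@torch.no_grad()
def sample_heun_grid(model_ema, size, n=16, steps=50):
    # Illustration only: same noise schedule as sample_lms above, different integrator.
    with K.utils.eval_mode(model_ema):
        x = torch.randn([n, input_channels, size[0], size[1]], device=device) * sigma_max
        sigmas = K.sampling.get_sigmas_karras(steps, sigma_min, sigma_max, rho=7., device=device)
        x_0 = K.sampling.sample_heun(model_ema, x, sigmas, disable=True)
        grid = torchvision.utils.make_grid(x_0, nrow=math.ceil(n ** 0.5), padding=0)
        return K.utils.to_pil_image(grid)

#plt.figure(figsize=(8, 8))
#plt.imshow(sample_heun_grid(model_ema, size))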