# Notebook setup cell: install third-party deps (`!` is IPython shell syntax, not Python).
!pip install -q diffusers datasets wandb lpips timm
import wandb
# Prompt for the wandb API key up front, before training starts.
wandb.login()
#@title imports
# NOTE(review): `wandb` is imported a second time here — redundant but harmless.
import wandb
import torch
import torchvision
from torch import nn
from torch import multiprocessing as mp
from torch.utils import data
from torchvision import datasets, transforms
from torchvision import transforms as T
from torchvision.transforms import functional as TF
from fastai.data.all import *
from fastai.vision.all import *
from fastai.callback.wandb import *
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.lookahead import Lookahead
import accelerate
from einops import rearrange
from functools import partial
import math
from copy import deepcopy
from pathlib import Path
from tqdm.auto import trange, tqdm
import k_diffusion as K
# Preliminary device pick; superseded later by `accelerator.device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
#Training Config
bs = 256 # the batch size
grad_accum_steps = 1 # the number of gradient accumulation steps
lr_max = 2e-4 # the max learning rate
num_workers = 8 # the number of data loader workers
resume = None # the checkpoint to resume from
save_every = 10000 # save every this many steps
training_seed = None # the random seed for training
start_method = 'spawn' # the multiprocessing start method. Options: 'fork', 'forkserver', 'spawn'
# Optimizer factory; lr_max is baked in here, the LR schedule decays from it.
opt_func = partial(torch.optim.AdamW, lr=lr_max, betas=(0.95, 0.999),
                   eps=1e-6, weight_decay=1e-3)
#Logging Config
sample_n = 64 # the number of images to sample for demo grids
demo_every = 500 # save a demo grid every this many steps
evaluate_every = 10000 # run the FID/KID evaluation every this many steps
evaluate_n = 2000 # the number of samples to draw to evaluate
name = 'KDiff_FashionMnist_Baseline' # the name of the run
wandb_project = 'FastDiffusion_KDiff_Fmnist' # the wandb project name (specify this to enable wandb)
wandb_save_model = False # save model to wandb
dataset_name = 'FashionMNIST' # wandb name for dataset used
comments = 'Initial baseline run of K-diffusion model on FashionMnist.' # comments logged in wandb
demo_imgs_dir = './demo_images'
metrics_dir = './metrics'
#Model Config
sz = 28  # image side length (FashionMNIST is 28x28)
size = [sz, sz]  # sample height/width used by the demo and eval samplers
input_channels = 1  # grayscale
patch_size = 1
mapping_out = 256  # width of the timestep/conditioning mapping network
depths = [2, 4, 4]  # residual blocks per U-Net stage
channels = [128, 128, 256]  # channel count per U-Net stage
self_attn_depths = [False, False, True]  # self-attention only at the deepest stage
cross_attn_depths = None  # no cross-attention (unconditional model)
has_variance = True  # model also predicts a per-pixel variance
dropout_rate = 0.05
augment_wrapper = True  # wrap the model for Karras-style augmentation conditioning
augment_prob = 0.12  # probability of applying each augmentation
sigma_data = 0.6162  # data RMS used for Karras preconditioning
sigma_min = 1e-2  # lowest noise level
sigma_max = 80  # highest noise level
skip_stages = 0
# NOTE(review): the original block re-assigned augment_prob, sigma_min and
# sigma_max a second time with identical values; the dead duplicates were removed.
#Model Save/Load
checkpoints_dir = './checkpoints'
model_path = Path(checkpoints_dir +'/kdiff_baseline_fmnist.pt')
model_ema_path = Path(checkpoints_dir +'/kdiff_baseline_fmnist_ema.pt')
# Create ./checkpoints if needed (parent directory of both paths above).
model_path.parent.mkdir(exist_ok=True)
# NOTE(review): set_start_method raises RuntimeError if the context was already
# set (e.g. on a notebook cell re-run) — consider force=True; confirm intent.
mp.set_start_method(start_method)
# Allow TF32 matmuls on Ampere+ GPUs for faster training at slightly lower precision.
torch.backends.cuda.matmul.allow_tf32 = True
# find_unused_parameters is only needed under DDP when some stages are skipped.
ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=skip_stages > 0)
accelerator = accelerate.Accelerator(kwargs_handlers=[ddp_kwargs], gradient_accumulation_steps=grad_accum_steps)
# From here on, use the accelerator-managed device (replaces the earlier pick).
device = accelerator.device
print(f'Process {accelerator.process_index} using device: {device}', flush=True)
def make_sample_density(mean=-1.2, std=1.2):
    """Build the training noise-level sampler.

    Returns a callable that draws sigmas from a log-normal distribution with
    location `mean` and scale `std` (the Karras et al. training density).
    """
    def density(*args, **kwargs):
        return K.utils.rand_log_normal(*args, loc=mean, scale=std, **kwargs)
    return density
def make_model():
    """Construct the inner U-Net denoiser from the module-level config.

    When `augment_wrapper` is set, the network is wrapped so it accepts the
    Karras augmentation conditioning vector via the mapping network.
    """
    config = dict(
        c_in=input_channels,
        feats_in=mapping_out,
        depths=depths,
        channels=channels,
        self_attn_depths=self_attn_depths,
        cross_attn_depths=cross_attn_depths,
        patch_size=patch_size,
        dropout_rate=dropout_rate,
        # 9 extra mapping features carry the augmentation conditioning.
        mapping_cond_dim=9 if augment_wrapper else 0,
        unet_cond_dim=0,
        cross_cond_dim=0,
        skip_stages=skip_stages,
        has_variance=has_variance,
    )
    net = K.models.ImageDenoiserModelV1(**config)
    if not augment_wrapper:
        return net
    return K.augmentation.KarrasAugmentWrapper(net)
def make_denoiser_wrapper():
    """Pick the denoiser wrapper class matching the model's output head.

    Models that predict a variance need the variance-aware loss wrapper.
    """
    wrapper_cls = K.layers.DenoiserWithVariance if has_variance else K.layers.Denoiser
    return partial(wrapper_cls, sigma_data=sigma_data)
# Preprocessing: resize/center-crop to `sz`, then the Karras augmentation
# pipeline. Each dataset item appears to yield a 3-tuple (see the training
# loop's `reals, _, aug_cond = batch[image_key]` unpack) — confirm element
# semantics against k_diffusion's KarrasAugmentationPipeline.
tf = transforms.Compose([
    transforms.Resize(sz, interpolation=transforms.InterpolationMode.LANCZOS),
    transforms.CenterCrop(sz),
    K.augmentation.KarrasAugmentationPipeline(augment_prob),
])
train_set = datasets.FashionMNIST('data', train=True, download=True, transform=tf)
if accelerator.is_main_process:
    try:
        print('Number of items in dataset:', len(train_set))
    except TypeError:
        # Iterable-style datasets are not sized; skip the count.
        pass
train_dl = data.DataLoader(train_set, bs, shuffle=True, drop_last=True, num_workers=num_workers, persistent_workers=True)
inner_model = make_model()
if accelerator.is_main_process:
    print('Parameters:', K.utils.n_params(inner_model))
# Wrap the raw U-Net in the Karras denoiser (preconditioning + loss).
model = make_denoiser_wrapper()(inner_model)
def init_wandb():
    """Start a wandb run for this process, logging the static run config."""
    import wandb
    wandb.init(
        project=wandb_project,
        config={
            'num_epochs': 'N/A',  # run length is step-based, not epoch-based
            'lr_max': lr_max,
            'comments': comments,
            'dataset': dataset_name,
            'parameters': K.utils.n_params(inner_model),
        },
        save_code=False,
    )
def init_training_manual_seed(accelerator):
    """Seed this process's RNG deterministically from the global training seed.

    Derives one independent seed per process so ranks do not share random
    streams. A no-op when `training_seed` is None.
    """
    if training_seed is None:
        return
    gen = torch.Generator().manual_seed(training_seed)
    seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes], generator=gen)
    torch.manual_seed(seeds[accelerator.process_index])
def log_step_to_wandb(epoch, loss, step, sched, ema_decay):
    """Log per-step training metrics (loss, LR, EMA decay) to wandb."""
    wandb.log(
        {
            'epoch': epoch,
            'loss': loss.item(),
            'lr': sched.get_last_lr()[0],
            'ema_decay': ema_decay,
        },
        step=step,
    )
def write_progress_to_tdqm(epoch, step, loss):
    """Print training progress without clobbering the tqdm progress bar.

    (Name keeps the original's 'tdqm' spelling — callers depend on it.)
    """
    message = f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}'
    tqdm.write(message)
opt = opt_func(inner_model.parameters())
init_training_manual_seed(accelerator)
# Only the main process logs to wandb (and only if a project name is set).
use_wandb = accelerator.is_main_process and wandb_project
if use_wandb: init_wandb()
# Inverse-decay LR schedule with warmup, plus an EMA decay warmup schedule.
sched = K.utils.InverseLR(opt, inv_gamma=20000.0, power=1.0, warmup=0.99)
ema_sched = K.utils.EMAWarmup(power=0.6667, max_value=0.9999)
# Index of the image entry within each dataloader batch.
image_key = 0
inner_model, opt, train_dl = accelerator.prepare(inner_model, opt, train_dl)
if use_wandb:
    wandb.watch(inner_model)
sample_density = make_sample_density()
# EMA copy of the denoiser, used for demos, evaluation and the final export.
model_ema = deepcopy(model)
epoch = 0
step = 0
# FID/KID evaluation setup. Bug fix: the original computed the real-image
# features (and created the metrics logger) unconditionally, which would pass
# extractor=None into compute_features whenever evaluation was disabled.
evaluate_enabled = evaluate_every > 0 and evaluate_n > 0
extractor = None
reals_features = None
if evaluate_enabled:
    extractor = K.evaluation.InceptionV3FeatureExtractor(device=device)
    train_iter = iter(train_dl)
    if accelerator.is_main_process:
        print('Computing features for reals...')
    # The feature fn ignores its argument and pulls batches from train_iter;
    # index [1] selects the batch tuple's second element (presumably the
    # unaugmented image — confirm against the augmentation pipeline).
    reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, evaluate_n, bs)
    if accelerator.is_main_process:
        Path(metrics_dir).mkdir(exist_ok=True)
        # Bug fix: write the CSV inside metrics_dir (the directory was created
        # but the original wrote the file to the working directory).
        metrics_log = K.utils.CSVLogger(f'{metrics_dir}/{name}_metrics.csv', ['step', 'fid', 'kid'])
    del train_iter
@torch.no_grad()
def demo(model_ema, step, size):
    """Sample a demo grid from the EMA model and save/log it.

    Runs on all processes — sampling is sharded per rank, then gathered —
    but only the main process writes the PNG and logs to wandb.
    """
    with K.utils.eval_mode(model_ema):
        if accelerator.is_main_process:
            tqdm.write('Sampling...')
        filename = f'{demo_imgs_dir}/{name}_demo_{step:08}.png'
        path = Path(filename)
        path.parent.mkdir(exist_ok=True)
        # Each process samples its share of the sample_n demo images.
        n_per_proc = math.ceil(sample_n / accelerator.num_processes)
        x = torch.randn([n_per_proc, input_channels, size[0], size[1]], device=device) * sigma_max
        sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
        x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=not accelerator.is_main_process)
        # Collective op: every rank must reach this gather.
        x_0 = accelerator.gather(x_0)[:sample_n]
        # For some reason the images are inverting... (sign flip works around it)
        x_0 = -x_0
        if accelerator.is_main_process:
            grid = torchvision.utils.make_grid(x_0, nrow=math.ceil(sample_n ** 0.5), padding=0)
            K.utils.to_pil_image(grid).save(filename)
            if use_wandb:
                wandb.log({'demo_grid': wandb.Image(filename)}, step=step)
@torch.no_grad()
def evaluate(model_ema, step, size):
    """Compute FID/KID of EMA-model samples against the cached real features.

    All ranks generate samples; only the main process computes and records
    the metrics.
    """
    with K.utils.eval_mode(model_ema):
        if not evaluate_enabled:
            return
        if accelerator.is_main_process:
            tqdm.write('Evaluating...')
        sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
        def sample_fn(n):
            # Draw n fresh samples from pure noise for feature extraction.
            x = torch.randn([n, input_channels, size[0], size[1]], device=device) * sigma_max
            x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=True)
            return x_0
        fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, evaluate_n, bs)
        if accelerator.is_main_process:
            fid = K.evaluation.fid(fakes_features, reals_features)
            kid = K.evaluation.kid(fakes_features, reals_features)
            print(f'FID: {fid.item():g}, KID: {kid.item():g}')
            # NOTE(review): redundant inner guard kept as-is; `fid`/`kid` are
            # only defined on the main process, so the writes must nest here.
            if accelerator.is_main_process:
                metrics_log.write(step, fid.item(), kid.item())
                if use_wandb:
                    wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step)
def save(step, epoch, opt, sched):
    """Checkpoint the model, EMA model, optimizer and schedulers.

    Called on all processes; waits for everyone so no rank is mid-step, then
    writes `{checkpoints_dir}/{name}_{step:08}.pth` via accelerator.save.
    """
    accelerator.wait_for_everyone()
    filename = f'{checkpoints_dir}/{name}_{step:08}.pth'
    if accelerator.is_main_process:
        # Bug fix: the original f-string did not interpolate the path, so the
        # message never said where the checkpoint went.
        tqdm.write(f'Saving to {filename}...')
    obj = {
        'model': accelerator.unwrap_model(model.inner_model).state_dict(),
        'model_ema': accelerator.unwrap_model(model_ema.inner_model).state_dict(),
        'opt': opt.state_dict(),
        'sched': sched.state_dict(),
        'ema_sched': ema_sched.state_dict(),
        'epoch': epoch,
        'step': step,
    }
    accelerator.save(obj, filename)
# Bug fix: prime ema_decay before the loop. It was previously assigned only
# inside the sync_gradients branch, so with grad_accum_steps > 1 the first
# logging call could hit a NameError before any synced optimizer step.
ema_decay = ema_sched.get_value()
try:
    # Train until interrupted (Ctrl-C / notebook stop) — no epoch limit.
    while True:
        for batch in tqdm(train_dl, disable=not accelerator.is_main_process):
            with accelerator.accumulate(model):
                # Batch item is a 3-tuple; the middle element is unused here.
                reals, _, aug_cond = batch[image_key]
                noise = torch.randn_like(reals)
                # Per-sample noise levels from the log-normal training density.
                sigma = sample_density([reals.shape[0]], device=device)
                losses = model.loss(reals, noise, sigma, aug_cond=aug_cond)
                # Gather per-sample losses across ranks for logging only;
                # backprop uses the local losses.
                losses_all = accelerator.gather(losses)
                loss = losses_all.mean()
                accelerator.backward(losses.mean())
                opt.step()
                sched.step()
                opt.zero_grad()
                if accelerator.sync_gradients:
                    # Update the EMA copy only on real (synced) optimizer steps.
                    ema_decay = ema_sched.get_value()
                    K.utils.ema_update(model, model_ema, ema_decay)
                    ema_sched.step()
            if accelerator.is_main_process and step % 25 == 0:
                write_progress_to_tdqm(epoch, step, loss)
                if use_wandb:
                    log_step_to_wandb(epoch, loss, step, sched, ema_decay)
            if step % demo_every == 0:
                demo(model_ema, step, size)
            if evaluate_enabled and step > 0 and step % evaluate_every == 0:
                evaluate(model_ema, step, size)
            if step > 0 and step % save_every == 0:
                save(step, epoch, opt, sched)
            step += 1
        epoch += 1
except KeyboardInterrupt:
    pass
# Final export: raw denoiser weights and the EMA weights.
torch.save(model.state_dict(), str(model_path))
torch.save(model_ema.state_dict(), str(model_ema_path))
# Rebuild a fresh model for inference and load the EMA weights into it.
inner_model = make_model().to(device)
model_ema = make_denoiser_wrapper()(inner_model)
# Robustness fix: map_location lets the checkpoint load even when it was
# written on a different device (e.g. CUDA-trained weights on a CPU host).
model_ema.load_state_dict(torch.load(str(model_ema_path), map_location=device))
@torch.no_grad()
def sample_lms(model_ema, size):
    """Sample a grid of images from the EMA model and return it as a PIL image.

    Mirrors `demo` but returns the grid instead of saving it; every rank
    participates in sampling and the gather.
    """
    with K.utils.eval_mode(model_ema):
        # Each process samples its share of the sample_n images.
        n_per_proc = math.ceil(sample_n / accelerator.num_processes)
        x = torch.randn([n_per_proc, input_channels, size[0], size[1]], device=device) * sigma_max
        sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
        x_0 = K.sampling.sample_lms(model_ema, x, sigmas, disable=not accelerator.is_main_process)
        # Collective op: every rank must reach this gather.
        x_0 = accelerator.gather(x_0)[:sample_n]
        # For some reason the images are inverting... (sign flip works around it)
        x_0 = -x_0
        grid = torchvision.utils.make_grid(x_0, nrow=math.ceil(sample_n ** 0.5), padding=0)
        return K.utils.to_pil_image(grid)
# Draw a final demo grid from the EMA model and display it.
# NOTE(review): `plt` is not imported explicitly — presumably provided by the
# fastai star imports above; confirm.
grid = sample_lms(model_ema, size)
fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(grid)