import os
os.environ['CUDA_VISIBLE_DEVICES']='1'
os.environ['OMP_NUM_THREADS']='1'
#!pip install -q diffusers datasets wandb lpips timm
import wandb
wandb.login()
#@title imports
import wandb
import torch
import torchvision
from torch import nn
from torchvision import transforms as T
from torchvision.transforms import functional as TF
from fastai.data.all import *
from fastai.vision.all import *
from fastai.callback.wandb import *
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.lookahead import Lookahead
from einops import rearrange
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
def RmsLookahead(params, alpha=0.5, k=6, *args, **kwargs):
rmsprop = RMSpropTF(params, *args, **kwargs)
return Lookahead(rmsprop, alpha, k)
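# Minimal sanity check (a sketch): the wrapped optimizer should step like any
# torch optimizer; _p is a throwaway parameter used only for illustration.
_p = nn.Parameter(torch.zeros(2))
_p.grad = torch.ones_like(_p)
RmsLookahead([_p], lr=1e-3).step()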
#Training Config
num_epochs = 10
bs = 128 # the batch size
lr_max = 1e-4 # the max learning rate
#opt_func = partial(OptimWrapper, opt=torch.optim.AdamW, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3, lr=lr_max)
opt_func = partial(OptimWrapper, opt=RmsLookahead, weight_decay=1e-3, lr=lr_max)
#Model Config
sz = 64
in_channels = 3
depths= [2, 4, 4]
channels= [128, 256, 512]
self_attn_depths = [False, False, True]
dropout_rate = 0.05
num_samples = 64
name = 'FastDiffusion_KDiff_CelebA_OptSamp' # the name of the run
wandb_project = 'FastDiffusion_KDiff_CelebA_OptSamp' # the wandb project name (specify this to enable wandb)
wandb_save_model = False # save model to wandb
dataset_name = 'CelebA' # wandb name for dataset used
comments = 'Celeb faces using kdiff-based unet and optimizer based sampling.' # comments logged in wandb
demo_imgs_dir = './demo_images'
metrics_dir = './metrics'
#Model Save/Load
checkpoints_dir = './checkpoints'
model_path = Path(checkpoints_dir)/f'{name}.pt'
model_path.parent.mkdir(parents=True, exist_ok=True)
# Perceptual loss
import lpips
#@title dataset from HF
from torchvision import transforms as T
from datasets import load_dataset
dataset = load_dataset('huggan/CelebA-faces')
tfm = T.Compose([T.Resize(sz), T.CenterCrop(sz)])
def transforms(examples):
examples["image"] = [tfm(image.convert("RGB")) for image in examples["image"]]
return examples
dataset = dataset.with_transform(transforms)['train']
# Example 64px image
dataset[0]['image']
# Types for the crappified input: subclassing PILImage/TensorImage lets
# fastai's type dispatch apply Crappify to the noisy input only, not the clean target
class PILImageNoised(PILImage): pass
class TensorImageNoised(TensorImage):
    def show(self, ctx=None, **kwargs):  # hook for custom display; currently just defers to TensorImage
        super().show(ctx=ctx, **kwargs)
PILImageNoised._tensor_cls = TensorImageNoised
# Crappification transform: blur plus a random mix of uniform noise (TODO experiment)
class Crappify(Transform):
def encodes(self, x:TensorImageNoised):
x = IntToFloatTensor()(x)
        blurred = T.GaussianBlur(3)(x)  # random gaussian blur (sigma drawn per call)
        noise_amount = torch.rand(x.shape[0], device=x.device)  # per-item mix amount
        noise = torch.rand_like(x)  # uniform noise, same shape and device as x
        # Scale back to 0..255 so the dataloader's IntToFloatTensor maps it to 0..1
        noised = torch.lerp(blurred, noise, noise_amount.view(-1, 1, 1, 1)) * 255
return noised
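# Quick sketch of Crappify on a throwaway uint8 batch (_demo is illustrative only):
_demo = TensorImageNoised(torch.randint(0, 256, (2, 3, 8, 8), dtype=torch.uint8))
print(Crappify()(_demo).shape)  # -> torch.Size([2, 3, 8, 8])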
# Dataloader
dblock = DataBlock(blocks=(ImageBlock(cls=PILImageNoised),ImageBlock(cls=PILImage)),
get_items=lambda pth: range(len(dataset)), # Gets the indexes
getters=[lambda idx: np.array(dataset[idx]['image'])]*2,
batch_tfms=[Crappify])
dls = dblock.dataloaders('', bs=bs)  # lower bs if the extra LPIPS network pushes memory too high
dls.show_batch()
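# The fixed antialiasing kernels and the Downsample2d/Upsample2d layers below
# follow the k-diffusion UNet implementation.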
_kernels = {
'linear':
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
'cubic':
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
0.43359375, 0.11328125, -0.03515625, -0.01171875],
'lanczos3':
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
_kernels['bilinear'] = _kernels['linear']
_kernels['bicubic'] = _kernels['cubic']
class Downsample2d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect'):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv2d(x, weight, stride=2)
class Upsample2d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect'):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
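# Shape check (a sketch): Downsample2d halves the spatial dims with a fixed
# antialiasing kernel applied per channel; Upsample2d doubles them.
_x = torch.randn(1, 3, 8, 8)
print(Downsample2d()(_x).shape, Upsample2d()(_x).shape)  # (1,3,4,4) and (1,3,16,16)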
def orthogonal_(module):
nn.init.orthogonal_(module.weight)
return module
class ResidualBlock(nn.Module):
def __init__(self, *main, skip=None):
super().__init__()
self.main = nn.Sequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input):
return self.main(input) + self.skip(input)
class ResConvBlock(nn.Module):
def __init__(self, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
super().__init__()
layers = [
nn.GroupNorm(num_groups=max(1, c_in // group_size), num_channels=c_in),
nn.GELU(),
nn.Conv2d(c_in, c_mid, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True),
nn.GroupNorm(num_groups=max(1, c_mid // group_size), num_channels=c_mid),
nn.GELU(),
nn.Conv2d(c_mid, c_out, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True)
]
self.main = nn.Sequential(*layers)
self.skip = nn.Identity() if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
def forward(self, input):
skip = self.skip(input)
input = self.main(input)
return input + skip
class SelfAttention2d(nn.Module):
def __init__(self, c_in, n_head, norm, dropout_rate=0.):
super().__init__()
assert c_in % n_head == 0
self.norm_in = norm(c_in)
self.n_head = n_head
self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
self.out_proj = nn.Conv2d(c_in, c_in, 1)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input):
n, c, h, w = input.shape
qkv = self.qkv_proj(self.norm_in(input))
qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
q, k, v = qkv.chunk(3, dim=1)
scale = k.shape[3] ** -0.25
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
att = self.dropout(att)
y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
return input + self.out_proj(y)
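# Residual shape check (a sketch, with an assumed single-group GroupNorm factory):
_sa = SelfAttention2d(64, n_head=1, norm=lambda c: nn.GroupNorm(1, c))
print(_sa(torch.randn(1, 64, 8, 8)).shape)  # -> torch.Size([1, 64, 8, 8])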
class DBlock(nn.Sequential):
def __init__(self, n_layers, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False):
modules = [nn.Identity()]
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
            modules.append(ResConvBlock(my_c_in, c_mid, my_c_out, group_size, dropout_rate))
if self_attn:
norm = lambda c_in: nn.GroupNorm(num_groups=max(1, my_c_out // group_size), num_channels=c_in)
modules.append(SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
super().__init__(*modules)
self.set_downsample(downsample)
def set_downsample(self, downsample):
self[0] = Downsample2d() if downsample else nn.Identity()
return self
class UBlock(nn.Sequential):
def __init__(self, n_layers, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False):
modules = []
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
modules.append(ResConvBlock(my_c_in, c_mid, my_c_out, group_size, dropout_rate))
if self_attn:
norm = lambda c_in: nn.GroupNorm(num_groups=max(1, my_c_out // group_size), num_channels=c_in)
modules.append(SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
modules.append(nn.Identity())
super().__init__(*modules)
self.set_upsample(upsample)
def forward(self, input, skip=None):
if skip is not None:
input = torch.cat([input, skip], dim=1)
return super().forward(input)
def set_upsample(self, upsample):
self[-1] = Upsample2d() if upsample else nn.Identity()
return self
class UNet(nn.Module):
def __init__(self, depths, channels, self_attn_depths, c_in=3, dropout_rate=0.0):
super().__init__()
self.proj_in = nn.Conv2d(c_in, channels[0], 1)
self.proj_out = nn.Conv2d(channels[0], c_in, 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
d_blocks, u_blocks = [], []
for i in range(len(depths)):
my_c_in = channels[max(0, i - 1)]
d_blocks.append(DBlock(depths[i], my_c_in, channels[i], channels[i], downsample=i>0, self_attn=self_attn_depths[i], dropout_rate=dropout_rate))
for i in range(len(depths)):
my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
my_c_out = channels[max(0, i - 1)]
u_blocks.append(UBlock(depths[i], my_c_in, channels[i], my_c_out, upsample=i>0, self_attn=self_attn_depths[i], dropout_rate=dropout_rate))
self.d_blocks = nn.ModuleList(d_blocks)
self.u_blocks = nn.ModuleList(reversed(u_blocks))
def forward(self, input):
input = self.proj_in(input)
skips = []
for block in self.d_blocks:
input = block(input)
skips.append(input)
for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
input = block(input, skip if i > 0 else None)
input = self.proj_out(input)
return input
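# Tiny-config shape check (a sketch; the real config is defined above): the
# UNet is fully convolutional, so output resolution should match the input.
_tiny = UNet(depths=[1, 1], channels=[32, 64], self_attn_depths=[False, False])
print(_tiny(torch.randn(1, 3, 16, 16)).shape)  # -> torch.Size([1, 3, 16, 16])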
model = UNet(
depths=depths,
channels=channels,
c_in = in_channels,
self_attn_depths=self_attn_depths,
dropout_rate=dropout_rate
).to(device)
#model
# Create a learner and pick LR
# learn = Learner(dls, model, loss_func=MSELossFlat())
loss_fn_alex = lpips.LPIPS(net='alex').to(device)
loss_fn_mse = MSELossFlat()
def combined_loss(preds, y):
    # normalize=True tells LPIPS the inputs are in [0, 1]; it rescales them to
    # the [-1, 1] range the network expects
    return loss_fn_alex(preds, y, normalize=True).mean() + loss_fn_mse(preds, y)
learn = Learner(dls, model, loss_func=combined_loss, opt_func=opt_func)
#learn.lr_find()
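# Optimizer-based sampling: the UNet was trained to denoise, so the residual
# (xt - denoised_output) points from the current image toward the model's
# clean estimate. Treating that residual as a gradient on xt and running a
# standard optimizer with a one-cycle LR schedule descends toward an image.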
def sample_based_on_optimizer(model, xt, optim_fn, n_steps, lr_max, pct_start):
    div = 1e5        # initial LR = lr_max / div
    div_final = 1e5  # final LR = lr_max / div_final
    lr_sched_fn = combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)
optim = OptimWrapper(opt=optim_fn([xt], lr=lr_max))
prev_noise_var = 1e6
eps = None
xt.requires_grad = True
for i in range(n_steps):
        pos = i / n_steps  # fraction of sampling completed; drives the one-cycle LR schedule
lr = lr_sched_fn(pos)
optim.set_hyper('lr', lr)
with torch.no_grad():
net_output = model(xt)
            flip_output = TF.hflip(model(TF.hflip(xt)))  # horizontal-flip TTA
            denoised_output = (net_output + flip_output)/2.0  # average the two denoised estimates
noise_pred = xt - denoised_output
# noise_pred is basically the grad, so GD on this should find a minimum!
xt.grad = noise_pred.float()
grad_mean = xt.grad.mean()
var_diff = (noise_pred.float().var() - prev_noise_var).abs()
        # Early stopping once the predicted-noise variance has stabilised
if i == n_steps-1 or (i > 0 and grad_mean < 0 and var_diff < 1e-5):
xt = net_output
#print(var_diff)
#print('Stopping at:', i)
break
prev_noise_var = noise_pred.float().var()
optim.step()
optim.zero_grad() # Not really needed since we're setting .grad ourselves but anyway...
        if grad_mean < 0.0:
            # Inject a little noise to keep exploring. Update in place: rebinding
            # xt (xt = xt + ...) would detach it from the optimizer's param list.
            with torch.no_grad():
                xt.add_(torch.randn_like(xt) * 0.0025)
return xt
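# Two-stage sampling: a short low-momentum SGD warmup roughs out the image,
# then RMSprop+Lookahead refines it over more steps.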
def generate_opt_based_samples(model, xt):
sgd_optim_fn = partial(torch.optim.SGD, momentum=0.01)
xt = xt.clone()
xt = sample_based_on_optimizer(model, xt=xt, optim_fn=sgd_optim_fn, n_steps=10, lr_max=5e-2, pct_start=0.5)
xt = sample_based_on_optimizer(model, xt=xt, optim_fn=RmsLookahead, n_steps=30, lr_max=5e-2, pct_start=0.25)
pred_image = xt
return pred_image
def generate_samples_grid(model):
    xt = torch.rand(num_samples, in_channels, sz, sz).to(device)  # uniform noise, matching Crappify's noise distribution
pred_image = generate_opt_based_samples(model, xt)
im = torchvision.utils.make_grid(pred_image.detach().cpu(), nrow=8).permute(1, 2, 0).clip(0, 1) * 255
    im = Image.fromarray(np.array(im).astype(np.uint8))
return im
#@title Callback for logging sample grids to wandb during training
from PIL import Image  # imported under its own name so it doesn't shadow fastai's PILImage
class LogSamplesBasicCallback(Callback):
def after_epoch(self):
model = self.learn.model
im = generate_samples_grid(model)
wandb.log({'Sample generations basic':wandb.Image(im)})
def after_step(self):
if self.train_iter%100 == 0: # Also log every 100 training steps
model = self.learn.model
im = generate_samples_grid(model)
wandb.log({'Sample generations basic':wandb.Image(im)})
def init_wandb():
import wandb
log_config = {}
log_config['num_epochs'] = num_epochs
log_config['max_lr'] = lr_max
log_config['comments'] = comments
log_config['dataset'] = dataset_name
wandb.init(project=wandb_project, config=log_config, save_code=False)
init_wandb()
learn.fit_one_cycle(num_epochs, lr_max=lr_max, cbs=[WandbCallback(n_preds=8), LogSamplesBasicCallback()])
wandb.finish()
learn.show_results()
torch.save(learn.model.state_dict(), str(model_path))
learn.model.load_state_dict(torch.load(str(model_path)))
generate_samples_grid(learn.model)