!pip install -q diffusers datasets wandb lpips timm
import wandb
wandb.login()
wandb: Currently logged in as: jantic. Use `wandb login --relogin` to force relogin
True
#@title imports
import wandb
import torch
import torchvision
from torch import nn
from torchvision import transforms as T
from fastai.data.all import *
from fastai.vision.all import *
from fastai.callback.wandb import *
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.lookahead import Lookahead
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
Using device: cuda
# Perceptual loss
import lpips
#@title dataset from HF
from torchvision import transforms as T
from datasets import load_dataset
dataset = load_dataset('huggan/CelebA-faces')
tfm = T.Compose([T.Resize(64), T.CenterCrop(64)])
def transforms(examples):
    examples["image"] = [tfm(image.convert("RGB")) for image in examples["image"]]
    return examples
dataset = dataset.with_transform(transforms)['train']
Using custom data configuration huggan--CelebA-faces-8a807f0d7d4912ca Found cached dataset parquet (F:/.cache/huggingface/datasets/huggan___parquet/huggan--CelebA-faces-8a807f0d7d4912ca/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
# Example 64px image
dataset[0]['image']
# Class for crappified image
class PILImageNoised(PILImage): pass
class TensorImageNoised(TensorImage):
    def show(self, ctx=None, **kwargs):
        super().show(ctx=ctx, **kwargs)
PILImageNoised._tensor_cls = TensorImageNoised
# Crappification transform (TODO: experiment with other corruptions)
class Crappify(Transform):
    def encodes(self, x:TensorImageNoised): # Only applied to the input type, not the target
        x = IntToFloatTensor()(x)
        blurred = T.GaussianBlur(3)(x) # Add some random blur
        noise_amount = torch.rand(x.shape[0], device=x.device) # Per-image noise level
        noise = torch.rand_like(x)
        noised = torch.lerp(blurred, noise, noise_amount.view(-1, 1, 1, 1)) * 255 # Back to 0-255; the pipeline's IntToFloatTensor runs after this and divides by 255
        return noised
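To see what the model will be asked to undo, here's a quick visual sanity check (a sketch, reusing the same blur+lerp recipe at fixed noise levels instead of random ones):
# Sketch: one image corrupted at 8 evenly spaced noise levels
img = IntToFloatTensor()(TensorImage(torch.from_numpy(np.array(dataset[0]['image'])).permute(2, 0, 1)[None]))
blurred, noise = T.GaussianBlur(3)(img), torch.rand_like(img)
levels = torch.linspace(0, 1, 8)
show_images([torch.lerp(blurred, noise, l.item())[0] for l in levels], nrows=1)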
# Dataloader
dblock = DataBlock(blocks=(ImageBlock(cls=PILImageNoised), ImageBlock(cls=PILImage)),
                   get_items=lambda pth: range(len(dataset)), # Gets the indexes
                   getters=[lambda idx: np.array(dataset[idx]['image'])]*2,
                   batch_tfms=[Crappify])
# dls = dblock.dataloaders('', bs=128)
dls = dblock.dataloaders('', bs=64) # Half batch size to save mem when using extra nn for perceptual loss
dls.show_batch()
Due to IPython and Windows limitation, python multiprocessing isn't available now. So `number_workers` is changed to 0 to avoid getting stuck
#@title The unet model
from diffusers import UNet2DModel
class Unetwrapper(Module):
    def __init__(self, in_channels=3, out_channels=3, sample_size=64):
        super().__init__()
        self.net = UNet2DModel(
            sample_size=sample_size,   # the target image resolution
            in_channels=in_channels,   # the number of input channels, 3 for RGB images
            out_channels=out_channels, # the number of output channels
            layers_per_block=2,        # how many ResNet layers to use per UNet block
            block_out_channels=(64, 64, 64, 128), # <<< Experiment with the number of blocks and their channel counts
            down_block_types=(
                "DownBlock2D",     # a regular ResNet downsampling block
                "DownBlock2D",     # a regular ResNet downsampling block
                "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
                "AttnDownBlock2D", # (attention uses lots of memory at higher resolutions - better to keep it at the lowest level or two)
            ),
            up_block_types=(
                "AttnUpBlock2D",   # a ResNet upsampling block with spatial self-attention
                "AttnUpBlock2D",
                "UpBlock2D",       # a regular ResNet upsampling block
                "UpBlock2D",
            ),
        )
    def forward(self, x): return self.net(x, 100).sample # No timestep conditioning here - we just pass a fixed dummy timestep (100)
model = Unetwrapper().to(device)
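As a quick sanity check on the size of this architecture (a sketch):
# Count trainable parameters
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{n_params/1e6:.1f}M trainable parameters')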
def RmsLookahead(params, alpha=0.5, k=6, *args, **kwargs):
    rmsprop = RMSpropTF(params, *args, **kwargs)
    return Lookahead(rmsprop, alpha, k)
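For reference, the Lookahead wrapper (Zhang et al. 2019) keeps a slow copy of the weights and, every k fast-optimizer steps, moves it a fraction alpha of the way toward the fast weights. A minimal standalone sketch of that rule on a toy problem (not timm's implementation):
# Lookahead by hand: SGD takes k=6 fast steps, then slow weights interpolate with alpha=0.5
p = torch.zeros(3, requires_grad=True)
fast_opt = torch.optim.SGD([p], lr=0.1)
slow = p.detach().clone()
for step in range(12):
    loss = ((p - 1.0)**2).sum()
    loss.backward(); fast_opt.step(); fast_opt.zero_grad()
    if (step + 1) % 6 == 0:
        slow += 0.5 * (p.detach() - slow)   # slow weights chase the fast ones
        with torch.no_grad(): p.copy_(slow) # fast weights restart from the slow position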
# Create a learner and pick LR
# learn = Learner(dls, model, loss_func=MSELossFlat())
loss_fn_alex = lpips.LPIPS(net='alex').to(device)
loss_fn_mse = MSELossFlat()
def combined_loss(preds, y):
    return loss_fn_alex.forward(preds, y).mean() + loss_fn_mse(preds, y)
opt_func = partial(OptimWrapper, opt=RmsLookahead)
learn = Learner(dls, model, loss_func=combined_loss, opt_func=opt_func)
learn.lr_find()
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: C:\Users\jsa16\anaconda3\envs\course22p2\lib\site-packages\lpips\weights\v0.1\alex.pth
KeyboardInterrupt
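Before training, it's worth eyeballing the relative scale of the two loss terms on a single batch (a quick sketch; note that lpips.LPIPS by default expects inputs scaled to [-1, 1], while the batches here are in [0, 1], so the perceptual term is approximate):
xb, yb = dls.one_batch()
with torch.no_grad():
    preds = model(xb)
    print('lpips:', loss_fn_alex(preds, yb).mean().item(), ' mse:', loss_fn_mse(preds, yb).item())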
#@title Callback for logging samples (shown later in sampling section)
from PIL import Image # note: importing this as `PILImage` would shadow fastai's PILImage class used by the DataBlock above

def sample_grid(model, n_steps=40):
    # Naive sampling: start from noise and repeatedly blend x towards the model's prediction
    x = torch.rand(64, 3, 64, 64).to(device)
    for i in range(n_steps):
        with torch.no_grad():
            pred = model(x)
        mix_factor = 1/(n_steps - i)
        x = x*(1-mix_factor) + pred*mix_factor
    im = torchvision.utils.make_grid(x.detach().cpu(), nrow=8).permute(1, 2, 0).clip(0, 1) * 255
    return Image.fromarray(np.array(im).astype(np.uint8))

class LogSamplesBasicCallback(Callback):
    def after_epoch(self):
        wandb.log({'Sample generations basic': wandb.Image(sample_grid(self.learn.model))})
    def after_step(self):
        if self.train_iter % 100 == 0: # Also log every 100 training steps
            wandb.log({'Sample generations basic': wandb.Image(sample_grid(self.learn.model))})
# Train a bit
cfg = dict(model.net.config)
cfg['num_epochs'] = 10
cfg['lr_max'] = 1e-4
cfg['comments'] = 'Using alex-based lpips loss, RmsPropLookahead instead of Adam'
cfg['dataset'] = 'faces'
wandb.init(project='fastdiffusion', job_type='quick train', config=cfg)
learn.fit_one_cycle(cfg['num_epochs'], lr_max=cfg['lr_max'], cbs=[WandbCallback(n_preds=8), LogSamplesBasicCallback()])
wandb.finish()
learn.show_results()
torch.save(learn.model.state_dict(), 'faces_10e_model.pt')
learn.model.load_state_dict(torch.load('./faces_10e_model.pt'))
<All keys matched successfully>
#@title sample with naive method from first test
n_steps = 40 # Set steps to 1 to see raw preds on first input
x = torch.rand(64, 3, 64, 64).to(device) # Raw noise starting point
for i in range(n_steps):
    with torch.no_grad():
        pred = model(x)
    mix_factor = 1/(n_steps - i)
    x = x*(1-mix_factor) + pred*mix_factor
fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8).permute(1, 2, 0))
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<matplotlib.image.AxesImage at 0x24ebe93cb80>
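Why mix_factor = 1/(n_steps - i)? The weight left on the starting noise after all the updates is the product of the (1 - mix_factor) terms, which telescopes to exactly zero: the final step has mix_factor = 1 and replaces x with the model's prediction outright. A tiny numeric check:
n = 5; w = 1.0
for i in range(n):
    w *= 1 - 1/(n - i) # weight remaining on the initial noise
print(w)               # 0.0 - the last step is pure prediction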
#@title sample with an optimiser
n_steps = 50 # We can stop early if pred noise is close to 0
x = torch.rand(64, 3, 64, 64).to(device)
x.requires_grad = True
lr_max=8e-3
div=100000.
div_final=1e5
pct_start=0.7
lr_sched_fn = combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)
optim = OptimWrapper(opt=torch.optim.SGD([x], lr=lr_max, momentum=0.01))
#optim = torch.optim.Adam([x], lr=base_lr)
#base_lr=8e-3
#optim = partial(OptimWrapper, opt=RmsLookahead)([x], lr=lr_max)
#optim = partial(OptimWrapper, opt=RMSpropTF)([x], lr=lr_max)
#optim = RAdam([x], lr=base_lr)
#optim = Lamb([x], lr=base_lr)
finetune = False
for i in range(n_steps):
    pos = (i+1)/n_steps
    lr = lr_sched_fn(pos)
    optim.set_hyper('lr', lr)
    with torch.no_grad():
        net_output = model(x)
        noise_pred = x - net_output
        # Early stopping
        if noise_pred.float().var() < 0.0025 or i == n_steps-1:
            x = net_output
            print('Stopping at:', i)
            break
    # noise_pred is basically the grad, so GD on this should find a minimum!
    x.grad = noise_pred.float()
    #if i%10==0:
    print(i, x.grad.mean(), x.grad.var(), noise_pred.float().var(), lr) # Useful to watch the noise variance
    optim.step()
    optim.zero_grad() # Not really needed since we're setting .grad ourselves but anyway...
fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8).permute(1, 2, 0))
Note: other optimizers have been tried, but none of them seems to have the stability required here.
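For reference, fastai's combined_cos(pct, start, middle, end) glues two cosine segments together: a warmup from start to middle over the first pct of the run, then an anneal from middle down to end. A quick sketch to visualize the schedule used above (values taken from the cell above):
sched = combined_cos(0.7, 8e-3/1e5, 8e-3, 8e-3/1e5)
ts = np.linspace(0, 1, 100)
plt.plot(ts, [sched(t) for t in ts])
plt.xlabel('position in run'); plt.ylabel('lr')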
import torchvision.transforms.functional as TF
#@title sample with an optimiser
n_steps = 10 # We can stop early if pred noise is close to 0
torch.random.manual_seed(1000)
x = torch.rand(64, 3, 64, 64).to(device)
x.requires_grad = True
lr_max=5e-2
div=100000.
div_final=1e5
pct_start=0.5
lr_sched_fn = combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)
optim = OptimWrapper(opt=torch.optim.SGD([x], lr=lr_max, momentum=0.01))
#optim = OptimWrapper(opt=torch.optim.Adam([x], lr=lr_max))
for i in range(n_steps):
    pos = i/n_steps
    lr = lr_sched_fn(pos)
    optim.set_hyper('lr', lr)
    with torch.no_grad():
        # Self-ensemble: average the prediction for x with the prediction for its horizontal flip
        net_output = (model(x) + TF.hflip(model(TF.hflip(x))))/2.0
        noise_pred = x - net_output
        # Early stopping
        if noise_pred.float().var() < 0.0025 or i == n_steps-1:
            x = net_output
            print('Stopping at:', i)
            break
    # noise_pred is basically the grad, so GD on this should find a minimum!
    x.grad = noise_pred.float()
    print(i, x.grad.mean(), x.grad.var(), noise_pred.float().var(), lr) # Useful to watch the noise variance
    optim.step()
    optim.zero_grad() # Not really needed since we're setting .grad ourselves but anyway...
fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8).permute(1, 2, 0))
0 tensor(0.0665, device='cuda:0') tensor(0.1045, device='cuda:0') tensor(0.1045, device='cuda:0') 5e-07
1 tensor(0.0665, device='cuda:0') tensor(0.1045, device='cuda:0') tensor(0.1045, device='cuda:0') 0.004775027532454601
2 tensor(0.0694, device='cuda:0') tensor(0.1068, device='cuda:0') tensor(0.1068, device='cuda:0') 0.01727490284009215
3 tensor(0.0720, device='cuda:0') tensor(0.1161, device='cuda:0') tensor(0.1161, device='cuda:0') 0.032725599385994016
4 tensor(0.0694, device='cuda:0') tensor(0.1170, device='cuda:0') tensor(0.1170, device='cuda:0') 0.04522547315544385
5 tensor(0.0671, device='cuda:0') tensor(0.1082, device='cuda:0') tensor(0.1082, device='cuda:0') 0.05
6 tensor(0.0629, device='cuda:0') tensor(0.0982, device='cuda:0') tensor(0.0982, device='cuda:0') 0.04522547040384979
7 tensor(0.0587, device='cuda:0') tensor(0.0899, device='cuda:0') tensor(0.0899, device='cuda:0') 0.03272559938599401
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
8 tensor(0.0559, device='cuda:0') tensor(0.0844, device='cuda:0') tensor(0.0844, device='cuda:0') 0.017274900614005988
Stopping at: 9
<matplotlib.image.AxesImage at 0x2512cf6a410>
Notes: the 50-step run below doesn't seem much better than 10 steps. Note also the use of gradient clipping in this case.
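torch.nn.utils.clip_grad_norm_ rescales the gradient in place so that its total L2 norm is at most the given threshold; with means and variances like those printed above, the raw norm of a 64x3x64x64 "gradient" can run into the hundreds, so 250.0 is a loose cap rather than aggressive clipping. A tiny sketch of the behavior:
g = torch.zeros(4, requires_grad=True)
g.grad = torch.full((4,), 10.0)         # gradient with L2 norm 20
torch.nn.utils.clip_grad_norm_(g, 5.0)  # rescaled in place to norm 5
print(g.grad.norm())                    # tensor(5.)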
#@title sample with an optimiser
n_steps = 50 # We can stop early if pred noise is close to 0
torch.random.manual_seed(1000)
x = torch.rand(64, 3, 64, 64).to(device)
x.requires_grad = True
lr_max=1e-2
div=100000.
div_final=1e5
pct_start=0.25
lr_sched_fn = combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)
optim = OptimWrapper(opt=torch.optim.SGD([x], lr=lr_max, momentum=0.01))
for i in range(n_steps):
    pos = i/n_steps
    lr = lr_sched_fn(pos)
    optim.set_hyper('lr', lr)
    with torch.no_grad():
        net_output = (model(x) + TF.hflip(model(TF.hflip(x))))/2.0
        noise_pred = x - net_output
        # Early stopping
        if noise_pred.float().var() < 0.0025 or i == n_steps-1:
            x = net_output
            print('Stopping at:', i)
            break
    # noise_pred is basically the grad, so GD on this should find a minimum!
    x.grad = noise_pred.float()
    if i%10==0:
        print(i, x.grad.mean(), x.grad.var(), noise_pred.float().var(), lr) # Useful to watch the noise variance
    torch.nn.utils.clip_grad_norm_(x, 250.0)
    optim.step()
    optim.zero_grad() # Not really needed since we're setting .grad ourselves but anyway...
fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8).permute(1, 2, 0))
0 tensor(0.0665, device='cuda:0') tensor(0.1045, device='cuda:0') tensor(0.1045, device='cuda:0') 1e-07
10 tensor(0.0712, device='cuda:0') tensor(0.1238, device='cuda:0') tensor(0.1238, device='cuda:0') 0.009045094631088768
20 tensor(0.0654, device='cuda:0') tensor(0.1174, device='cuda:0') tensor(0.1174, device='cuda:0') 0.009045094493509081
30 tensor(0.0596, device='cuda:0') tensor(0.1049, device='cuda:0') tensor(0.1049, device='cuda:0') 0.005522686593312814
40 tensor(0.0568, device='cuda:0') tensor(0.0988, device='cuda:0') tensor(0.0988, device='cuda:0') 0.0016544302391959132
Stopping at: 49
<matplotlib.image.AxesImage at 0x251381a1630>