#import os
#os.environ['CUDA_VISIBLE_DEVICES']='1'
#os.environ['OMP_NUM_THREADS']='1'
import logging, torch, torchvision, torch.nn.functional as F, torchvision.transforms.functional as TF, matplotlib as mpl
from matplotlib import pyplot as plt
from functools import partial
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torchvision.utils import make_grid
from datasets import load_dataset,load_dataset_builder
from miniai.datasets import *
from miniai.learner import *
from fastprogress import progress_bar
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.adafactor import Adafactor
from timm.optim.lookahead import Lookahead
from fastai.callback.schedule import combined_cos
from fastai.layers import SequentialEx, MergeLayer
from fastai.losses import MSELossFlat
from fastcore.basics import store_attr
from fastai.torch_core import TensorImage
from fastai.optimizer import OptimWrapper
from einops import rearrange, repeat
import PIL
import numpy as np
# Perceptual loss
import lpips
def RmsLookahead(params, alpha=0.5, k=6, *args, **kwargs):
opt = RMSpropTF(params, *args, **kwargs)
return Lookahead(opt, alpha, k)
def AdamLookahead(params, alpha=0.5, k=6, *args, **kwargs):
opt = optim.Adam(params, *args, **kwargs)
return Lookahead(opt, alpha, k)
opt_func = optim.Adam
lr_max = 1e-3
bs = 256
sz = 28
device = 'cuda'
unet_width = 1
unet_resids = [2,4,4]
num_classes = 10
dropout_rate = 0.05
mean = 0.2859
std = 0.353
def normalize_img(x):
#return x
return (x-mean)/std
def denormalize_img(x):
#return x
return (x*std)+mean
loss_fn_alex = lpips.LPIPS(net='alex').to(device)
loss_fn_mse = MSELossFlat()
resize_sz = 64 if sz < 64 else sz
pl_resizer = torchvision.transforms.Resize(resize_sz)
def combined_loss(preds, y):
eps, x0, xt = y
#return loss_fn_mse(preds, eps)
pred_images = denormalize_img(pl_resizer(preds+xt))
target_images = denormalize_img(pl_resizer(x0))
return loss_fn_alex.forward(pred_images, target_images).mean() + loss_fn_mse(preds, eps)
loss_func = combined_loss
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
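As a quick smoke test (my addition, not part of the original notebook), the combined loss can be exercised on random tensors shaped like a training batch; the target is the (eps, x0, xt) tuple that the DDPM callback below constructs:
# Hypothetical smoke test: all four tensors are stand-ins for a real batch
preds = torch.randn(4, 1, sz, sz, device=device) # model's predicted noise
eps = torch.randn(4, 1, sz, sz, device=device) # target noise
x0 = torch.randn(4, 1, sz, sz, device=device) # clean (normalized) images
xt = torch.randn(4, 1, sz, sz, device=device) # corrupted images
print(combined_loss(preds, (eps, x0, xt))) # a single scalar loss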
mpl.rcParams['image.cmap'] = 'gray_r'
logging.disable(logging.WARNING)
Load a dataset:
x,y = 'image','label'
#name = "mnist" #"fashion_mnist"
name = "fashion_mnist"
dsd = load_dataset(name)
@inplace
def transformi(b):
b[x] = [normalize_img(TF.to_tensor(o)) for o in b[x]]
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs)
dt = dls.train
xb,yb = next(iter(dt))
xb.shape,yb[:10]
(torch.Size([256, 1, 28, 28]), tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5]))
Define a model:
def conv2dks7(inc, outc, stride=1): return nn.Sequential(nn.Conv2d(inc, outc, kernel_size=7, padding=3, stride=stride), nn.Dropout2d(dropout_rate, inplace=True))
def conv2dks5(inc, outc, stride=1): return nn.Sequential(nn.Conv2d(inc, outc, kernel_size=5, padding=2, stride=stride), nn.Dropout2d(dropout_rate, inplace=True))
def conv2dks3(inc, outc, stride=1): return nn.Sequential(nn.Conv2d(inc, outc, kernel_size=3, padding=1, stride=stride), nn.Dropout2d(dropout_rate, inplace=True))
def conv2dks1(inc, outc): return nn.Sequential(nn.Conv2d(inc, outc, kernel_size=1), nn.Dropout2d(dropout_rate, inplace=True))
def conv2dtrans(inc, outc): return nn.Sequential(nn.ConvTranspose2d(inc, outc, 4, 2, 1), nn.Dropout2d(dropout_rate, inplace=True))
"""
code from here: https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
"""
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
class PixelShuffleUpsample(nn.Module):
def __init__(self, dim, dim_out = None):
super().__init__()
dim_out = default(dim_out, dim)
conv = nn.Conv2d(dim, dim_out * 4, 1)
self.net = nn.Sequential(
conv,
nn.SiLU(),
nn.PixelShuffle(2)
)
self.init_conv_(conv)
def init_conv_(self, conv):
o, i, h, w = conv.weight.shape
conv_weight = torch.empty(o // 4, i, h, w)
nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def forward(self, x):
return self.net(x)
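A side note (my observation, with a hypothetical check): because each filter is repeated four times before the PixelShuffle, at initialization the module is exactly a 1x1 conv + SiLU followed by nearest-neighbour upsampling, which avoids checkerboard artifacts:
# Sanity check: compare against nearest-neighbour upsampling of SiLU(conv(x)),
# using every 4th filter (the unrepeated copies) of the module's own conv
up = PixelShuffleUpsample(8)
x = torch.randn(2, 8, 7, 7)
ref = F.interpolate(F.silu(F.conv2d(x, up.net[0].weight[::4], up.net[0].bias[::4])), scale_factor=2, mode='nearest')
print(torch.allclose(up(x), ref)) # True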
# from https://github.com/crowsonkb/k-diffusion/blob/f4e99857772fc3a126ba886aadf795a332774878/k_diffusion/layers.py
_kernels = {
'linear':
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
'cubic':
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
0.43359375, 0.11328125, -0.03515625, -0.01171875],
'lanczos3':
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
_kernels['bilinear'] = _kernels['linear']
_kernels['bicubic'] = _kernels['cubic']
class Downsample2d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect'):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv2d(x, weight, stride=2)
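Downsample2d is a fixed (non-learned) anti-aliasing blur followed by a stride-2 convolution; a quick shape check (my addition):
ds = Downsample2d()
print(ds(torch.randn(2, 3, 28, 28)).shape) # torch.Size([2, 3, 14, 14])
print(sum(p.numel() for p in ds.parameters())) # 0 -- the kernel is a buffer, not a parameter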
norm = partial(torch.nn.GroupNorm, num_groups=8)
#norm = torch.nn.BatchNorm2d
act = torch.nn.SiLU
def conv_layer(conv_fn, inc, outc, stride=1):
return torch.nn.Sequential(conv_fn(inc, outc), act(), norm(num_channels=outc))
def residual_layer(inc):
return SequentialEx(conv_layer(conv2dks3, inc, inc),
conv_layer(conv2dks3, inc, inc),
MergeLayer(False))
def down_layer(inc, outc, resids=1):
layers = [Downsample2d(),
conv_layer(conv2dks3, inc, outc)]
for i in range(resids):
layers.append(residual_layer(outc))
return torch.nn.Sequential(*layers)
def up_layer(inc, outc, resids=1):
layers = [conv_layer(conv2dks1, inc, inc//2)]
for i in range(resids):
layers.append(residual_layer(inc//2))
layers.extend([PixelShuffleUpsample(inc//2, outc), act(), norm(num_channels=outc)])
return torch.nn.Sequential(*layers)
def out_layer(inc, midc, outc, resids=1):
layers = [conv_layer(conv2dks1, inc, midc)]
for i in range(resids):
layers.append(residual_layer(midc))
layers.append(conv_layer(conv2dks1, midc, outc))
return torch.nn.Sequential(*layers)
class BasicUNet(nn.Module):
"A minimal UNet implementation."
def __init__(self, inc, outc, resids=[1,2,2], width:int=1):
super().__init__()
self.proj_in = nn.Conv2d(inc, 128*width, 1)
self.down_layers = torch.nn.ModuleList([down_layer(128*width, 128*width, resids[0]),
down_layer(128*width, 128*width, resids[1]),
down_layer(128*width, 256*width, resids[2])])
self.up_layers = torch.nn.ModuleList([up_layer(256*width, 256*width, resids[2]),
up_layer(384*width,128*width, resids[1]),
up_layer(256*width, 128*width, resids[0])])
self.out_layers = out_layer(inc+128*width, 64*width, 32*width, resids[0])
self.proj_out = nn.Conv2d(32*width, outc, 1)
def forward(self, x):
x_orig = x.clone()
x = self.proj_in(x)
h = []
for i, l in enumerate(self.down_layers):
x = l(x)
h.append(x)
#if i < 2: x = F.max_pool2d(x, 2)
for i, l in enumerate(self.up_layers):
x_cross = h.pop()
if x.shape[-2:] != x_cross.shape[-2:]:
x = F.interpolate(x, x_cross.shape[-2:], mode='nearest')
x = torch.cat([x_cross,x], dim=1) if i > 0 else x
x = l(x)
if x.shape[-2:] != x_orig.shape[-2:]:
x = F.interpolate(x, x_orig.shape[-2:], mode='nearest')
x = torch.cat([x_orig,x], dim=1)
x = self.out_layers(x)
x = self.proj_out(x)
# eps can range from about -6 to 6 (0 mean 1 std normal dist *2, basically)
x = x.sigmoid()*12-6.0
return x
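A dummy forward pass (my addition) to confirm the network maps a batch of 28x28 images to outputs of the same shape, squashed into (-6, 6) by the final scaled sigmoid:
m = BasicUNet(1, 1, resids=[1,1,1])
out = m(torch.randn(2, 1, sz, sz))
print(out.shape, out.min().item(), out.max().item()) # torch.Size([2, 1, 28, 28]), values within (-6, 6)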
def generate_random_noise(batch_size):
#return normalize_img(torch.rand(batch_size, 1, sz, sz))
return torch.randn(batch_size, 1, sz, sz)
def generate_random_noise_like(x):
#return normalize_img(torch.rand_like(x))
return torch.randn_like(x)
Define the corruption:
def corrupt(x, amount):
"Corrupt the input `x` by mixing it with noise according to `amount`"
noise = generate_random_noise_like(x)
amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works
xt = x*(1-amount) + noise*amount # equivalently: xt = x - amount*(x - noise)
eps = amount*(x - noise)         # so x = xt + eps exactly
return xt, eps
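With these definitions the clean image is recovered exactly as x0 = xt + eps, which is why the loss earlier compares preds + xt against x0. A quick check (my addition):
x = torch.randn(4, 1, sz, sz)
xt, eps = corrupt(x, torch.rand(4))
print(torch.allclose(x, xt + eps, atol=1e-6)) # True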
class ConditionalDDPMCallback(Callback):
def __init__(self, n_steps=1000, beta_min=0.0001, beta_max=0.02, label_conditioning=False):
store_attr()
self.tensor_type=TensorImage
self.label_conditioning = label_conditioning
def before_fit(self):
self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(device) # variance schedule, linearly increased with timestep
self.alpha = 1. - self.beta
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
self.sigma = torch.sqrt(self.beta)
def sample_timesteps(self, x, dtype=torch.long):
return torch.randint(self.n_steps, (x.shape[0],), device=device, dtype=dtype)
def generate_noise(self, x):
return generate_random_noise_like(x)
def noise_image(self, x, eps, t):
alpha_bar_t = self.alpha_bar[t][:, None, None, None]
return torch.sqrt(alpha_bar_t)*x + torch.sqrt(1-alpha_bar_t)*eps # noisify the image
def before_batch_training(self):
x0 = self.learn.batch[0] # original images
x1 = self.learn.batch[1] if self.label_conditioning else None
noise_amount = torch.rand(x0.shape[0]).to(device) # Choose a random corruption amount per image
xt, eps = corrupt(x0, noise_amount)
#eps = self.generate_noise(x0) # noise same shape as x0
#t = self.sample_timesteps(x0) # select random timesteps
#xt = self.noise_image(x0, eps, t) # add noise to the image
self.learn.xb = (xt,) if x1 is None else (xt,x1) # model input is the noisy image (plus class labels when conditioning)
self.learn.yb = (eps,x0,xt) # targets: the noise to predict, plus x0/xt for the perceptual part of the loss
def sampling_algo(self, xt, t):
t_batch = torch.full((xt.shape[0],), t, device=device, dtype=torch.long)
z = self.generate_noise(xt) if t > 0 else torch.zeros_like(xt)
alpha_t = self.alpha[t] # get noise level at current timestep
alpha_bar_t = self.alpha_bar[t]
sigma_t = self.sigma[t]
alpha_bar_t_1 = self.alpha_bar[t-1] if t > 0 else torch.tensor(1., device=device)
beta_bar_t = 1 - alpha_bar_t
beta_bar_t_1 = 1 - alpha_bar_t_1
x0hat = (xt - torch.sqrt(beta_bar_t) * self.model(xt))/torch.sqrt(alpha_bar_t)
x0hat = torch.clamp(x0hat, -1, 1)
xt = x0hat * torch.sqrt(alpha_bar_t_1)*(1-alpha_t)/beta_bar_t + xt * torch.sqrt(alpha_t)*beta_bar_t_1/beta_bar_t + sigma_t*z
return xt
def sample(self):
xt = self.generate_noise(self.xb[0]) # a full batch at once!
for t in progress_bar(reversed(range(self.n_steps)), total=self.n_steps, leave=False):
xt = self.sampling_algo(xt, t)
return xt
def before_batch_sampling(self):
xt = self.sample()
self.learn.pred = (xt,)
raise CancelBatchException
def before_batch(self):
if not hasattr(self, 'gather_preds'): self.before_batch_training()
else: self.before_batch_sampling()
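For reference, sampling_algo implements the standard DDPM ancestral sampling step (Ho et al. 2020): the model output is treated as $\hat\epsilon$, an estimate of the clean image is formed and clamped, and the posterior mean plus noise gives the next sample:

$$\hat x_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\hat\epsilon}{\sqrt{\bar\alpha_t}}, \qquad x_{t-1} = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,\hat x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t + \sigma_t z$$

where $\beta_t = 1-\alpha_t$, $\sigma_t = \sqrt{\beta_t}$, and $z \sim \mathcal{N}(0, I)$ (zero at $t=0$).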
class MyTrainCB(TrainCB):
def predict(self):
xb = self.learn.xb
self.learn.preds = self.learn.model(*xb)
def get_loss(self):
self.learn.loss = self.learn.loss_func(self.learn.preds, self.learn.yb)
Logging callback:
class LogLossesCB(Callback):
def __init__(self): self.losses = []
def after_batch(self): self.losses.append(self.learn.loss.item())
def after_fit(self): plt.plot(self.losses)
I chose to write a new one-cycle LR scheduling callback:
class OneCycle(Callback):
def __init__(self, lr_max):
div,div_final,pct_start = 25.,1e5,0.3
self.lr_sched_fn = combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)
self.ns = []
def after_batch(self):
self.ns.append(bs)
n_steps = len(self.learn.dls.train) * self.learn.n_epochs * bs
i = sum(self.ns)
pos = i/(n_steps)
lr = self.lr_sched_fn(pos)
self.learn.lr = lr
def before_fit(self):
lr = self.lr_sched_fn(0)
self.learn.lr = lr
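To see the shape of the schedule, a quick plot (my addition) of the same combined_cos curve: cosine warmup over the first 30% of training, then cosine annealing down to lr_max/1e5:
sched = combined_cos(0.3, lr_max/25., lr_max, lr_max/1e5)
xs = torch.linspace(0, 1, 100)
plt.plot(xs, [sched(x.item()) for x in xs]);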
model = BasicUNet(1, 1, width=unet_width, resids=unet_resids)
cbs = [CudaCB(), ConditionalDDPMCallback(label_conditioning=False), MyTrainCB(), ProgressCB(), LogLossesCB(), OneCycle(lr_max)]
learn = Learner(model, dls, loss_func, lr=lr_max, cbs=cbs, opt_func=opt_func)
learn.fit(2)
Viewing the predictions on images with increasing noise levels:
# Some noisy data
xb = xb[:8].cpu()
amount = torch.linspace(0, 1, xb.shape[0]) # Left to right -> more corruption
noised_x, eps = corrupt(xb, amount)
with torch.no_grad():
preds = model(noised_x.cuda()).detach().cpu()
pred_images = denormalize_img(noised_x + preds)
def show_grid(ax, tens, title=None):
if title: ax.set_title(title)
ax.imshow(make_grid(tens.cpu())[0])
fig, axs = plt.subplots(3, 1, figsize=(11, 6))
show_grid(axs[0], xb, 'Input data')
show_grid(axs[1], noised_x, 'Corrupted data')
show_grid(axs[2], pred_images, 'Network Predictions')
plt.hist(preds[0].reshape(-1))
(array([ 2., 3., 6., 72., 366., 185., 90., 38., 18., 4.]),
array([-0.24057674, -0.18854837, -0.13652 , -0.08449163, -0.03246326,
0.01956511, 0.07159348, 0.12362184, 0.17565021, 0.22767858,
0.27970695], dtype=float32),
<BarContainer object of 10 artists>)
A very basic sampling method (not ideal): just take a few equal-sized steps towards the model's prediction:
# Take one: break the process into n_steps steps and move part of the way towards the prediction each time:
n_steps = 5
xb = generate_random_noise(8).to(device) # Start from pure random noise
step_history = [xb.detach().cpu()]
pred_output_history = []
for i in range(n_steps):
with torch.no_grad():
pred = model(xb) # Predict the noise eps
pred_image = xb + pred # Estimate of x0, since x0 = xt + eps
pred_output_history.append(pred_image.detach().cpu()) # Store model output for plotting
mix_factor = 1/(n_steps - i) # How much we move towards the prediction
xb = xb*(1-mix_factor) + pred_image*mix_factor # Move part of the way there
if i < n_steps-1: step_history.append(denormalize_img(xb.detach().cpu())) # Store step for plotting
fig, axs = plt.subplots(n_steps, 2, figsize=(15, n_steps), sharex=True)
for i in range(n_steps):
axs[i, 0].imshow(make_grid(step_history[i])[0]),
axs[i, 1].imshow(make_grid(pred_output_history[i])[0])
Giving the model the labels as conditioning.
class ClassConditionedUNet(nn.Module):
"Wraps a BasicUNet but adds several input channels for class conditioning"
def __init__(self, in_channels, out_channels, num_classes=10, class_emb_channels=4, width=1, resids=1):
super().__init__()
self.class_emb = nn.Embedding(num_classes, class_emb_channels)
self.net = BasicUNet(in_channels+class_emb_channels, out_channels, width=width, resids=resids) # input channels = in_channels+class_emb_channels
def forward(self, x, class_labels):
n,c,h,w = x.shape
class_cond = self.class_emb(class_labels) # Map labels to the embedding dimension
class_cond = class_cond.view(n, class_cond.shape[1], 1, 1).expand(n, class_cond.shape[1], h, w) # Broadcast over the spatial dims
# Net input is x and the class conditioning concatenated along the channel dim
net_input = torch.cat((x, class_cond), 1)
return self.net(net_input)
model = ClassConditionedUNet(1, 1, num_classes=num_classes, width=unet_width, resids=unet_resids)
cbs = [CudaCB(), ConditionalDDPMCallback(label_conditioning=True), MyTrainCB(), ProgressCB(), LogLossesCB(), OneCycle(lr_max)]
learn = Learner(model, dls, loss_func, lr=1e-3, cbs=cbs, opt_func=opt_func)
learn.fit(50)
torch.save(learn.model.state_dict(), 'fashion_mnist.pt')
model.load_state_dict(torch.load('./fashion_mnist.pt'))
model=model.to(device).eval()
# source: https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
def right_pad_dims_to(x, t):
padding_dims = x.ndim - t.ndim
if padding_dims <= 0:
return t
return t.view(*t.shape, *((1,) * padding_dims))
def dynamic_thresholding(x_start, dynamic_thresholding_percentile=0.9, max_val=3.0):
x_start = x_start/max_val
s = torch.quantile(
rearrange(x_start, 'b ... -> b (...)').abs(),
dynamic_thresholding_percentile,
dim = -1
)
s.clamp_(min=1.0)
s = right_pad_dims_to(x_start, s)
x_start = x_start.clamp(-s, s) / s
return x_start*max_val
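A small demo of the effect (my addition): values beyond each sample's 90th-percentile magnitude are rescaled rather than hard-clipped, so nothing ends up outside ±max_val:
x = torch.randn(2, 1, 8, 8) * 5 # deliberately over-range
out = dynamic_thresholding(x, 0.9, 3.0)
print(x.abs().max().item(), out.abs().max().item()) # the second value is at most 3.0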
def sample_based_on_optimizer(model, xt, class_labels, optim_fn, n_steps, lr_max, pct_start):
div=100000.
div_final=1e5
lr_sched_fn = combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)
optim = OptimWrapper(opt=optim_fn([xt], lr=lr_max))
prev_noise_var = 1e6
eps = None
xt.requires_grad = True
for i in range(n_steps):
pos = i/(n_steps)
lr = lr_sched_fn(pos)
optim.set_hyper('lr', lr)
in_xt = xt
with torch.no_grad():
in_xt = dynamic_thresholding(in_xt, 0.9, 3.0)
eps = model(in_xt, class_labels)
#noise_pred = (net_output + flip_output)/2.0
# Early stopping
if (eps.float()).var() < 0.01 or i == n_steps-1:
print('Stopping at:', i)
return xt, eps
#print(xt-in_xt)
xt.grad = -eps.float() + (xt-in_xt)
grad_mean = xt.grad.mean()
prev_noise_var = eps.float().var()
if i%10==0:
#print('xt:', i, xt.mean(), xt.max(), xt.min())
#print('eps:', i, eps.mean(), eps.max(), eps.min(), eps.float().var())
print(i, xt.grad.mean(), xt.grad.var(), xt.float().var(), lr) # Useful to watch the noise variance
optim.step()
optim.zero_grad()
if i%10==0:
print('xt:', i, xt.mean(), xt.max(), xt.min())
#print('eps:', i, eps.mean(), eps.max(), eps.min(), eps.float().var())
#print(i, xt.grad.mean(), xt.grad.var(), xt.float().var(), lr) # Useful to watch the noise variance
#print(grad_mean)
if grad_mean < 0.0:
return xt, eps
# print('grad: ' + str(grad_mean))
#xt = xt + torch.randn_like(xt)*0.0025
# amount = torch.ones((xt.shape[0], )).to(device)*0.005
# xt,_ = corrupt(xt, amount)
return xt, eps
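The intuition here (my framing): since training arranged things so that x0 = xt + eps, the predicted noise points from the current sample towards a clean image. Setting xt.grad = -eps and taking an optimizer step moves the sample a learning-rate-sized fraction of the way there, x ← x + lr·eps(x), with the optimizer's momentum and adaptivity smoothing the trajectory; the extra (xt - in_xt) term additionally nudges the sample towards its dynamically-thresholded version.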
def generate_opt_based_samples(model, xt, class_labels):
sgd_optim_fn = partial(torch.optim.SGD, momentum=0.01)
xt = xt.clone()
#xt, eps = sample_based_on_optimizer(model, xt=xt, class_labels=class_labels, optim_fn=sgd_optim_fn, n_steps=40, lr_max=2e-2, pct_start=0.25)
xt, eps = sample_based_on_optimizer(model, xt=xt, class_labels=class_labels, optim_fn=sgd_optim_fn, n_steps=15, lr_max=3e-2, pct_start=0.5)
xt, eps = sample_based_on_optimizer(model, xt=xt, class_labels=class_labels, optim_fn=RmsLookahead, n_steps=20, lr_max=5e-2, pct_start=0.25)
pred_image = xt+eps
print(pred_image.min(), pred_image.max(), pred_image.mean())
return pred_image, eps
torch.random.manual_seed(1000)
num_examples = 80
num_classes = 10
xt = generate_random_noise(num_examples).to(device)
# Seems to give better results with noise_amount=0.95 rather than 1.0
noise_amount = torch.ones((xt.shape[0], )).to(device)*0.95
xt,_ = corrupt(xt, noise_amount)
class_labels = torch.tensor([[i]*(num_examples//num_classes) for i in range(num_classes)]).flatten().to(device)
pred_image, eps = generate_opt_based_samples(model, xt=xt, class_labels=class_labels)
pred_image = denormalize_img(pred_image).detach().cpu()
xt = xt.detach().cpu()
eps = eps.detach().cpu()
pred_img = (pred_image).clip(0, 1)
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(make_grid(pred_img, nrow=8)[0]);
0 tensor(0.0883, device='cuda:0', grad_fn=<MeanBackward0>) tensor(1.3084, device='cuda:0', grad_fn=<VarBackward0>) tensor(0.8984, device='cuda:0', grad_fn=<VarBackward0>) 3e-07
xt: 0 tensor(-0.0039, device='cuda:0', grad_fn=<MeanBackward0>) tensor(4.6173, device='cuda:0', grad_fn=<MaxBackward1>) tensor(-3.9280, device='cuda:0', grad_fn=<MinBackward1>)
10 tensor(0.1194, device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.9403, device='cuda:0', grad_fn=<VarBackward0>) tensor(0.6612, device='cuda:0', grad_fn=<VarBackward0>) 0.022500073378353153
xt: 10 tensor(-0.0253, device='cuda:0', grad_fn=<MeanBackward0>) tensor(3.9515, device='cuda:0', grad_fn=<MaxBackward1>) tensor(-3.3798, device='cuda:0', grad_fn=<MinBackward1>)
Stopping at: 14
0 tensor(0.1173, device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.8501, device='cuda:0', grad_fn=<VarBackward0>) tensor(0.6080, device='cuda:0', grad_fn=<VarBackward0>) 5e-07
xt: 0 tensor(-0.0291, device='cuda:0', grad_fn=<MeanBackward0>) tensor(3.8627, device='cuda:0', grad_fn=<MaxBackward1>) tensor(-3.2993, device='cuda:0', grad_fn=<MinBackward1>)
10 tensor(0.0774, device='cuda:0', grad_fn=<MeanBackward0>) tensor(0.5102, device='cuda:0', grad_fn=<VarBackward0>) tensor(0.4621, device='cuda:0', grad_fn=<VarBackward0>) 0.03750012432431383
xt: 10 tensor(-0.0644, device='cuda:0', grad_fn=<MeanBackward0>) tensor(3.4306, device='cuda:0', grad_fn=<MaxBackward1>) tensor(-2.8415, device='cuda:0', grad_fn=<MinBackward1>)
Stopping at: 19
tensor(-1.0224, device='cuda:0', grad_fn=<MinBackward1>) tensor(2.3579, device='cuda:0', grad_fn=<MaxBackward1>) tensor(-0.1371, device='cuda:0', grad_fn=<MeanBackward0>)
Sampling as before over 15 steps, but this time with the labels as conditioning:
n_steps = 15
xb = generate_random_noise(80).cuda()
yb = torch.tensor([[i]*8 for i in range(10)]).flatten().cuda()
for i in range(n_steps):
noise_amount = torch.ones((xb.shape[0], )).to(device) * (1-(i/n_steps))
with torch.no_grad():
pred = model(xb, yb)
pred_image = xb + pred
#pred_image = dynamic_thresholding(pred_image, 0.9, 3.0)
mix_factor = 1/(n_steps - i)
xb = xb*(1-mix_factor) + pred_image*mix_factor
#xb = dynamic_thresholding(xb, 0.85)
#xb = dynamic_thresholding(xb, 0.9999)
# Optional: Add a bit of extra noise back at early steps
if i < 10: xb,_ = corrupt(xb, torch.ones((xb.shape[0], )).to(device)*0.05)
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(make_grid(denormalize_img(xb).detach().cpu().clip(0, 1), nrow=8)[0]);
You can swap between mnist and fashion_mnist as the dataset without making any other changes. This seems to work surprisingly well, given the lack of fiddling with the training setup and architecture.
def count_parameters(module):
return sum(p.numel() for p in module.parameters())
count_parameters(model)
12121577