import logging, torch, torchvision, torch.nn.functional as F, torchvision.transforms.functional as TF, matplotlib as mpl
from matplotlib import pyplot as plt
from functools import partial
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torchvision.utils import make_grid
from datasets import load_dataset,load_dataset_builder
from miniai.datasets import *
from miniai.learner import *
from fastprogress import progress_bar
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.adafactor import Adafactor
from timm.optim.lookahead import Lookahead
from fastai.callback.schedule import combined_cos
def RmsLookahead(params, *args, alpha=0.5, k=6, **kwargs):
    """Build a timm `RMSpropTF` optimizer wrapped in `Lookahead`.

    `alpha` and `k` are keyword-only so that a positional learning rate,
    as passed by `Learner.fit` (`opt_func(model.parameters(), lr)`), is
    forwarded to `RMSpropTF` instead of being silently captured as `alpha`.

    Args:
        params: parameters (or param groups) to optimize.
        *args, **kwargs: forwarded unchanged to `RMSpropTF` (e.g. `lr`).
        alpha: forwarded to `Lookahead` as its `alpha`.
        k: forwarded to `Lookahead` as its `k`.

    Returns:
        The `Lookahead` optimizer wrapping the `RMSpropTF` instance.
    """
    base_opt = RMSpropTF(params, *args, **kwargs)
    return Lookahead(base_opt, alpha, k)
def AdamLookahead(params, *args, alpha=0.5, k=6, **kwargs):
    """Build a `torch.optim.Adam` optimizer wrapped in timm's `Lookahead`.

    `alpha` and `k` are keyword-only so that a positional learning rate,
    as passed by `Learner.fit` (`opt_func(model.parameters(), lr)`), is
    forwarded to `Adam` instead of being silently captured as `alpha`.

    Args:
        params: parameters (or param groups) to optimize.
        *args, **kwargs: forwarded unchanged to `optim.Adam` (e.g. `lr`).
        alpha: forwarded to `Lookahead` as its `alpha`.
        k: forwarded to `Lookahead` as its `k`.

    Returns:
        The `Lookahead` optimizer wrapping the `Adam` instance.
    """
    base_opt = optim.Adam(params, *args, **kwargs)
    return Lookahead(base_opt, alpha, k)
# Notebook-wide display / logging defaults.
# Default imshow colormap: reversed grayscale (higher values drawn darker).
mpl.rcParams['image.cmap'] = 'gray_r'
# Ignore logging calls at WARNING severity and below (quiets library chatter).
logging.disable(logging.WARNING)
Load a dataset:
# Keys of the image and label fields in each dataset example/batch.
x,y = 'image','label'
# Alternative dataset kept for experimentation:
#name = "mnist" #"fashion_mnist"
name = "fashion_mnist"
# Hugging Face DatasetDict holding the dataset's splits.
dsd = load_dataset(name)
0%| | 0/2 [00:00<?, ?it/s]
@inplace
def transformi(b):
    """Replace every image stored under key `x` in batch dict `b` with a tensor (in place)."""
    images = b[x]
    b[x] = list(map(TF.to_tensor, images))
# Training hyperparameters; later cells read these as globals.
opt_func = optim.Adam
lr_max = 1e-3
bs = 256
# Attach `transformi` so images are converted to tensors as examples are accessed.
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs)
# Pull one training batch to sanity-check shapes and labels (displayed below).
dt = dls.train
xb,yb = next(iter(dt))
xb.shape,yb[:10]
(torch.Size([256, 1, 28, 28]), tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5]))
Define a model:
def conv2dks7(inc, outc):
    "7x7 convolution, stride 1, padded so the spatial size is preserved."
    ks = 7
    return nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=ks, padding=ks // 2)
def conv2dks3(inc, outc):
    "3x3 convolution, stride 1, padded so the spatial size is preserved."
    ks = 3
    return nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=ks, padding=ks // 2)
def conv2dks1(inc, outc):
    "Pointwise (1x1) convolution mixing channels only; no padding needed."
    return nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1)
# Layer factories used by the builders below, called as norm(channels) and act().
# GroupNorm alternative (8 groups). num_groups must be passed positionally:
# with num_groups=8 as a keyword, norm(outc) would bind outc to num_groups too
# and raise TypeError.
#norm = partial(torch.nn.GroupNorm, 8)
norm = torch.nn.BatchNorm2d
act = torch.nn.SiLU
def init_layer(inc, outc):
    "Stem block: 7x7 conv (`inc` -> `outc`), then activation, then normalization."
    conv = conv2dks7(inc, outc)
    return torch.nn.Sequential(conv, act(), norm(outc))
def down_layer(inc, outc):
    "Encoder block: 7x7 conv (`inc` -> `outc`), then activation, then normalization."
    conv = conv2dks7(inc, outc)
    return torch.nn.Sequential(conv, act(), norm(outc))
def up_layer(inc, outc, activation=True):
    """Decoder block.

    1x1 conv halving `inc`, activation, norm, then a 3x3 conv to `outc`.
    When `activation` is true an extra activation + norm follows the 3x3
    conv; BasicUNet passes False for its final (output) block.
    """
    mid = inc // 2
    layers = [conv2dks1(inc, mid), act(), norm(mid), conv2dks3(mid, outc)]
    if activation:
        layers += [act(), norm(outc)]
    return torch.nn.Sequential(*layers)
class BasicUNet(nn.Module):
    """A minimal UNet implementation.

    Encoder: three 7x7-conv blocks (inc -> 64 -> 128 -> 128), each followed
    by SiLU, with 2x max-pooling after the first two. Decoder: three blocks
    that each concatenate the matching encoder activation (skip connection)
    in front of the current features; all but the first upsample 2x first.
    The output is mapped into (-0.5, 1.5) via 2*sigmoid - 0.5.
    Assumes H and W are divisible by 4 so the skip tensors line up.
    """
    def __init__(self, inc, outc):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([
            down_layer(inc, 64),
            down_layer(64, 128),
            down_layer(128, 128),
        ])
        self.up_layers = torch.nn.ModuleList([
            up_layer(256, 128),
            up_layer(256, 64),
            up_layer(128, outc, False),
        ])

    def forward(self, x):
        skips = []
        for depth, block in enumerate(self.down_layers):
            x = F.silu(block(x))
            skips.append(x)
            if depth < 2:
                x = F.max_pool2d(x, 2)
        for depth, block in enumerate(self.up_layers):
            if depth > 0:
                x = F.interpolate(x, scale_factor=2)
            skip = skips.pop()  # deepest remaining encoder activation
            x = block(torch.cat([skip, x], dim=1))
        return x.sigmoid() * 2 - 0.5
Define the corruption:
def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with uniform noise according to `amount`.

    Returns ``x * (1 - amount) + noise * amount`` where ``noise`` is drawn
    with ``torch.rand_like(x)`` (uniform on [0, 1)).

    Args:
        x: batched tensor whose first dimension is the batch, e.g. (B, C, H, W).
        amount: per-sample mixing weights of shape (B,), where 0 keeps `x` and
            1 gives pure noise. A Python number or 0-d tensor is also
            accepted and applied to every sample.

    Returns:
        The corrupted tensor, broadcast to the shape of `x`.
    """
    noise = torch.rand_like(x)
    # Accept plain numbers and tensors living on another device.
    amount = torch.as_tensor(amount, device=x.device)
    # Reshape to (B, 1, ..., 1) so it broadcasts over every non-batch dim of `x`.
    amount = amount.reshape(-1, *([1] * (x.dim() - 1)))
    return x * (1 - amount) + noise * amount
Logging callback:
class LogLossesCB(Callback):
    "Record `learn.loss` after every batch and plot the recorded history once fitting ends."
    def __init__(self):
        self.losses = []

    def after_batch(self):
        loss_value = self.learn.loss.item()
        self.losses.append(loss_value)

    def after_fit(self):
        plt.plot(self.losses)
I chose to write two new callbacks: a one-cycle learning-rate scheduler and a training callback that teaches the model to denoise:
class OneCycle(Callback):
    """One-cycle learning-rate schedule (fastai `combined_cos`), updated per training batch.

    The schedule position is the fraction of training batches completed in
    the current fit, clamped to [0, 1]. The LR goes from `lr_max/div` up to
    `lr_max` over the first `pct_start` of training, then down to
    `lr_max/div_final`. Each new LR is written into every param group of
    `learn.opt` and mirrored in `learn.lr`.
    """
    def __init__(self, lr_max, div=25., div_final=1e5, pct_start=0.3):
        self.lr_sched_fn = combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)
        self.n_done = 0  # training batches completed in the current fit

    def _set_lr(self, pos):
        "Evaluate the schedule at `pos` and apply the LR to the learner and its optimizer."
        lr = self.lr_sched_fn(pos)
        self.learn.lr = lr
        # Learner.fit builds the optimizer once from learn.lr, so changing
        # learn.lr alone never reaches it: update the param groups directly.
        opt = getattr(self.learn, 'opt', None)
        if opt is not None:
            for group in opt.param_groups:
                group['lr'] = lr

    def before_fit(self):
        self.n_done = 0  # restart the cycle on every fit
        self._set_lr(0)

    def after_batch(self):
        # after_batch also fires for validation batches; only training steps
        # should advance the schedule.
        if not self.learn.model.training:
            return
        self.n_done += 1
        n_total = max(len(self.learn.dls.train) * self.learn.n_epochs, 1)
        self._set_lr(min(self.n_done / n_total, 1.0))
class MyTrainCB(TrainCB):
    "Denoising train step: corrupt each clean batch by a random per-sample amount and reconstruct it."
    def predict(self):
        clean = self.learn.batch[0]
        n = clean.shape[0]
        amounts = torch.rand(n).to(clean.device)  # one random corruption level per image
        self.learn.preds = self.learn.model(corrupt(clean, amounts))  # noisy images in

    def get_loss(self):
        clean = self.learn.batch[0]  # the clean images are the targets
        self.learn.loss = self.learn.loss_func(self.learn.preds, clean)
# Unconditional denoiser: 1-channel noisy image in, 1-channel reconstruction out.
model = BasicUNet(1, 1)
# MyTrainCB does the denoising forward/loss; CudaCB presumably handles GPU
# placement (TODO confirm); LogLossesCB records losses; OneCycle schedules the LR.
cbs = [MyTrainCB(), CudaCB(), ProgressCB(), LogLossesCB(), OneCycle(lr_max)]
# MSE between the reconstruction and the clean images.
learn = Learner(model, dls, nn.MSELoss(), lr=lr_max, cbs=cbs, opt_func=opt_func)
learn.fit(2)
Viewing the predictions on images with increasing noise levels:
# Some noisy data: the first 8 images of the sample batch, corrupted with
# amounts rising linearly from 0 (clean) to 1 (pure noise).
xb = xb[:8].cpu()
amount = torch.linspace(0, 1, xb.shape[0]) # Left to right -> more corruption
noised_x = corrupt(xb, amount)
# Denoise on the GPU without tracking gradients.
# NOTE(review): relies on the model already being in eval mode. Learner's last
# epoch is a validation epoch, which calls model.train(False).
with torch.no_grad(): preds = model(noised_x.cuda()).detach().cpu()
def show_grid(ax, tens, title=None):
    """Draw the image batch `tens` as one grid on axes `ax`, showing the grid's first channel."""
    if title:
        ax.set_title(title)
    grid = make_grid(tens.cpu())
    ax.imshow(grid[0])
# Three rows: clean inputs, their corrupted versions, and the model's reconstructions.
fig, axs = plt.subplots(3, 1, figsize=(11, 6))
show_grid(axs[0], xb, 'Input data')
show_grid(axs[1], noised_x, 'Corrupted data')
show_grid(axs[2], preds, 'Network Predictions')
# Pixel-value histogram of the first reconstruction (BasicUNet outputs lie in (-0.5, 1.5)).
plt.hist(preds[0].reshape(-1))
(array([176., 220., 8., 17., 6., 19., 13., 83., 205., 37.]),
array([-0.15290517, -0.03342845, 0.08604827, 0.20552498, 0.32500172,
0.44447842, 0.5639551 , 0.68343186, 0.8029086 , 0.9223853 ,
1.041862 ], dtype=float32),
<BarContainer object of 10 artists>)
A very basic (not ideal) sampling method: take a fixed number of steps (5 here), each moving a fraction of the remaining distance toward the model's current prediction:
# Take one: split sampling into `n_steps` steps. At step i, move 1/(n_steps - i)
# of the remaining distance toward the model's current prediction, so the
# final step lands exactly on it.
device = 'cuda'
n_steps = 5
xb = torch.rand(8, 1, 28, 28).to(device) # Start from random
step_history = [xb.detach().cpu()]
pred_output_history = []
for i in range(n_steps):
    with torch.no_grad(): pred = model(xb) # Predict the denoised x0
    pred_output_history.append(pred.detach().cpu()) # Store model output for plotting
    mix_factor = 1/(n_steps - i) # How much we move towards the prediction
    xb = xb*(1-mix_factor) + pred*mix_factor # Move part of the way there
    if i < n_steps-1: step_history.append(xb.detach().cpu()) # Store step for plotting
# One row per step. Left: the sample fed to the model at that step. Right: its prediction.
fig, axs = plt.subplots(n_steps, 2, figsize=(15, n_steps), sharex=True)
for i in range(n_steps):
    axs[i, 0].imshow(make_grid(step_history[i])[0]),  # trailing comma builds a throwaway tuple (harmless)
    axs[i, 1].imshow(make_grid(pred_output_history[i])[0])
# NOTE(review): identical to the previous plotting code; it redraws the same
# step/prediction figure and can likely be removed.
fig, axs = plt.subplots(n_steps, 2, figsize=(15, n_steps), sharex=True)
for i in range(n_steps):
    axs[i, 0].imshow(make_grid(step_history[i])[0]),  # trailing comma builds a throwaway tuple (harmless)
    axs[i, 1].imshow(make_grid(pred_output_history[i])[0])
Giving the model the labels as conditioning.
class ClassConditionedUNet(nn.Module):
    """Wrap a BasicUNet, adding learned class-embedding channels as conditioning.

    Each label is embedded into `class_emb_channels` values. These are tiled
    over the spatial grid and concatenated after the image channels, so the
    inner UNet receives `in_channels + class_emb_channels` input channels.
    """
    def __init__(self, in_channels, out_channels, num_classes=10, class_emb_channels=4):
        super().__init__()
        self.class_emb = nn.Embedding(num_classes, class_emb_channels)
        self.net = BasicUNet(in_channels + class_emb_channels, out_channels)

    def forward(self, x, class_labels):
        batch, _, height, width = x.shape
        emb = self.class_emb(class_labels)  # (batch, class_emb_channels)
        n_emb = emb.shape[1]
        # Tile each sample's embedding over every pixel: (batch, n_emb, height, width).
        cond = emb.view(batch, n_emb, 1, 1).expand(batch, n_emb, height, width)
        return self.net(torch.cat((x, cond), 1))
class MyTrainCB(TrainCB):
    "Denoising train step that also feeds the batch labels to the model as conditioning."
    def predict(self):
        images, labels = self.learn.batch[0], self.learn.batch[1]
        amounts = torch.rand(images.shape[0]).to(images.device)  # per-image corruption level
        noisy = corrupt(images, amounts)
        self.learn.preds = self.learn.model(noisy, labels)  # << Labels as conditioning

    def get_loss(self):
        target = self.learn.batch[0]  # clean images
        self.learn.loss = self.learn.loss_func(self.learn.preds, target)
# Class-conditioned denoiser: 1 image channel in (plus label embedding), 1 channel out.
model = ClassConditionedUNet(1, 1)
cbs = [MyTrainCB(), CudaCB(), ProgressCB(), LogLossesCB(), OneCycle(lr_max)]
# NOTE(review): lr is hard-coded to 1e-3 here while the earlier run used `lr_max`
# (currently the same value); consider using `lr_max` for consistency.
learn = Learner(model, dls, nn.MSELoss(), lr=1e-3, cbs=cbs, opt_func=opt_func)
learn.fit(10)
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Input In [29], in <cell line: 1>() ----> 1 learn.fit(10) File D:\Deep Learning\fastdiffusion\nbs\jason\miniai\learner.py:147, in Learner.fit(self, n_epochs) 145 self.epochs = range(n_epochs) 146 self.opt = self.opt_func(self.model.parameters(), self.lr) --> 147 self._fit() File D:\Deep Learning\fastdiffusion\nbs\jason\miniai\learner.py:114, in with_cbs.__call__.<locals>._f(o, *args, **kwargs) 112 try: 113 o.callback(f'before_{self.nm}') --> 114 f(o, *args, **kwargs) 115 o.callback(f'after_{self.nm}') 116 except globals()[f'Cancel{self.nm.title()}Exception']: pass File D:\Deep Learning\fastdiffusion\nbs\jason\miniai\learner.py:152, in Learner._fit(self) 149 @with_cbs('fit') 150 def _fit(self): 151 for self.epoch in self.epochs: --> 152 self.one_epoch(True) 153 self.one_epoch(False) File D:\Deep Learning\fastdiffusion\nbs\jason\miniai\learner.py:137, in Learner.one_epoch(self, train) 135 self.model.train(train) 136 self.dl = self.dls.train if train else self.dls.valid --> 137 self._one_epoch() File D:\Deep Learning\fastdiffusion\nbs\jason\miniai\learner.py:114, in with_cbs.__call__.<locals>._f(o, *args, **kwargs) 112 try: 113 o.callback(f'before_{self.nm}') --> 114 f(o, *args, **kwargs) 115 o.callback(f'after_{self.nm}') 116 except globals()[f'Cancel{self.nm.title()}Exception']: pass File D:\Deep Learning\fastdiffusion\nbs\jason\miniai\learner.py:141, in Learner._one_epoch(self) 139 @with_cbs('epoch') 140 def _one_epoch(self): --> 141 for self.iter,self.batch in enumerate(self.dl): self.one_batch() File D:\Deep Learning\fastdiffusion\nbs\jason\miniai\learner.py:114, in with_cbs.__call__.<locals>._f(o, *args, **kwargs) 112 try: 113 o.callback(f'before_{self.nm}') --> 114 f(o, *args, **kwargs) 115 o.callback(f'after_{self.nm}') 116 except globals()[f'Cancel{self.nm.title()}Exception']: pass File D:\Deep 
Learning\fastdiffusion\nbs\jason\miniai\learner.py:131, in Learner.one_batch(self) 129 if self.model.training: 130 self.backward() --> 131 self.step() 132 self.zero_grad() File D:\Deep Learning\fastdiffusion\nbs\jason\miniai\learner.py:160, in Learner.callback(self, method_nm) 159 def callback(self, method_nm): --> 160 for cb in sorted(self.cbs, key=attrgetter('order')): getattr(cb, method_nm,identity)() File D:\Deep Learning\fastdiffusion\nbs\jason\miniai\learner.py:170, in TrainCB.step(self) --> 170 def step(self): self.learn.opt.step() File ~\anaconda3\envs\course22p2\lib\site-packages\torch\optim\optimizer.py:113, in Optimizer._hook_for_profile.<locals>.profile_hook_step.<locals>.wrapper(*args, **kwargs) 111 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__) 112 with torch.autograd.profiler.record_function(profile_name): --> 113 return func(*args, **kwargs) File ~\anaconda3\envs\course22p2\lib\site-packages\torch\autograd\grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs) 24 @functools.wraps(func) 25 def decorate_context(*args, **kwargs): 26 with self.clone(): ---> 27 return func(*args, **kwargs) File ~\anaconda3\envs\course22p2\lib\site-packages\torch\optim\adam.py:157, in Adam.step(self, closure) 153 max_exp_avg_sqs.append(state['max_exp_avg_sq']) 155 state_steps.append(state['step']) --> 157 adam(params_with_grad, 158 grads, 159 exp_avgs, 160 exp_avg_sqs, 161 max_exp_avg_sqs, 162 state_steps, 163 amsgrad=group['amsgrad'], 164 beta1=beta1, 165 beta2=beta2, 166 lr=group['lr'], 167 weight_decay=group['weight_decay'], 168 eps=group['eps'], 169 maximize=group['maximize'], 170 foreach=group['foreach'], 171 capturable=group['capturable']) 173 return loss File ~\anaconda3\envs\course22p2\lib\site-packages\torch\optim\adam.py:213, in adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach, capturable, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize) 210 else: 211 func = 
_single_tensor_adam --> 213 func(params, 214 grads, 215 exp_avgs, 216 exp_avg_sqs, 217 max_exp_avg_sqs, 218 state_steps, 219 amsgrad=amsgrad, 220 beta1=beta1, 221 beta2=beta2, 222 lr=lr, 223 weight_decay=weight_decay, 224 eps=eps, 225 maximize=maximize, 226 capturable=capturable) File ~\anaconda3\envs\course22p2\lib\site-packages\torch\optim\adam.py:258, in _single_tensor_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize, capturable) 255 # update step 256 step_t += 1 --> 258 if weight_decay != 0: 259 grad = grad.add(param, alpha=weight_decay) 261 # Decay the first and second moment running average coefficient KeyboardInterrupt:
Sampling as before over 20 steps, but this time with the labels as conditioning:
# Iteratively denoise 80 random images (8 per class, labels 0..9) over `n_steps` steps.
n_steps = 20
# If training was interrupted mid-epoch (as in the traceback above), the model is
# left in training mode; switch to eval so BatchNorm uses its running statistics
# rather than this batch's statistics (and does not overwrite them).
model.eval()
xb = torch.rand(80, 1, 28, 28).to(device)
yb = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)
with torch.no_grad():
    for i in range(n_steps):
        pred = model(xb, yb)
        # Move 1/(n_steps - i) of the remaining way to the prediction; the last step lands on it.
        mix_factor = 1/(n_steps - i)
        xb = xb*(1-mix_factor) + pred*mix_factor
        # Optional: Add a bit of extra noise back at early steps
        if i < 10: xb = corrupt(xb, torch.ones((xb.shape[0], )).to(device)*0.05)
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(make_grid(xb.detach().cpu().clip(0, 1), nrow=8)[0]);
You can switch between mnist and fashion_mnist (set `name` at the top) without making any other changes. This seems to work (surprisingly, given the lack of fiddling with training and architecture).