import logging, torch, torchvision, torch.nn.functional as F, torchvision.transforms.functional as TF, matplotlib as mpl
from matplotlib import pyplot as plt
from functools import partial
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torchvision.utils import make_grid
from datasets import load_dataset,load_dataset_builder
from miniai.datasets import *
from miniai.learner import *
from fastprogress import progress_bar
from timm.optim.rmsprop_tf import RMSpropTF
from timm.optim.adafactor import Adafactor
from timm.optim.lookahead import Lookahead
from fastai.callback.schedule import combined_cos
def RmsLookahead(params, alpha=0.5, k=6, *args, **kwargs):
    """Return timm's `RMSpropTF` optimizer wrapped in a `Lookahead` with parameters `alpha` and `k`.

    Any further positional/keyword arguments are forwarded to `RMSpropTF`.
    NOTE(review): `alpha` and `k` precede `*args`, so a call such as
    `RmsLookahead(params, lr)` binds `lr` to `alpha` rather than to the inner
    optimizer. Pass the learning rate as `lr=...` (or give `alpha`/`k` explicitly).
    """
    return Lookahead(RMSpropTF(params, *args, **kwargs), alpha, k)
def AdamLookahead(params, alpha=0.5, k=6, *args, **kwargs):
    """Return `torch.optim.Adam` wrapped in timm's `Lookahead` with parameters `alpha` and `k`.

    Any further positional/keyword arguments are forwarded to `optim.Adam`.
    NOTE(review): `alpha` and `k` precede `*args`, so a call such as
    `AdamLookahead(params, lr)` binds `lr` to `alpha` rather than to Adam.
    Pass the learning rate as `lr=...` (or give `alpha`/`k` explicitly).
    """
    base = optim.Adam(params, *args, **kwargs)
    return Lookahead(base, alpha, k)
# Default image colormap: reversed grey, so high pixel values are drawn dark.
mpl.rcParams.update({'image.cmap': 'gray_r'})
# Suppress all log records at WARNING level and below.
logging.disable(level=logging.WARNING)
Load a dataset:
# Names of the image and label columns in the loaded dataset.
x, y = 'image', 'label'
# Dataset to load; alternative: name = "mnist"
name = "fashion_mnist"
dsd = load_dataset(name)
0%| | 0/2 [00:00<?, ?it/s]
@inplace
def transformi(b):
    "Replace the images stored under key `x` of batch `b` with tensors (via `TF.to_tensor`)."
    b[x] = list(map(TF.to_tensor, b[x]))
# Training hyper-parameters shared by the runs below.
opt_func = optim.Adam
lr_max = 1e-3
bs = 256

# Attach the tensor-conversion transform and build the dataloaders.
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs)

# Pull one training batch to sanity-check shapes and labels.
dt = dls.train
xb, yb = next(iter(dt))
xb.shape, yb[:10]
(torch.Size([256, 1, 28, 28]), tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5]))
Define a model:
def _same_conv(inc, outc, ks):
    "Conv2d with odd kernel size `ks` and padding `ks // 2` (spatial size preserved at stride 1)."
    return nn.Conv2d(inc, outc, kernel_size=ks, padding=ks // 2)

def conv2dks7(inc, outc): return _same_conv(inc, outc, 7)

def conv2dks3(inc, outc): return _same_conv(inc, outc, 3)

def conv2dks1(inc, outc): return _same_conv(inc, outc, 1)
# Normalisation layer factory: GroupNorm with 8 groups, so every channel count
# passed to it must be divisible by 8.
# NOTE(review): `torch.nn.BatchNorm2d` takes `num_features`, not `num_channels`,
# so it cannot be dropped in here unchanged for the `norm(num_channels=...)` calls.
norm = partial(nn.GroupNorm, num_groups=8)
# Activation layer class.
act = nn.SiLU
def init_layer(inc, outc):
    "7x7 conv -> activation -> norm; same structure as `down_layer` (not used by `BasicUNet`)."
    modules = [conv2dks7(inc, outc), act(), norm(num_channels=outc)]
    return torch.nn.Sequential(*modules)
def down_layer(inc, outc):
    "Encoder stage: 7x7 same-padding conv, then activation, then normalisation."
    conv = conv2dks7(inc, outc)
    return torch.nn.Sequential(conv, act(), norm(num_channels=outc))
def up_layer(inc, outc, activation=True):
    """Decoder stage: 1x1 conv halving the channels, act, norm, then a 3x3 conv to `outc`.

    When `activation` is falsy (used for the UNet's output stage) the trailing
    activation and normalisation are left off.
    """
    mid = inc // 2
    layers = [conv2dks1(inc, mid), act(), norm(num_channels=mid), conv2dks3(mid, outc)]
    if not activation:
        return torch.nn.Sequential(*layers)
    return torch.nn.Sequential(*layers, act(), norm(num_channels=outc))
class BasicUNet(nn.Module):
    """A minimal UNet: three down stages and three up stages joined by skip connections.

    The final output goes through a sigmoid, so predictions lie in (0, 1).
    """
    def __init__(self, inc, outc):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([down_layer(inc, 32), down_layer(32, 64), down_layer(64, 64)])
        self.up_layers = torch.nn.ModuleList([up_layer(128, 64), up_layer(128, 32), up_layer(64, outc, False)])

    def forward(self, x):
        h = []  # one skip activation per down stage
        n_down = len(self.down_layers)
        for i, l in enumerate(self.down_layers):
            x = F.silu(l(x))
            h.append(x)
            # Downsample between stages, but not after the last (bottleneck) stage.
            if i < n_down - 1: x = F.max_pool2d(x, 2)
        for i, l in enumerate(self.up_layers):
            # On the first up stage the popped skip is the bottleneck output itself.
            skip = h.pop()
            # Upsample to the skip's exact spatial size rather than a fixed x2, so inputs
            # whose sides are not divisible by 4 still line up for the concat
            # (identical result to scale_factor=2 when they are).
            if i > 0: x = F.interpolate(x, size=skip.shape[-2:])
            x = torch.cat([skip, x], dim=1)
            x = l(x)
        return x.sigmoid()
Define the corruption:
def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with noise according to `amount`.

    Returns `x * (1 - amount) + noise * amount`, where `noise` is drawn uniformly
    from [0, 1) with `torch.rand_like(x)`. `amount` holds one mixing factor per
    sample along the first dimension of `x` (a scalar tensor or Python float is
    also accepted). It is broadcast over all remaining dimensions, whatever the
    rank of `x`.
    """
    noise = torch.rand_like(x)
    amount = torch.as_tensor(amount, device=x.device)
    # Shape (B, 1, 1, ...) so each sample's amount broadcasts over the rest of x.
    amount = amount.reshape(-1, *([1] * (x.dim() - 1)))
    return x * (1 - amount) + noise * amount
Logging callback:
class LogLossesCB(Callback):
    "Record the loss after every batch and plot the recorded history when fitting ends."
    def __init__(self):
        self.losses = []

    def after_batch(self):
        # NOTE(review): if the learner also runs validation batches through
        # `after_batch`, their losses are interleaved with training losses here; confirm.
        batch_loss = self.learn.loss
        self.losses.append(batch_loss.item())

    def after_fit(self):
        plt.plot(self.losses)
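I chose to write two new callbacks. The first, `OneCycle`, schedules the learning rate with a one-cycle (cosine) policy. The second, `MyTrainCB`, is a training callback that feeds noise-corrupted images to the model and uses the clean images as targets: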
class OneCycle(Callback):
    """One-cycle learning-rate schedule built from fastai's `combined_cos`.

    The schedule runs from `lr_max/div` up to `lr_max` over the first `pct_start`
    of training, then down to `lr_max/div_final`. Progress is measured in training
    batches over the whole `fit`. Each new rate is written both to `learn.lr` and
    to the optimizer's `param_groups`. The optimizer reads its rate from
    `param_groups`, so updating `learn.lr` alone would not change training.
    """
    def __init__(self, lr_max, div=25., div_final=1e5, pct_start=0.3):
        self.lr_sched_fn = combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)
        self.n_iter = 0  # training batches completed in the current fit

    def _set_lr(self, lr):
        "Apply `lr` to the learner and, when one exists, to every optimizer param group."
        self.learn.lr = lr
        opt = getattr(self.learn, 'opt', None)
        if opt is not None:
            for g in opt.param_groups: g['lr'] = lr

    def before_fit(self):
        self.n_iter = 0  # restart the schedule on every call to `fit`
        self._set_lr(self.lr_sched_fn(0))

    def after_batch(self):
        # Only training batches advance the schedule; counting validation batches
        # would push the position past 1 and bend the curve back up.
        if not self.learn.model.training: return
        self.n_iter += 1
        total = len(self.learn.dls.train) * self.learn.n_epochs
        self._set_lr(self.lr_sched_fn(min(self.n_iter / total, 1.)))
class MyTrainCB(TrainCB):
    "Denoising step: the model sees noise-corrupted images and is scored against the clean ones."
    def predict(self):
        clean = self.learn.batch[0]
        # One random corruption amount per image, moved to the batch's device.
        amounts = torch.rand(clean.shape[0]).to(clean.device)
        self.learn.preds = self.learn.model(corrupt(clean, amounts))

    def get_loss(self):
        # Targets are the clean images from the batch.
        self.learn.loss = self.learn.loss_func(self.learn.preds, self.learn.batch[0])
# Unconditioned denoiser: 1 input channel (noisy image) -> 1 output channel.
model = BasicUNet(1, 1)
cbs = [
    MyTrainCB(),       # feed corrupted images, compare predictions with clean ones
    CudaCB(),
    ProgressCB(),
    LogLossesCB(),     # record per-batch losses and plot them after fitting
    OneCycle(lr_max),  # one-cycle learning-rate schedule
]
learn = Learner(model, dls, nn.MSELoss(), lr=lr_max, cbs=cbs, opt_func=opt_func)
learn.fit(2)
Viewing the predictions on images with increasing noise levels:
# Up to eight clean images, corrupted by an amount rising linearly from 0 (left) to 1 (right).
xb = xb[:8].cpu()
amount = torch.linspace(0, 1, xb.shape[0])
noised_x = corrupt(xb, amount)
# Run the model on whatever device its weights are on instead of assuming CUDA;
# `no_grad` already prevents graph construction, so no `.detach()` is needed.
with torch.no_grad(): preds = model(noised_x.to(next(model.parameters()).device)).cpu()
def show_grid(ax, tens, title=None):
    "Draw channel 0 of a `make_grid` mosaic of `tens` on `ax`, with an optional title."
    if title:
        ax.set_title(title)
    grid = make_grid(tens.cpu())
    ax.imshow(grid[0])
# Three stacked panels: clean inputs, their corrupted versions, and the model's outputs.
fig, axs = plt.subplots(3, 1, figsize=(11, 6))
show_grid(axs[0], xb, 'Input data')
show_grid(axs[1], noised_x, 'Corrupted data')
show_grid(axs[2], preds, 'Network Predictions')
# Pixel-value histogram of the first prediction. Start a new figure so the
# histogram is not drawn onto the current axes (the predictions panel).
plt.figure()
plt.hist(preds[0].reshape(-1))
(array([383., 24., 12., 5., 9., 10., 15., 57., 207., 62.]),
array([2.1125010e-05, 9.6003935e-02, 1.9198675e-01, 2.8796956e-01,
3.8395238e-01, 4.7993517e-01, 5.7591802e-01, 6.7190081e-01,
7.6788360e-01, 8.6386645e-01, 9.5984924e-01], dtype=float32),
<BarContainer object of 10 artists>)
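A very basic sampling method (not ideal). Starting from random noise, it takes a small number of steps (e.g. 5 or 10) towards the model's prediction, each time moving 1/(steps remaining) of the way, so the final step lands exactly on the prediction: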
# Naive sampler: start from uniform noise and, at each of `n_steps` steps, move a
# fraction 1/(steps remaining) of the way towards the model's denoised prediction.
# The fraction grows (1/5, 1/4, ..., 1 for n_steps=5), so the last step lands
# exactly on the model's final prediction.
device = next(model.parameters()).device  # follow the model instead of assuming 'cuda'
n_steps = 5
xb = torch.rand(8, 1, 28, 28).to(device)  # start from random noise
step_history = [xb.cpu()]
pred_output_history = []
for i in range(n_steps):
    with torch.no_grad(): pred = model(xb)  # predict the denoised x0
    pred_output_history.append(pred.cpu())  # store model output for plotting
    mix_factor = 1 / (n_steps - i)  # how much we move towards the prediction
    xb = xb * (1 - mix_factor) + pred * mix_factor  # move part of the way there
    if i < n_steps - 1: step_history.append(xb.cpu())  # store step for plotting
# Left column: the sample fed to the model at each step; right column: its prediction.
fig, axs = plt.subplots(n_steps, 2, figsize=(15, n_steps), sharex=True)
for i in range(n_steps):
    axs[i, 0].imshow(make_grid(step_history[i])[0])
    axs[i, 1].imshow(make_grid(pred_output_history[i])[0])
# NOTE(review): this redraws the same step/prediction figure as the cell just before it.
fig, axs = plt.subplots(n_steps, 2, figsize=(15, n_steps), sharex=True)
for row in range(n_steps):
    left, right = axs[row, 0], axs[row, 1]
    left.imshow(make_grid(step_history[row])[0])
    right.imshow(make_grid(pred_output_history[row])[0])
Giving the model the labels as conditioning.
class ClassConditionedUNet(nn.Module):
    """Wraps a `BasicUNet`, adding `class_emb_channels` input channels for class conditioning.

    Each class label is mapped to a learned embedding vector. The vector is
    broadcast over every pixel and concatenated to the image channels before
    the UNet runs.
    """
    def __init__(self, in_channels, out_channels, num_classes=10, class_emb_channels=4):
        super().__init__()
        self.class_emb = nn.Embedding(num_classes, class_emb_channels)
        # UNet input channels = image channels + class-embedding channels.
        self.net = BasicUNet(in_channels + class_emb_channels, out_channels)

    def forward(self, x, class_labels):
        # `class_labels` must hold one label per sample in x (the view below requires it).
        n, _, h, w = x.shape  # (batch, channels, height, width)
        class_cond = self.class_emb(class_labels)  # map labels to embedding vectors
        # Broadcast each embedding over the spatial grid: (n, emb) -> (n, emb, h, w).
        class_cond = class_cond.view(n, -1, 1, 1).expand(-1, -1, h, w)
        # Net input is x and the class conditioning concatenated along the channel dim.
        net_input = torch.cat((x, class_cond), 1)
        return self.net(net_input)
class MyTrainCB(TrainCB):
    "Class-conditioned denoising step: the batch labels `batch[1]` are passed to the model as conditioning."
    def predict(self):
        imgs, labels = self.learn.batch[0], self.learn.batch[1]
        amounts = torch.rand(imgs.shape[0]).to(imgs.device)
        noisy = corrupt(imgs, amounts)
        self.learn.preds = self.learn.model(noisy, labels)  # labels as conditioning

    def get_loss(self):
        # Loss is measured against the clean images.
        self.learn.loss = self.learn.loss_func(self.learn.preds, self.learn.batch[0])
# Class-conditioned denoiser: 1 image channel in, 1 channel out.
model = ClassConditionedUNet(1, 1)
cbs = [MyTrainCB(), CudaCB(), ProgressCB(), LogLossesCB(), OneCycle(lr_max)]
# Use the shared `lr_max` (as the unconditioned run does) instead of repeating the literal,
# so the learner's lr and the OneCycle peak cannot drift apart.
learn = Learner(model, dls, nn.MSELoss(), lr=lr_max, cbs=cbs, opt_func=opt_func)
learn.fit(10)
Sampling as before over 20 steps, but this time with the labels as conditioning:
# Sample 8 images for each of the 10 classes, conditioning on the label.
n_steps = 20
xb = torch.rand(80, 1, 28, 28).to(device)
yb = torch.arange(10).repeat_interleave(8).to(device)  # 0 x8, 1 x8, ..., 9 x8
for i in range(n_steps):
    with torch.no_grad():
        pred = model(xb, yb)
    mix_factor = 1 / (n_steps - i)  # move 1/(steps remaining) of the way to the prediction
    xb = xb * (1 - mix_factor) + pred * mix_factor
    # Optional: add a bit of extra noise back during the first 10 steps.
    if i < 10: xb = corrupt(xb, torch.ones((xb.shape[0], )).to(device) * 0.05)
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
# Show the samples 8 per row, so each row holds one class (yb repeats each label
# 8 times in order); values are clipped to [0, 1] before display.
ax.imshow(
    make_grid(xb.detach().cpu().clip(0, 1), nrow=8)[0]
);
You can switch the dataset between `mnist` and `fashion_mnist` (via `name` above) without making any other changes. This seems to work (surprisingly, given the lack of fiddling with training and architecture).