import torch, numpy as np, seaborn as sb, torchvision.transforms.functional as TF
import matplotlib.pyplot as plt, matplotlib as mpl
from torch import nn,tensor
from fastcore.utils import *
from k_diffusion import *
from einops import rearrange
from torchvision import datasets  #, transforms, utils
from torch.utils import data
torch.manual_seed(42)
mpl.rcParams['image.cmap'] = 'gray_r'
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
from miniai.datasets import *
# model_path = Path('outputs/base_00060000.pth')
model_path = Path('outputs/mse-no-temb_00060000.pth')
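# Image size, the Karras sigma range, sigma_data (roughly the std of the [-1,1]-scaled training data), and batch size.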
sz = 28
sigma_min,sigma_max,sigma_data = 1e-2,80,0.6162
bs = 16
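# TF.to_tensor gives pixels in [0,1]; rescale to [-1,1].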
def tf(image): return TF.to_tensor(image)*2-1
train_set = datasets.FashionMNIST('data', train=True, download=True, transform=tf)
train_set[0][0].shape
torch.Size([1, 28, 28])
train_dl = data.DataLoader(train_set, bs, shuffle=True, drop_last=True, num_workers=4)
x,y = next(iter(train_dl))
x.shape
torch.Size([16, 1, 28, 28])
show_images(x, imsize=1.5)
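# Draw one noise level (sigma) per image, uniform on [0,1).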
n = len(x)
sigmas = torch.rand(n)
sigmas
tensor([0.39, 0.60, 0.26, 0.79, 0.94, 0.13, 0.93, 0.59, 0.87, 0.57, 0.74, 0.43, 0.89, 0.57, 0.27, 0.63])
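# rbroadcast appends trailing unit dims so sigmas (shape [n]) broadcasts against noise (shape [n,1,sz,sz]); then noise each image at its own sigma.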
def rbroadcast(x, t): return x.reshape(x.shape + (1,)*(t.ndim-x.ndim))
noise = torch.randn(n,1,sz,sz)
noised = x+noise*rbroadcast(sigmas,noise)
show_images(noised, imsize=1.5)
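# k-diffusion denoiser U-Net for 1-channel Fashion-MNIST images; t_embed=False drops the noise-level (sigma) embedding, matching the 'mse-no-temb' checkpoint.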
model = models.ImageDenoiserModelV1(
    t_embed=False, c_in=1, feats_in=256, depths=[2,4,4], channels=[64,128,256],
    self_attn_depths=[False,False,True], cross_attn_depths=None, patch_size=1,
    dropout_rate=0.05, mapping_cond_dim=9, unet_cond_dim=0, cross_cond_dim=0,
    skip_stages=0, has_variance=False).cuda()
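# Load the EMA weights, stripping the 'inner_model.' prefix (presumably left by the Denoiser wrapper the checkpoint was saved through).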
chkpt = torch.load(model_path)['model_ema']
# model = Denoiser(model, unscaled=True).cuda()
nn.modules.utils.consume_prefix_in_state_dict_if_present(chkpt, 'inner_model.')
model.load_state_dict(chkpt);
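# Denoise the noisy batch; with no sigma embedding, a constant sigma of 1. is passed for every image.
# The next cells show the prediction, the residual (noised - out), and the clean originals.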
out = model(noised.cuda(), tensor(1.).cuda()).data.cpu()
show_images(out, imsize=1.5)
show_images(noised-out, imsize=1.5)
show_images(x, imsize=1.5)
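# Karras noise schedule: 20 sigmas from sigma_max down to sigma_min (rho=7 spacing), with a trailing 0 appended.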
sigmas = sampling.get_sigmas_karras(20, sigma_min, sigma_max, rho=7.)
sigmas
tensor([ 80.00, 60.97, 45.97, 34.24, 25.18, 18.26, 13.04, 9.15, 6.30, 4.25, 2.80, 1.80, 1.12,
0.67, 0.39, 0.21, 0.11, 0.05, 0.02, 0.01, 0.00])
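# Sample 8 images with the second-order Heun sampler, starting from pure noise scaled to sigma_max.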
torch.manual_seed(42)
xr = torch.randn([8, 1, sz, sz]).cuda() * sigma_max
x_0 = sampling.sample_heun(model, xr.cuda(), sigmas.cuda()).data.cpu()
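# Clamp to [-1,1], rescale to [0,1], and tile the 8 samples into a 2x4 grid for display.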
x = (x_0.clamp(-1,1)+1)/2
x = rearrange(x, '(b1 b2) c h w -> (b1 h) (b2 w) c ', b2=4)
plt.figure(figsize=(5,5))
plt.imshow(x, cmap='gray_r')
plt.axis('off');
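# The log-normal density (loc=-1.2, scale=1.2) typically used to sample noise levels during training, plotted two ways:
# cut=0 stops the KDE at the data limits; clip=(0,5) restricts evaluation to [0,5].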
loc,scale = -1.2,1.2
density = partial(utils.rand_log_normal, loc=loc, scale=scale)
sb.kdeplot(density(1000), cut=0);
sb.kdeplot(density(1000), clip=(0,5));
# def tf(image):
# h, w = image.size
# image = np.array(image, dtype=np.float32)[..., None] / 255
# image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
# return image