import fastai
from fastai import * # Quick access to most common functionality
from fastai.vision import * # Quick access to computer vision functionality
from fastai.callbacks import *
from torchvision.models import vgg16_bn
# Root of the ImageNet localization data; one sub-folder per class.
PATH = Path('/DATA/kaggle/imgnetloc/ILSVRC/Data/CLS-LOC/')
PATH_TRN = PATH/'train'
# Low-res crop size, upscale factor, and batch size; high-res size follows from them.
sz_lr=72
scale,bs = 4,24
sz_hr = sz_lr*scale
# Collect every training image filename across all class folders.
classes = list(PATH_TRN.iterdir())
fnames_full = []
for class_folder in progress_bar(classes):
    for fname in class_folder.iterdir():
        fnames_full.append(fname)
# Keep a reproducible random 2% sample of the images (switch keep_pct to 1. for the full set).
np.random.seed(42)
keep_pct = 0.02
#keep_pct = 1.
keeps = np.random.rand(len(fnames_full)) < keep_pct
image_fns = np.array(fnames_full, copy=False)[keeps]
len(image_fns)
# 90/10 train/valid split. The "label" of each image is the image itself:
# for super-resolution, input and target are produced from the same file.
valid_pct=0.1
src = (ImageToImageList(image_fns)
       .random_split_by_pct(valid_pct, seed=42)
       .label_from_func(lambda x: x))
def get_data(bs, sz_lr, sz_hr, num_workers=12, **kwargs):
    """Build a DataBunch of (low-res input, high-res target) pairs from the
    module-level `src` split.

    Inputs are resized to sz_lr and targets to sz_hr; both sides are
    normalized with imagenet stats (do_y=True).
    """
    # Train side: random dihedral flip/rotation plus a centered crop;
    # valid side: centered crop only (deterministic).
    train_tfms = [dihedral_affine(p=0.75), crop_pad(row_pct=0.5, col_pct=0.5)]
    valid_tfms = [crop_pad(row_pct=0.5, col_pct=0.5)]
    labelled = src.transform([train_tfms, valid_tfms], size=sz_lr)
    labelled = labelled.transform_labels(size=sz_hr)
    bunch = labelled.databunch(bs=bs, num_workers=num_workers, **kwargs)
    return bunch.normalize(imagenet_stats, do_y=True)
# Rebuild the bunch at the working sizes and sanity-check one batch's shapes.
sz_lr = 72
scale,bs = 4,24
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
data.train_ds[0:3]
x,y = data.dl().one_batch()
x.shape, y.shape
def make_img(x, idx=0):
    """Denormalize batch `x` (using the module-level `data` stats), clamp to
    [0, 1], and wrap image `idx` as a fastai Image for display."""
    denormed = torch.clamp(data.denorm(x.cpu()), 0, 1)
    return Image(denormed[idx])
# Visual sanity check: show one low-res input next to its high-res target.
idx=5
x_img = make_img(x, idx)
y_img = make_img(y, idx)
x_img.show(), y_img.show()
def wn(x):
    """Wrap module `x` in PyTorch weight normalization.

    (PEP 8: a named lambda should be a `def`; interface is unchanged.)
    """
    return torch.nn.utils.weight_norm(x)
def conv(ni, nf, kernel_size=3, actn=True):
    """Weight-normalized conv layer with 'same' padding, optionally followed
    by an inplace ReLU, packaged as an nn.Sequential."""
    padding = kernel_size // 2
    modules = [wn(nn.Conv2d(ni, nf, kernel_size, padding=padding))]
    if actn:
        modules.append(nn.ReLU(True))
    return nn.Sequential(*modules)
class ResSequential(nn.Module):
    """Residual wrapper around a list of layers.

    forward(x) = x + res_scale * Sequential(layers)(x).
    The inner module is kept as attribute `m` (state_dict / external indexing
    depend on this name).
    """

    def __init__(self, layers, res_scale=1.0):
        super().__init__()
        self.res_scale = res_scale
        self.m = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.m(x) * self.res_scale
def res_block(nf):
    """Standard SR residual block: two 3x3 convs (second without activation),
    with the residual branch scaled by 0.1 (EDSR-style stabilization)."""
    body = [conv(nf, nf), conv(nf, nf, actn=False)]
    return ResSequential(body, 0.1)
def upsample(ni, nf, scale):
    """Stack log2(scale) (conv -> PixelShuffle(2)) stages to upscale by `scale`.

    `scale` must be a power of two. Uses round(math.log2(...)) instead of
    int(math.log(scale, 2)) so floating-point noise cannot drop a stage for
    larger scales. After the first PixelShuffle the channel count is nf, so
    later stages take nf input channels (identical to the original whenever
    ni == nf, as at every existing call site).
    """
    layers = []
    for i in range(round(math.log2(scale))):
        in_ch = ni if i == 0 else nf
        layers += [conv(in_ch, nf*4), nn.PixelShuffle(2)]
    return nn.Sequential(*layers)
class SrResnet(nn.Module):
    """Super-resolution residual network.

    head conv (3 -> nf), n_res residual blocks, a conv, a PixelShuffle
    upsampler by `scale`, and a final conv back to 3 channels (no activation).

    Fix: the original ignored `nf` and hard-coded 64 throughout; `nf` is now
    honored. Behavior is unchanged for every existing call site, which passes
    nf=64.
    """

    def __init__(self, nf, scale, n_res=8):
        super().__init__()
        features = [conv(3, nf)]
        for i in range(n_res):
            features.append(res_block(nf))
        features += [conv(nf, nf),
                     upsample(nf, nf, scale),
                     conv(nf, 3, actn=False)]
        self.features = nn.Sequential(*features)

    def forward(self, x):
        return self.features(x)
def icnr(x, scale, init=nn.init.kaiming_normal_):
    """ICNR initialization for a conv weight feeding a PixelShuffle layer.

    Builds a sub-kernel with out_channels / scale^2 output channels, then
    tiles it so each group of scale^2 output channels starts identical --
    the PixelShuffle upsampling then begins as an artifact-free (checkerboard
    free) upscale. Returns a new tensor shaped like `x`; `x` itself is not
    modified.
    """
    n_sub = int(x.shape[0] / (scale ** 2))
    sub = init(torch.zeros([n_sub] + list(x.shape[1:])))
    # Move in-channels first and flatten the spatial dims: (in, n_sub, kh*kw).
    sub = sub.transpose(0, 1).contiguous()
    sub = sub.view(sub.shape[0], sub.shape[1], -1)
    # Replicate each sub-kernel scale^2 times along the flattened axis.
    tiled = sub.repeat(1, 1, scale ** 2)
    # Reshape to (in, out, kh, kw) and swap back to conv layout (out, in, kh, kw).
    full_shape = [x.shape[1], x.shape[0]] + list(x.shape[2:])
    return tiled.contiguous().view(full_shape).transpose(0, 1)
model = SrResnet(64, scale)
#model = torch.load('old.pth')
# wd=1e-7
# learn = Learner(data, nn.DataParallel(model,[0,2]), loss_func=F.mse_loss, opt_func=torch.optim.Adam, wd=wd, true_wd=False)
# Stage 1: pixel-loss (MSE) training at a large low-res crop size.
sz_lr = 288
scale,bs = 4,12
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
learn = Learner(data, nn.DataParallel(model), loss_func=F.mse_loss)
learn.lr_find()
learn.recorder.plot()
# Resume from the saved pixel-loss checkpoint and keep training.
learn = learn.load('pixel_v2')
lr = 1e-3
learn.fit_one_cycle(1, lr)
lr = 1e-3
learn.fit_one_cycle(1, lr)
lr = 2e-4
learn.fit_one_cycle(1, lr)
learn.save('pixel_v2')
# Rebuild at a smaller crop / batch for visual inspection of the results.
sz_lr = 72
scale,bs = 4,4
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
learn = Learner(data, nn.DataParallel(model), loss_func=F.mse_loss)
learn = learn.load('pixel_v2')
def plot_x_y_pred(x, pred, y, figsize):
    """Show each batch row as an (input, prediction, target) image triple,
    one row of the figure per batch element."""
    n_rows = x.shape[0]
    fig, axs = plt.subplots(n_rows, 3, figsize=figsize)
    for row in range(n_rows):
        for col, batch in enumerate((x, pred, y)):
            make_img(batch, row).show(ax=axs[row, col])
    plt.tight_layout()
# Push one validation batch through the raw model and eyeball a prediction
# against its target.
x, y = next(iter(learn.data.valid_dl))
y_pred = model(x)
x[0:3].shape
make_img(y_pred.detach(), 2).show(), make_img(y.detach(), 2).show()
def plot_some(learn, do_denorm=True):
    """Plot input/prediction/target triples for the first 3 validation images.

    NOTE(review): `do_denorm` is accepted for backward compatibility but is
    unused -- make_img always denormalizes.
    """
    x, y = next(iter(learn.data.valid_dl))
    # Fix: predict with the learner's own model instead of the module-level
    # `model` global, so the function works for whichever learner is passed in.
    y_pred = learn.model(x).detach()
    x = x.detach()
    y = y.detach()
    # NOTE(review): figsize is taken from the prediction's pixel dims (these
    # are inches to matplotlib) -- yields very large figures; presumably
    # intentional for inspecting fine detail.
    plot_x_y_pred(x[0:3], y_pred[0:3], y[0:3], figsize=y_pred.shape[-2:])
plot_some(learn)
# Frozen VGG16-bn feature extractor for the perceptual loss. Hook points are
# the layers immediately before each MaxPool (richest activations per scale).
m_vgg_feat = vgg16_bn(True).cuda().eval().features
requires_grad(m_vgg_feat, False)
blocks = [i-1 for i,o in enumerate(children(m_vgg_feat))
          if isinstance(o,nn.MaxPool2d)]
blocks, [m_vgg_feat[i] for i in blocks]
class FeatureLoss(nn.Module):
    """Perceptual (VGG-feature) loss: pixel L1/100 plus weighted MSE between
    activations captured at `layer_ids` of `m_feat` via forward hooks.

    Per-term values are exposed in `self.metrics` (keys in
    `self.metric_names`) so a callback can report them.
    """

    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: gradients must flow back through the hooked activations.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['L1'] + [f'feat_{i}' for i in range(len(layer_ids))]
        self.metrics = {name: 0. for name in self.metric_names}

    def make_feature(self, bs, o, clone=False):
        """Flatten one hooked activation to (bs, -1); clone so it survives
        the next forward pass through m_feat."""
        flat = o.view(bs, -1)
        return flat.clone() if clone else flat

    def make_features(self, x, clone=False):
        """Run m_feat on x and collect the flattened hooked activations."""
        bs = x.shape[0]
        self.m_feat(x)
        return [self.make_feature(bs, stored, clone) for stored in self.hooks.stored]

    def forward(self, input, target):
        # Target features first, cloned: the input pass below overwrites hooks.stored.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        l1_loss = F.l1_loss(input, target) / 100
        self.feat_losses = [l1_loss]
        self.feat_losses += [F.mse_loss(f_in, f_out) * w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        for i, name in enumerate(self.metric_names):
            self.metrics[name] = self.feat_losses[i]
        self.metrics['L1'] = l1_loss
        self.loss = sum(self.feat_losses)
        return self.loss
class ReportLossMetrics(LearnerCallback):
    """Callback that surfaces the loss function's per-term metrics in the
    recorder, averaged (batch-size weighted) over validation each epoch."""
    _order = -20  # must run before the Recorder so it can register metric names

    def on_train_begin(self, **kwargs):
        self.metric_names = self.learn.loss_func.metric_names
        self.learn.recorder.add_metric_names(self.metric_names)

    def on_epoch_begin(self, **kwargs):
        # Reset the weighted sums and the sample count.
        self.metrics = {name: 0. for name in self.metric_names}
        self.nums = 0

    def on_batch_end(self, last_target, train, **kwargs):
        # Only accumulate on validation batches.
        if train:
            return
        bs = last_target.size(0)
        for name in self.metric_names:
            self.metrics[name] += bs * self.learn.loss_func.metrics[name]
        self.nums += bs

    def on_epoch_end(self, **kwargs):
        if not self.nums:
            return
        averages = [self.metrics[name] / self.nums for name in self.metric_names]
        self.learn.recorder.add_metrics(averages)
# Stage 2: train with the perceptual (feature) loss, small crops first.
sz_lr = 32
scale,bs = 4,24
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
feat_loss = FeatureLoss(m_vgg_feat, blocks[0:4], [0.25,0.25,0.25,0.25])
model = SrResnet(64, scale)
learn = Learner(data, nn.DataParallel(model), loss_func=feat_loss, callback_fns=[ReportLossMetrics])
learn.load('pixel_v2')
model = learn.model.module
# Re-initialize both PixelShuffle conv weights with ICNR to suppress
# checkerboard artifacts. features[nres+2] is the upsample Sequential;
# its sub-modules [0] and [2] are conv blocks whose [0] is the Conv2d.
nres = 8
conv_shuffle = model.features[nres+2][0][0]
kernel = icnr(conv_shuffle.weight, scale=scale)
conv_shuffle.weight.data.copy_(kernel);
conv_shuffle = model.features[nres+2][2][0]
kernel = icnr(conv_shuffle.weight, scale=scale)
conv_shuffle.weight.data.copy_(kernel);
# Freeze everything, then re-enable grads only for the upsampler and the
# final conv (features[10] and features[11]).
learn.freeze_to(999)
for i in range(10,12): requires_grad(model.features[i], True)
learn.lr_find()
learn.recorder.plot()
# Larger crops, smaller batch; continue from the feature-loss checkpoint.
sz_lr = 128
scale,bs = 4,4
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
model = SrResnet(64, scale)
learn = Learner(data, nn.DataParallel(model), loss_func=feat_loss, callback_fns=[ReportLossMetrics])
learn = learn.load('enhance_feat_v2')
lr=1e-3
learn.unfreeze()
learn.fit_one_cycle(1, lr)
learn.save('enhance_feat_v2')
learn = learn.load('enhance_feat_v2')
learn.unfreeze()
learn.lr_find()
learn.recorder.plot()
lr=1e-5
learn.fit_one_cycle(1, lr)
learn.save('enhance_feat2')
learn = learn.load('enhance_feat2')
# Compare results: feature-loss checkpoint vs pixel-loss checkpoint on the
# same data setup.
sz_lr = 72
scale,bs = 4,4
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
learn = Learner(data, nn.DataParallel(model), loss_func=F.mse_loss)
learn = learn.load('enhance_feat_v2')
plot_some(learn)
sz_lr = 72
scale,bs = 4,4
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
learn = Learner(data, nn.DataParallel(model), loss_func=F.mse_loss)
learn = learn.load('pixel_v2')
plot_some(learn)