import fastai
from fastai import * # Quick access to most common functionality
from fastai.vision import * # Quick access to computer vision functionality
from fastai.callbacks import *
from torchvision.models import vgg16_bn
# Root of the ImageNet localization (CLS-LOC) data; one sub-folder per class.
PATH = Path('/DATA/kaggle/imgnetloc/ILSVRC/Data/CLS-LOC/')
PATH_TRN = PATH/'train'
# Low-resolution input size; the high-resolution target is `scale`x larger.
sz_lr=224//4
scale,bs = 4,24
sz_hr = sz_lr*scale
# Collect every training image path, walking one class folder at a time.
classes = list(PATH_TRN.iterdir())
fnames_full = []
for class_folder in progress_bar(classes):
    for fname in class_folder.iterdir():
        fnames_full.append(fname)
# Keep a reproducible random 2% sample of the images for faster iteration
# (set keep_pct = 1. for a full run).
np.random.seed(42)
keep_pct = 0.02
#keep_pct = 1.
keeps = np.random.rand(len(fnames_full)) < keep_pct
image_fns = np.array(fnames_full, copy=False)[keeps]
len(image_fns)
valid_pct=0.1
# Super-resolution setup: the label is the image itself (label_from_func
# returns the input path); the transforms later resize input and label
# to different sizes.
src = (ImageToImageList(image_fns)
       .random_split_by_pct(valid_pct, seed=42)
       .label_from_func(lambda x: x))
def get_data(bs, sz_lr, sz_hr, num_workers=12, **kwargs):
    "Build a DataBunch from the module-level `src`: low-res inputs, high-res targets."
    # Standard fastai augmentations, with vertical flips enabled as well.
    aug_tfms = get_transforms(flip_vert=True)
    # Inputs are resized to sz_lr, labels (targets) to sz_hr.
    lls = src.transform(aug_tfms, size=sz_lr)
    lls = lls.transform_labels(size=sz_hr)
    # .normalize(imagenet_stats, do_y=True) is intentionally left out here.
    return lls.databunch(bs=bs, num_workers=num_workers, **kwargs)
# Rebuild the data at a 72px low-res input (288/4) with a 4x upscale factor.
sz_lr = 288//4
scale,bs = 4,24
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
data.train_ds[0:3]
# Pull one batch to sanity-check shapes — presumably x is sz_lr-sized and
# y is sz_hr-sized; verify against the printed shapes.
x,y = data.dl().one_batch()
x.shape, y.shape
class Block(nn.Module):
    """WDSR-B residual block: 1x1 expand -> activation -> 1x1 reduce ->
    spatial conv, added back to the input scaled by `res_scale`.

    Args:
        n_feats: number of input/output channels.
        kernel_size: spatial kernel of the final conv ('same' padding).
        wn: weight-normalization wrapper applied to every conv.
        act: activation module; a fresh nn.ReLU(True) when None.
             (Fix: the original used a mutable default `nn.ReLU(True)`,
             sharing one module instance across every Block built with
             the default.)
        res_scale: multiplier on the residual branch.
    """
    def __init__(self, n_feats, kernel_size, wn, act=None, res_scale=1):
        super().__init__()
        self.res_scale = res_scale
        if act is None:
            act = nn.ReLU(True)
        expand = 6    # channel expansion factor ("wide activation")
        linear = 0.8  # fraction of channels kept after the linear reduction
        body = [
            # 1x1 convs need no padding (the original spelled this 1//2 == 0).
            wn(nn.Conv2d(n_feats, n_feats * expand, 1, padding=0)),
            act,
            wn(nn.Conv2d(n_feats * expand, int(n_feats * linear), 1, padding=0)),
            # spatial conv with 'same' padding
            wn(nn.Conv2d(int(n_feats * linear), n_feats, kernel_size,
                         padding=kernel_size // 2)),
        ]
        self.body = nn.Sequential(*body)

    def forward(self, x):
        # residual connection: x + res_scale * body(x)
        res = self.body(x) * self.res_scale
        res += x
        return res
class WDSR(nn.Module):
    """WDSR-style super-resolution network.

    Pipeline: normalize with ImageNet stats -> head conv -> stack of wide
    residual Blocks -> tail (conv + PixelShuffle upsample), plus a shallow
    skip branch (5x5 conv + PixelShuffle) straight from the input; the sum
    is de-normalized back to image space.

    Args:
        scale: upscale factor (PixelShuffle factor).
        n_resblocks: number of residual Blocks in the body.
        n_feats: feature channels inside the body.
        res_scale: residual scaling passed to each Block.
        n_colors: image channels (default 3).
    """
    def __init__(self, scale, n_resblocks, n_feats, res_scale, n_colors=3):
        super().__init__()
        # hyper-params
        kernel_size = 3
        act = nn.ReLU(True)
        # wn = lambda x: x   # identity — uncomment to disable weight norm
        wn = lambda x: torch.nn.utils.weight_norm(x)
        mean, std = imagenet_stats
        # Fix: torch.autograd.Variable is deprecated (a no-op wrapper since
        # PyTorch 0.4) — plain tensors behave identically.  These stay plain
        # attributes (not buffers) so saved state_dicts keep loading; forward()
        # moves them onto the input's device/dtype.
        self.rgb_mean = torch.FloatTensor(mean).view([1, n_colors, 1, 1])
        self.rgb_std = torch.FloatTensor(std).view([1, n_colors, 1, 1])
        # head: lift the image into feature space
        head = [wn(nn.Conv2d(n_colors, n_feats, 3, padding=3 // 2))]
        # body: stack of wide-activation residual blocks
        body = [Block(n_feats, kernel_size, act=act, res_scale=res_scale, wn=wn)
                for i in range(n_resblocks)]
        # tail: project to scale^2 * n_colors channels, then pixel-shuffle up
        out_feats = scale * scale * n_colors
        tail = [wn(nn.Conv2d(n_feats, out_feats, 3, padding=3 // 2)),
                nn.PixelShuffle(scale)]
        # skip: one wide conv straight from the input, upsampled the same way
        skip = [wn(nn.Conv2d(n_colors, out_feats, 5, padding=5 // 2)),
                nn.PixelShuffle(scale)]
        # pad: only used by the commented-out eval-time padding in forward()
        pad = [torch.nn.ReplicationPad2d(5 // 2)]
        # make object members
        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)
        self.tail = nn.Sequential(*tail)
        self.skip = nn.Sequential(*skip)
        self.pad = nn.Sequential(*pad)

    def forward(self, x):
        # normalize with ImageNet statistics (on the input's device/dtype)
        mean = self.rgb_mean.to(x)
        std = self.rgb_std.to(x)
        x = (x - mean) / std
        #if not self.training:
        #    x = self.pad(x)
        s = self.skip(x)
        x = self.head(x)
        x = self.body(x)
        x = self.tail(x)
        x += s
        # de-normalize back to image space
        x = x*std + mean
        return x
# --- Stage 1: pre-train the model with a plain pixel (MSE) loss ---
scale=4
n_resblocks=8
n_feats=64
res_scale= 1.
model = WDSR(scale, n_resblocks, n_feats, res_scale).cuda()
sz_lr = 72
scale,bs = 4,24
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
#loss = CropTargetForLoss(F.l1_loss)
loss = F.mse_loss
# DataParallel: train the same weights across all visible GPUs.
learn = Learner(data, nn.DataParallel(model), loss_func=loss)
# learn.lr_find(num_it=500, start_lr=1e-5, end_lr=1000)
# learn.recorder.plot()
#learn.load('pixel')
lr = 1e-3
learn.fit_one_cycle(1, lr)
learn.save('pixel')
# Second cycle at a lower learning rate, overwriting the same checkpoint.
learn.fit_one_cycle(1, lr/5)
learn.save('pixel')
# --- Stage 2: continue pixel-loss training at a much larger input size ---
# (512 -> 2048 output; batch size drops to 4 to fit in memory)
sz_lr = 512
scale,bs = 4,4
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
#loss = CropTargetForLoss(F.l1_loss)
loss = F.mse_loss
# Fresh Learner on the new data; reload the stage-1 weights before training.
learn = Learner(data, nn.DataParallel(model), loss_func=loss)
learn = learn.load('pixel')
learn.fit_one_cycle(1, lr)
learn.save('pixel')
# Pretrained, frozen VGG16-BN feature extractor for the perceptual loss.
# Fix: the original chained `.features` twice — `vgg16_bn(True).features`
# is already the nn.Sequential of conv layers, and a second `.features`
# on that Sequential raises AttributeError.
m_vgg_feat = vgg16_bn(True).features.cuda().eval()
requires_grad(m_vgg_feat, False)
# Indices of the layer just before each MaxPool2d: the deepest features
# at each spatial resolution, used as perceptual-loss tap points.
blocks = [i-1 for i,o in enumerate(children(m_vgg_feat))
          if isinstance(o,nn.MaxPool2d)]
blocks, [m_vgg_feat[i] for i in blocks]
class FeatureLoss(nn.Module):
    """Perceptual loss: pixel L1 plus weighted MSE between `m_feat`
    activations of prediction and target at the layers in `layer_ids`.

    Per-component values are stashed in `self.metrics` so a callback
    (e.g. ReportLossMetrics) can report them alongside the total.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat  # frozen feature extractor (e.g. VGG16-BN features)
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: gradients must flow back through the hooked activations.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        # NOTE(review): if len(layer_wgts) != len(layer_ids), zip() in
        # forward() silently truncates — confirm the caller passes matching
        # lengths.
        self.wgts = layer_wgts
        self.metrics = {}
        self.metric_names = ['L1'] + [f'feat_{i}' for i in range(len(layer_ids))]
        for name in self.metric_names: self.metrics[name] = 0.
    def make_feature(self, bs, o, clone=False):
        # Flatten one hooked activation to (bs, -1); clone when it must
        # survive the next forward pass through m_feat.
        feat = o.view(bs, -1)
        if clone: feat = feat.clone()
        return feat
    def make_features(self, x, clone=False):
        # Run m_feat for its side effect: the hooks capture the activations.
        bs = x.shape[0]
        self.m_feat(x)
        return [self.make_feature(bs, o, clone) for o in self.hooks.stored]
    def forward(self, input, target):
        # Target features first, cloned — the second m_feat pass (on input)
        # overwrites hooks.stored.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        # /100 here and *100 on the return keep the stored per-component
        # metrics at a readable scale while preserving gradient magnitude.
        l1_loss = F.l1_loss(input,target)/100
        self.feat_losses = [l1_loss]
        self.feat_losses += [F.mse_loss(f_in, f_out)*w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        for i,name in enumerate(self.metric_names): self.metrics[name] = self.feat_losses[i]
        self.metrics['L1'] = l1_loss
        self.loss = sum(self.feat_losses)
        return self.loss*100
class ReportLossMetrics(LearnerCallback):
    "Report the per-component losses stored on `loss_func.metrics` as epoch metrics."
    _order = -20  # needs to run before the recorder
    def on_train_begin(self, **kwargs):
        # Register one metric column per named loss component.
        self.metric_names = self.learn.loss_func.metric_names
        self.learn.recorder.add_metric_names(self.metric_names)
    def on_epoch_begin(self, **kwargs):
        # Reset the running (batch-size weighted) sums.
        self.metrics = {name: 0. for name in self.metric_names}
        self.nums = 0
    def on_batch_end(self, last_target, train, **kwargs):
        # Accumulate on validation batches only.
        if train: return
        batch_size = last_target.size(0)
        self.nums += batch_size
        for name in self.metric_names:
            self.metrics[name] += batch_size * self.learn.loss_func.metrics[name]
    def on_epoch_end(self, **kwargs):
        # Hand the batch-size-weighted averages to the recorder.
        if not self.nums: return
        averages = [self.metrics[name] / self.nums for name in self.metric_names]
        self.learn.recorder.add_metrics(averages)
# --- Stage 3: fine-tune with the perceptual (VGG feature) loss ---
sz_lr = 200
scale,bs = 4,4
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
# NOTE(review): blocks[:2] selects two feature layers but three weights are
# supplied; zip() inside FeatureLoss silently drops the last weight (0.30) —
# confirm whether blocks[:3] was intended.
feat_loss = FeatureLoss(m_vgg_feat, blocks[:2], [0.25,0.45,0.30])
learn = Learner(data, nn.DataParallel(model), loss_func=feat_loss, callback_fns=[ReportLossMetrics])
#learn = learn.load('pixel')
# learn.lr_find()
# learn.recorder.plot()
lr=1e-3
learn.fit_one_cycle(1, lr)
learn.save('enhance_feat')
learn = learn.load('enhance_feat')
def make_img(x, idx=0):
    "Clamp batch `x` to [0, 1] on the CPU and wrap item `idx` as a fastai Image."
    clamped = torch.clamp(x.cpu(), 0, 1)
    return Image(clamped[idx])
def plot_x_y_pred(x, pred, y, figsize):
    "Show each batch item as a row of (input, prediction, target) images."
    n_rows = x.shape[0]
    fig, axes = plt.subplots(n_rows, 3, figsize=figsize)
    for row in range(n_rows):
        for col, batch in enumerate((x, pred, y)):
            make_img(batch, row).show(ax=axes[row, col])
    plt.tight_layout()
def plot_some(learn, do_denorm=True, figsize=None):
    """Plot the first 4 validation (input, prediction, target) triples.

    Fix: the original called the module-global `model` instead of the model
    attached to `learn`, so the learner argument's model was ignored.
    NOTE(review): `do_denorm` is currently unused — confirm whether
    de-normalization was meant to be applied here.
    """
    x, y = next(iter(learn.data.valid_dl))
    y_pred = learn.model(x)
    # Detach everything: this is display-only, no gradients needed.
    y_pred = y_pred.detach()
    x = x.detach()
    y = y.detach()
    if figsize is None: figsize = y_pred.shape[-2:]
    plot_x_y_pred(x[0:4], y_pred[0:4], y[0:4], figsize=figsize)
# --- Visual comparison of the two checkpoints on validation images ---
sz_lr = 64
scale,bs = 4,24
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
loss = F.mse_loss
learn = Learner(data, nn.DataParallel(model), loss_func=loss)
# Feature-loss (perceptual) weights.
learn = learn.load('enhance_feat')
plot_some(learn, figsize=(256,256))
# Pixel-loss (MSE) weights.
learn = learn.load('pixel')
plot_some(learn, figsize=(256,256))