import fastai
from fastai import *            # Quick access to most common functionality
from fastai.vision import *     # Quick access to computer vision functionality
from fastai.callbacks import *
from torchvision.models import vgg16_bn

# ---------------------------------------------------------------------------
# Data locations and low-res / high-res patch sizes.
# ---------------------------------------------------------------------------
PATH = Path('/DATA/kaggle/imgnetloc/ILSVRC/Data/CLS-LOC/')
PATH_TRN = PATH/'train'

sz_lr = 224//4          # low-resolution input size
scale, bs = 4, 24       # upscale factor and batch size
sz_hr = sz_lr*scale     # high-resolution target size

# Walk every class folder once and collect all image file names.
classes = list(PATH_TRN.iterdir())
fnames_full = []
for class_folder in progress_bar(classes):
    for fname in class_folder.iterdir():
        fnames_full.append(fname)

# Keep a reproducible random subsample (2%) so experiments iterate quickly.
np.random.seed(42)
keep_pct = 0.02
#keep_pct = 1.
keeps = np.random.rand(len(fnames_full)) < keep_pct
image_fns = np.array(fnames_full, copy=False)[keeps]
len(image_fns)

valid_pct = 0.1
# Super-resolution: each image is its own label (the hi-res target is built
# from the same file by the size transforms below).
src = (ImageToImageList(image_fns)
       .random_split_by_pct(valid_pct, seed=42)
       .label_from_func(lambda x: x))


def get_data(bs, sz_lr, sz_hr, num_workers=12, **kwargs):
    """Build a DataBunch of (low-res input, high-res target) pairs from `src`.

    `bs`: batch size; `sz_lr`/`sz_hr`: input and target crop sizes;
    extra kwargs are forwarded to `.databunch()`.
    """
    tfms = get_transforms(flip_vert=True)
    data = (src
            .transform(tfms, size=sz_lr)
            .transform_labels(size=sz_hr)
            .databunch(bs=bs, num_workers=num_workers, **kwargs))
            #.normalize(imagenet_stats, do_y=True))
    return data


sz_lr = 288//4
scale, bs = 4, 24
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
data.train_ds[0:3]
x, y = data.dl().one_batch()
x.shape, y.shape


class Block(nn.Module):
    """Wide-activation residual block (WDSR-B style).

    1x1 conv expands the channels, ReLU, 1x1 conv squeezes them back down,
    then a kxk conv returns to `n_feats`; the (scaled) result is added to
    the input.  All convs are wrapped by `wn` (weight normalization).
    """

    def __init__(self, n_feats, kernel_size, wn, act=nn.ReLU(True), res_scale=1):
        super().__init__()
        self.res_scale = res_scale
        expand = 6      # channel expansion factor before the activation
        linear = 0.8    # squeeze ratio after the activation
        body = [
            wn(nn.Conv2d(n_feats, n_feats*expand, 1, padding=1//2)),
            act,
            wn(nn.Conv2d(n_feats*expand, int(n_feats*linear), 1, padding=1//2)),
            wn(nn.Conv2d(int(n_feats*linear), n_feats, kernel_size,
                         padding=kernel_size//2)),
        ]
        self.body = nn.Sequential(*body)

    def forward(self, x):
        res = self.body(x) * self.res_scale
        res += x
        return res


class WDSR(nn.Module):
    """WDSR super-resolution network.

    Head conv lifts RGB into feature space, a stack of wide-activation
    residual blocks processes it, and the tail projects to
    `scale**2 * n_colors` channels followed by PixelShuffle upsampling.
    A shallow skip branch (one 5x5 conv + PixelShuffle) is added to the
    tail output.  Inputs are normalized with ImageNet statistics and the
    output is denormalized back to image space.
    """

    def __init__(self, scale, n_resblocks, n_feats, res_scale, n_colors=3):
        super().__init__()
        # hyper-params
        kernel_size = 3
        act = nn.ReLU(True)
        # wn = lambda x: x
        wn = lambda x: torch.nn.utils.weight_norm(x)

        # ImageNet mean/std kept as buffers so they follow the module across
        # devices and are stored in the state dict.  (The original used the
        # deprecated torch.autograd.Variable.)
        mean, std = imagenet_stats
        self.register_buffer(
            'rgb_mean', torch.FloatTensor(mean).view([1, n_colors, 1, 1]))
        self.register_buffer(
            'rgb_std', torch.FloatTensor(std).view([1, n_colors, 1, 1]))

        # head module: RGB -> feature space
        head = [wn(nn.Conv2d(n_colors, n_feats, 3, padding=3//2))]

        # body module: n_resblocks wide-activation residual blocks
        body = [Block(n_feats, kernel_size, act=act, res_scale=res_scale, wn=wn)
                for _ in range(n_resblocks)]

        # tail module: project then upsample with pixel shuffle
        out_feats = scale*scale*n_colors
        tail = [wn(nn.Conv2d(n_feats, out_feats, 3, padding=3//2)),
                nn.PixelShuffle(scale)]

        # skip branch straight from the (normalized) input
        skip = [wn(nn.Conv2d(n_colors, out_feats, 5, padding=5//2)),
                nn.PixelShuffle(scale)]

        # replication padding (only used by the commented-out inference path)
        pad = [torch.nn.ReplicationPad2d(5//2)]

        # make object members
        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)
        self.tail = nn.Sequential(*tail)
        self.skip = nn.Sequential(*skip)
        self.pad = nn.Sequential(*pad)

    def forward(self, x):
        # `.to(x)` matches device AND dtype of the incoming batch.
        mean = self.rgb_mean.to(x)
        std = self.rgb_std.to(x)
        x = (x - mean) / std
        #if not self.training:
        #    x = self.pad(x)
        s = self.skip(x)
        x = self.head(x)
        x = self.body(x)
        x = self.tail(x)
        x += s
        # back to image space
        x = x*std + mean
        return x


scale = 4
n_resblocks = 8
n_feats = 64
res_scale = 1.
model = WDSR(scale, n_resblocks, n_feats, res_scale).cuda()

# ---------------------------------------------------------------------------
# Stage 1: pixel (MSE) pre-training on 72px -> 288px crops.
# ---------------------------------------------------------------------------
sz_lr = 72
scale, bs = 4, 24
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
#loss = CropTargetForLoss(F.l1_loss)
loss = F.mse_loss
learn = Learner(data, nn.DataParallel(model), loss_func=loss)
# learn.lr_find(num_it=500, start_lr=1e-5, end_lr=1000)
# learn.recorder.plot()
#learn.load('pixel')
lr = 1e-3
learn.fit_one_cycle(1, lr)
learn.save('pixel')
learn.fit_one_cycle(1, lr/5)
learn.save('pixel')

# ---------------------------------------------------------------------------
# Stage 2: continue pixel training on large 512px crops (smaller batch).
# ---------------------------------------------------------------------------
sz_lr = 512
scale, bs = 4, 4
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
#loss = CropTargetForLoss(F.l1_loss)
loss = F.mse_loss
learn = Learner(data, nn.DataParallel(model), loss_func=loss)
learn = learn.load('pixel')
learn.fit_one_cycle(1, lr)
learn.save('pixel')

# ---------------------------------------------------------------------------
# Perceptual (VGG feature) loss.
# BUG FIX: the original chained `.features` twice
# (`vgg16_bn(True).features.cuda().eval().features`); the second access
# raises AttributeError because the first already returns the nn.Sequential
# of conv layers.
# ---------------------------------------------------------------------------
m_vgg_feat = vgg16_bn(True).features.cuda().eval()
requires_grad(m_vgg_feat, False)
# Tap the activations just before each max-pool — the usual perceptual-loss
# feature layers.
blocks = [i-1 for i, o in enumerate(children(m_vgg_feat))
          if isinstance(o, nn.MaxPool2d)]
blocks, [m_vgg_feat[i] for i in blocks]


class FeatureLoss(nn.Module):
    """Perceptual loss: scaled pixel L1 plus weighted MSE between VGG features.

    `layer_ids` pick which modules of `m_feat` to hook; `layer_wgts` weight
    the corresponding feature-space MSE terms (surplus weights are silently
    dropped by `zip` in `forward`).  Per-term values are exposed through
    `self.metrics` so `ReportLossMetrics` can display them.
    """

    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False keeps the graph so the feature terms backprop to input
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metrics = {}
        self.metric_names = ['L1'] + [f'feat_{i}' for i in range(len(layer_ids))]
        for name in self.metric_names:
            self.metrics[name] = 0.

    def make_feature(self, bs, o, clone=False):
        """Flatten one hooked activation to (bs, -1), optionally cloning it."""
        feat = o.view(bs, -1)
        if clone:
            feat = feat.clone()
        return feat

    def make_features(self, x, clone=False):
        """Run `x` through the feature net and collect the hooked activations."""
        bs = x.shape[0]
        self.m_feat(x)  # forward pass fills self.hooks.stored
        return [self.make_feature(bs, o, clone) for o in self.hooks.stored]

    def forward(self, input, target):
        # Clone the target features: the second m_feat() call overwrites the
        # hook storage.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        l1_loss = F.l1_loss(input, target)/100
        self.feat_losses = [l1_loss]
        self.feat_losses += [F.mse_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        for i, name in enumerate(self.metric_names):
            self.metrics[name] = self.feat_losses[i]
        self.metrics['L1'] = l1_loss
        self.loss = sum(self.feat_losses)
        return self.loss*100


class ReportLossMetrics(LearnerCallback):
    """Report the loss function's per-term `metrics` as extra Recorder columns,
    averaged (weighted by batch size) over the validation set each epoch."""

    _order = -20  # Needs to run before the recorder

    def on_train_begin(self, **kwargs):
        self.metric_names = self.learn.loss_func.metric_names
        self.learn.recorder.add_metric_names(self.metric_names)

    def on_epoch_begin(self, **kwargs):
        self.metrics = {}
        for name in self.metric_names:
            self.metrics[name] = 0.
        self.nums = 0

    def on_batch_end(self, last_target, train, **kwargs):
        if not train:  # only accumulate over validation batches
            bs = last_target.size(0)
            for name in self.metric_names:
                self.metrics[name] += bs * self.learn.loss_func.metrics[name]
            self.nums += bs

    def on_epoch_end(self, **kwargs):
        if self.nums:
            metrics = [self.metrics[name]/self.nums for name in self.metric_names]
            self.learn.recorder.add_metrics(metrics)


# ---------------------------------------------------------------------------
# Stage 3: fine-tune with the perceptual loss on 200px crops.
# ---------------------------------------------------------------------------
sz_lr = 200
scale, bs = 4, 4
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
# NOTE(review): three weights but only two feature layers (blocks[:2]); the
# trailing 0.30 is ignored by zip — confirm whether blocks[:3] was intended.
feat_loss = FeatureLoss(m_vgg_feat, blocks[:2], [0.25, 0.45, 0.30])
learn = Learner(data, nn.DataParallel(model), loss_func=feat_loss,
                callback_fns=[ReportLossMetrics])
#learn = learn.load('pixel')
# learn.lr_find()
# learn.recorder.plot()
lr = 1e-3
learn.fit_one_cycle(1, lr)
learn.save('enhance_feat')
learn = learn.load('enhance_feat')


def make_img(x, idx=0):
    """Clamp a batch tensor to [0, 1] on CPU and wrap item `idx` as an Image."""
    return Image(torch.clamp(x.cpu(), 0, 1)[idx])


def plot_x_y_pred(x, pred, y, figsize):
    """Plot one row of (input, prediction, target) per batch item."""
    rows = x.shape[0]
    fig, axs = plt.subplots(rows, 3, figsize=figsize)
    for i in range(rows):
        make_img(x, i).show(ax=axs[i, 0])
        make_img(pred, i).show(ax=axs[i, 1])
        make_img(y, i).show(ax=axs[i, 2])
    plt.tight_layout()


def plot_some(learn, do_denorm=True, figsize=None):
    """Run one validation batch through the model and plot 4 triples."""
    x, y = next(iter(learn.data.valid_dl))
    # NOTE(review): uses the global `model`, not `learn.model`; they share
    # weights here because every Learner wraps the same module — confirm if
    # this script is ever refactored.
    y_pred = model(x)
    y_pred = y_pred.detach()
    x = x.detach()
    y = y.detach()
    if figsize is None:
        figsize = y_pred.shape[-2:]
    plot_x_y_pred(x[0:4], y_pred[0:4], y[0:4], figsize=figsize)


# ---------------------------------------------------------------------------
# Visual comparison of the perceptual-loss vs. pixel-loss checkpoints.
# ---------------------------------------------------------------------------
sz_lr = 64
scale, bs = 4, 24
sz_hr = sz_lr*scale
data = get_data(bs, sz_lr, sz_hr)
loss = F.mse_loss
learn = Learner(data, nn.DataParallel(model), loss_func=loss)
learn = learn.load('enhance_feat')
plot_some(learn, figsize=(256, 256))
learn = learn.load('pixel')
plot_some(learn, figsize=(256, 256))