import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from torchvision.models import vgg16_bn
# Fetch the PETS dataset with fastai's `untar_data`; `path` is the returned root folder.
path = untar_data(URLs.PETS)
# Original images: used as the high-resolution targets.
path_hr = path/'images'
# Degraded copies produced by `crappify` are written under this folder.
path_lr = path/'crappy'
# NOTE(review): GPU index 2 is hard-coded; this will likely error on machines
# with fewer than three CUDA devices -- adjust for your hardware.
torch.cuda.set_device(2)
Prepare the input data by crappifying images.
def crappify(fn, i):
    """Write a degraded ("crappy") copy of image `fn` under `path_lr`.

    The copy keeps the same relative path it had under `path_hr`. It is
    resized with `resize_to(..., 96, use_min=True)` (presumably shorter side
    ~96 px -- confirm against fastai), converted to RGB, and saved with a
    random `quality` between 10 and 70 inclusive.

    Args:
        fn: path of a source image located under `path_hr`.
        i: index passed by the caller (presumably fastai's `parallel`); unused.
    """
    dest = path_lr/fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Use a context manager so the source file handle is released as soon as
    # the pixels have been read, instead of whenever the GC collects it.
    with PIL.Image.open(fn) as src:
        targ_sz = resize_to(src, 96, use_min=True)
        # resize() returns a new, fully loaded image independent of `src`.
        img = src.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
    img.save(dest, quality=random.randint(10, 70))
Uncomment the first time you run this notebook.
#il = ImageItemList.from_folder(path_hr)
#parallel(crappify, il.items)
For gradual resizing we can change the commented line here.
# Progressive resizing: switch to one of these alternative settings.
#bs,size = 32,128
#bs,size = 8,256
bs, size = 16, 160
arch = models.resnet34

# Critic data: images from the 'crappy' and 'images' sub-folders, labelled by
# folder name, with a seeded 10% validation split.
classes = ['crappy', 'images']
src = (ImageItemList
       .from_folder(path, include=classes)
       .random_split_by_pct(0.1, seed=42))
ll = src.label_from_folder(classes=classes)
data_crit = (ll
             .transform(get_transforms(max_zoom=2.), size=size)
             .databunch(bs=bs)
             .normalize(imagenet_stats))
# NOTE(review): overrides the class count fastai inferred; the reason is not
# visible here -- TODO confirm.
data_crit.c = 3
data_crit.show_batch(rows=4, ds_type=DatasetType.Valid)
# Default settings applied to every layer built through `conv`.
conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)

def conv(ni:int, nf:int, ks:int=3, stride:int=1, **kwargs):
    """Build a fastai `conv_layer` from `ni` to `nf` channels using `conv_args` defaults.

    Keywords in `kwargs` override the matching entries of `conv_args`.
    Previously, passing e.g. `norm_type=` raised a duplicate-keyword TypeError.
    """
    return conv_layer(ni, nf, ks=ks, stride=stride, **{**conv_args, **kwargs})
def critic(n_channels:int=3, nf:int=128, n_blocks:int=3, p:float=0.25):
    """Build the critic (discriminator) as an `nn.Sequential`.

    Args:
        n_channels: number of input image channels.
        nf: channel count after the first stride-2 convolution. It is doubled
            by the dense res block and again by every downsampling block.
        n_blocks: number of additional stride-2 downsampling conv blocks.
            Self-attention is enabled only in the first of them.
        p: dropout probability (a float; the first dropout uses `p/2`).

    Returns:
        A network ending in a 1-channel convolution with no activation,
        followed by `Flatten`, so it yields one un-activated score per
        remaining spatial position.
    """
    layers = [
        conv(n_channels, nf, ks=4, stride=2),
        nn.Dropout2d(p/2),
        # Removing conv_args because spectral norm makes this slow
        res_block(nf, dense=True)]  # , **conv_args)]
    nf *= 2  # after dense block
    for i in range(n_blocks):
        layers += [
            nn.Dropout2d(p),
            conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
        nf *= 2
    layers += [
        conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten()]
    return nn.Sequential(*layers)
Module to apply the loss function to every element of the last features before taking the mean.
class AdaptiveLoss(nn.Module):
    "Wrap `crit` so a per-sample target is repeated across every element of that sample's output."

    def __init__(self, crit):
        super(AdaptiveLoss, self).__init__()
        self.crit = crit

    def forward(self, output, target):
        # Assumes one label per sample (shape (bs,)); it is broadcast along the
        # feature axis to match `output` and cast to float for the criterion.
        per_element_target = target[:, None].expand_as(output).float()
        return self.crit(output, per_element_target)
Specific accuracy metric.
def accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor:
    "Thresholded accuracy where each sample's label in `y_true` is broadcast across its row of `y_pred`."
    probs = y_pred.sigmoid() if sigmoid else y_pred
    # Assumes y_true holds one label per sample; repeat it for every prediction column.
    targets = y_true[:, None].expand_as(probs).byte()
    hits = (probs > thresh) == targets
    return hits.float().mean()
Pretrain the critic on crappy vs not crappy.
# Pretrain the critic on crappy vs. original images. The BCE-with-logits loss
# is applied per output element via AdaptiveLoss. Weights are saved as 'critic-pre2'.
learn_critic = Learner(data_crit, critic(), metrics=accuracy_thresh_expand, loss_func=AdaptiveLoss(nn.BCEWithLogitsLoss()))
learn_critic.fit_one_cycle(6, 1e-3)
learn_critic.save('critic-pre2')
Now let's pretrain the generator.
arch = models.resnet34
# Generator inputs are the crappified images, with a seeded 10% validation split.
src = ImageImageList.from_folder(path_lr).random_split_by_pct(0.1, seed=42)

def get_data(bs, size):
    """Return a normalized DataBunch (batch size `bs`, image size `size`) for the generator.

    Each crappified input in `src` is labelled with the file of the same name
    in `path_hr`. `tfm_y=True` presumably applies the same random transforms
    to the targets so input and target stay aligned.
    """
    def hr_counterpart(x):
        # Target = original image sharing this file's name.
        return path_hr/x.name

    labelled = src.label_from_func(hr_counterpart)
    data = (labelled
            .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
            .databunch(bs=bs)
            .normalize(imagenet_stats))
    # Presumably 3 output channels for the U-Net generator -- TODO confirm.
    data.c = 3
    return data
# Build the generator data at the current (bs, size).
data_gen = get_data(bs,size)
wd = 1e-3
# U-Net generator on the `arch` encoder, trained with pixel-wise MSE.
learn_gen = unet_learner(data_gen, arch, wd=wd, blur=True, norm_type=NormType.Spectral, self_attention=True, sigmoid=True,
loss_func=MSELossFlat())
# First stage at the default lr (encoder presumably frozen by default -- confirm),
# then unfreeze and fine-tune the whole network with discriminative lrs.
learn_gen.fit_one_cycle(2, pct_start=0.8)
learn_gen.unfreeze()
learn_gen.fit_one_cycle(2, slice(1e-6,1e-3))
# NOTE(review): this rebuilds the data with the same (bs, size) as the first
# call above, so the resolution does not change; presumably a larger size was
# intended here for progressive resizing -- confirm.
data_gen = get_data(bs,size)
learn_gen.data = data_gen
learn_gen.fit_one_cycle(2, slice(1e-6,1e-3))
learn_gen.show_results(rows=8)
learn_gen.save('gen-pre2')
Now we'll combine those pretrained models in a GAN.
from fastai.vision.gan import *
These are the same losses used during pretraining.
# Same loss functions as in pretraining.
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())
# NOTE(review): `loss_gen` is not passed to `learn_gen` below, which builds its own MSELossFlat().
loss_gen = MSELossFlat()
# Rebuild both learners and restore the pretrained weights (relies on
# `.load()` returning the learner).
learn_crit = Learner(data_crit, critic(), loss_func=loss_critic).load('critic-pre2')
learn_gen = unet_learner(data_gen, arch, wd=wd, blur=True, norm_type=NormType.Spectral, self_attention=True, sigmoid=True,
loss_func=MSELossFlat()).load('gen-pre2')
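To define a GAN Learner, we just have to specify the learner objects for the generator and the critic. The switcher is a callback that decides when to switch from discriminator to generator and vice versa. With the adaptive switcher (`AdaptiveGANSwitcher`, commented out below), we do as many iterations of the discriminator as needed to get its loss back below 0.5, then one iteration of the generator. The active setting below instead uses `FixedGANSwitcher` with `n_crit=1, n_gen=1`, i.e. one critic iteration followed by one generator iteration.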
The loss of the critic is given by `learn_crit.loss_func`. We take the average of this loss function on the batch of real predictions (target 1) and the batch of fake predictions (target 0).
The loss of the generator is a weighted sum (weights in `weights_gen`) of two terms. The first is `learn_crit.loss_func` on the batch of fakes (passed through the critic to become predictions) with a target of 1. The second is `learn_gen.loss_func` applied to the output (the batch of fakes) and the target (the corresponding batch of high-resolution images).
@dataclass
class GANDiscriminativeLR(LearnerCallback):
    "`Callback` that scales the optimizer lr by `mult_lr` while the critic is being trained."
    mult_lr:float = 1.

    def on_batch_begin(self, train, **kwargs):
        "Boost the lr by `mult_lr` before a critic training batch."
        if self.learn.gan_trainer.gen_mode or not train:
            return
        self.learn.opt.lr *= self.mult_lr

    def on_step_end(self, **kwargs):
        "Undo the `mult_lr` boost after a critic optimizer step."
        if self.learn.gan_trainer.gen_mode:
            return
        self.learn.opt.lr /= self.mult_lr
# Alternative: switch back based on a critic-loss threshold.
#switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.5)
# Active: alternate n_crit=1 critic iteration with n_gen=1 generator iteration.
switcher = partial(FixedGANSwitcher, n_crit=1, n_gen=1)
# weights_gen=(1., 250.) weights the two generator-loss terms (presumably critic
# term first, pixel loss second -- confirm). Adam uses beta1=0; weight decay is off.
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,250.), show_img=False, switcher=switcher,
opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=0)
# GANDiscriminativeLR scales the lr by mult_lr=5 for critic training batches.
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
# Two consecutive 20-epoch runs at a constant lr of 1e-4.
learn.fit(20,1e-4)
learn.fit(20,1e-4)
#Without dense block, adaptive schedule
# NOTE(review): the label above looks stale -- critic() uses res_block(dense=True)
# and the active switcher is FixedGANSwitcher, not the adaptive one.
learn.show_results()