"""Super-resolution GAN training script (fastai v1 API).

Pipeline:
  1. "Crappify" high-res pet images (downscale, JPEG artifacts, number overlay).
  2. Pre-train a U-Net generator (MSE loss) to restore the crappified images.
  3. Save the generator's outputs, pre-train a critic to tell them from originals.
  4. Combine both in a GANLearner and fine-tune adversarially.
"""
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '3'  # pin GPU before torch/fastai touch CUDA

import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from torchvision.models import vgg16_bn
from PIL import Image, ImageDraw, ImageFont

path = untar_data(URLs.PETS)
path_hr = path/'images'   # high-resolution originals
path_lr = path/'crappy'   # degraded copies produced by crappify()
# torch.cuda.set_device(1)


def crappify(fn, i):
    """Degrade one image: resize to 96px on the min side, JPEG-compress at a
    random quality in [10, 70], and stamp that quality value as white text at
    a random position in the upper-left quadrant.

    `i` is the index supplied by fastai's `parallel` runner and is unused.
    """
    dest = path_lr/fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    img = PIL.Image.open(fn)
    targ_sz = resize_to(img, 96, use_min=True)
    img = img.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
    w, h = img.size
    q = random.randint(10, 70)
    ImageDraw.Draw(img).text((random.randint(0, w//2), random.randint(0, h//2)),
                             str(q), fill=(255, 255, 255))
    img.save(dest, quality=q)  # low `quality` introduces JPEG artifacts


# One-time dataset generation — uncomment to (re)build the crappy images.
# il = ImageItemList.from_folder(path_hr)
# parallel(crappify, il.items)

bs, size = 24, 160
# bs,size = 8,256   # alternative: smaller batches, larger images
arch = models.resnet34  # original assigned this twice; duplicate removed

src = ImageImageList.from_folder(path_lr).random_split_by_pct(0.1, seed=42)


def get_data(bs, size):
    """Build an image-to-image DataBunch mapping crappy input -> high-res target."""
    data = (src.label_from_func(lambda x: path_hr/x.name)
               .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
               .databunch(bs=bs).normalize(imagenet_stats, do_y=True))
    data.c = 3  # RGB output channels
    return data


data_gen = get_data(bs, size)
wd = 1e-3


def create_gen_learner():
    """U-Net generator with self-attention; plain MSE loss for pre-training."""
    return unet_learner(data_gen, arch, wd=wd, blur=True, norm_type=NormType.Weight,
                        self_attention=True, y_range=(-3., 3.),
                        loss_func=MSELossFlat())


# --- Generator pre-training ---
learn_gen = create_gen_learner()
learn_gen.fit_one_cycle(2, pct_start=0.8)
learn_gen.unfreeze()
learn_gen.fit_one_cycle(3, slice(1e-6, 1e-3))
learn_gen.show_results(rows=4)
learn_gen.save('gen-pre2')
learn_gen.load('gen-pre2')

# --- Store generator outputs to disk for critic pre-training ---
name_gen = 'image_gen'
path_gen = path/name_gen
path_gen.mkdir(exist_ok=True)


def save_preds(dl):
    """Run the generator over `dl`, saving each prediction under its source filename."""
    names = dl.dataset.items
    i = 0
    for b in dl:
        preds = learn_gen.pred_batch(batch=b, reconstruct=True)
        for o in preds:
            o.save(path_gen/names[i].name)
            i += 1


save_preds(data_gen.fix_dl)
PIL.Image.open(path_gen.ls()[0])

# --- Critic data: classify generated vs. real images ---
classes = [name_gen, 'images']
src = ImageItemList.from_folder(path, include=classes).random_split_by_pct(0.1, seed=42)
ll = src.label_from_folder(classes=classes)
data_crit = (ll.transform(get_transforms(max_zoom=2.), size=size)
               .databunch(bs=bs).normalize(imagenet_stats))
data_crit.c = 3
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)

# Shared settings for every critic conv: leaky ReLU + spectral normalization.
conv_args = dict(leaky=0.2, norm_type=NormType.Spectral)


def conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
    """Conv layer with the shared critic settings (spectral norm, leaky ReLU)."""
    return conv_layer(ni, nf, ks=ks, stride=stride, **conv_args, **kwargs)


def critic(n_channels: int = 3, nf: int = 128, n_blocks: int = 3, p: float = 0.15):
    """Patch-style critic: strided convs with dropout, self-attention on the
    first downsampling block, ending in a 1-channel conv + Flatten so the
    output is one score per spatial location.
    """
    layers = [
        conv(n_channels, nf, ks=4, stride=2),
        nn.Dropout2d(p/2),
        res_block(nf, dense=True, **conv_args)]
    nf *= 2  # after dense block
    for i in range(n_blocks):
        layers += [
            nn.Dropout2d(p),
            conv(nf, nf*2, ks=4, stride=2, self_attention=(i == 0))]
        nf *= 2
    layers += [
        conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten()]
    return nn.Sequential(*layers)


class AdaptiveLoss(nn.Module):
    "Expand a per-image scalar target to the critic's per-patch output shape."

    def __init__(self, crit):
        super().__init__()
        self.crit = crit

    def forward(self, output, target):
        return self.crit(output, target[:, None].expand_as(output).float())


def accuracy_thresh_expand(y_pred: Tensor, y_true: Tensor, thresh: float = 0.5,
                           sigmoid: bool = True) -> Rank0Tensor:
    "Compute accuracy when `y_pred` and `y_true` are the same size."
    if sigmoid:
        y_pred = y_pred.sigmoid()
    return ((y_pred > thresh) == y_true[:, None].expand_as(y_pred).byte()).float().mean()


def create_critic_learner(loss_func, metrics):
    """Wrap a fresh critic on the real-vs-generated data in a Learner."""
    return Learner(data_crit, critic(), metrics=metrics, loss_func=loss_func, wd=wd)


# --- Critic pre-training ---
learn_critic = create_critic_learner(metrics=accuracy_thresh_expand,
                                     loss_func=AdaptiveLoss(nn.BCEWithLogitsLoss()))
# NOTE(review): the original ran this fit twice back-to-back (12 epochs total);
# kept as-is — confirm whether the duplicate call was intentional.
learn_critic.fit_one_cycle(6, 1e-3)
learn_critic.fit_one_cycle(6, 1e-3)
learn_critic.save('critic-pre2')

# --- Adversarial fine-tuning ---
from fastai.vision.gan import *

loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())
loss_gen = MSELossFlat()  # kept for reference; generator loss set in create_gen_learner

# Drop the pre-training learners before rebuilding, to free GPU memory.
learn_crit = None
learn_gen = None
gc.collect()
learn_crit = create_critic_learner(metrics=None,
                                   loss_func=loss_critic).load('critic-pre2')
learn_gen = create_gen_learner().load('gen-pre2')


@dataclass
class GANDiscriminativeLR(LearnerCallback):
    "`Callback` that handles multiplying the learning rate by `mult_lr` for the critic."
    mult_lr: float = 5.

    def on_batch_begin(self, train, **kwargs):
        "Multiply the current lr if necessary."
        if not self.learn.gan_trainer.gen_mode and train:
            self.learn.opt.lr *= self.mult_lr

    def on_step_end(self, **kwargs):
        "Put the LR back to its value if necessary."
        if not self.learn.gan_trainer.gen_mode:
            self.learn.opt.lr /= self.mult_lr


# Switch to generator training once critic loss drops below the threshold.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
# switcher = partial(FixedGANSwitcher, n_crit=1, n_gen=1)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1., 250.),
                                 show_img=True, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=wd)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
learn.fit(1, 4e-5)
learn.fit(20, 4e-5)
learn.show_results()
learn.fit(40, 4e-5)
learn.show_results()
learn.fit(20, 4e-6)
learn.show_results(rows=24)