import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
# Download/extract the PETS dataset; originals live in 'images' and the
# crappified (degraded) inputs are written to 'crappy'.
path = untar_data(URLs.PETS)
path_hr = path / 'images'
path_lr = path / 'crappy'
Prepare the input data by crappifying images.
使用噪化的图片准备输入数据
from crappify import *
Uncomment the first time you run this notebook.
如果首次运行这个notebook,请把注释打开
#il = ImageList.from_folder(path_hr)
#parallel(crappifier(path_lr, path_hr), il.items)
For gradual resizing we can change the commented line here.
如果需要逐步增大图片尺寸(渐进式缩放),可以改用下面被注释掉的那几行设置
# Batch size and image size for the current training stage.
# For gradual (progressive) resizing, swap in one of the commented alternatives.
bs, size = 32, 128
# bs, size = 24, 160
# bs, size = 8, 256

# Backbone used for the U-Net generator.
arch = models.resnet34
Now let's pretrain the generator.
现在我们来预训练生成器
# Input/target item list: crappified images from `path_lr` are the inputs.
# Hold out 10% for validation with a fixed seed so the split is reproducible.
# (`arch` comes from the configuration cell; it does not need re-assigning here.)
src = ImageImageList.from_folder(path_lr).split_by_rand_pct(0.1, seed=42)
def get_data(bs, size):
    """Build the generator's DataBunch: crappy image in, original image out.

    Args:
        bs: batch size.
        size: target size passed to the transforms (also applied to the
            targets because ``tfm_y=True``).

    Returns:
        A DataBunch normalized with ``imagenet_stats`` on both inputs and
        targets, with ``c`` set to 3.
    """
    def hr_target(lr_path):
        # The target is the file with the same name in the high-res folder.
        return path_hr / lr_path.name

    labelled = src.label_from_func(hr_target)
    transformed = labelled.transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
    data = transformed.databunch(bs=bs).normalize(imagenet_stats, do_y=True)
    # NOTE(review): presumably the number of output channels (RGB) the U-Net
    # should produce — confirm.
    data.c = 3
    return data
data_gen = get_data(bs, size)
data_gen.show_batch(4)

# Generator training hyper-parameters.
wd = 1e-3                   # weight decay
y_range = (-3., 3.)         # output range handed to unet_learner
loss_gen = MSELossFlat()    # pixel-wise loss between prediction and target
def create_gen_learner():
    """Return a fresh self-attention U-Net learner for the generator.

    Uses the module-level ``data_gen``, ``arch``, ``wd``, ``y_range`` and
    ``loss_gen``.
    """
    return unet_learner(
        data_gen,
        arch,
        wd=wd,
        blur=True,
        norm_type=NormType.Weight,
        self_attention=True,
        y_range=y_range,
        loss_func=loss_gen,
    )
# First generator training phase, before the network is unfrozen below.
learn_gen = create_gen_learner()
learn_gen.fit_one_cycle(2, pct_start=0.8)
| epoch | train_loss | valid_loss |
|---|---|---|
| 1 | 0.061653 | 0.053493 |
| 2 | 0.051248 | 0.047272 |
# Fine-tune every layer, with learning rates spread from 1e-6 to 1e-3.
learn_gen.unfreeze()
learn_gen.fit_one_cycle(3, slice(1e-6, 1e-3))
| epoch | train_loss | valid_loss |
|---|---|---|
| 1 | 0.050429 | 0.046088 |
| 2 | 0.049056 | 0.043954 |
| 3 | 0.045437 | 0.043146 |
learn_gen.show_results(rows=4)

# Checkpoint the pretrained generator and reload it.
learn_gen.save('gen-pre2')
learn_gen.load('gen-pre2');

# Folder that will hold the generator's predictions (used to pretrain the critic).
name_gen = 'image_gen'
path_gen = path / name_gen
# shutil.rmtree(path_gen)  # uncomment to clear previously generated images
path_gen.mkdir(exist_ok=True)
def save_preds(dl, learn=None, dest=None):
    """Run the generator over ``dl`` and save each prediction as an image file.

    Each prediction is written to ``dest/<name>``, where ``<name>`` is the file
    name of the dataset item at the same position in ``dl.dataset.items``. The
    iteration order of ``dl`` must therefore match the order of its items.

    Args:
        dl: dataloader to predict on. It must iterate over batches and expose
            ``dataset.items`` (paths whose ``.name`` is reused for the output).
        learn: object providing ``pred_batch``; defaults to the module-level
            ``learn_gen``, looked up at call time.
        dest: output directory; defaults to the module-level ``path_gen``.

    Raises:
        IndexError: if ``dl`` yields more predictions than there are items in
            ``dl.dataset.items``.
    """
    # Resolve defaults at call time, not definition time: ``learn_gen`` is
    # rebound later in the notebook.
    if learn is None:
        learn = learn_gen
    if dest is None:
        dest = path_gen
    names = iter(dl.dataset.items)
    for batch in dl:
        for img in learn.pred_batch(batch=batch, reconstruct=True):
            item = next(names, None)
            if item is None:
                raise IndexError('more predictions than items in dl.dataset.items')
            img.save(dest / item.name)
# Write the generator's predictions for data_gen.fix_dl into path_gen.
save_preds(data_gen.fix_dl)
# Display the first generated file.
PIL.Image.open(path_gen.ls()[0])

# Drop the generator and reclaim memory before training the critic.
learn_gen = None
gc.collect()
3755
Pretrain the critic to distinguish generated images (`image_gen`) from the original images.
使用生成器输出的图片(image_gen)和原始图片(images)来预训练评价网络(批评者网络)
def get_crit_data(classes, bs, size):
    """Build a DataBunch for the critic, labelling images by their folder.

    Args:
        classes: folder names under ``path`` to include; each folder is a label.
        bs: batch size.
        size: target size passed to the transforms.

    Returns:
        A DataBunch normalized with ``imagenet_stats``, with ``c`` set to 3.
    """
    items = (ImageList.from_folder(path, include=classes)
             .split_by_rand_pct(0.1, seed=42))
    data = (items.label_from_folder(classes=classes)
            .transform(get_transforms(max_zoom=2.), size=size)
            .databunch(bs=bs)
            .normalize(imagenet_stats))
    # NOTE(review): `classes` has two entries here, yet `c` is forced to 3;
    # the critic built by gan_critic() may not read it — confirm.
    data.c = 3
    return data
# Critic data: generated images vs original images.
data_crit = get_crit_data([name_gen, 'images'], bs=bs, size=size)
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)

# Binary cross-entropy on logits, wrapped in fastai's AdaptiveLoss.
loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())
def create_critic_learner(data, metrics):
    """Return a Learner that trains a ``gan_critic()`` model on ``data``.

    Uses the module-level ``loss_critic`` and ``wd``.
    """
    critic = gan_critic()
    return Learner(data, critic, metrics=metrics, loss_func=loss_critic, wd=wd)
# Pretrain the critic for 6 epochs at lr 1e-3.
learn_critic = create_critic_learner(data_crit, metrics=accuracy_thresh_expand)
learn_critic.fit_one_cycle(6, 1e-3)
| epoch | train_loss | valid_loss | accuracy_thresh_expand |
|---|---|---|---|
| 1 | 0.678256 | 0.687312 | 0.531083 |
| 2 | 0.434768 | 0.366180 | 0.851823 |
| 3 | 0.186435 | 0.128874 | 0.955214 |
| 4 | 0.120681 | 0.072901 | 0.980228 |
| 5 | 0.099568 | 0.107304 | 0.962564 |
| 6 | 0.071958 | 0.078094 | 0.976239 |
# Checkpoint the pretrained critic; it is reloaded under this name for the GAN stage.
learn_critic.save('critic-pre2')
Now we'll combine those pretrained models in a GAN.
现在我们将这些经过预训练的模型整合到GAN里。
# Release the pretrained learners before building the GAN. The critic learner
# was created as `learn_critic`, so that is the name that must be cleared for
# its memory to become collectable.
learn_critic = None
learn_gen = None
gc.collect()
15794
# Rebuild both learners and load their pretrained weights for the GAN stage.
# NOTE(review): the critic was pretrained on name_gen vs 'images', but 'crappy'
# is used here; presumably this DataBunch only serves to construct the learner
# — confirm.
data_crit = get_crit_data(['crappy', 'images'], bs=bs, size=size)
learn_crit = create_critic_learner(data_crit, metrics=None).load('critic-pre2')
learn_gen = create_gen_learner().load('gen-pre2')
To define a GAN Learner, we just have to specify the learner objects for the generator and the critic. The switcher is a callback that decides when to switch from discriminator to generator and vice versa. Here we do as many iterations of the discriminator as needed to get its loss below a threshold (`critic_thresh`, set to 0.65 in the code below), then one iteration of the generator.
为了定义一个GAN学习器,我们只需要为生成器和评价网络分别指定学习器对象。switcher(切换器)是一个回调,它决定何时在鉴别器和生成器之间切换。这里我们持续训练鉴别器,直到其损失降到阈值(`critic_thresh`,下面的代码中设为0.65)以下,然后再对生成器训练一次迭代。
The loss of the critic is given by learn_crit.loss_func. We take the average of this loss function on the batch of real predictions (target 1) and the batch of fake predictions (target 0).
评价网络的损失由learn_crit.loss_func给出。我们分别在真实图片的预测批次(目标为1)和伪造图片的预测批次(目标为0)上计算这个损失函数,然后取二者的平均值。
The loss of the generator is a weighted sum (weights in weights_gen) of learn_crit.loss_func on the batch of fakes (passed through the critic to become predictions) with a target of 1, and learn_gen.loss_func applied to the output (the batch of fakes) and the target (the corresponding batch of super-resolution images).
生成器网络的损失是learn_crit.loss_func和learn_gen.loss_func的加权和(权重在weights_gen中):前者作用于伪造(生成的)图片批次经过评价网络得到的预测值,目标为1;后者作用于生成器的输出(伪造批次)和对应的目标(对应的一批超分辨率图像)。
# Keep training the critic until its loss drops below critic_thresh, then
# switch to the generator.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(
    learn_gen,
    learn_crit,
    weights_gen=(1., 50.),  # presumably (critic term, pixel-loss term) — confirm
    show_img=False,
    switcher=switcher,
    opt_func=partial(optim.Adam, betas=(0., 0.99)),
    wd=wd,
)
# GANDiscriminativeLR with mult_lr=5.; presumably scales the critic's lr — confirm.
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

lr = 1e-4
learn.fit(40, lr)
| epoch | train_loss | gen_loss | disc_loss |
|---|---|---|---|
| 1 | 2.071352 | 2.025429 | 4.047686 |
| 2 | 1.996251 | 1.850199 | 3.652173 |
| 3 | 2.001999 | 2.035176 | 3.612669 |
| 4 | 1.921844 | 1.931835 | 3.600355 |
| 5 | 1.987216 | 1.961323 | 3.606629 |
| 6 | 2.022372 | 2.102732 | 3.609494 |
| 7 | 1.900056 | 2.059208 | 3.581742 |
| 8 | 1.942305 | 1.965547 | 3.538015 |
| 9 | 1.954079 | 2.006257 | 3.593008 |
| 10 | 1.984677 | 1.771790 | 3.617556 |
| 11 | 2.040979 | 2.079904 | 3.575464 |
| 12 | 2.009052 | 1.739175 | 3.626755 |
| 13 | 2.014115 | 1.204614 | 3.582353 |
| 14 | 2.042148 | 1.747239 | 3.608723 |
| 15 | 2.113957 | 1.831483 | 3.684338 |
| 16 | 1.979398 | 1.923163 | 3.600483 |
| 17 | 1.996756 | 1.760739 | 3.635300 |
| 18 | 1.976695 | 1.982629 | 3.575843 |
| 19 | 2.088960 | 1.822936 | 3.617471 |
| 20 | 1.949941 | 1.996513 | 3.594223 |
| 21 | 2.079416 | 1.918284 | 3.588732 |
| 22 | 2.055047 | 1.869254 | 3.602390 |
| 23 | 1.860164 | 1.917518 | 3.557776 |
| 24 | 1.945440 | 2.033273 | 3.535242 |
| 25 | 2.026493 | 1.804196 | 3.558001 |
| 26 | 1.875208 | 1.797288 | 3.511697 |
| 27 | 1.972286 | 1.798044 | 3.570746 |
| 28 | 1.950635 | 1.951106 | 3.525849 |
| 29 | 2.013820 | 1.937439 | 3.592216 |
| 30 | 1.959477 | 1.959566 | 3.561970 |
| 31 | 2.012466 | 2.110288 | 3.539897 |
| 32 | 1.982466 | 1.905378 | 3.559940 |
| 33 | 1.957023 | 2.207354 | 3.540873 |
| 34 | 2.049188 | 1.942845 | 3.638360 |
| 35 | 1.913136 | 1.891638 | 3.581291 |
| 36 | 2.037127 | 1.808180 | 3.572567 |
| 37 | 2.006383 | 2.048738 | 3.553226 |
| 38 | 2.000312 | 1.657985 | 3.594805 |
| 39 | 1.973937 | 1.891186 | 3.533843 |
| 40 | 2.002513 | 1.853988 | 3.554688 |
learn.save('gan-1c')

# Continue GAN training on 192-pixel images with batch size 16 and half the lr.
learn.data = get_data(16, 192)
learn.fit(10, lr / 2)
| epoch | train_loss | gen_loss | disc_loss |
|---|---|---|---|
| 1 | 2.578580 | 2.415008 | 4.716179 |
| 2 | 2.620808 | 2.487282 | 4.729377 |
| 3 | 2.596190 | 2.579693 | 4.796489 |
| 4 | 2.701113 | 2.522197 | 4.821410 |
| 5 | 2.545030 | 2.401921 | 4.710739 |
| 6 | 2.638539 | 2.548171 | 4.776103 |
| 7 | 2.551988 | 2.513859 | 4.644952 |
| 8 | 2.629724 | 2.490307 | 4.701890 |
| 9 | 2.552170 | 2.487726 | 4.728183 |
| 10 | 2.597136 | 2.478334 | 4.649708 |
learn.show_results(rows=16)
# NOTE(review): this reuses the 'gan-1c' name and overwrites the checkpoint
# saved before the get_data(16, 192) stage; use a distinct name if both are needed.
learn.save('gan-1c')