%matplotlib inline
from fastai2.vision.all import *
from fastai2.vision.gan import *
For this lesson, we'll be using the bedrooms from the LSUN dataset. The full dataset is a bit too large, so we'll use a sample from Kaggle.
path = untar_data(URLs.LSUN_BEDROOMS)
We then grab all the images in the folder with the data block API. We don't create a validation set here for reasons we'll explain later. The inputs are random noise vectors of size 100 by default (this can be changed by replacing generate_noise with partial(generate_noise, size=...)), and the targets are the images of bedrooms.
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]))
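As a rough illustration of what the get_x function contributes here, the sketch below uses only the standard library. It assumes the function receives an item (an image file name), ignores it, and returns a vector of Gaussian samples. generate_noise_sketch and the file name are hypothetical stand-ins, not fastai's implementation, which returns a tensor rather than a list.

```python
import random
from functools import partial

def generate_noise_sketch(fn, size=100):
    # Hypothetical stand-in: ignore the item `fn` and return `size` Gaussian samples
    return [random.gauss(0.0, 1.0) for _ in range(size)]

# Default noise size
noise = generate_noise_sketch("bedroom_0001.jpg")  # made-up file name

# Changing the size with partial, mirroring partial(generate_noise, size=...)
small_noise_fn = partial(generate_noise_sketch, size=64)
small = small_noise_fn("bedroom_0001.jpg")

print(len(noise), len(small))
```

Using partial keeps the one-argument calling convention that get_x expects while fixing the extra size argument.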
def get_dls(bs, size):
    dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                       get_x = generate_noise,
                       get_items = get_image_files,
                       splitter = IndexSplitter([]),
                       item_tfms=Resize(size, method=ResizeMethod.Crop),
                       batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
    return dblock.dataloaders(path, path=path, bs=bs)
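The Normalize.from_stats call above uses a mean and standard deviation of 0.5 for every channel. Assuming pixel values have already been scaled to [0, 1] by the time normalization runs, this maps them to [-1, 1], a common choice for GANs whose generator output is squashed to the same range. The helper names below are hypothetical and only illustrate the arithmetic.

```python
def normalize(x, mean=0.5, std=0.5):
    # (x - mean) / std, applied to a single pixel value for illustration
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    # Inverse mapping, used when turning outputs back into displayable pixels
    return y * std + mean

darkest = normalize(0.0)    # lower end of the assumed [0, 1] input range
brightest = normalize(1.0)  # upper end of the assumed [0, 1] input range
round_trip = denormalize(normalize(0.25))
```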
We'll begin with a small size since GANs take a lot of time to train.
dls = get_dls(128, 64)
dls.show_batch(max_n=16)
GAN stands for Generative Adversarial Nets, which were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in our dataset, and the critic tries to tell real images apart from the ones the generator produces. The generator returns images; the critic returns a single number (usually 0. for fake images and 1. for real ones).
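We train them against each other in the sense that at each step (more or less), we:

1. Freeze the generator and train the critic for one step: take a batch of true images (let's call that real), generate a batch of fake images (let's call that fake), have the critic evaluate both, and update the critic's weights with a loss that rewards it for telling them apart.
2. Freeze the critic and train the generator for one step: generate a batch of fake images, have the critic evaluate it, and update the generator's weights with a loss that rewards it when the critic takes those images for real ones.

Here, we'll use the Wasserstein GAN.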
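The alternating scheme and the Wasserstein-style losses can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's code: the sign convention (the critic minimizes mean score on real minus mean score on fake, the generator minimizes mean score on fake), the clipping value of 0.01, and the ratio of five critic steps per generator step are all assumptions, and other write-ups flip the signs. All function names are hypothetical.

```python
from statistics import mean

def critic_loss(real_scores, fake_scores):
    # Assumed convention: the critic is trained to minimize this difference
    return mean(real_scores) - mean(fake_scores)

def generator_loss(fake_scores):
    # Assumed convention: the generator is trained to minimize the critic's
    # mean score on generated images
    return mean(fake_scores)

def clip_weights(weights, clip=0.01):
    # Weight clipping keeps the critic's parameters in [-clip, clip]
    return [max(-clip, min(clip, w)) for w in weights]

def schedule(n_steps, n_crit=5, n_gen=1):
    # Which model is trained at each step: n_crit critic steps,
    # then n_gen generator steps, repeated
    steps = []
    while len(steps) < n_steps:
        steps += ["critic"] * n_crit + ["generator"] * n_gen
    return steps[:n_steps]

real_scores = [0.0, 0.2]   # made-up critic outputs on real images
fake_scores = [1.0, 0.8]   # made-up critic outputs on generated images
c_loss = critic_loss(real_scores, fake_scores)
g_loss = generator_loss(fake_scores)
plan = schedule(7)
```

Training the critic several times per generator step is one reading of the "more or less" above: the two models do not have to be updated in strict alternation.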
We create a generator and a critic that we pass to GANLearner.wgan. The noise size (100 by default) is the size of the random vector from which our generator creates images.
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic = basic_critic(64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = partial(Adam, mom=0.))
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(30, 2e-4, wd=0)
| epoch | train_loss | gen_loss | crit_loss | time |
|---|---|---|---|---|
| 0 | -0.811070 | 0.608824 | -1.134686 | 02:17 |
| 1 | -0.854871 | 0.607675 | -1.080127 | 02:16 |
| 2 | -0.659264 | 0.519865 | -0.971711 | 02:16 |
| 3 | -0.570475 | 0.475425 | -0.892289 | 02:17 |
| 4 | -0.691494 | 0.434911 | -0.837299 | 02:17 |
| 5 | -0.560793 | 0.392304 | -0.769569 | 02:19 |
| 6 | -0.495189 | 0.350668 | -0.708923 | 02:17 |
| 7 | -0.453588 | 0.314808 | -0.646859 | 02:17 |
| 8 | -0.420760 | 0.267400 | -0.581581 | 02:17 |
| 9 | -0.392180 | 0.241475 | -0.531013 | 03:09 |
| 10 | -0.353970 | 0.206983 | -0.481969 | 02:18 |
| 11 | -0.326770 | 0.177502 | -0.439688 | 02:16 |
| 12 | -0.299592 | 0.162772 | -0.404918 | 02:17 |
| 13 | -0.299998 | 0.140973 | -0.382297 | 02:17 |
| 14 | -0.271852 | 0.142364 | -0.361767 | 02:17 |
| 15 | -0.262536 | 0.126330 | -0.345981 | 02:17 |
| 16 | -0.240438 | 0.123006 | -0.329001 | 02:17 |
| 17 | -0.232625 | 0.120965 | -0.315758 | 02:17 |
| 18 | -0.213570 | 0.129917 | -0.301637 | 02:17 |
| 19 | -0.216735 | 0.116807 | -0.287692 | 02:17 |
| 20 | -0.210406 | 0.117491 | -0.275790 | 02:17 |
| 21 | -0.186631 | 0.102863 | -0.261194 | 02:18 |
| 22 | -0.186007 | 0.100930 | -0.249909 | 02:17 |
| 23 | -0.183062 | 0.087650 | -0.239010 | 02:17 |
| 24 | -0.174468 | 0.084587 | -0.230857 | 02:17 |
| 25 | -0.168775 | 0.071609 | -0.220272 | 02:17 |
| 26 | -0.152947 | 0.074026 | -0.214880 | 02:17 |
| 27 | -0.163944 | 0.075298 | -0.208268 | 02:17 |
| 28 | -0.148677 | 0.067286 | -0.202055 | 02:17 |
| 29 | -0.146988 | 0.058458 | -0.197080 | 02:17 |
#learn.gan_trainer.switch(gen_mode=True)
learn.show_results(max_n=16, figsize=(8,8), ds_idx=0)