%matplotlib inline
from fastai.gen_doc.nbdoc import *
from fastai.vision import *
from fastai.vision.gan import *
GAN stands for Generative Adversarial Nets; they were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in our dataset, while the critic tries to classify real images from the fake ones the generator produces. The generator returns images, the critic a feature map (which can be a single number depending on the input size). Usually the critic is trained to return 0. everywhere for fake images and 1. everywhere for real ones.
This module contains all the necessary functions to create a GAN.
We train them against each other, in the sense that at each step (more or less) we:

1. Freeze the generator and train the critic for one step by:
   - getting one batch of true images (let's call that real)
   - generating one batch of fake images (let's call that fake)
   - having the critic evaluate each batch and computing a loss function from the losses on real and fake; the important part is that it rewards positively the detection of real images and penalizes the fake ones
   - updating the weights of the critic with the gradients of this loss
2. Freeze the critic and train the generator for one step by:
   - generating one batch of fake images
   - having the critic evaluate it
   - returning a loss that rewards positively the critic thinking those are real images
   - updating the weights of the generator with the gradients of this loss
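In pseudocode, one such alternating step might look like the following sketch. This is a conceptual outline only, not the fastai implementation (which handles the switching through callbacks); the models, optimizers, noise sampler and loss functions are assumed given.

def gan_step(real, generator, critic, opt_gen, opt_crit, noise_fn, crit_loss, gen_loss):
    # 1. freeze the generator and train the critic
    fake = generator(noise_fn()).detach()           # detach: no gradients flow into the generator
    loss_c = crit_loss(critic(real), critic(fake))  # rewards detecting real, penalizes fake
    loss_c.backward(); opt_crit.step(); opt_crit.zero_grad()
    # 2. freeze the critic and train the generator
    fake_pred = critic(generator(noise_fn()))
    loss_g = gen_loss(fake_pred)                    # rewards fake_pred being close to 1
    loss_g.backward(); opt_gen.step(); opt_gen.zero_grad()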
show_doc(GANLearner)
class GANLearner[source][test]
GANLearner(data:DataBunch, generator:Module, critic:Module, gen_loss_func:LossFunction, crit_loss_func:LossFunction, switcher:Callback=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:float=None, **learn_kwargs) :: Learner
A Learner suitable for GANs.
This is the general constructor to create a GAN; you may want to use one of the factory methods below, which are easier to use. It creates a GAN from data, a generator and a critic. The data should have as inputs what the generator expects, and the desired images as targets.
gen_loss_func is the loss function that will be applied to the generator. It takes three arguments fake_pred, target, output and should return a rank-0 tensor. output is the result of the generator applied to the input (the xs of the batch), target is the ys of the batch and fake_pred is the result of the critic being given output. output and target can be used to add a specific loss to the GAN loss (pixel loss, feature loss), and for a good training of the GAN the loss should encourage fake_pred to be as close to 1 as possible (the generator is trained to fool the critic).
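For instance, a minimal generator loss along those lines might look like this sketch; the binary cross-entropy adversarial term and the L1 pixel term, with equal weights, are illustrative choices, not the library's defaults:

def gen_loss_func(fake_pred, target, output):
    # adversarial part: reward the generator when the critic predicts 1 on fakes
    adv = F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))
    # optional pixel part: keep the generated output close to the target images
    pixel = F.l1_loss(output, target)
    return adv + pixel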
crit_loss_func is the loss function that will be applied to the critic. It takes two arguments, real_pred and fake_pred. real_pred is the result of the critic on the target images (the ys of the batch) and fake_pred is the result of the critic applied to a batch of fake images, generated by the generator from the xs of the batch.
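A matching critic loss could be sketched as follows, again assuming a binary cross-entropy formulation:

def crit_loss_func(real_pred, fake_pred):
    # push predictions on real images toward 1 and on fake images toward 0
    return (F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred))
            + F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred)))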
switcher is a Callback that should tell the GAN when to switch from critic to generator and vice versa. By default it does 5 iterations of the critic for 1 iteration of the generator. The model begins training with the generator if gen_first=True. If switch_eval=True, the model that isn't being trained is switched to eval mode (otherwise it is left in training mode, which means statistics like the running mean in batchnorm layers are updated and dropout is applied).
clip should be set to a certain value if one wants to clip the weights (see the Wasserstein GAN for instance).
If show_img=True, one image generated by the GAN is shown at the end of each epoch.
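Putting it together, a direct construction with the loss functions sketched above might look like this; data, generator and critic are assumed to be built already, and the FixedGANSwitcher settings shown are the defaults:

learn = GANLearner(data, generator, critic, gen_loss_func, crit_loss_func,
                   switcher=partial(FixedGANSwitcher, n_crit=5, n_gen=1))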
show_doc(GANLearner.from_learners)
from_learners[source][test]
from_learners(learn_gen:Learner, learn_crit:Learner, switcher:Callback=None, weights_gen:Point=None, **learn_kwargs)
Create a GAN from learn_gen and learn_crit.
Directly creates a GANLearner from two Learners: one for the generator and one for the critic. The switcher and all kwargs will be passed to the initialization of GANLearner along with the following loss functions:
- loss_func_crit is the mean of learn_crit.loss_func applied to real_pred and a target of ones, with learn_crit.loss_func applied to fake_pred and a target of zeros
- loss_func_gen is the mean of learn_crit.loss_func applied to fake_pred and a target of ones (to fool the discriminator), with learn_gen.loss_func applied to output and target. The weights of each of those contributions can be passed in weights_gen (default is 1. and 1.)

show_doc(GANLearner.wgan)
wgan[source][test]
wgan(data:DataBunch, generator:Module, critic:Module, switcher:Callback=None, clip:float=0.01, **learn_kwargs)
Create a WGAN from data, generator and critic.
The Wasserstein GAN is detailed in this article: https://arxiv.org/abs/1701.07875. switcher and the kwargs will be passed to the GANLearner init; clip is the weight clipping.
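For example, a WGAN on a folder of 64-pixel images can be set up along these lines (path is assumed to point at an image folder, e.g. LSUN bedrooms; the transforms and hyper-parameters follow the usual WGAN recipe and are illustrative):

data = (GANItemList.from_folder(path, noise_sz=100)
        .split_none()
        .label_from_func(noop)
        .transform(tfms=[[crop_pad(size=64, row_pct=(0,1), col_pct=(0,1))], []],
                   size=64, tfm_y=True)
        .databunch(bs=128))
generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
critic = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)
learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                        opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=0.)
learn.fit(1, 2e-4)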
show_doc(FixedGANSwitcher, title_level=3)
class FixedGANSwitcher[source][test]
FixedGANSwitcher(learn:Learner, n_crit:Union[int,Callable]=1, n_gen:Union[int,Callable]=1) :: LearnerCallback
Switcher to do n_crit iterations of the critic then n_gen iterations of the generator.
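For instance, to do one generator iteration for every critic iteration (instead of the default 5 critic iterations), one could pass this hypothetical setting:

switcher = partial(FixedGANSwitcher, n_crit=1, n_gen=1)
learn = GANLearner.wgan(data, generator, critic, switcher=switcher)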
show_doc(FixedGANSwitcher.on_train_begin)
on_train_begin[source][test]
on_train_begin(**kwargs)
Initiate the iteration counts.
show_doc(FixedGANSwitcher.on_batch_end)
on_batch_end[source][test]
on_batch_end(iteration, **kwargs)
Switch the model if necessary.
show_doc(AdaptiveGANSwitcher, title_level=3)
class AdaptiveGANSwitcher[source][test]
AdaptiveGANSwitcher(learn:Learner, gen_thresh:float=None, critic_thresh:float=None) :: LearnerCallback
Switcher that goes back to generator/critic when the loss goes below gen_thresh/crit_thresh.
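For example, to keep training the critic until its loss drops below a threshold before handing over to the generator (the 0.65 value is illustrative):

switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.wgan(data, generator, critic, switcher=switcher)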
show_doc(AdaptiveGANSwitcher.on_batch_end)
on_batch_end[source][test]
on_batch_end(last_loss, **kwargs)
Switch the model if necessary.
If you want to train your critic at a different learning rate than the generator, this will let you do it automatically (even if you have a learning rate schedule).
show_doc(GANDiscriminativeLR, title_level=3)
class GANDiscriminativeLR[source][test]
GANDiscriminativeLR(learn:Learner, mult_lr:float=5.0) :: LearnerCallback
Callback that handles multiplying the learning rate by mult_lr for the critic.
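A sketch of how one might attach it, appending to callback_fns so the callback is re-created for each call to fit (mult_lr=5. is the default):

learn = GANLearner.wgan(data, generator, critic)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))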
show_doc(GANDiscriminativeLR.on_batch_begin)
on_batch_begin[source][test]
on_batch_begin(train, **kwargs)
Multiply the current lr if necessary.
show_doc(GANDiscriminativeLR.on_step_end)
on_step_end[source][test]
on_step_end(**kwargs)
Put the LR back to its value if necessary.
show_doc(basic_critic)
This model contains a first 4 by 4 convolutional layer of stride 2 from n_channels to n_features, followed by n_extra_layers 3 by 3 convolutional layers of stride 1. Then we put as many 4 by 4 convolutional layers of stride 2 (with a number of features multiplied by 2 at each stage) as necessary for the in_size to be reduced to 1. kwargs can be used to customize the convolutional layers and are passed to conv_layer.
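For instance, a critic for 3 by 64 by 64 images can be built and sanity-checked like this (a minimal sketch following the description above):

critic = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)
x = torch.randn(16, 3, 64, 64)  # a batch of 16 images
preds = critic(x)               # one flattened prediction per image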
show_doc(basic_generator)
This model contains a first 4 by 4 transposed convolutional layer of stride 1 from noise_size to the last number of features of the corresponding critic. Then we put as many 4 by 4 transposed convolutional layers of stride 2 (with a number of features divided by 2 at each stage) as necessary for the image to end up with height and width in_size//2. At the end, we add n_extra_layers 3 by 3 convolutional layers of stride 1. The last layer is a transposed convolution of size 4 by 4 and stride 2, followed by tanh. kwargs can be used to customize the convolutional layers and are passed to conv_layer.
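A quick shape check, mirroring the critic example above; the noise is fed as a noise_sz by 1 by 1 tensor, which is what NoisyItem produces:

generator = basic_generator(in_size=64, n_channels=3, noise_sz=100)
z = torch.randn(16, 100, 1, 1)  # a batch of 16 noise vectors
imgs = generator(z)             # expected: torch.Size([16, 3, 64, 64])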
show_doc(gan_critic)
gan_critic[source][test]
gan_critic(n_channels:int=3, nf:int=128, n_blocks:int=3, p:int=0.15)
Critic to train a GAN.
show_doc(GANTrainer)
class GANTrainer[source][test]
GANTrainer(learn:Learner, switch_eval:bool=False, clip:float=None, beta:float=0.98, gen_first:bool=False, show_img:bool=True) :: LearnerCallback
Handles GAN Training.
LearnerCallback that is responsible for handling the two different optimizers (one for the generator and one for the critic) and doing all the work behind the scenes so that the generator (or the critic) is in training mode, with parameters requiring gradients, each time we switch.
switch_eval=True means that the GANTrainer will put the model that isn't training into eval mode (if it's False, its running statistics, like in batchnorm layers, will be updated and dropout will be applied). clip is the clipping applied to the weights (if not None). beta is the coefficient for the moving averages, as the GANTrainer tracks the generator loss and the critic loss separately. gen_first=True means the training begins with the generator (with the critic if it's False). If show_img=True we show a generated image at the end of each epoch.
show_doc(GANTrainer.switch)
If gen_mode is left as None, just put the model in the other mode (critic if it was in generator mode and vice versa).
show_doc(GANTrainer.on_train_begin)
on_train_begin[source][test]
on_train_begin(**kwargs)
Create the optimizers for the generator and critic if necessary, initialize smootheners.
show_doc(GANTrainer.on_epoch_begin)
on_epoch_begin[source][test]
on_epoch_begin(epoch, **kwargs)
Put the critic or the generator back to eval if necessary.
show_doc(GANTrainer.on_batch_begin)
on_batch_begin[source][test]
on_batch_begin(last_input, last_target, **kwargs)
Clamp the weights with self.clip if it's not None, return the correct input.
show_doc(GANTrainer.on_backward_begin)
on_backward_begin[source][test]
on_backward_begin(last_loss, last_output, **kwargs)
Record last_loss in the proper list.
show_doc(GANTrainer.on_epoch_end)
on_epoch_end[source][test]
on_epoch_end(pbar, epoch, last_metrics, **kwargs)
Put the various losses in the recorder and show a sample image.
show_doc(GANTrainer.on_train_end)
on_train_end[source][test]
on_train_end(**kwargs)
Switch in generator mode for showing results.
show_doc(GANModule, title_level=3)
show_doc(GANModule.switch)
If gen_mode is left as None, just put the model in the other mode (critic if it was in generator mode and vice versa).
show_doc(GANLoss, title_level=3)
class GANLoss[source][test]
GANLoss(loss_funcG:Callable, loss_funcC:Callable, gan_model:GANModule) :: PrePostInitMeta :: GANModule
Wrapper around loss_funcC (for the critic) and loss_funcG (for the generator).
show_doc(AdaptiveLoss, title_level=3)
class AdaptiveLoss[source][test]
AdaptiveLoss(crit) :: PrePostInitMeta :: Module
Expand the target to match the output size before applying crit.
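For instance, this lets a loss written for one scalar target per image work when the critic returns a feature map; wrapping nn.BCEWithLogitsLoss here is an illustrative choice:

loss_func = AdaptiveLoss(nn.BCEWithLogitsLoss())
pred = torch.randn(16, 25)  # a critic's flattened feature map
targ = torch.ones(16)       # one label per image
l = loss_func(pred, targ)   # targ is expanded to pred's size before BCE is applied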
show_doc(accuracy_thresh_expand)
accuracy_thresh_expand[source][test]
accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True) → Rank0Tensor
Compute accuracy after expanding y_true to the size of y_pred.
show_doc(NoisyItem, title_level=3)
show_doc(GANItemList, title_level=3)
show_doc(GANItemList.show_xys)
show_xys[source][test]
show_xys(xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs)
Shows ys (target images) on a figure of figsize.
show_doc(GANItemList.show_xyzs)
show_xyzs[source][test]
show_xyzs(xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs)
Shows zs (generated images) on a figure of figsize.
show_doc(GANLoss.critic)
show_doc(GANModule.forward)
forward[source][test]
forward(*args)
show_doc(GANLoss.generator)
show_doc(NoisyItem.apply_tfms)
apply_tfms[source][test]
apply_tfms(tfms, **kwargs)
Subclass this method if you want to apply data augmentation with tfms to this ItemBase.
show_doc(AdaptiveLoss.forward)
forward[source][test]
forward(output, target)
show_doc(GANItemList.get)
get[source][test]
get(i)
Subclass if you want to customize how to create item i from self.items.
show_doc(GANItemList.reconstruct)
reconstruct[source][test]
reconstruct(t)
Reconstruct one of the underlying items from its data t.