%reload_ext autoreload
%autoreload 2

from fastai import *
from fastai.vision import *

PATH = Path('../data/lsun')
IMG_PATH = PATH/'bedroom'

def create_csv_file(sample=False):
    "List every bedroom image in a csv with a dummy label (optionally keep a ~10% sample)."
    files = PATH.glob('bedroom/**/*.jpg')
    with (PATH/'files.csv').open('w') as fo:
        for f in files:
            if not sample or random.random() < 0.1: fo.write(f'{f},0\n')

#create_csv_file(sample=False)

df = pd.read_csv(PATH/'files.csv', header=None)
fns, ys = np.array(df[0]), np.array(df[1])
train_ds = ImageDataset(fns, ys)

size = 64
train_tds = DatasetTfm(train_ds, tfms=[crop_pad(size=size, row_pct=(0,1), col_pct=(0,1))], size=size)
norm, denorm = normalize_funcs(mean=torch.tensor([0.5,0.5,0.5]), std=torch.tensor([0.5,0.5,0.5]))
data = DataBunch.create(train_tds, valid_ds=None, path=PATH, bs=128, tfms=[norm])
data.valid_dl = None

def conv_layer1(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=False,
                bn:bool=True, leaky:bool=False, slope:float=0.1, transpose:bool=False):
    "Conv (or transposed conv) from `ni` to `nf` channels, followed by (Leaky)ReLU and optional batchnorm."
    if padding is None: padding = (ks-1)//2 if not transpose else 0
    conv_func = nn.ConvTranspose2d if transpose else nn.Conv2d
    activ = nn.LeakyReLU(inplace=True, negative_slope=slope) if leaky else nn.ReLU(inplace=True)
    layers = [conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), activ]
    if bn: layers.append(nn.BatchNorm2d(nf))
    return nn.Sequential(*layers)

def AvgFlatten():
    "Average the critic outputs over the batch and flatten to a single value."
    return Lambda(lambda x: x.mean(0).view(1))

def discriminator(in_size, n_channels, n_features, n_extra_layers=0):
    "DCGAN-style critic: strided convs halve the spatial size from `in_size` down to 4, then a final 4x4 conv."
    layers = [conv_layer1(n_channels, n_features, 4, 2, 1, bn=False, leaky=True, slope=0.2)]
    cur_size, cur_ftrs = in_size//2, n_features
    layers.append(nn.Sequential(*[conv_layer1(cur_ftrs, cur_ftrs, 3, 1, leaky=True, slope=0.2)
                                  for _ in range(n_extra_layers)]))
    while cur_size > 4:
        layers.append(conv_layer1(cur_ftrs, cur_ftrs*2, 4, 2, 1, leaky=True, slope=0.2))
        cur_ftrs *= 2; cur_size //= 2
    layers += [conv2d(cur_ftrs, 1, 4, padding=0), AvgFlatten()]
    return nn.Sequential(*layers)

def generator(in_size, noise_sz, n_channels, n_features, n_extra_layers=0):
    "DCGAN-style generator: transposed convs double the spatial size from 4 up to `in_size`."
    cur_size, cur_ftrs = 4, n_features//2
    while cur_size < in_size: cur_size *= 2; cur_ftrs *= 2
    layers = [conv_layer1(noise_sz, cur_ftrs, 4, 1, transpose=True)]
    cur_size = 4
    while cur_size < in_size // 2:
        layers.append(conv_layer1(cur_ftrs, cur_ftrs//2, 4, 2, 1, transpose=True))
        cur_ftrs //= 2; cur_size *= 2
    layers += [conv_layer1(cur_ftrs, cur_ftrs, 3, 1, 1, transpose=True) for _ in range(n_extra_layers)]
    layers += [conv2d_trans(cur_ftrs, n_channels, 4, 2, 1, bias=False), nn.Tanh()]
    return nn.Sequential(*layers)

# Sanity check: build a 64px generator and critic with one extra layer each.
generator(64, 100, 3, 64, 1)
discriminator(64, 3, 64, 1)

class BasicGAN(nn.Module):
    "Wraps a generator and a critic in one module; `forward` dispatches on `gen`."
    def __init__(self, in_size, noise_sz, n_channels, n_features, n_extra_layers=0):
        super().__init__()
        self.discriminator = discriminator(in_size, n_channels, n_features, n_extra_layers)
        self.generator = generator(in_size, noise_sz, n_channels, n_features, n_extra_layers)

    def forward(self, x, gen=False): return self.generator(x) if gen else self.discriminator(x)

def first_disc_iter(gen_iter):
    "Warm-up schedule: 100 critic steps per generator step for the first 25 generator steps, then 5."
    return 100 if (gen_iter < 25 or gen_iter%500 == 0) else 5

def standard_disc_iter(gen_iter):
    "Standard schedule: 5 critic steps per generator step (100 every 500 generator steps)."
    return 100 if gen_iter%500 == 0 else 5

noise_sz = 100
def create_noise(x, b, grad=True):
    "Create a batch of `b` noise vectors on the same device as `x`."
    return x.new(b, noise_sz, 1, 1).normal_(0, 1).requires_grad_(grad)

class WassersteinLoss(nn.Module):
    "Difference between the critic's scores on real and generated images."
    def forward(self, real, fake): return real[0] - fake[0]

@dataclass
class GANTrainer(LearnerCallback):
    "Callback that turns a regular `Learner` into a WGAN trainer: it clips the critic's weights, computes the Wasserstein loss, and updates the generator every `n_disc_iter` critic steps."
    loss_fn:LossFunction = WassersteinLoss()
    n_disc_iter:Callable = standard_disc_iter
    clip:float = 0.01
    bs:int = 64
    def _set_trainable(self, gen=False):
        "Freeze one network, unfreeze the other; when switching to the generator, copy the current hyper-parameters into its optimizer."
        requires_grad(self.learn.model.generator, gen)
        requires_grad(self.learn.model.discriminator, not gen)
        if gen:
            self.opt_gen.lr, self.opt_gen.mom = self.learn.opt.lr, self.learn.opt.mom
            self.opt_gen.wd, self.opt_gen.beta = self.learn.opt.wd, self.learn.opt.beta

    def on_train_begin(self, **kwargs):
        "Create one optimizer per network and point the learner's optimizer at the critic."
        opt_fn = self.learn.opt_fn
        lr, wd, true_wd, bn_wd = self.learn.opt.lr, self.learn.opt.wd, self.learn.opt.true_wd, self.learn.opt.bn_wd
        self.opt_gen = OptimWrapper.create(opt_fn, lr, [nn.Sequential(*flatten_model(self.learn.model.generator))],
                                           wd=wd, true_wd=true_wd, bn_wd=bn_wd)
        self.opt_disc = OptimWrapper.create(opt_fn, lr, [nn.Sequential(*flatten_model(self.learn.model.discriminator))],
                                            wd=wd, true_wd=true_wd, bn_wd=bn_wd)
        self.learn.opt.opt = self.opt_disc.opt
        self.disc_iters, self.gen_iters = 0, 0
        self._set_trainable()
        self.dlosses, self.glosses = [], []

    def on_batch_begin(self, **kwargs):
        "Clip the critic's weights to enforce the Lipschitz constraint."
        for p in self.learn.model.discriminator.parameters(): p.data.clamp_(-self.clip, self.clip)

    def on_backward_begin(self, last_output, last_input, **kwargs):
        "Replace the learner's loss with the Wasserstein loss between the real and a generated batch."
        fake = self.learn.model(create_noise(last_input, last_input.size(0), False), gen=True)
        fake.requires_grad_(True)
        loss = self.loss_fn(last_output, self.learn.model(fake))
        self.dlosses.append(loss.detach().cpu())
        return loss

    def on_batch_end(self, last_input, **kwargs):
        "Every `n_disc_iter` critic steps, take one generator step."
        self.disc_iters += 1
        if self.disc_iters == self.n_disc_iter(self.gen_iters):
            self.disc_iters = 0
            self._set_trainable(True)
            loss = self.learn.model(self.learn.model(create_noise(last_input, self.bs), gen=True)).mean().view(1)[0]
            self.glosses.append(loss.detach().cpu())
            self.learn.model.generator.zero_grad()
            loss.backward()
            self.opt_gen.step()
            self.gen_iters += 1
            self._set_trainable()

class NoopLoss(nn.Module):
    "The real loss is computed inside `GANTrainer`; this just passes the critic output through."
    def forward(self, output, target): return output[0]

wgan = BasicGAN(64, 100, 3, 64, 1)
learn = Learner(data, wgan, loss_fn=NoopLoss(), opt_fn=optim.RMSprop, wd=0.)
cb = GANTrainer(learn, bs=128, n_disc_iter=first_disc_iter)
learn.callbacks.append(cb)
learn.fit(1, 1e-4)

# Generate a batch of fake images and display a 5x5 grid.
x,y = next(iter(learn.data.train_dl))
tst = learn.model(create_noise(x, 64, False), gen=True)
imgs = denorm(tst.cpu()).numpy().clip(0,1)
fig, axs = plt.subplots(5, 5, figsize=(8,8))
for i, ax in enumerate(axs.flatten()):
    ax.imshow(imgs[i].transpose(1,2,0))
    ax.axis('off')
plt.tight_layout()

learn.save('temp')

# Continue training with the standard schedule and a lower learning rate,
# swapping in a fresh trainer rather than stacking a second GANTrainer on the first.
cb = GANTrainer(learn, bs=128, n_disc_iter=standard_disc_iter)
learn.callbacks = [cb]
learn.fit(1, 1e-5)

tst = learn.model(create_noise(x, 64, False), gen=True)
imgs = denorm(tst.cpu()).numpy().clip(0,1)
fig, axs = plt.subplots(5, 5, figsize=(8,8))
for i, ax in enumerate(axs.flatten()):
    ax.imshow(imgs[i].transpose(1,2,0))
    ax.axis('off')
plt.tight_layout()
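
# Sketch only (not part of the original notebook): `GANTrainer` accumulates the critic
# losses in `cb.dlosses` and the generator losses in `cb.glosses`, so we can plot both
# curves to monitor training. The generator is updated far less often than the critic,
# so its curve is stretched over the same x-range for comparison.
d_loss = torch.stack(cb.dlosses).numpy()
g_loss = torch.stack(cb.glosses).numpy()
plt.figure(figsize=(8,4))
plt.plot(d_loss, label='critic')
plt.plot(np.linspace(0, len(d_loss)-1, num=len(g_loss)), g_loss, label='generator')
plt.xlabel('critic iterations')
plt.legend()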