%reload_ext autoreload
%autoreload 2
from fastai import *
from fastai.vision import *
PATH = Path('../data/lsun')
IMG_PATH = PATH/'bedroom'
def create_csv_file(sample=False):
    files = PATH.glob('bedroom/**/*.jpg')
    with (PATH/'files.csv').open('w') as fo:
        for f in files:
            if not sample or random.random() < 0.1: fo.write(f'{f},0\n')
#create_csv_file(sample=False)
df = pd.read_csv(PATH/'files.csv', header=None)
fns, ys = np.array(df[0]), np.array(df[1])
train_ds = ImageDataset(fns, ys)
size = 64
train_tds = DatasetTfm(train_ds, tfms = [crop_pad(size=size, row_pct=(0,1), col_pct=(0,1))], size=size)
norm, denorm = normalize_funcs(mean = torch.tensor([0.5,0.5,0.5]), std = torch.tensor([0.5,0.5,0.5]))
data = DataBunch.create(train_tds, valid_ds=None, path=PATH, bs=128, tfms=[norm])
data.valid_dl = None
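# Optional sanity check (not in the original flow; x_chk is just an illustrative name): one batch
# from the training DataLoader should come back normalized to roughly [-1, 1] with shape
# (bs, 3, size, size), matching the Tanh output of the generator defined below.
x_chk,_ = next(iter(data.train_dl))
x_chk.shape, x_chk.min().item(), x_chk.max().item()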
def conv_layer1(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=False, bn:bool=True,
                leaky:bool=False, slope:float=0.1, transpose:bool=False):
    if padding is None: padding = (ks-1)//2 if not transpose else 0
    conv_func = nn.ConvTranspose2d if transpose else nn.Conv2d
    activ = nn.LeakyReLU(inplace=True, negative_slope=slope) if leaky else nn.ReLU(inplace=True)
    layers = [conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), activ]
    if bn: layers.append(nn.BatchNorm2d(nf))
    return nn.Sequential(*layers)
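# Quick shape check (illustration only, not part of the original flow): with stride 2, the standard
# conv layer halves the spatial size and the transposed variant doubles it; both apply BatchNorm
# after the activation.
conv_layer1(3, 8, 4, 2, 1)(torch.randn(2, 3, 64, 64)).shape                   # -> torch.Size([2, 8, 32, 32])
conv_layer1(8, 3, 4, 2, 1, transpose=True)(torch.randn(2, 8, 32, 32)).shape   # -> torch.Size([2, 3, 64, 64])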
def AvgFlatten():
    return Lambda(lambda x: x.mean(0).view(1))
def discriminator(in_size, n_channels, n_features, n_extra_layers=0):
    layers = [conv_layer1(n_channels, n_features, 4, 2, 1, bn=False, leaky=True, slope=0.2)]
    cur_size, cur_ftrs = in_size//2, n_features
    layers.append(nn.Sequential(*[conv_layer1(cur_ftrs, cur_ftrs, 3, 1, leaky=True, slope=0.2) for _ in range(n_extra_layers)]))
    while cur_size > 4:
        layers.append(conv_layer1(cur_ftrs, cur_ftrs*2, 4, 2, 1, leaky=True, slope=0.2))
        cur_ftrs *= 2 ; cur_size //= 2
    layers += [conv2d(cur_ftrs, 1, 4, padding=0), AvgFlatten()]
    return nn.Sequential(*layers)
def generator(in_size, noise_sz, n_channels, n_features, n_extra_layers=0):
    cur_size, cur_ftrs = 4, n_features//2
    while cur_size < in_size: cur_size *= 2; cur_ftrs *= 2
    layers = [conv_layer1(noise_sz, cur_ftrs, 4, 1, transpose=True)]
    cur_size = 4
    while cur_size < in_size // 2:
        layers.append(conv_layer1(cur_ftrs, cur_ftrs//2, 4, 2, 1, transpose=True))
        cur_ftrs //= 2; cur_size *= 2
    layers += [conv_layer1(cur_ftrs, cur_ftrs, 3, 1, 1, transpose=True) for _ in range(n_extra_layers)]
    layers += [conv2d_trans(cur_ftrs, n_channels, 4, 2, 1, bias=False), nn.Tanh()]
    return nn.Sequential(*layers)
generator(64, 100, 3, 64, 1)
discriminator(64, 3, 64, 1)
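# Shape check (illustration only, not part of the original training flow): the generator maps a
# noise_sz x 1 x 1 latent to a 64x64 RGB image, while the discriminator (the WGAN critic) collapses
# a whole batch into a single score through AvgFlatten.
generator(64, 100, 3, 64, 1)(torch.randn(4, 100, 1, 1)).shape   # -> torch.Size([4, 3, 64, 64])
discriminator(64, 3, 64, 1)(torch.randn(4, 3, 64, 64)).shape    # -> torch.Size([1])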
class BasicGAN(nn.Module):
    def __init__(self, in_size, noise_sz, n_channels, n_features, n_extra_layers=0):
        super().__init__()
        self.discriminator = discriminator(in_size, n_channels, n_features, n_extra_layers)
        self.generator = generator(in_size, noise_sz, n_channels, n_features, n_extra_layers)
    def forward(self, x, gen=False):
        return self.generator(x) if gen else self.discriminator(x)
def first_disc_iter(gen_iter):
    return 100 if (gen_iter < 25 or gen_iter%500 == 0) else 5
def standard_disc_iter(gen_iter):
    return 100 if gen_iter%500 == 0 else 5
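# Schedule check (illustration): following the WGAN recipe, the critic gets 100 updates per
# generator step at the start of training (and on every 500th generator step), 5 otherwise.
first_disc_iter(0), first_disc_iter(30), standard_disc_iter(30), standard_disc_iter(500)
# -> (100, 5, 5, 100)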
noise_sz = 100
def create_noise(x, b, grad=True): return x.new(b, noise_sz, 1, 1).normal_(0, 1).requires_grad_(grad)
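# Example (illustration only): the latent vectors are drawn on the same device/dtype as the batch x
# and shaped (b, noise_sz, 1, 1), which is what the generator's first transposed conv expects.
create_noise(torch.randn(2, 3, 64, 64), 2, grad=False).shape   # -> torch.Size([2, 100, 1, 1])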
class WassersteinLoss(nn.Module):
    def forward(self, real, fake): return real[0] - fake[0]
@dataclass
class GANTrainer(LearnerCallback):
    loss_fn:LossFunction = WassersteinLoss()
    n_disc_iter:Callable = standard_disc_iter
    clip:float = 0.01
    bs:int = 64
    def _set_trainable(self, gen=False):
        requires_grad(self.learn.model.generator, gen)
        requires_grad(self.learn.model.discriminator, not gen)
        if gen:
            self.opt_gen.lr, self.opt_gen.mom = self.learn.opt.lr, self.learn.opt.mom
            self.opt_gen.wd, self.opt_gen.beta = self.learn.opt.wd, self.learn.opt.beta
    def on_train_begin(self, **kwargs):
        opt_fn = self.learn.opt_fn
        lr, wd, true_wd, bn_wd = self.learn.opt.lr, self.learn.opt.wd, self.learn.opt.true_wd, self.learn.opt.bn_wd
        self.opt_gen = OptimWrapper.create(opt_fn, lr,
                                           [nn.Sequential(*flatten_model(self.learn.model.generator))],
                                           wd=wd, true_wd=true_wd, bn_wd=bn_wd)
        self.opt_disc = OptimWrapper.create(opt_fn, lr,
                                            [nn.Sequential(*flatten_model(self.learn.model.discriminator))],
                                            wd=wd, true_wd=true_wd, bn_wd=bn_wd)
        self.learn.opt.opt = self.opt_disc.opt
        self.disc_iters, self.gen_iters = 0, 0
        self._set_trainable()
        self.dlosses,self.glosses = [],[]
    def on_batch_begin(self, **kwargs):
        # WGAN weight clipping: keep the critic's weights inside [-clip, clip]
        for p in self.learn.model.discriminator.parameters():
            p.data.clamp_(-self.clip, self.clip)
    def on_backward_begin(self, last_output, last_input, **kwargs):
        fake = self.learn.model(create_noise(last_input, last_input.size(0), False), gen=True)
        fake.requires_grad_(True)
        loss = self.loss_fn(last_output, self.learn.model(fake))
        self.dlosses.append(loss.detach().cpu())
        return loss
    def on_batch_end(self, last_input, **kwargs):
        self.disc_iters += 1
        if self.disc_iters == self.n_disc_iter(self.gen_iters):
            self.disc_iters = 0
            self._set_trainable(True)
            loss = self.learn.model(self.learn.model(create_noise(last_input,self.bs), gen=True)).mean().view(1)[0]
            self.glosses.append(loss.detach().cpu())
            self.learn.model.generator.zero_grad()
            loss.backward()
            self.opt_gen.step()
            self.gen_iters += 1
            self._set_trainable()
class NoopLoss(nn.Module):
    def forward(self, output, target): return output[0]
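# NoopLoss only forwards the critic's score so the Learner's training loop runs; the loss that is
# actually backpropagated comes from GANTrainer (on_backward_begin for the critic, on_batch_end for
# the generator), with both logged in dlosses/glosses.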
wgan = BasicGAN(64, 100, 3, 64, 1)
learn = Learner(data, wgan, loss_fn=NoopLoss(), opt_fn=optim.RMSprop, wd=0.)
cb = GANTrainer(learn, bs=128, n_disc_iter=first_disc_iter)
learn.callbacks.append(cb)
learn.fit(1, 1e-4)
x,y = next(iter(learn.data.train_dl))
tst = learn.model(create_noise(x,64,False), gen=True)
imgs = denorm(tst.cpu()).numpy().clip(0,1)
fig,axs = plt.subplots(5,5,figsize=(8,8))
for i,ax in enumerate(axs.flatten()):
    ax.imshow(imgs[i].transpose(1,2,0))
    ax.axis('off')
plt.tight_layout()
learn.save('temp')
cb = GANTrainer(learn, bs=128, n_disc_iter=standard_disc_iter)
learn.callbacks = [cb]  # swap in the new GANTrainer so the earlier one no longer drives training
learn.fit(1, 1e-5)
tst = learn.model(create_noise(x,64,False), gen=True)
imgs = denorm(tst.cpu()).numpy().clip(0,1)
fig,axs = plt.subplots(5,5,figsize=(8,8))
for i,ax in enumerate(axs.flatten()):
    ax.imshow(imgs[i].transpose(1,2,0))
    ax.axis('off')
plt.tight_layout()