"""Train a Wasserstein GAN on the LSUN bedrooms dataset with fastai v1.

The generator's input is a random noise vector (``NoisyItem``) and its target
is a real bedroom image, so the data pipeline pairs "noise -> image" and
``GANLearner.wgan`` trains the generator against a critic.
"""

# Enable IPython's autoreload extension when running inside IPython/Jupyter.
# The `%reload_ext` / `%autoreload` magics are not valid Python syntax, so the
# equivalent API is used instead; `get_ipython` only exists inside IPython.
try:
    _ipython = get_ipython()  # noqa: F821
except NameError:
    _ipython = None
if _ipython is not None:
    _ipython.run_line_magic('reload_ext', 'autoreload')
    _ipython.run_line_magic('autoreload', '2')

from fastai import *
from fastai.vision import *

path = untar_data(URLs.LSUN_BEDROOMS)


class NoisyItem(ItemBase):
    """Generator input: a random noise tensor of shape (noise_sz, 1, 1)."""

    def __init__(self, noise_sz):
        # `obj` records the noise size; `data` is freshly sampled on every
        # construction of the item.
        self.obj, self.data = noise_sz, torch.randn(noise_sz, 1, 1)

    def __str__(self):
        # Noise has no meaningful text representation when showing batches.
        return ''

    def apply_tfms(self, tfms, **kwargs):
        # Transforms are not applied to noise; the item is returned unchanged.
        return self


class GANItemList(ImageItemList):
    """Item list whose inputs are random noise and whose labels are images.

    Every index yields a new ``NoisyItem`` as the input; labels are built with
    ``_label_cls`` (an ``ImageItemList``), i.e. the real images.
    """

    _label_cls = ImageItemList

    def __init__(self, items, noise_sz: int = 100, **kwargs):
        super().__init__(items, **kwargs)
        self.noise_sz = noise_sz
        # Register `noise_sz` so it is carried over when fastai creates new
        # lists of this class (presumably on split/label — see fastai docs).
        self.copy_new.append('noise_sz')

    def get(self, i):
        # The stored item at `i` is ignored: the input is always fresh noise.
        return NoisyItem(self.noise_sz)

    def reconstruct(self, t):
        # Rebuild an item from a batch tensor; its first dim is the noise size.
        return NoisyItem(t.size(0))

    def show_xys(self, xs, ys, imgsize: int = 4,
                 figsize: Optional[Tuple[int, int]] = None, **kwargs):
        # Inputs and targets are swapped so the images (ys) get displayed
        # instead of the noise inputs (xs).
        super().show_xys(ys, xs, imgsize=imgsize, figsize=figsize, **kwargs)


def get_data(bs, size):
    """Build the noise -> image DataBunch from the images under ``path``.

    Args:
        bs: batch size.
        size: side length the images are cropped/padded to.

    Returns:
        An ``ImageDataBunch`` with no validation set, whose targets (the
        images) are normalized with per-channel mean 0.5 and std 0.5.
    """
    train_ds = (
        GANItemList.from_folder(path)
        # `noop` labels each item with itself, so the target is the image.
        .label_from_func(noop)
        # Crop/pad to `size`; the (0, 1) pct ranges presumably mean a random
        # crop position. tfm_y=True applies the transform to the targets.
        .transform(
            tfms=[crop_pad(size=size, row_pct=(0, 1), col_pct=(0, 1))],
            size=size,
            tfm_y=True,
        )
    )
    # Normalize only the targets; the noise input is left untouched. Mean and
    # std of 0.5 presumably map pixel values from [0, 1] to [-1, 1].
    stats = [torch.tensor([0.5, 0.5, 0.5]), torch.tensor([0.5, 0.5, 0.5])]
    return (
        ImageDataBunch.create(train_ds, valid_ds=None, path=path, bs=bs)
        .normalize(do_x=False, stats=stats, do_y=True)
    )


data = get_data(128, 64)
data.show_batch(rows=5)

# Kept at this point (as in the original) rather than at the top: this star
# import comes after the local classes have been used and may shadow names
# defined above.
from fastai.vision.gan import *

# in_size=64 matches the `size` passed to get_data above.
generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
critic = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)
learn = GANLearner.wgan(data, generator, critic, opt_func=optim.RMSprop, wd=0.)
learn.fit(1, 1e-4)