%load_ext autoreload
%autoreload 2

#export
from nb_005b import *

PATH = Path('data/carvana')
PATH_PNG = PATH/'train_masks_png'
PATH_X_FULL = PATH/'train'
PATH_X_128 = PATH/'train-128'
PATH_Y_FULL = PATH_PNG
PATH_Y_128 = PATH/'train_masks-128'

# start with the 128x128 images
PATH_X = PATH_X_128
PATH_Y = PATH_Y_128

img_f = next(PATH_X.iterdir())
open_image(img_f).show()

#export
class ImageMask(Image):
    "An `Image` holding a segmentation mask: lighting transforms are no-ops and resampling is nearest-neighbor."
    def lighting(self, func, *args, **kwargs): return self

    def refresh(self):
        self.sample_kwargs['mode'] = 'nearest'
        return super().refresh()

def open_mask(fn): return ImageMask(pil2tensor(PIL.Image.open(fn)).long())

def get_y_fn(x_fn): return PATH_Y/f'{x_fn.name[:-4]}_mask.png'

img_y_f = get_y_fn(img_f)
y = open_mask(img_y_f)
y.show()

#export
# Same as `show_image`, but renamed with _ prefix
def _show_image(img, ax=None, figsize=(3,3), hide_axis=True, cmap='binary', alpha=None):
    if ax is None: fig,ax = plt.subplots(figsize=figsize)
    ax.imshow(image2np(img), cmap=cmap, alpha=alpha)
    if hide_axis: ax.axis('off')
    return ax

def show_image(x, y=None, ax=None, figsize=(3,3), alpha=0.5, hide_axis=True, cmap='viridis'):
    "Show image `x`, optionally overlaying mask `y` with transparency `alpha`."
    ax1 = _show_image(x, ax=ax, hide_axis=hide_axis, cmap=cmap)
    if y is not None: _show_image(y, ax=ax1, alpha=alpha, hide_axis=hide_axis, cmap=cmap)
    if hide_axis: ax1.axis('off')

def _show(self, ax=None, y=None, **kwargs):
    if y is not None: y = y.data
    return show_image(self.data, ax=ax, y=y, **kwargs)
Image.show = _show

x = open_image(img_f)
x.show(y=y)
x.shape

#export
class DatasetTfm(Dataset):
    "Apply `tfms` to the inputs and, if `tfm_y`, to the targets, reusing the same resolved random state for the targets."
    def __init__(self, ds:Dataset, tfms:Collection[Callable]=None, tfm_y:bool=False, **kwargs):
        self.ds,self.tfms,self.tfm_y,self.x_kwargs = ds,tfms,tfm_y,kwargs
        self.y_kwargs = {**self.x_kwargs, 'do_resolve':False}  # don't reset random vars

    def __len__(self): return len(self.ds)

    def __getitem__(self,idx):
        x,y = self.ds[idx]
        x = apply_tfms(self.tfms, x, **self.x_kwargs)
        if self.tfm_y: y = apply_tfms(self.tfms, y, **self.y_kwargs)
        return x,y

    def __getattr__(self,k): return getattr(self.ds, k)

import nb_002b,nb_005
nb_002b.DatasetTfm = DatasetTfm
nb_005.DatasetTfm = DatasetTfm

#export
class MatchedImageDataset(DatasetBase):
    "A dataset of matched image/mask file pairs."
    def __init__(self, x:Collection[Path], y:Collection[Path]):
        assert len(x)==len(y)
        self.x,self.y = np.array(x),np.array(y)

    def __getitem__(self, i): return open_image(self.x[i]), open_mask(self.y[i])

def get_datasets(path):
    x_fns = [o for o in path.iterdir() if o.is_file()]
    y_fns = [get_y_fn(o) for o in x_fns]
    mask = [o>=1008 for o in range(len(x_fns))]
    arrs = arrays_split(mask, x_fns, y_fns)
    return [MatchedImageDataset(*o) for o in arrs]

train_ds,valid_ds = get_datasets(PATH_X_128)
train_ds,valid_ds

x,y = next(iter(train_ds))
x.shape, y.shape, type(x), type(y)

size = 128

def get_tfm_datasets(size):
    datasets = get_datasets(PATH_X_128 if size<=128 else PATH_X_FULL)
    tfms = get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)
    return transform_datasets(*datasets, tfms=tfms, tfm_y=True, size=size)

train_tds,*_ = get_tfm_datasets(size)

_,axes = plt.subplots(1,4, figsize=(12,6))
for i,ax in enumerate(axes.flat):
    imgx,imgy = train_tds[i]
    imgx.show(ax, y=imgy)

#export
def normalize_batch(b, mean, std, do_y=False):
    x,y = b
    x = normalize(x,mean,std)
    if do_y: y = normalize(y,mean,std)
    return x,y

def normalize_funcs(mean, std, do_y=False, device=None):
    if device is None: device = default_device
    return (partial(normalize_batch, mean=mean.to(device), std=std.to(device), do_y=do_y),
            partial(denormalize, mean=mean, std=std))

default_norm,default_denorm = normalize_funcs(*imagenet_stats)
bs = 64
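# Quick sanity check (added for illustration; not part of the original
# notebook). Assuming `imagenet_stats` is a (mean, std) pair of 3-element
# tensors and `normalize`/`denormalize` broadcast per-channel stats over an
# NCHW batch as elsewhere in these notebooks, normalizing then denormalizing
# should round-trip, and the mask should pass through untouched with do_y=False.
norm_cpu,denorm_cpu = normalize_funcs(*imagenet_stats, device=torch.device('cpu'))
x_demo = torch.rand(2, 3, 8, 8)          # fake image batch
y_demo = torch.zeros(2, 1, 8, 8).long()  # fake mask batch
xn,yn = norm_cpu((x_demo, y_demo))
assert torch.allclose(denorm_cpu(xn), x_demo, atol=1e-5)
assert (yn == y_demo).all()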
def get_data(size, bs):
    return DataBunch.create(*get_tfm_datasets(size), bs=bs, tfms=default_norm)

data = get_data(size, bs)

#export
def show_xy_images(x,y,rows,figsize=(9,9)):
    "Show a grid of `rows` by `rows` images from batch `x` with masks from `y` overlaid."
    fig,axs = plt.subplots(rows,rows,figsize=figsize)
    for i,ax in enumerate(axs.flatten()): show_image(x[i], y=y[i], ax=ax)
    plt.tight_layout()

x,y = next(iter(data.train_dl))
x,y = x.cpu(),y.cpu()
x = default_denorm(x)
show_xy_images(x,y,4, figsize=(9,9))
x.shape, y.shape

#export
class Debugger(nn.Module):
    "A module that drops into the debugger during the forward pass, to inspect intermediate activations."
    def forward(self,x):
        set_trace()
        return x

class StdUpsample(nn.Module):
    "Transposed conv, then ReLU, then batchnorm: doubles the spatial resolution."
    def __init__(self, nin, nout):
        super().__init__()
        self.conv = conv2d_trans(nin, nout)
        self.bn = nn.BatchNorm2d(nout)

    def forward(self, x): return self.bn(F.relu(self.conv(x)))

def std_upsample_head(c, *nfs):
    "A segmentation head: four upsampling blocks, then a final transposed conv down to `c` channels."
    return nn.Sequential(
        nn.ReLU(),
        *(StdUpsample(nfs[i],nfs[i+1]) for i in range(4)),
        conv2d_trans(nfs[-1], c)
    )

head = std_upsample_head(2, 512,256,256,256,256)
head

#export
def dice(input, targs):
    "Dice coefficient metric for binary target"
    n = targs.shape[0]
    input = input.argmax(dim=1).view(n,-1)
    targs = targs.view(n,-1)
    intersect = (input*targs).sum().float()
    union = (input+targs).sum().float()
    return 2. * intersect / union

def accuracy(input, targs):
    "Pixelwise accuracy"
    n = targs.shape[0]
    input = input.argmax(dim=1).view(n,-1)
    targs = targs.view(n,-1)
    return (input==targs).float().mean()

class CrossEntropyFlat(nn.CrossEntropyLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target"
    def forward(self, input, target):
        n,c,*_ = input.shape
        return super().forward(input.view(n, c, -1), target.view(n, -1))

metrics = [accuracy, dice]

learn = ConvLearner(data, tvm.resnet34, 2, custom_head=head, metrics=metrics, loss_fn=CrossEntropyFlat())

lr_find(learn)
learn.recorder.plot()

lr = 1e-1
learn.fit_one_cycle(10, slice(lr))
learn.unfreeze()
learn.save('0')
learn.load('0')

lr = 2e-2
learn.fit_one_cycle(10, slice(lr/100,lr))

x,y,py = learn.pred_batch()
py = py.argmax(dim=1).unsqueeze(1)
for i,ax in enumerate(plt.subplots(4,4,figsize=(10,10))[1].flat):
    show_image(default_denorm(x[i].cpu()), py[i], ax=ax)

learn.save('1')

# scale up to 512px images
size = 512
bs = 8
data = get_data(size, bs)
learn.data = data
learn.load('1')
learn.freeze()

lr = 2e-2
learn.fit_one_cycle(5, slice(lr))
learn.save('2')
learn.load('2')

lr = 2e-2
learn.unfreeze()
learn.fit_one_cycle(8, slice(lr/100,lr))
learn.save('3')

x,y,py = learn.pred_batch()
py = py.argmax(dim=1).unsqueeze(1)
for i,ax in enumerate(plt.subplots(4,4,figsize=(10,10))[1].flat):
    show_image(default_denorm(x[i].cpu()), py[i]>0, ax=ax)
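# Toy check of the `dice` metric defined above (added for illustration; the
# tensors here are made up). Predicting class 1 at every pixel against a 2x2
# target that is half ones gives intersect=2 and union=6, so dice = 2*2/6 = 2/3.
logits = torch.zeros(1, 2, 2, 2)
logits[:,1] = 1.                      # argmax over dim 1 picks class 1 at every pixel
targ = torch.tensor([[[1,1],[0,0]]])  # 2x2 target mask, half foreground
assert abs(dice(logits, targ).item() - 2/3) < 1e-6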