%reload_ext autoreload
%autoreload 2

#export
from nb_005b import *

PATH = Path('../../data/carvana')
PATH_PNG = PATH/'train_masks_png'
PATH_X_FULL = PATH/'train'
PATH_X_128 = PATH/'train-128'
PATH_Y_FULL = PATH_PNG
PATH_Y_128 = PATH/'train_masks-128'

# Start with the 128x128 images.
PATH_X = PATH_X_128
PATH_Y = PATH_Y_128

img_f = next(PATH_X.iterdir())
open_image(img_f).show()

#export
class ImageMask(Image):
    "Class for an image segmentation target."
    def clone(self)->'ImageBase':
        "Clone this item."
        return self.__class__(self.px.clone())

    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'Image': return self

    def refresh(self):
        self.sample_kwargs['mode'] = 'nearest'
        return super().refresh()

    @property
    def data(self)->TensorImage:
        "Return this image's pixels as a tensor of class indices."
        return self.px.long()

def open_mask(fn:PathOrStr) -> ImageMask:
    "Return an `ImageMask` object created from the mask in file `fn`."
    return ImageMask(pil2tensor(PIL.Image.open(fn)).float())

def get_y_fn(x_fn): return PATH_Y/f'{x_fn.name[:-4]}_mask.png'

img_y_f = get_y_fn(img_f)
y = open_mask(img_y_f)
y.show()
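# Why `ImageMask` disables lighting transforms and forces `mode='nearest'` in `refresh`: a
# toy sketch (not part of the pipeline; assumes `torch` and `F` are in the namespace via the
# `nb_005b` star import) showing that bilinear resampling of a 0/1 mask invents fractional
# "labels", while nearest-neighbour keeps the target categorical.
toy_mask = torch.tensor([[0.,0.,1.,1.]])[None,None]   # shape (1,1,1,4): a tiny binary mask
bilin    = F.interpolate(toy_mask, scale_factor=2, mode='bilinear', align_corners=False)
nearest  = F.interpolate(toy_mask, scale_factor=2, mode='nearest')
bilin.unique(), nearest.unique()   # bilinear introduces 0.25/0.75; nearest stays in {0., 1.}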
#export
# Same as the earlier `show_image`, but renamed with a `_` prefix.
def _show_image(img:Image, ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=True,
                cmap:str='binary', alpha:float=None) -> plt.Axes:
    if ax is None: fig,ax = plt.subplots(figsize=figsize)
    ax.imshow(image2np(img), cmap=cmap, alpha=alpha)
    if hide_axis: ax.axis('off')
    return ax

def show_image(x:Image, y:Image=None, ax:plt.Axes=None, figsize:tuple=(3,3), alpha:float=0.5,
               hide_axis:bool=True, cmap:str='viridis'):
    ax1 = _show_image(x, ax=ax, hide_axis=hide_axis, cmap=cmap)
    if y is not None: _show_image(y, ax=ax1, alpha=alpha, hide_axis=hide_axis, cmap=cmap)
    if hide_axis: ax1.axis('off')

def _show(self:Image, ax:plt.Axes=None, y:Image=None, **kwargs):
    if y is not None: y = y.data
    return show_image(self.data, ax=ax, y=y, **kwargs)

Image.show = _show

x = open_image(img_f)
x.show(y=y)
x.shape
y.shape

#export
class DatasetTfm(Dataset):
    "`Dataset` that applies a list of transforms to every item drawn."
    def __init__(self, ds:Dataset, tfms:TfmList=None, tfm_y:bool=False, **kwargs:Any):
        "This dataset will apply `tfms` to `ds`."
        self.ds,self.tfms,self.kwargs,self.tfm_y = ds,tfms,kwargs,tfm_y
        # `do_resolve=False` reuses the transform parameters already resolved for x,
        # so x and y receive exactly the same random augmentation.
        self.y_kwargs = {**self.kwargs, 'do_resolve':False}

    def __len__(self)->int: return len(self.ds)

    def __getitem__(self, idx:int)->Tuple[Image,Any]:
        "Return tfms(x),y."
        x,y = self.ds[idx]
        x = apply_tfms(self.tfms, x, **self.kwargs)
        if self.tfm_y: y = apply_tfms(self.tfms, y, **self.y_kwargs)
        return x, y

    def __getattr__(self, k):
        "Passthrough access to the wrapped dataset's attributes."
        return getattr(self.ds, k)

import nb_002b
nb_002b.DatasetTfm = DatasetTfm

#export
class SegmentationDataset(DatasetBase):
    "A dataset for segmentation tasks."
    def __init__(self, x:Collection[PathOrStr], y:Collection[PathOrStr]):
        assert len(x)==len(y)
        self.x,self.y = np.array(x),np.array(y)

    def __getitem__(self, i:int) -> Tuple[Image,ImageMask]:
        return open_image(self.x[i]), open_mask(self.y[i])

def get_datasets(path):
    x_fns = [o for o in path.iterdir() if o.is_file()]
    y_fns = [get_y_fn(o) for o in x_fns]
    # Split into two datasets by file index (threshold 1008).
    mask = [o>=1008 for o in range(len(x_fns))]
    arrs = arrays_split(mask, x_fns, y_fns)
    return [SegmentationDataset(*o) for o in arrs]

train_ds,valid_ds = get_datasets(PATH_X_128)
train_ds,valid_ds

x,y = next(iter(train_ds))
x.shape, y.shape, type(x), type(y)

size = 128

def get_tfm_datasets(size):
    datasets = get_datasets(PATH_X_128 if size<=128 else PATH_X_FULL)
    tfms = get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)
    return transform_datasets(*datasets, tfms=tfms, tfm_y=True, size=size, padding_mode='border')

transform_datasets

train_tds,*_ = get_tfm_datasets(size)

_,axes = plt.subplots(1,4, figsize=(12,6))
for i, ax in enumerate(axes.flat):
    imgx,imgy = train_tds[i]
    imgx.show(ax, y=imgy)

default_norm,default_denorm = normalize_funcs(*imagenet_stats)
bs = 64

def get_data(size, bs):
    return DataBunch.create(*get_tfm_datasets(size), bs=bs, tfms=default_norm)

data = get_data(size, bs)

#export
def show_xy_images(x:Tensor, y:Tensor, rows:int, figsize:tuple=(9,9)):
    "Show a selection of images and targets from a given batch."
    fig, axs = plt.subplots(rows, rows, figsize=figsize)
    for i, ax in enumerate(axs.flatten()): show_image(x[i], y=y[i], ax=ax)
    plt.tight_layout()

x,y = next(iter(data.train_dl))
x,y = x.cpu(),y.cpu()
x = default_denorm(x)
show_xy_images(x, y, 4, figsize=(9,9))
x.shape, y.shape

#export
class Debugger(nn.Module):
    "A module to debug inside a model."
    def forward(self, x:Tensor) -> Tensor:
        set_trace()
        return x

class StdUpsample(nn.Module):
    "Standard upsample module: transposed conv, ReLU, then batchnorm."
    def __init__(self, n_in:int, n_out:int):
        super().__init__()
        self.conv = conv2d_trans(n_in, n_out)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, x:Tensor) -> Tensor:
        return self.bn(F.relu(self.conv(x)))

def std_upsample_head(c, *nfs:Collection[int]) -> Model:
    "Create a sequence of upsample layers ending in a `c`-channel output."
    return nn.Sequential(
        nn.ReLU(),
        *(StdUpsample(nfs[i],nfs[i+1]) for i in range(4)),
        conv2d_trans(nfs[-1], c)
    )

head = std_upsample_head(2, 512,256,256,256,256)
head
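# Shape sketch for the head (an aside, assuming `conv2d_trans` acts as a stride-2 transposed
# conv, which its use here implies): the four `StdUpsample` blocks plus the final
# `conv2d_trans` each double the spatial size, a 32x upsample overall. For a 128px input the
# resnet34 backbone leaves a 4x4x512 feature map, which the head should map back to a
# 2-channel 128x128 prediction. A throwaway head is used so the real `head`'s batchnorm
# statistics stay untouched.
with torch.no_grad(): out = std_upsample_head(2, 512,256,256,256,256)(torch.randn(2, 512, 4, 4))
out.shape   # expected: torch.Size([2, 2, 128, 128])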
#export
def dice(input:Tensor, targs:Tensor) -> Rank0Tensor:
    "Dice coefficient metric for a binary target."
    n = targs.shape[0]
    input = input.argmax(dim=1).view(n,-1)
    targs = targs.view(n,-1)
    intersect = (input*targs).sum().float()
    union = (input+targs).sum().float()
    return 2. * intersect / union

def accuracy(input:Tensor, targs:Tensor) -> Rank0Tensor:
    "Pixel-wise accuracy."
    n = targs.shape[0]
    input = input.argmax(dim=1).view(n,-1)
    targs = targs.view(n,-1)
    return (input==targs).float().mean()

class CrossEntropyFlat(nn.CrossEntropyLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    def forward(self, input:Tensor, target:Tensor) -> Rank0Tensor:
        n,c,*_ = input.shape
        return super().forward(input.view(n, c, -1), target.view(n, -1))

metrics = [accuracy, dice]

learn = ConvLearner(data, tvm.resnet34, custom_head=head, metrics=metrics, loss_fn=CrossEntropyFlat())

lr_find(learn)
learn.recorder.plot()

lr = 1e-1
learn.fit_one_cycle(10, slice(lr))
learn.unfreeze()
learn.save('0')
learn.load('0')

lr = 2e-2
learn.fit_one_cycle(10, slice(lr/100,lr))

x,y,py = learn.pred_batch()
py = py.argmax(dim=1).unsqueeze(1)

for i, ax in enumerate(plt.subplots(4,4,figsize=(10,10))[1].flat):
    show_image(default_denorm(x[i].cpu()), py[i], ax=ax)

learn.save('1')

# Scale up to 512px images (built from the full-size originals) with a smaller batch size.
size = 512
bs = 8

data = get_data(size, bs)
learn.data = data
learn.load('1')
learn.freeze()

lr = 2e-2
learn.fit_one_cycle(5, slice(lr))
learn.save('2')
learn.load('2')

lr = 2e-2
learn.unfreeze()
learn.fit_one_cycle(8, slice(lr/100,lr))
learn.save('3')

x,py = learn.pred_batch()
for i, ax in enumerate(plt.subplots(4,4,figsize=(10,10))[1].flat):
    show_image(default_denorm(x[i].cpu()), py[i]>0, ax=ax)

# One-off conversion of the Carvana GIF masks to PNG, plus 128px copies of images and masks.
def convert_img(fn):
    PIL.Image.open(fn).save(PATH_PNG/f'{fn.name[:-4]}.png')

def resize_img(fn, dirname):
    PIL.Image.open(fn).resize((128,128)).save((fn.parent.parent)/dirname/fn.name)

def do_conversion():
    PATH_PNG.mkdir(exist_ok=True)
    PATH_X.mkdir(exist_ok=True)
    PATH_Y.mkdir(exist_ok=True)

    files = list((PATH/'train_masks').iterdir())
    with ThreadPoolExecutor(8) as e: e.map(convert_img, files)

    files = list(PATH_PNG.iterdir())
    with ThreadPoolExecutor(8) as e: e.map(partial(resize_img, dirname='train_masks-128'), files)

    files = list((PATH/'train').iterdir())
    with ThreadPoolExecutor(8) as e: e.map(partial(resize_img, dirname='train-128'), files)
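# Usage sketch (assumes the raw Kaggle Carvana layout under PATH: full-size JPEGs in `train/`
# and GIF masks in `train_masks/`): run the conversion once, before any of the cells above,
# so that `train_masks_png`, `train-128` and `train_masks-128` exist.
if not PATH_X_128.exists(): do_conversion()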