%reload_ext autoreload
%autoreload 2

#export
from nb_006 import *
import gc

PATH = Path('data/carvana')
PATH_PNG = PATH/'train_masks_png'
PATH_X_FULL = PATH/'train'
PATH_X_128 = PATH/'train-128'
PATH_Y_FULL = PATH_PNG
PATH_Y_128 = PATH/'train_masks-128'
PATH_X = PATH_X_128
PATH_Y = PATH_Y_128

def get_y_fn(x_fn): return PATH_Y/f'{x_fn.name[:-4]}_mask.png'

def get_datasets(path):
    x_fns = [o for o in path.iterdir() if o.is_file()]
    y_fns = [get_y_fn(o) for o in x_fns]
    mask = [o>=1008 for o in range(len(x_fns))]  # split train/valid at index 1008
    arrs = arrays_split(mask, x_fns, y_fns)
    return [MatchedImageDataset(*o) for o in arrs]

size = 128

def get_tfm_datasets(size):
    datasets = get_datasets(PATH_X_128 if size<=128 else PATH_X_FULL)
    tfms = get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)
    return transform_datasets(*datasets, tfms=tfms, tfm_y=True, size=size)

default_norm,default_denorm = normalize_funcs(*imagenet_stats)
bs = 32

def get_data(size, bs):
    return DataBunch.create(*get_tfm_datasets(size), bs=bs, tfms=default_norm)

data = get_data(size, bs)

#export
Sizes = List[List[int]]

def in_channels(m:Model) -> int:
    "Return the number of input channels of the first weight layer"
    for l in flatten_model(m):
        if hasattr(l, 'weight'): return l.weight.shape[1]
    raise Exception('No weight layer')

def model_sizes(m:Model, size:tuple=(256,256), full:bool=True) -> Tuple[Sizes,Tensor,Hooks]:
    "Pass a dummy input through the model to get the various sizes"
    hooks = hook_outputs(m)
    ch_in = in_channels(m)
    x = torch.zeros(1,ch_in,*size)
    x = m.eval()(x)
    res = [o.stored.shape for o in hooks]
    if not full: hooks.remove()
    return (res,x,hooks) if full else (res,x)

def get_sfs_idxs(sizes:Sizes, last:bool=True) -> List[int]:
    "Get the indexes of the layers where the size of the activation changes"
    if last:
        feature_szs = [size[-1] for size in sizes]
        sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
        if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
    else: sfs_idxs = list(range(len(sizes)))
    return sfs_idxs

#export
class UnetBlock(nn.Module):
    "A basic U-Net block"
    def __init__(self, up_in_c:int, x_in_c:int, hook:Hook):
        super().__init__()
        self.hook = hook
        ni = up_in_c
        self.upconv = conv2d_trans(ni, ni//2)  # H, W -> 2H, 2W
        ni = ni//2 + x_in_c
        self.conv1 = conv2d(ni, ni//2)
        ni = ni//2
        self.conv2 = conv2d(ni, ni)
        self.bn = nn.BatchNorm2d(ni)

    def forward(self, up_in:Tensor) -> Tensor:
        up_out = self.upconv(up_in)
        cat_x = torch.cat([up_out, self.hook.stored], dim=1)  # skip connection from the encoder
        x = F.relu(self.conv1(cat_x))
        x = F.relu(self.conv2(x))
        return self.bn(x)

#export
class DynamicUnet(nn.Sequential):
    "U-Net created from a given architecture"
    def __init__(self, encoder:Model, n_classes:int, last:bool=True):
        imsize = (256,256)
        sfs_szs,x,self.sfs = model_sizes(encoder, size=imsize)
        sfs_idxs = reversed(get_sfs_idxs(sfs_szs, last))
        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(conv2d_relu(ni, ni*2, bn=True),
                                    conv2d_relu(ni*2, ni, bn=True))
        x = middle_conv(x)
        layers = [encoder, nn.ReLU(), middle_conv]

        for idx in sfs_idxs:
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[idx])
            layers.append(unet_block)
            x = unet_block(x)

        ni = unet_block.conv2.out_channels
        if imsize != sfs_szs[0][-2:]: layers.append(conv2d_trans(ni, ni))
        layers.append(conv2d(ni, n_classes, 1))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"): self.sfs.remove()

metrics = [accuracy_thresh,dice]
lr = 1e-3

class CrossEntropyFlat(nn.CrossEntropyLoss):
    "Same as nn.CrossEntropyLoss, but flattens the spatial dimensions of input and target first"
    def forward(self, input, target):
        n,c,*_ = input.shape
        return super().forward(input.view(n, c, -1), target.view(n, -1))
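# A quick sanity check of CrossEntropyFlat (a hypothetical illustration, not a
# cell from the original notebook): logits of shape (n, c, h, w) and an integer
# mask of shape (n, h, w) are flattened to (n, c, h*w) and (n, h*w), which is
# the K-dimensional form that nn.CrossEntropyLoss accepts.
loss_fn = CrossEntropyFlat()
dummy_logits = torch.randn(2, 2, 128, 128)           # n=2 images, c=2 classes
dummy_mask = torch.randint(0, 2, (2, 128, 128))      # per-pixel class indices
assert loss_fn(dummy_logits, dummy_mask).dim() == 0  # reduces to a scalar loss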
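# Before building the real model, here is a sketch of the machinery DynamicUnet
# relies on (hypothetical, using a throwaway untrained encoder): model_sizes
# records the activation shape after each encoder layer via hooks, and
# get_sfs_idxs picks out the layers where the feature-map size changes, i.e.
# where the skip connections should be taken from.
tmp_encoder = create_body(tvm.resnet34(False), 2)
szs, _, hks = model_sizes(tmp_encoder, size=(256,256))
print(get_sfs_idxs(szs))  # indexes of the resolution changes in the encoder
hks.remove()              # detach the hooks so the dummy pass leaves no state behind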
body = create_body(tvm.resnet34(True), 2)
model = DynamicUnet(body, n_classes=2).cuda()
learn = Learner(data, model, metrics=metrics, loss_fn=CrossEntropyFlat())
learn.split([model[0][6], model[1]])
learn.freeze()

lr_find(learn)
learn.recorder.plot()

learn.fit_one_cycle(1, slice(lr), pct_start=0.05)
learn.fit_one_cycle(6, slice(lr), pct_start=0.05)
learn.save('u0')

x,py = learn.pred_batch()
for i, ax in enumerate(plt.subplots(4,4,figsize=(10,10))[1].flat):
    show_image(default_denorm(x[i].cpu()), py[i]>0, ax=ax)

learn.unfreeze()
lr = 1e-3
learn.fit_one_cycle(6, slice(lr/100,lr), pct_start=0.05)

size = 512
bs = 8
learn.data = get_data(size, bs)
learn.freeze()
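# The notebook stops right after switching to 512px data with the encoder
# frozen again. As a hedged sketch (an assumption, not present in the source),
# a natural continuation is to release the memory held by the old 128px
# batches, presumably what `import gc` above was added for, and then repeat
# the fine-tuning schedule at the higher resolution.
gc.collect()              # drop Python references to the old DataBunch
torch.cuda.empty_cache()  # return cached GPU memory before the larger images
learn.fit_one_cycle(1, slice(lr), pct_start=0.05)  # assumed: same schedule as at 128px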