%matplotlib inline
%reload_ext autoreload
%autoreload 2

from nb_007 import *

PATH = Path('../data/cifar10/')
torch.backends.cudnn.benchmark = True

class Lambda(nn.Module):
    "Wrap an arbitrary function as an nn.Module."
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x): return self.func(x)

def ResizeBatch(*size): return Lambda(lambda x: x.view((-1,)+size))
def Flatten():          return Lambda(lambda x: x.view((x.size(0), -1)))
def PoolFlatten():      return nn.Sequential(nn.AdaptiveAvgPool2d(1), Flatten())

def conv_2d(ni, nf, ks, stride):
    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=False)

def bn(ni, init_zero=False):
    "BatchNorm2d layer with weight initialized to 0 or 1 and bias to 0."
    m = nn.BatchNorm2d(ni)
    m.weight.data.fill_(0 if init_zero else 1)
    m.bias.data.zero_()
    return m

def bn_relu_conv(ni, nf, ks, stride, init_zero=False):
    bn_initzero = bn(ni, init_zero=init_zero)
    return nn.Sequential(bn_initzero, nn.ReLU(inplace=True), conv_2d(ni, nf, ks, stride))

def noop(x): return x

class BasicBlock(nn.Module):
    "Pre-activation residual block: BN-ReLU then two convs, with a 1x1 conv shortcut when channels change."
    def __init__(self, ni, nf, stride, drop_p=0.0):
        super().__init__()
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv_2d(ni, nf, 3, stride)
        self.conv2 = bn_relu_conv(nf, nf, 3, 1)
        self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None
        self.shortcut = conv_2d(ni, nf, 1, stride) if ni != nf else noop

    def forward(self, x):
        x2 = F.relu(self.bn(x), inplace=True)
        r = self.shortcut(x2)
        x = self.conv1(x2)
        if self.drop: x = self.drop(x)
        x = self.conv2(x) * 0.2  # scale the residual branch before adding the shortcut
        return x.add_(r)

def _make_group(N, ni, nf, block, stride, drop_p):
    "N blocks; only the first changes the channel count and applies the stride."
    return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p) for i in range(N)]

class WideResNet(nn.Module):
    def __init__(self, num_groups, N, num_classes, k=1, drop_p=0.0, start_nf=16):
        super().__init__()
        n_channels = [start_nf]
        for i in range(num_groups): n_channels.append(start_nf*(2**i)*k)

        layers = [conv_2d(3, n_channels[0], 3, 1)]  # conv1
        for i in range(num_groups):
            layers += _make_group(N, n_channels[i], n_channels[i+1], BasicBlock,
                                  (1 if i == 0 else 2), drop_p)

        layers += [nn.BatchNorm2d(n_channels[-1]), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1),
                   Flatten(), nn.Linear(n_channels[-1], num_classes)]
        self.features = nn.Sequential(*layers)

    def forward(self, x): return self.features(x)

def wrn_22(): return WideResNet(num_groups=3, N=3, num_classes=10, k=6, drop_p=0.)

model = wrn_22()

# CIFAR-10 data: per-channel normalization plus pad / random-crop / horizontal-flip augmentation for training.
train_ds, valid_ds = ImageDataset.from_folder(PATH/'train'), ImageDataset.from_folder(PATH/'test')

cifar_mean, cifar_std = map(tensor, ([0.4914, 0.48216, 0.44653], [0.24703, 0.24349, 0.26159]))
cifar_norm, cifar_denorm = normalize_funcs(cifar_mean, cifar_std)

train_tfms = [pad(padding=4), crop(size=32, row_pct=(0,1), col_pct=(0,1)), flip_lr(p=0.5)]

data = DataBunch.create(train_ds, valid_ds, bs=512, train_tfm=train_tfms, tfms=cifar_norm, num_workers=8)

# Quick single-epoch run in full precision to check the training loop and the 1cycle schedule.
model = wrn_22()
learn = Learner(data, model)
learn.metrics = [accuracy]
learn.fit_one_cycle(1, 3e-3, wd=0.4, div_factor=10)
learn.recorder.plot_lr()

# Full training in mixed precision: convert the model to half precision and add the
# MixedPrecision callback to handle loss scaling and fp32 master weights.
model = wrn_22()
model = model2half(model)
learn = Learner(data, model)
learn.metrics = [accuracy]
learn.callbacks.append(MixedPrecision(learn))

%time learn.fit_one_cycle(25, 3e-3, wd=0.4, div_factor=10)

%time learn.fit_one_cycle(30, 3e-3, wd=0.4, div_factor=10)
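
# A minimal sanity check, not part of the original notebook: run a dummy CIFAR-10 sized
# batch through a fresh full-precision wrn_22() to confirm it maps (bs, 3, 32, 32) inputs
# to (bs, 10) logits. The batch size of 4 and the random input are arbitrary assumptions.
check_model = wrn_22().eval()
xb = torch.randn(4, 3, 32, 32)
with torch.no_grad(): out = check_model(xb)
out.shape                                          # expected: torch.Size([4, 10])
sum(p.numel() for p in check_model.parameters())   # total parameter count of WRN-22 (k=6)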