%load_ext autoreload %autoreload 2 %matplotlib inline #export from exp.nb_07 import * x_train,y_train,x_valid,y_valid = get_data() x_train,x_valid = normalize_to(x_train,x_valid) train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid) nh,bs = 50,512 c = y_train.max().item()+1 loss_func = F.cross_entropy data = DataBunch(*get_dls(train_ds, valid_ds, bs), c) mnist_view = view_tfm(1,28,28) cbfs = [Recorder, partial(AvgStatsCallback,accuracy), CudaCallback, partial(BatchTransformXCallback, mnist_view)] nfs = [8,16,32,64,64] class ConvLayer(nn.Module): def __init__(self, ni, nf, ks=3, stride=2, sub=0., **kwargs): super().__init__() self.conv = nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride, bias=True) self.relu = GeneralRelu(sub=sub, **kwargs) def forward(self, x): return self.relu(self.conv(x)) @property def bias(self): return -self.relu.sub @bias.setter def bias(self,v): self.relu.sub = -v @property def weight(self): return self.conv.weight learn,run = get_learn_run(nfs, data, 0.6, ConvLayer, cbs=cbfs) run.fit(2, learn) learn,run = get_learn_run(nfs, data, 0.6, ConvLayer, cbs=cbfs) #export def get_batch(dl, run): run.xb,run.yb = next(iter(dl)) for cb in run.cbs: cb.set_runner(run) run('begin_batch') return run.xb,run.yb xb,yb = get_batch(data.train_dl, run) #export def find_modules(m, cond): if cond(m): return [m] return sum([find_modules(o,cond) for o in m.children()], []) def is_lin_layer(l): lin_layers = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear, nn.ReLU) return isinstance(l, lin_layers) mods = find_modules(learn.model, lambda o: isinstance(o,ConvLayer)) mods def append_stat(hook, mod, inp, outp): d = outp.data hook.mean,hook.std = d.mean().item(),d.std().item() mdl = learn.model.cuda() with Hooks(mods, append_stat) as hooks: mdl(xb) for hook in hooks: print(hook.mean,hook.std) #export def lsuv_module(m, xb): h = Hook(m, append_stat) while mdl(xb) is not None and abs(h.mean) > 1e-3: m.bias -= h.mean while mdl(xb) is not None and abs(h.std-1) > 1e-3: m.weight.data /= h.std h.remove() return h.mean,h.std for m in mods: print(lsuv_module(m, xb)) %time run.fit(2, learn) !python notebook2script.py 07a_lsuv.ipynb