#export
from fastai2.torch_basics import *
from fastai2.tabular.core import *

#hide
from nbdev.showdoc import *

#default_exp tabular.model

#export
def emb_sz_rule(n_cat):
    "Rule of thumb to pick embedding size corresponding to `n_cat`"
    # Grows sub-linearly with the number of categories and is capped at 600.
    return min(600, round(1.6 * n_cat**0.56))

#export
def _one_emb_sz(classes, n, sz_dict=None):
    """Pick an embedding size for `n` depending on `classes` if not given in `sz_dict`.

    Returns the tuple `(n_cat, sz)` where `n_cat` is `len(classes[n])` and `sz`
    is `sz_dict[n]` when present, otherwise `emb_sz_rule(n_cat)`.
    """
    sz_dict = ifnone(sz_dict, {})
    n_cat = len(classes[n])
    sz = sz_dict.get(n, int(emb_sz_rule(n_cat)))  # rule of thumb
    return n_cat,sz

#export
def get_emb_sz(to, sz_dict=None):
    """Get the embedding sizes for every name in `to.cat_names`.

    Sizes come from `sz_dict` when it has an entry for the name, otherwise from
    the rule of thumb applied to `to.classes`. Returns a list of
    `(n_categories, embedding_size)` tuples, in the order of `to.cat_names`.
    """
    return [_one_emb_sz(to.classes, n, sz_dict) for n in to.cat_names]

#export
class TabularModel(Module):
    """Basic model for tabular data.

    - `emb_szs`: iterable of `(ni, nf)` pairs, one `Embedding(ni, nf)` is built per pair
    - `n_cont`: number of continuous features concatenated to the embeddings
    - `out_sz`: number of outputs of the last linear layer
    - `layers`: sizes of the hidden linear layers
    - `ps`: dropout probability for each hidden layer (a single value is
      repeated for every layer, `None` means no dropout)
    - `embed_p`: dropout probability applied to the concatenated embeddings
    - `y_range`: if given, a `SigmoidRange(*y_range)` is appended to the model
    - `use_bn`: use batchnorm in the hidden layers
    - `bn_final`: also use batchnorm in the output layer (only if `use_bn`)
    - `bn_cont`: apply batchnorm to the continuous inputs

    Raises `ValueError` if `ps` is a sequence with fewer entries than `layers`.
    """
    # NOTE(review): there is no explicit `super().__init__()` call here; presumably
    # the fastai `Module` base class takes care of it -- confirm before changing.
    def __init__(self, emb_szs, n_cont, out_sz, layers, ps=None, embed_p=0.,
                 y_range=None, use_bn=True, bn_final=False, bn_cont=True):
        # Work on plain lists so tuples (or other sequences) can be concatenated below.
        layers = list(layers)
        ps = ifnone(ps, [0]*len(layers))
        ps = list(ps) if is_listy(ps) else [ps]*len(layers)
        if len(ps) < len(layers):
            # `zip` below would silently stop early and drop the last layers,
            # producing a network that does not even output `out_sz` features.
            raise ValueError(f"`ps` has {len(ps)} element(s) but `layers` has {len(layers)}; "
                             "pass one dropout probability per hidden layer")
        self.embeds = nn.ModuleList([Embedding(ni, nf) for ni,nf in emb_szs])
        self.emb_drop = nn.Dropout(embed_p)
        self.bn_cont = nn.BatchNorm1d(n_cont) if bn_cont else None
        n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_emb,self.n_cont = n_emb,n_cont
        sizes = [n_emb + n_cont] + layers + [out_sz]
        # ReLU after every hidden layer, no activation after the output layer.
        actns = [nn.ReLU(inplace=True) for _ in range(len(sizes)-2)] + [None]
        # The extra `0.` is the dropout of the output layer; batchnorm is only
        # applied to that last layer when `bn_final` is set.
        _layers = [LinBnDrop(sizes[i], sizes[i+1], bn=use_bn and (i!=len(actns)-1 or bn_final), p=p, act=a)
                   for i,(p,a) in enumerate(zip(ps+[0.],actns))]
        if y_range is not None: _layers.append(SigmoidRange(*y_range))
        self.layers = nn.Sequential(*_layers)

    def forward(self, x_cat, x_cont=None):
        "Embed the columns of `x_cat`, concatenate them with `x_cont` and apply the layers."
        if self.n_emb != 0:
            # Column `i` of `x_cat` is looked up in the `i`-th embedding.
            x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
            x = torch.cat(x, 1)
            x = self.emb_drop(x)
        if self.n_cont != 0:
            if self.bn_cont is not None: x_cont = self.bn_cont(x_cont)
            # Without embeddings the continuous features are the whole input.
            x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont
        return self.layers(x)

#export
@delegates(TabularModel.__init__)
def tabular_config(**kwargs):
    "Convenience function to easily create a config for `tabular_model`"
    return kwargs

#hide
from nbdev.export import notebook2script
notebook2script()