#!/usr/bin/env python
# coding: utf-8

# # Neural text generation

# In[1]:


from seq2seq import *


# In[2]:


path = Config().data_path()/'giga-fren'
data = load_data(path)
model_path = Config().model_path()
emb_enc = torch.load(model_path/'fr_emb.pth')
emb_dec = torch.load(model_path/'en_emb.pth')


# In[3]:


class Seq2SeqRNN_attn(nn.Module):
    def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1):
        super().__init__()
        self.nl,self.nh,self.out_sl,self.pr_force = nl,nh,out_sl,1
        self.bos_idx,self.pad_idx = bos_idx,pad_idx
        self.emb_enc,self.emb_dec = emb_enc,emb_dec
        self.emb_sz_enc,self.emb_sz_dec = emb_enc.embedding_dim,emb_dec.embedding_dim
        self.voc_sz_dec = emb_dec.num_embeddings
        self.emb_enc_drop = nn.Dropout(0.15)
        self.gru_enc = nn.GRU(self.emb_sz_enc, nh, num_layers=nl, dropout=0.25,
                              batch_first=True, bidirectional=True)
        self.out_enc = nn.Linear(2*nh, self.emb_sz_dec, bias=False)
        self.gru_dec = nn.GRU(self.emb_sz_dec + 2*nh, self.emb_sz_dec, num_layers=nl,
                              dropout=0.1, batch_first=True)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(self.emb_sz_dec, self.voc_sz_dec)
        self.out.weight.data = self.emb_dec.weight.data
        self.enc_att = nn.Linear(2*nh, self.emb_sz_dec, bias=False)
        self.hid_att = nn.Linear(self.emb_sz_dec, self.emb_sz_dec)
        self.V = self.init_param(self.emb_sz_dec)

    def encoder(self, bs, inp):
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, hid = self.gru_enc(emb, 2*h)
        # fold the forward/backward directions of the bidirectional GRU together
        # so the final hidden state can initialise the unidirectional decoder
        pre_hid = hid.view(2, self.nl, bs, self.nh).permute(1,2,0,3).contiguous()
        pre_hid = pre_hid.view(self.nl, bs, 2*self.nh)
        hid = self.out_enc(pre_hid)
        return hid,enc_out

    def decoder(self, dec_inp, hid, enc_att, enc_out):
        hid_att = self.hid_att(hid[-1])
        # additive attention: score each encoder output against the last decoder state
        u = torch.tanh(enc_att + hid_att[:,None])
        attn_wgts = F.softmax(u @ self.V, 1)
        # context = attention-weighted sum of the encoder outputs
        ctx = (attn_wgts[...,None] * enc_out).sum(1)
        emb = self.emb_dec(dec_inp)
        outp, hid = self.gru_dec(torch.cat([emb, ctx], 1)[:,None], hid)
        outp = self.out(self.out_drop(outp[:,0]))
        return hid, outp

    def forward(self, inp, targ=None):
        bs, sl = inp.size()
        hid,enc_out = self.encoder(bs, inp)
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        enc_att = self.enc_att(enc_out)

        res = []
        for i in range(self.out_sl):
            hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)
            res.append(outp)
            dec_inp = outp.max(1)[1]
            if (dec_inp==self.pad_idx).all(): break
            # teacher forcing: with probability pr_force, feed the target token
            # instead of the model's own prediction
            if (targ is not None) and (random.random()<self.pr_force):
                if i>=targ.shape[1]: continue
                dec_inp = targ[:,i]
        return torch.stack(res, dim=1)

    def initHidden(self, bs): return one_param(self).new_zeros(2*self.nl, bs, self.nh)
    def init_param(self, *sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))


# In[4]:


model = Seq2SeqRNN_attn(emb_enc, emb_dec, 256, 30)
learn = Learner(data, model, loss_func=seq2seq_loss, metrics=seq2seq_acc,
                callback_fns=partial(TeacherForcing, end_epoch=30))


# In[14]:


learn.fit_one_cycle(5, 3e-3)


# In[19]:


# learn.save('5')


# In[5]:


learn.load('5');


# In[7]:


def preds_acts(learn, ds_type=DatasetType.Valid):
    "Same as `get_predictions` but also returns non-reconstructed activations"
    learn.model.eval()
    ds = learn.data.train_ds
    rxs,rys,rzs,xs,ys,zs = [],[],[],[],[],[]  # 'r' == 'reconstructed'
    with torch.no_grad():
        for xb,yb in progress_bar(learn.dl(ds_type)):
            out = learn.model(xb)
            for x,y,z in zip(xb,yb,out):
                rxs.append(ds.x.reconstruct(x))
                rys.append(ds.y.reconstruct(y))
                preds = z.argmax(1)
                rzs.append(ds.y.reconstruct(preds))
                for a,b in zip([xs,ys,zs],[x,y,z]): a.append(b)
    return rxs,rys,rzs,xs,ys,zs


# In[8]:


rxs,rys,rzs,xs,ys,zs = preds_acts(learn)


# In[10]:


idx=701
rx,ry,rz = rxs[idx],rys[idx],rzs[idx]
x,y,z = xs[idx],ys[idx],zs[idx]
rx,ry,rz
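# The model's own forward pass decodes greedily, always taking `outp.max(1)[1]`.
# Before the sampling strategies below, here is that baseline written as a
# standalone selection function for comparison with `select_topk` and
# `select_nucleus`; this is an illustrative sketch, and the name `select_greedy`
# is not part of the original notebook.

# In[ ]:


def select_greedy(outp):
    # always pick the single most likely token: deterministic, but tends to
    # produce dull or repetitive text compared with sampling
    return outp.argmax(dim=-1, keepdim=True)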
# In[11]:


def select_topk(outp, k=5):
    # top-k sampling: keep the k most likely tokens and pick one of them uniformly
    probs = F.softmax(outp,dim=-1)
    vals,idxs = probs.topk(k, dim=-1)
    return idxs[torch.randint(k, (1,))]


# Nucleus (top-p) sampling, from [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751).

# In[12]:


from random import choice

def select_nucleus(outp, p=0.5):
    # nucleus sampling: pick among the smallest set of tokens whose
    # cumulative probability exceeds p
    probs = F.softmax(outp,dim=-1)
    idxs = torch.argsort(probs, descending=True)
    res,cumsum = [],0.
    for idx in idxs:
        res.append(idx)
        cumsum += probs[idx]
        if cumsum>p: return idxs.new_tensor([choice(res)])


# In[13]:


def decode(self, inp):
    inp = inp[None]
    bs, sl = inp.size()
    hid,enc_out = self.encoder(bs, inp)
    dec_inp = inp.new_zeros(bs).long() + self.bos_idx
    enc_att = self.enc_att(enc_out)

    res = []
    for i in range(self.out_sl):
        hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)
        # sample the next token instead of taking the argmax
        dec_inp = select_nucleus(outp[0], p=0.3)
#         dec_inp = select_topk(outp[0], k=2)
        res.append(dec_inp)
        if (dec_inp==self.pad_idx).all(): break
    return torch.cat(res)


# In[14]:


def predict_with_decode(learn, x, y):
    learn.model.eval()
    ds = learn.data.train_ds
    with torch.no_grad():
        out = decode(learn.model, x)
        rx = ds.x.reconstruct(x)
        ry = ds.y.reconstruct(y)
        rz = ds.y.reconstruct(out)
    return rx,ry,rz


# In[91]:


rx,ry,rz = predict_with_decode(learn, x, y)
rz
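# A minimal usage sketch (not from the original notebook): run the sampling
# decoder over a few of the raw tensors collected by `preds_acts` and compare
# each output with its reference translation. The indices are arbitrary examples.

# In[ ]:


for i in (0, 350, 701):
    rx_i, ry_i, rz_i = predict_with_decode(learn, xs[i], ys[i])
    print('src :', rx_i)
    print('targ:', ry_i)
    print('pred:', rz_i)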
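# A quick sanity check of the two samplers on a toy logit vector (illustrative,
# not in the original notebook): `select_topk` returns one of the k most likely
# ids, while `select_nucleus` draws only from the smallest set of ids whose
# probabilities sum past `p`.

# In[ ]:


toy_logits = torch.tensor([4., 2., 1., 0.5, 0.1])
select_topk(toy_logits, k=3), select_nucleus(toy_logits, p=0.9)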