from fastai.text import *

path = Config().data_path()/'giga-fren'
path.ls()

# One-time preprocessing of the giga-fren corpus into an easy question-translation CSV:
#with open(path/'giga-fren.release2.fixed.fr') as f: fr = f.read().split('\n')
#with open(path/'giga-fren.release2.fixed.en') as f: en = f.read().split('\n')
#re_eq = re.compile('^(Wh[^?.!]+\?)')
#re_fq = re.compile('^([^?.!]+\?)')
#en_fname = path/'giga-fren.release2.fixed.en'
#fr_fname = path/'giga-fren.release2.fixed.fr'
#lines = ((re_eq.search(eq), re_fq.search(fq))
#         for eq, fq in zip(open(en_fname, encoding='utf-8'), open(fr_fname, encoding='utf-8')))
#qs = [(e.group(), f.group()) for e,f in lines if e and f]
#qs = [(q1,q2) for q1,q2 in qs]
#df = pd.DataFrame({'fr': [q[1] for q in qs], 'en': [q[0] for q in qs]}, columns=['en', 'fr'])
#df.to_csv(path/'questions_easy.csv', index=False)

df = pd.read_csv(path/'questions_easy.csv')
df.head()

def seq2seq_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=True, backwards:bool=False) -> Tuple[LongTensor, LongTensor]:
    "Function that collects samples and adds padding. Flips token order if needed."
    samples = to_data(samples)
    max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples])
    res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx
    res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx
    if backwards: pad_first = not pad_first
    for i,s in enumerate(samples):
        if pad_first:
            res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])
        else:
            res_x[i,:len(s[0])],res_y[i,:len(s[1])] = LongTensor(s[0]),LongTensor(s[1])
    if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1)
    return res_x,res_y
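# Quick sanity check of the collate behavior on toy token ids (made-up values, not taken
# from the dataset): with the default pad_first=True, the shorter sequence of each pair is
# left-padded with pad_idx.
toy_samples = [([2,3,4,5], [6,7]), ([2,3], [6,7,8,9])]
toy_x, toy_y = seq2seq_collate(toy_samples)
toy_x, toy_y  # (tensor([[2,3,4,5],[1,1,2,3]]), tensor([[1,1,6,7],[6,7,8,9]]))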
class Seq2SeqDataBunch(TextDataBunch):
    "Create a `TextDataBunch` suitable for training an RNN classifier."
    @classmethod
    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None,
               pad_idx=1, pad_first=False, device:torch.device=None, no_check:bool=False,
               backwards:bool=False, **dl_kwargs) -> DataBunch:
        "Function that transforms the `datasets` into a `DataBunch` for classification. Passes `**dl_kwargs` on to `DataLoader()`."
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)
        train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)
        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)
        dataloaders = [train_dl]
        for ds in datasets[1:]:
            lengths = [len(t) for t in ds.x.items]
            sampler = SortSampler(ds.x, key=lengths.__getitem__)
            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))
        return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)

class Seq2SeqTextList(TextList):
    _bunch = Seq2SeqDataBunch
    _label_cls = TextList

src = Seq2SeqTextList.from_df(df, path=path, cols='en').split_by_rand_pct().label_from_df(cols='fr', label_cls=TextList)

# 90th percentile of source and target lengths
np.percentile([len(o) for o in src.train.x.items] + [len(o) for o in src.valid.x.items], 90)
np.percentile([len(o) for o in src.train.y.items] + [len(o) for o in src.valid.y.items], 90)

# Remove pairs where either sentence is longer than 30 tokens
src = src.filter_by_func(lambda x,y: len(x) > 30 or len(y) > 30)
len(src.train) + len(src.valid)

data = src.databunch()
data.save('en2fr')
data = load_data(path, 'en2fr')
data.show_batch()

import fastText as ft

fr_vecs = ft.load_model(str(path/'cc.fr.300.bin'))

def create_emb(vecs, itos, em_sz=300, mult=1.):
    "Build an embedding layer initialized from the fastText `vecs` for every word of `itos` they cover."
    emb = nn.Embedding(len(itos), em_sz, padding_idx=1)
    wgts = emb.weight.data
    vec_dic = {w:vecs.get_word_vector(w) for w in vecs.get_words()}
    miss = []
    for i,w in enumerate(itos):
        try: wgts[i] = tensor(vec_dic[w])
        except KeyError: miss.append(w)
    return emb

emb_dec = create_emb(fr_vecs, data.y.vocab.itos)
torch.save(emb_dec, path/'models'/'fr_dec_emb.pth')
del fr_vecs

from fastai.text.models.qrnn import QRNN, QRNNLayer

class SimpleQRNN(nn.Module):
    "A simple QRNN language model with tied input/output embeddings."
    def __init__(self, vocab_sz, emb_sz=300, n_hid=256, n_layers=2, p_inp=0.1, p_hid=0.1, p_out=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_sz, emb_sz)
        self.inp_drop = nn.Dropout(p_inp)
        self.encoder = QRNN(emb_sz, n_hid, n_layers=n_layers, dropout=p_hid, save_prev_x=True)
        self.out_enc = nn.Linear(n_hid, emb_sz)
        self.out_drop = nn.Dropout(p_out)
        self.decoder = nn.Linear(emb_sz, vocab_sz)
        self.decoder.weight = self.embed.weight
        self.n_layers,self.n_hid,self.bs = n_layers,n_hid,1

    def forward(self, inp):
        if self.bs != inp.size(0):
            self.bs = inp.size(0)
            self.init_hidden(inp.size(0))
        enc = self.inp_drop(self.embed(inp))
        enc, h = self.encoder(enc, self.hidden)
        self.hidden = h.detach()
        return self.decoder(self.out_drop(self.out_enc(enc)))

    def reset(self): self.encoder.reset()

    def init_hidden(self, bs): self.hidden = one_param(self).new_zeros(self.n_layers, bs, self.n_hid)
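# Helper sketch (hypothetical, not used below): before `del fr_vecs` above, one could check
# how much of the target vocab is actually covered by the fastText vectors, since
# `create_emb` silently leaves missing words at their random initialization.
def emb_coverage(vecs, itos):
    known = set(vecs.get_words())
    missing = [w for w in itos if w not in known]
    return 1 - len(missing)/len(itos), missing[:10]
# e.g. emb_coverage(fr_vecs, data.y.vocab.itos)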
def convert_weights(wgts:Weights, stoi_wgts:Dict[str,int], itos_new:Collection[str]) -> Weights:
    "Convert the model `wgts` to go with a new vocabulary."
    dec_bias, enc_wgts = wgts.get('decoder.bias', None), wgts['embed.weight']
    wgts_m = enc_wgts.mean(0)
    if dec_bias is not None: bias_m = dec_bias.mean(0)
    new_w = enc_wgts.new_zeros((len(itos_new),enc_wgts.size(1)))
    if dec_bias is not None: new_b = dec_bias.new_zeros((len(itos_new),))
    for i,w in enumerate(itos_new):
        r = stoi_wgts[w] if w in stoi_wgts else -1
        new_w[i] = enc_wgts[r] if r>=0 else wgts_m
        if dec_bias is not None: new_b[i] = dec_bias[r] if r>=0 else bias_m
    wgts['embed.weight'] = new_w
    wgts['decoder.weight'] = new_w.clone()
    if dec_bias is not None: wgts['decoder.bias'] = new_b
    return wgts

data_lm = (TextList.from_df(df, path=path, cols='en', vocab=data.x.vocab)
           .split_by_rand_pct()
           .label_for_lm()
           .databunch())

model = SimpleQRNN(len(data.x.vocab.itos))
pretrained_wgts = torch.load(path/'small_qrnn.pth')['model']
pretrained_vocab = Vocab.load(path/'small_qrnn_vocab.pkl')
model.load_state_dict(convert_weights(pretrained_wgts, pretrained_vocab.stoi, data_lm.vocab.itos))

learn = Learner(data_lm, model, metrics=[accuracy, Perplexity()])
learn.fit(5, 1e-2)
learn.save('finetuned', with_opt=False)

class Seq2SeqQRNN(nn.Module):
    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25,
                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):
        super().__init__()
        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx
        self.emb_enc = emb_enc
        self.emb_enc_drop = nn.Dropout(p_inp)
        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc)
        self.out_enc = nn.Linear(n_hid, emb_enc.weight.size(1), bias=False)
        self.hid_dp = nn.Dropout(p_hid)
        self.emb_dec = emb_dec
        self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)
        self.out_drop = nn.Dropout(p_out)
        self.out = nn.Linear(emb_dec.weight.size(1), emb_dec.weight.size(0))
        self.out.weight.data = self.emb_dec.weight.data

    def forward(self, inp):
        bs,sl = inp.size()
        self.encoder.reset()
        self.decoder.reset()
        hid = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, hid = self.encoder(emb, hid)
        hid = self.out_enc(self.hid_dp(hid))
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        outs = []
        for i in range(self.max_len):
            emb = self.emb_dec(dec_inp).unsqueeze(1)
            out, hid = self.decoder(emb, hid)
            out = self.out(self.out_drop(out[:,0]))
            outs.append(out)
            dec_inp = out.max(1)[1]
            if (dec_inp==self.pad_idx).all(): break
        return torch.stack(outs, dim=1)

    def initHidden(self, bs): return one_param(self).new_zeros(self.n_layers, bs, self.n_hid)

def seq2seq_loss(out, targ, pad_idx=1):
    bs,targ_len = targ.size()
    _,out_len,vs = out.size()
    if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx)
    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)
    return CrossEntropyFlat()(out, targ)

def seq2seq_acc(out, targ, pad_idx=1):
    bs,targ_len = targ.size()
    _,out_len,vs = out.size()
    if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx)
    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)
    out = out.argmax(2)
    return (out==targ).float().mean()

class NGram():
    def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n
    def __eq__(self, other):
        if len(self.ngram) != len(other.ngram): return False
        return np.all(np.array(self.ngram) == np.array(other.ngram))
    def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))
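# Toy check of the padding logic in seq2seq_loss/seq2seq_acc above (random values, made-up
# sizes): when the decoder stops early (out_len < targ_len), the logits are padded along the
# sequence dimension so both functions compare tensors of the same length.
toy_out  = torch.randn(2, 5, 100)           # bs=2, out_len=5, vocab=100
toy_targ = torch.randint(2, 100, (2, 7))    # targ_len=7 > out_len, so `out` gets padded
seq2seq_loss(toy_out, toy_targ), seq2seq_acc(toy_out, toy_targ)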
def get_grams(x, n, max_n=5000):
    return x if n==1 else [NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]

def get_correct_ngrams(pred, targ, n, max_n=5000):
    pred_grams,targ_grams = get_grams(pred, n, max_n=max_n),get_grams(targ, n, max_n=max_n)
    pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)
    return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)

class CorpusBLEU(Callback):
    def __init__(self, vocab_sz):
        self.vocab_sz = vocab_sz
        self.name = 'bleu'

    def on_epoch_begin(self, **kwargs):
        self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4

    def on_batch_end(self, last_output, last_target, **kwargs):
        last_output = last_output.argmax(dim=-1)
        for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):
            self.pred_len += len(pred)
            self.targ_len += len(targ)
            for i in range(4):
                c,t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)
                self.corrects[i] += c
                self.counts[i] += t

    def on_epoch_end(self, last_metrics, **kwargs):
        precs = [c/t for c,t in zip(self.corrects,self.counts)]
        len_penalty = math.exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1
        bleu = len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25)
        return add_metrics(last_metrics, bleu)

emb_enc = learn.model.embed
emb_dec = torch.load(path/'models'/'fr_dec_emb.pth')

model1 = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)
learn1 = Learner(data, model1, loss_func=seq2seq_loss, metrics=seq2seq_acc)

# Copy the pretrained LM embedding and encoder weights into the seq2seq encoder
new_wgts = model.state_dict()
wgts = model1.state_dict()
for k,k1 in zip(wgts.keys(), list(new_wgts.keys())[:-4]):
    wgts[k].data = new_wgts[k1].data
model1.load_state_dict(wgts)
learn1.save('init')

learn = Learner(data, model1, loss_func=seq2seq_loss,
                metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))])
learn = learn.load('init')

def seq2seq_split(model):
    return [[model.emb_enc, model.emb_enc_drop, model.encoder],
            [model.out_enc, model.hid_dp, model.emb_dec, model.decoder, model.out_drop, model.out]]

learn = learn.split(seq2seq_split)
learn.freeze()
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(8, 1e-2)

class TeacherForcing(LearnerCallback):
    "Pass the target to the model during training so it can feed it back to the decoder with probability `pr_force`."
    def __init__(self, learn, end_epoch):
        super().__init__(learn)
        self.end_epoch = end_epoch

    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        if train: return {'last_input': [last_input, last_target]}

    def on_epoch_begin(self, epoch, **kwargs):
        self.learn.model.pr_force = 1 - 0.5 * epoch/self.end_epoch
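# `get_predictions` is called below but never defined in this file. A minimal sketch,
# assuming fastai v1's `DatasetType`, `progress_bar` and the `reconstruct` API; it also
# returns the raw input tensors as a fourth value to match the unpacking used below.
def get_predictions(learn, ds_type=DatasetType.Valid):
    learn.model.eval()
    inputs, targets, outputs, tensor_inputs = [], [], [], []
    with torch.no_grad():
        for xb, yb in progress_bar(learn.dl(ds_type)):
            out = learn.model(xb)
            for x, y, z in zip(xb, yb, out):
                inputs.append(learn.data.train_ds.x.reconstruct(x))
                targets.append(learn.data.train_ds.y.reconstruct(y))
                outputs.append(learn.data.train_ds.y.reconstruct(z.argmax(1)))
                tensor_inputs.append(x)
    return inputs, targets, outputs, tensor_inputs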
class Seq2SeqQRNN(nn.Module):
    "Same encoder/decoder as above, but with optional teacher forcing controlled by `pr_force`."
    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25,
                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):
        super().__init__()
        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx
        self.emb_enc = emb_enc
        self.emb_enc_drop = nn.Dropout(p_inp)
        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc)
        self.out_enc = nn.Linear(n_hid, emb_enc.weight.size(1), bias=False)
        self.hid_dp = nn.Dropout(p_hid)
        self.emb_dec = emb_dec
        self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)
        self.out_drop = nn.Dropout(p_out)
        self.out = nn.Linear(emb_dec.weight.size(1), emb_dec.weight.size(0))
        self.out.weight.data = self.emb_dec.weight.data
        self.pr_force = 0.

    def forward(self, inp, targ=None):
        bs,sl = inp.size()
        hid = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, hid = self.encoder(emb, hid)
        hid = self.out_enc(self.hid_dp(hid))
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        res = []
        for i in range(self.max_len):
            emb = self.emb_dec(dec_inp).unsqueeze(1)
            outp, hid = self.decoder(emb, hid)
            outp = self.out(self.out_drop(outp[:,0]))
            res.append(outp)
            dec_inp = outp.data.max(1)[1]
            if (dec_inp==self.pad_idx).all(): break
            if (targ is not None) and (random.random()<self.pr_force):
                if i>=targ.shape[1]: break
                dec_inp = targ[:,i]
        return torch.stack(res, dim=1)

    def initHidden(self, bs): return one_param(self).new_zeros(self.n_layers, bs, self.n_hid)

model = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)
learn = Learner(data, model, loss_func=seq2seq_loss,
                metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))],
                callback_fns=partial(TeacherForcing, end_epoch=8))
learn = learn.load('init')
learn.fit_one_cycle(8, 1e-2)

inputs, targets, outputs, tensor_inputs = get_predictions(learn)
inputs[700],targets[700],outputs[700]
inputs[705],targets[705],outputs[705]
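# The TeacherForcing callback above anneals the forcing probability linearly from 1.0 at
# epoch 0 down toward 0.5 at the end of training; for end_epoch=8 the per-epoch schedule is:
[1 - 0.5 * e / 8 for e in range(8)]   # [1.0, 0.9375, 0.875, ..., 0.5625]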
class Seq2SeqQRNN(nn.Module):
    "Seq2seq QRNN with a bidirectional encoder; the two directions of the final hidden state are concatenated before being projected for the decoder."
    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25,
                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):
        super().__init__()
        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx
        self.emb_enc = emb_enc
        self.emb_enc_drop = nn.Dropout(p_inp)
        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc, bidirectional=True)
        self.out_enc = nn.Linear(2*n_hid, emb_enc.weight.size(1), bias=False)
        self.hid_dp = nn.Dropout(p_hid)
        self.emb_dec = emb_dec
        self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)
        self.out_drop = nn.Dropout(p_out)
        self.out = nn.Linear(emb_dec.weight.size(1), emb_dec.weight.size(0))
        self.out.weight.data = self.emb_dec.weight.data
        self.pr_force = 0.

    def forward(self, inp, targ=None):
        bs,sl = inp.size()
        hid = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, hid = self.encoder(emb, hid)
        hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()
        hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        res = []
        for i in range(self.max_len):
            emb = self.emb_dec(dec_inp).unsqueeze(1)
            outp, hid = self.decoder(emb, hid)
            outp = self.out(self.out_drop(outp[:,0]))
            res.append(outp)
            dec_inp = outp.data.max(1)[1]
            if (dec_inp==self.pad_idx).all(): break
            if (targ is not None) and (random.random()<self.pr_force):
                if i>=targ.shape[1]: break
                dec_inp = targ[:,i]
        return torch.stack(res, dim=1)

    def initHidden(self, bs): return one_param(self).new_zeros(2*self.n_layers, bs, self.n_hid)

#emb_enc = torch.load(path/'models'/'fr_emb.pth')
#emb_dec = torch.load(path/'models'/'en_emb.pth')
emb_enc = torch.load(path/'models'/'en_emb2.pth')
emb_dec = torch.load(path/'models'/'fr_emb2.pth')

model = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)
learn = Learner(data, model, loss_func=seq2seq_loss, metrics=seq2seq_acc,
                callback_fns=partial(TeacherForcing, end_epoch=8))
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(8, 1e-2)

inputs, targets, outputs, input_tensors = get_predictions(learn)
inputs[700], targets[700], outputs[700]
inputs[701], targets[701], outputs[701]
inputs[4001], targets[4001], outputs[4001]

data = load_data(path, 'en2fr')

def init_param(*sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))
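# Shape check (toy sizes) of the hidden-state regrouping used by the bidirectional models
# above and below: the encoder hidden state starts as (2*n_layers, bs, n_hid) and is
# reshaped to (n_layers, bs, 2*n_hid) so out_enc can project it for the decoder.
nl, b, nh = 2, 3, 256
h = torch.zeros(2*nl, b, nh)
h = h.view(2, nl, b, nh).permute(1,2,0,3).contiguous().view(nl, b, 2*nh)
h.shape   # torch.Size([2, 3, 512])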
class Seq2SeqQRNN(nn.Module):
    "Seq2seq QRNN with additive attention over the encoder outputs."
    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25,
                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):
        super().__init__()
        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx
        self.emb_enc = emb_enc
        self.emb_enc_drop = nn.Dropout(p_inp)
        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc)
        self.out_enc = nn.Linear(n_hid, emb_enc.weight.size(1), bias=False)
        self.hid_dp = nn.Dropout(p_hid)
        self.emb_dec = emb_dec
        self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)
        self.out_drop = nn.Dropout(p_out)
        emb_sz = emb_dec.weight.size(1)
        self.out = nn.Linear(emb_sz, emb_dec.weight.size(0))
        self.out.weight.data = self.emb_dec.weight.data #Try tying
        self.W1 = init_param(n_hid, emb_sz)
        self.l2 = nn.Linear(emb_sz, emb_sz)
        self.l3 = nn.Linear(emb_sz+n_hid, emb_sz)
        self.V = init_param(emb_sz)
        self.pr_force = 0.

    def forward(self, inp, targ=None):
        bs,sl = inp.size()
        hid = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, hid = self.encoder(emb, hid)
        hid = self.out_enc(self.hid_dp(hid))
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        res = []
        w1e = enc_out @ self.W1
        for i in range(self.max_len):
            w2h = self.l2(hid[-1])
            u = torch.tanh(w1e + w2h[:,None])
            a = F.softmax(u @ self.V, 1)
            Xa = (a.unsqueeze(2) * enc_out).sum(1)
            emb = self.emb_dec(dec_inp)
            wgt_enc = self.l3(torch.cat([emb, Xa], 1))
            outp, hid = self.decoder(wgt_enc[:,None], hid)
            outp = self.out(self.out_drop(outp[:,0]))
            res.append(outp)
            dec_inp = outp.data.max(1)[1]
            if (dec_inp==self.pad_idx).all(): break
            if (targ is not None) and (random.random()<self.pr_force):
                if i>=targ.shape[1]: break
                dec_inp = targ[:,i]
        return torch.stack(res, dim=1)

    def initHidden(self, bs): return one_param(self).new_zeros(self.n_layers, bs, self.n_hid)

#emb_enc = torch.load(path/'models'/'fr_emb.pth')
#emb_dec = torch.load(path/'models'/'en_emb.pth')
emb_enc = torch.load(path/'models'/'en_emb2.pth')
emb_dec = torch.load(path/'models'/'fr_emb2.pth')

model = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)
learn = Learner(data, model, loss_func=seq2seq_loss, metrics=seq2seq_acc,
                callback_fns=partial(TeacherForcing, end_epoch=8))
learn.load('init', strict=False)
learn.fit_one_cycle(8, 3e-3)

inputs, targets, outputs, input_tensors = get_predictions(learn)
inputs[700], targets[700], outputs[700]
inputs[701], targets[701], outputs[701]
inputs[4002], targets[4002], outputs[4002]

data = load_data(path, 'en2fr')

def init_param(*sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))
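# Shape walk-through of the additive attention above (toy sizes, unidirectional case):
# enc_out is (bs, sl, n_hid), W1 is (n_hid, emb_sz) and V is (emb_sz,); the attention
# weights `a` come out as (bs, sl) and the context as (bs, n_hid).
b, s, nh, es = 3, 7, 256, 300
toy_enc_out = torch.zeros(b, s, nh)
toy_w1e = toy_enc_out @ torch.zeros(nh, es)           # (bs, sl, emb_sz)
toy_w2h = torch.zeros(b, es)                          # plays the role of l2(hid[-1])
toy_u   = torch.tanh(toy_w1e + toy_w2h[:,None])       # (bs, sl, emb_sz)
toy_a   = F.softmax(toy_u @ torch.zeros(es), 1)       # (bs, sl); uniform here since everything is zero
toy_ctx = (toy_a.unsqueeze(2) * toy_enc_out).sum(1)   # (bs, n_hid)
toy_ctx.shape                                         # torch.Size([3, 256])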
class Seq2SeqQRNN(nn.Module):
    "Bidirectional encoder combined with the additive attention above."
    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25,
                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):
        super().__init__()
        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx
        self.emb_enc = emb_enc
        self.emb_enc_drop = nn.Dropout(p_inp)
        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc, bidirectional=True)
        self.out_enc = nn.Linear(2*n_hid, emb_enc.weight.size(1), bias=False)
        self.hid_dp = nn.Dropout(p_hid)
        self.emb_dec = emb_dec
        self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)
        self.out_drop = nn.Dropout(p_out)
        emb_sz = emb_dec.weight.size(1)
        self.out = nn.Linear(emb_sz, emb_dec.weight.size(0))
        self.out.weight.data = self.emb_dec.weight.data #Try tying
        self.W1 = init_param(2*n_hid, emb_sz)
        self.l2 = nn.Linear(emb_sz, emb_sz)
        self.l3 = nn.Linear(emb_sz+2*n_hid, emb_sz)
        self.V = init_param(emb_sz)
        self.pr_force = 0.

    def forward(self, inp, targ=None):
        bs,sl = inp.size()
        hid = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, hid = self.encoder(emb, hid)
        hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()
        hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        res = []
        w1e = enc_out @ self.W1
        for i in range(self.max_len):
            w2h = self.l2(hid[-1])
            u = torch.tanh(w1e + w2h[:,None])
            a = F.softmax(u @ self.V, 1)
            Xa = (a.unsqueeze(2) * enc_out).sum(1)
            emb = self.emb_dec(dec_inp)
            wgt_enc = self.l3(torch.cat([emb, Xa], 1))
            outp, hid = self.decoder(wgt_enc[:,None], hid)
            outp = self.out(self.out_drop(outp[:,0]))
            res.append(outp)
            dec_inp = outp.data.max(1)[1]
            if (dec_inp==self.pad_idx).all(): break
            if (targ is not None) and (random.random()<self.pr_force):
                if i>=targ.shape[1]: break
                dec_inp = targ[:,i]
        return torch.stack(res, dim=1)

    def initHidden(self, bs): return one_param(self).new_zeros(2*self.n_layers, bs, self.n_hid)
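# The `#Try tying` comments above hint at full weight tying: assigning `.data` only copies
# the initial values, while assigning the Parameter itself keeps the output projection and
# the decoder embedding tied throughout training (a possible variant, not what the models
# above do). A minimal illustration:
tied_emb = nn.Embedding(100, 16)
tied_out = nn.Linear(16, 100, bias=False)
tied_out.weight = tied_emb.weight        # same Parameter object, so both layers share gradients
tied_out.weight is tied_emb.weight       # True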
class Seq2SeqQRNN(nn.Module):
    "Bidirectional encoder with attention; the attention context is concatenated to the decoder input at each step."
    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25,
                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):
        super().__init__()
        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx
        self.emb_enc = emb_enc
        self.emb_enc_drop = nn.Dropout(p_inp)
        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc, bidirectional=True)
        self.out_enc = nn.Linear(2*n_hid, emb_enc.weight.size(1), bias=False)
        self.hid_dp = nn.Dropout(p_hid)
        self.emb_dec = emb_dec
        emb_sz = emb_dec.weight.size(1)
        self.decoder = QRNN(emb_sz + 2*n_hid, emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)
        self.out_drop = nn.Dropout(p_out)
        self.out = nn.Linear(emb_sz, emb_dec.weight.size(0))
        self.out.weight.data = self.emb_dec.weight.data #Try tying
        self.enc_att = nn.Linear(2*n_hid, emb_sz, bias=False)
        self.hid_att = nn.Linear(emb_sz, emb_sz)
        self.V = init_param(emb_sz)
        self.pr_force = 0.

    def forward(self, inp, targ=None):
        bs,sl = inp.size()
        hid = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, hid = self.encoder(emb, hid)
        hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()
        hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        res = []
        enc_att = self.enc_att(enc_out)
        for i in range(self.max_len):
            hid_att = self.hid_att(hid[-1])
            u = torch.tanh(enc_att + hid_att[:,None])
            attn_wgts = F.softmax(u @ self.V, 1)
            ctx = (attn_wgts[...,None] * enc_out).sum(1)
            emb = self.emb_dec(dec_inp)
            outp, hid = self.decoder(torch.cat([emb, ctx], 1)[:,None], hid)
            outp = self.out(self.out_drop(outp[:,0]))
            res.append(outp)
            dec_inp = outp.data.max(1)[1]
            if (dec_inp==self.pad_idx).all(): break
            if (targ is not None) and (random.random()<self.pr_force):
                if i>=targ.shape[1]: break
                dec_inp = targ[:,i]
        return torch.stack(res, dim=1)

    def initHidden(self, bs): return one_param(self).new_zeros(2*self.n_layers, bs, self.n_hid)

emb_enc = torch.load(path/'models'/'fr_emb.pth')
emb_dec = torch.load(path/'models'/'en_emb.pth')
#emb_enc = torch.load(path/'models'/'en_emb2.pth')
#emb_dec = torch.load(path/'models'/'fr_emb2.pth')

model = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)
learn = Learner(data, model, loss_func=seq2seq_loss, metrics=seq2seq_acc,
                callback_fns=partial(TeacherForcing, end_epoch=8))
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(8, 3e-3)

inputs, targets, outputs, input_tensors = get_predictions(learn)
inputs[700], targets[700], outputs[700]
inputs[701], targets[701], outputs[701]
inputs[4002], targets[4002], outputs[4002]
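# To compare this attention model with the earlier runs on the same footing, the BLEU
# metric defined above can be attached for a validation pass (a suggestion, not part of
# the original run):
learn.metrics = [seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))]
learn.validate()   # [val_loss, seq2seq_acc, bleu]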