#!/usr/bin/env python
# coding: utf-8

# This notebook is adapted from [this one](https://github.com/fastai/fastai_docs/blob/master/dev_course/dl2/translation_transformer.ipynb) created by Sylvain Gugger.
#
# See also [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html) from Harvard NLP.

# # Attention and the Transformer

# Nvidia AI researcher [Chip Huyen](https://huyenchip.com/) wrote a great post [Top 8 trends from ICLR 2019](https://huyenchip.com/2019/05/12/top-8-trends-from-iclr-2019.html) in which one of the trends is that *RNN is losing its luster with researchers*.
#
# There's good reason for this, RNNs can be a pain: parallelization can be tricky and they can be difficult to debug. Since language is recursive, it seemed like RNNs were a good conceptual fit with NLP, but recently methods using *attention* have been achieving state of the art results on NLP.
#
# This is still an area of very active research, for instance, a recent paper [Pay Less Attention with Lightweight and Dynamic Convolutions](https://arxiv.org/abs/1901.10430) showed that convolutions can beat attention on some tasks, including English to German translation. More research is needed on the various strengths of RNNs, CNNs, and transformers/attention, and perhaps on approaches to combine the best of each.

# In[1]:


from fastai.text import *


# In[2]:


path = Config().data_path()/'giga-fren'
path.ls()


# ## Load data

# We reuse the same functions as in the translation notebook to load our data.

# In[3]:


def seq2seq_collate(samples, pad_idx=1, pad_first=True, backwards=False):
    """Collate `samples` into two padded `LongTensor` batches.

    Args:
        samples: sequence of `(x, y)` pairs of token ids (run through
            `to_data` first).
        pad_idx: value written to every position not covered by a sample.
        pad_first: pad on the left when true, on the right otherwise.
        backwards: flip the token order along dim 1. The padding side is
            inverted before filling so that, after the flip, the padding
            sits on the side `pad_first` asked for.

    Returns:
        `(res_x, res_y)` of shapes `(len(samples), max_len_x)` and
        `(len(samples), max_len_y)`.
    """
    samples = to_data(samples)
    max_len_x = max(len(s[0]) for s in samples)
    max_len_y = max(len(s[1]) for s in samples)
    res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx
    res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx
    if backwards: pad_first = not pad_first
    for i, s in enumerate(samples):
        len_x, len_y = len(s[0]), len(s[1])
        if pad_first:
            # Explicit start offsets: a negative slice `-len:` would turn into
            # `-0:` (the whole row) for an empty sequence and fail to assign.
            res_x[i, max_len_x - len_x:] = LongTensor(s[0])
            res_y[i, max_len_y - len_y:] = LongTensor(s[1])
        else:
            res_x[i, :len_x] = LongTensor(s[0])
            res_y[i, :len_y] = LongTensor(s[1])
    if backwards: res_x, res_y = res_x.flip(1), res_y.flip(1)
    return res_x, res_y


# In[4]:


class Seq2SeqDataBunch(TextDataBunch):
    "Create a `TextDataBunch` suitable for training a sequence-to-sequence model."
    @classmethod
    def create(cls, train_ds, valid_ds, test_ds=None, path='.', bs=32, val_bs=None, pad_idx=1,
               dl_tfms=None, pad_first=False, device=None, no_check=False, backwards=False, **dl_kwargs):
        """Turn the datasets into a `DataBunch` that batches with `seq2seq_collate`.

        The training loader uses a `SortishSampler` keyed on the length of each
        source item and drops the last incomplete batch; every other loader
        uses a `SortSampler` keyed on source length and batch size `val_bs`
        (defaults to `bs`). `pad_idx`, `pad_first` and `backwards` are bound
        into the collate function; `dl_tfms` is forwarded to the `DataBunch`
        constructor. Passes `**dl_kwargs` on to `DataLoader()`.
        """
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)
        train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)
        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)
        dataloaders = [train_dl]
        for ds in datasets[1:]:
            lengths = [len(t) for t in ds.x.items]
            sampler = SortSampler(ds.x, key=lengths.__getitem__)
            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))
        # `dl_tfms` used to be accepted and then dropped; hand it to the DataBunch.
        return cls(*dataloaders, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn,
                   no_check=no_check)


# In[5]:


class Seq2SeqTextList(TextList):
    "`TextList` whose labels are also a `TextList` and whose data bunch is a `Seq2SeqDataBunch`."
    _bunch = Seq2SeqDataBunch
    _label_cls = TextList


# Refer to notebook 7-seq2seq-translation for the code we used to create, process, and save this data.
# In[6]:


data = load_data(path)


# In[7]:


data.show_batch()


# ## Transformer model

# ![Transformer model](images/Transformer.png)

# ### Shifting

# We add a transform to the dataloader that shifts the targets right and adds a padding at the beginning.

# In[8]:


v = data.vocab


# In[9]:


v.stoi['xxpad']


# In[10]:


def shift_tfm(b, pad_idx=1):
    """Turn a batch `(x, y)` into `([x, shifted_y], y)`.

    `y` is left-padded with one `pad_idx` column; the padded tensor without
    its last column is returned next to `x` as a second input, and the padded
    tensor without its first column (the original `y`) is the target.
    """
    x, y = b
    y = F.pad(y, (1, 0), value=pad_idx)
    return [x, y[:, :-1]], y[:, 1:]


# In[11]:


data.add_tfm(shift_tfm)


# ### Embeddings

# The input and output embeddings are traditional PyTorch embeddings (and we can use pretrained vectors if we want to). The transformer model isn't a recurrent one, so it has no idea of the relative positions of the words. To help it with that, they add to the input embeddings a positional encoding, which is a sinusoid (sine and cosine) of a certain frequency:

# In[156]:


d = 30
torch.arange(0., d, 2.)/d


# In[12]:


class PositionalEncoding(nn.Module):
    "Encode the position with a sinusoid."
    def __init__(self, d):
        super().__init__()
        self.d = d
        # One frequency per pair of output channels: 1 / 10000^(2i/d).
        self.register_buffer('freq', 1 / (10000 ** (torch.arange(0., d, 2.)/d)))

    def forward(self, pos):
        "Return a `(len(pos), d)` encoding for the 1-D float tensor `pos`: sines first, then cosines."
        inp = torch.ger(pos, self.freq)
        enc = torch.cat([inp.sin(), inp.cos()], dim=-1)
        # For odd `d` the concatenation is `d + 1` wide; trim so the width is always `d`.
        return enc[:, :self.d]


# In[13]:


tst_encoding = PositionalEncoding(20)
res = tst_encoding(torch.arange(0,100).float())
_, ax = plt.subplots(1,1)
for i in range(1,5): ax.plot(res[:,i])


# In[14]:


res[:6,:6]


# In[15]:


class TransformerEmbedding(nn.Module):
    "Embedding + positional encoding + dropout"
    def __init__(self, vocab_sz, emb_sz, inp_p=0.):
        super().__init__()
        self.emb_sz = emb_sz
        self.embed = embedding(vocab_sz, emb_sz)
        self.pos_enc = PositionalEncoding(emb_sz)
        self.drop = nn.Dropout(inp_p)

    def forward(self, inp):
        "Embed `inp`, scale by `sqrt(emb_sz)`, add the encoding of positions `0..inp.size(1)-1`, apply dropout."
        # assumes `inp` is (batch, seq_len) token ids -- positions are taken along dim 1
        pos = torch.arange(0, inp.size(1), device=inp.device).float()
        return self.drop(self.embed(inp) * math.sqrt(self.emb_sz) + self.pos_enc(pos))


# ### Feed forward

# The feed forward cell is easy: it's just two linear layers with a skip connection and a LayerNorm.
# In[16]:


def feed_forward(d_model, d_ff, ff_p=0., double_drop=True):
    """Build the position-wise feed-forward cell.

    Layers, in order: `Linear(d_model, d_ff)`, `ReLU`, an optional `Dropout(ff_p)`
    (when `double_drop`), `Linear(d_ff, d_model)`, `Dropout(ff_p)`, `MergeLayer()`
    and `LayerNorm(d_model)`, wrapped in a `SequentialEx`.
    """
    # NOTE(review): `SequentialEx`/`MergeLayer` come from fastai; per the text
    # above they provide the skip connection -- confirm against fastai.layers.
    layers = [nn.Linear(d_model, d_ff), nn.ReLU()]
    if double_drop: layers.append(nn.Dropout(ff_p))
    return SequentialEx(*layers, nn.Linear(d_ff, d_model), nn.Dropout(ff_p), MergeLayer(), nn.LayerNorm(d_model))


# ### Multi-head attention

# ![Multi head attention](images/attention.png)

# In[117]:


class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and a LayerNorm.

    Args:
        n_heads: number of attention heads.
        d_model: size of the input/output features.
        d_head: size of each head; defaults to `d_model // n_heads`.
        p: dropout probability, used on the attention weights and on the
            projected output.
        bias: whether the four linear projections have a bias.
        scale: divide the attention scores by `sqrt(d_head)`.
    """
    def __init__(self, n_heads, d_model, d_head=None, p=0., bias=True, scale=True):
        super().__init__()
        d_head = d_model // n_heads if d_head is None else d_head
        self.n_heads, self.d_head, self.scale = n_heads, d_head, scale
        self.q_wgt, self.k_wgt, self.v_wgt = [nn.Linear(d_model, n_heads * d_head, bias=bias)
                                              for _ in range(3)]
        self.out = nn.Linear(n_heads * d_head, d_model, bias=bias)
        self.drop_att, self.drop_res = nn.Dropout(p), nn.Dropout(p)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, q, kv, mask=None):
        "Attend from `q` over `kv`; returns `LayerNorm(q + dropout(out(attention)))`."
        return self.ln(q + self.drop_res(self.out(self._apply_attention(q, kv, mask=mask))))

    def create_attn_mat(self, x, layer, bs):
        "Project `x` with `layer` and reshape to `(bs, n_heads, x.size(1), d_head)`."
        return layer(x).view(bs, x.size(1), self.n_heads, self.d_head).permute(0, 2, 1, 3)

    def _apply_attention(self, q, kv, mask=None):
        """Return the concatenated head outputs, shape `(bs, q.size(1), n_heads * d_head)`.

        Score positions where `mask` is non-zero/True are set to `-inf` before
        the softmax, so they receive zero attention weight.
        """
        bs, seq_len = q.size(0), q.size(1)
        wq = self.create_attn_mat(q, self.q_wgt, bs)
        wk = self.create_attn_mat(kv, self.k_wgt, bs)
        wv = self.create_attn_mat(kv, self.v_wgt, bs)
        attn_score = wq @ wk.transpose(2, 3)
        if self.scale: attn_score = attn_score / math.sqrt(self.d_head)
        if mask is not None:
            # `masked_fill` wants a boolean mask; cast so that legacy uint8
            # (`.byte()`) masks keep working.
            attn_score = attn_score.float().masked_fill(mask.bool(), -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = attn_prob @ wv
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, seq_len, -1)


# ### Masking

# The attention layer uses a mask to avoid paying attention to certain timesteps. The first thing is that we don't really want the network to pay attention to the padding, so we're going to mask it.
# The second thing is that since this model isn't recurrent, we need to mask (in the output) all the tokens we're not supposed to see yet (otherwise it would be cheating).

# In[18]:


def get_output_mask(inp, pad_idx=1):
    """Return a boolean look-ahead mask of shape `(1, 1, seq_len, seq_len)`.

    `seq_len` is `inp.size(1)`. Entry `[.., i, j]` is True when `j > i`, i.e.
    for the positions that must be hidden from position `i`. `pad_idx` is
    currently unused: the padding-aware variant is kept below for reference.
    """
    # Padding-aware variant (disabled):
    # return ((inp == pad_idx)[:,None,:,None].long() + torch.triu(inp.new_ones(inp.size(1),inp.size(1)), diagonal=1)[None,None] != 0)
    seq_len = inp.size(1)
    # bool rather than uint8: byte masks are deprecated for `masked_fill`.
    return torch.triu(inp.new_ones(seq_len, seq_len), diagonal=1)[None, None].bool()


# Example of mask for the future tokens:

# In[19]:


torch.triu(torch.ones(10,10), diagonal=1).byte()


# ### Encoder and decoder blocks

# We are now ready to regroup these layers in the blocks we add in the model picture:
#
# ![Transformer model](images/Transformer.png)

# In[125]:


class EncoderBlock(nn.Module):
    "Encoder block of a Transformer model."
    #Can't use Sequential directly cause more than one input...
    def __init__(self, n_heads, d_model, d_head, d_inner, p=0., bias=True, scale=True, double_drop=True):
        super().__init__()
        self.mha = MultiHeadAttention(n_heads, d_model, d_head, p=p, bias=bias, scale=scale)
        self.ff = feed_forward(d_model, d_inner, ff_p=p, double_drop=double_drop)

    def forward(self, x, mask=None):
        "Self-attention over `x` (optionally masked) followed by the feed-forward cell."
        return self.ff(self.mha(x, x, mask=mask))


# In[126]:


class DecoderBlock(nn.Module):
    "Decoder block of a Transformer model."
    #Can't use Sequential directly cause more than one input...
    def __init__(self, n_heads, d_model, d_head, d_inner, p=0., bias=True, scale=True, double_drop=True):
        super().__init__()
        self.mha1 = MultiHeadAttention(n_heads, d_model, d_head, p=p, bias=bias, scale=scale)
        self.mha2 = MultiHeadAttention(n_heads, d_model, d_head, p=p, bias=bias, scale=scale)
        self.ff = feed_forward(d_model, d_inner, ff_p=p, double_drop=double_drop)

    def forward(self, x, enc, mask_out=None):
        "Masked self-attention over `x`, then unmasked attention over `enc`, then the feed-forward cell."
        return self.ff(self.mha2(self.mha1(x, x, mask_out), enc))


# ### The whole model

# In[127]:


class Transformer(Module):
    """Encoder/decoder Transformer mapping `(inp, out)` token ids to output-vocabulary scores.

    The output projection shares its weight with the decoder embedding.
    """
    # NOTE(review): no explicit `super().__init__()` -- this presumably relies on
    # fastai's `Module` base class calling it before `__init__`; confirm.
    def __init__(self, inp_vsz, out_vsz, n_layers=6, n_heads=8, d_model=256, d_head=32,
                 d_inner=1024, p=0.1, bias=True, scale=True, double_drop=True, pad_idx=1):
        self.enc_emb = TransformerEmbedding(inp_vsz, d_model, p)
        self.dec_emb = TransformerEmbedding(out_vsz, d_model, 0.)
        args = (n_heads, d_model, d_head, d_inner, p, bias, scale, double_drop)
        self.encoder = nn.ModuleList([EncoderBlock(*args) for _ in range(n_layers)])
        self.decoder = nn.ModuleList([DecoderBlock(*args) for _ in range(n_layers)])
        self.out = nn.Linear(d_model, out_vsz)
        # Weight tying: the output layer reuses the decoder embedding matrix.
        self.out.weight = self.dec_emb.embed.weight
        self.pad_idx = pad_idx

    def forward(self, inp, out):
        "Encode `inp`, decode `out` against it under the look-ahead mask, project with `self.out`."
        mask_out = get_output_mask(out, self.pad_idx)
        enc, out = self.enc_emb(inp), self.dec_emb(out)
        enc = compose(self.encoder)(enc)
        out = compose(self.decoder)(out, enc, mask_out)
        return self.out(out)


# #### Bleu metric (see dedicated notebook)

# In[23]:


class NGram():
    "Hashable wrapper around a sequence of token ids, so n-grams can be counted in a `Counter`."
    def __init__(self, ngram, max_n=5000): self.ngram, self.max_n = ngram, max_n

    def __eq__(self, other):
        # Let Python handle comparisons with foreign types instead of raising.
        if not isinstance(other, NGram): return NotImplemented
        if len(self.ngram) != len(other.ngram): return False
        return bool(np.all(np.array(self.ngram) == np.array(other.ngram)))

    def __hash__(self):
        # Reads the ids as digits of a base-`max_n` number.
        return int(sum([o * self.max_n**i for i, o in enumerate(self.ngram)]))


# In[24]:


def get_grams(x, n, max_n=5000):
    "Return `x` itself for `n == 1`, else the list of every length-`n` window of `x` as `NGram`s."
    return x if n==1 else [NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]


# In[25]:


def get_correct_ngrams(pred, targ, n, max_n=5000):
    "Return `(clipped number of `n`-grams of `pred` found in `targ`, number of `n`-grams in `pred`)`."
    pred_grams, targ_grams = get_grams(pred, n, max_n=max_n), get_grams(targ, n, max_n=max_n)
    pred_cnt, targ_cnt = Counter(pred_grams), Counter(targ_grams)
    return sum([min(c, targ_cnt[g]) for g, c in pred_cnt.items()]), len(pred_grams)


# In[26]:


class CorpusBLEU(Callback):
    "Metric callback accumulating 1- to 4-gram precisions over an epoch and reporting a BLEU score."
    def __init__(self, vocab_sz):
        self.vocab_sz = vocab_sz
        self.name = 'bleu'

    def on_epoch_begin(self, **kwargs):
        "Reset the length and n-gram counters."
        self.pred_len, self.targ_len, self.corrects, self.counts = 0, 0, [0]*4, [0]*4

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Take the argmax of `last_output` over its last dim and update the counters row by row."
        last_output = last_output.argmax(dim=-1)
        for pred, targ in zip(last_output.cpu().numpy(), last_target.cpu().numpy()):
            # Lengths are the full row lengths; nothing is stripped from the rows here.
            self.pred_len += len(pred)
            self.targ_len += len(targ)
            for i in range(4):
                c, t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)
                self.corrects[i] += c
                self.counts[i] += t

    def on_epoch_end(self, last_metrics, **kwargs):
        "Combine the four precisions (geometric mean) with the brevity penalty and report the result."
        # An n-gram order that was never counted (e.g. rows shorter than n) scores 0
        # instead of raising ZeroDivisionError.
        precs = [c/t if t > 0 else 0. for c, t in zip(self.corrects, self.counts)]
        if self.pred_len == 0: len_penalty = 0.
        elif self.pred_len < self.targ_len: len_penalty = math.exp(1 - self.targ_len/self.pred_len)
        else: len_penalty = 1
        bleu = len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25)
        return add_metrics(last_metrics, bleu)


# ### Training

# In[128]:


n_x_vocab,n_y_vocab = len(data.train_ds.x.vocab.itos), len(data.train_ds.y.vocab.itos)

model = Transformer(n_x_vocab, n_y_vocab, d_model=256)
learn = Learner(data, model, metrics=[accuracy, CorpusBLEU(n_y_vocab)], loss_func = CrossEntropyFlat())


# In[129]:


learn.lr_find()
learn.recorder.plot()


# In[154]:


learn.fit_one_cycle(8, 5e-4, div_factor=5)


# In[49]:


def get_predictions(learn, ds_type=DatasetType.Valid):
    """Run `learn.model` in eval mode over the `ds_type` loader.

    Returns three parallel lists `(inputs, targets, outputs)` of reconstructed
    items: the source, the target and the argmax prediction of every row.
    """
    learn.model.eval()
    inputs, targets, outputs = [], [], []
    with torch.no_grad():
        for xb, yb in progress_bar(learn.dl(ds_type)):
            out = learn.model(*xb)
            # Targets come from `yb`: `xb[1]` is the decoder input fed to the
            # model, not the sequence the outputs are trained to match.
            for x, y, z in zip(xb[0], yb, out):
                inputs.append(learn.data.train_ds.x.reconstruct(x))
                targets.append(learn.data.train_ds.y.reconstruct(y))
                outputs.append(learn.data.train_ds.y.reconstruct(z.argmax(1)))
    return inputs, targets, outputs


# In[50]:


inputs, targets, outputs = get_predictions(learn)


# In[ ]:


inputs[10],targets[10],outputs[10]


# In[ ]:


inputs[700],targets[700],outputs[700]


# In[ ]:


inputs[701],targets[701],outputs[701]


# In[ ]:


inputs[2500],targets[2500],outputs[2500]


# In[ ]:


inputs[4002],targets[4002],outputs[4002]


# ### Label smoothing

# They point out in the paper that using label smoothing helped getting a better BLEU/accuracy, even if it made the loss worse.

# In[51]:


model = Transformer(len(data.train_ds.x.vocab.itos), len(data.train_ds.y.vocab.itos), d_model=256)


# In[52]:


learn = Learner(data, model, metrics=[accuracy, CorpusBLEU(len(data.train_ds.y.vocab.itos))],
                loss_func=FlattenedLoss(LabelSmoothingCrossEntropy, axis=-1))


# In[53]:


learn.fit_one_cycle(8, 5e-4, div_factor=5)


# In[ ]:


learn.fit_one_cycle(8, 5e-4, div_factor=5)


# In[ ]:


print("Quels sont les atouts particuliers du Canada en recherche sur l'obésité sur la scène internationale ?")
print("What are Specific strengths canada strengths in obesity - ? are up canada ? from international international stage ?")
print("Quelles sont les répercussions politiques à long terme de cette révolution scientifique mondiale ?")
print("What are the long the long - term policies implications of this global scientific ? ?")


# In[ ]:


# Recompute the predictions with the label-smoothing model; otherwise the
# cells below would still show the outputs of the previous model.
inputs, targets, outputs = get_predictions(learn)


# In[ ]:


inputs[10],targets[10],outputs[10]


# In[ ]:


inputs[700],targets[700],outputs[700]


# In[ ]:


inputs[701],targets[701],outputs[701]


# In[ ]:


inputs[4001],targets[4001],outputs[4001]


# ### Test leakage

# If we change a token in the targets at position n, it shouldn't impact the predictions before that.

# In[ ]:


learn.model.eval()


# In[ ]:


xb,yb = data.one_batch(cpu=False)


# In[ ]:


# Two copies of the first sample that differ only in the decoder input at position 15.
inp1,out1 = xb[0][:1],xb[1][:1]
inp2,out2 = inp1.clone(),out1.clone()
out2[0,15] = 10


# In[ ]:


y1 = learn.model(inp1, out1)
y2 = learn.model(inp2, out2)


# In[ ]:


# Expected to be 0: predictions before position 15 must not depend on the token at 15.
(y1[0,:15] - y2[0,:15]).abs().mean()