from fastai.text import *
# Load the id -> token vocabulary. The `with` block closes the file handle as soon as
# the pickle has been read (the bare `open(...)` call used to leave it dangling).
# NOTE(review): unpickling can run arbitrary code -- only load a trusted 'itos_tfm.pkl'.
with open('itos_tfm.pkl', 'rb') as f:
    itos = pickle.load(f)
vocab_sz = len(itos)
vocab_sz  # notebook-style echo of the vocabulary size; a no-op when run as a script

# Build the transformer language model on the GPU and restore the trained weights.
# NOTE(review): the positional numbers are presumably context length, layer/head counts,
# model/head/feed-forward widths and dropout rates -- confirm against the signature of
# fastai's `get_transformer_lm`.
model = get_transformer_lm(vocab_sz, 512, 12, 12, 768, 64, 768*4, 0.1, 0.1, 0.1, 0., act=Activation.GeLU,
                           double_drop=False, out_bias=False).cuda()
# The model already lives on the GPU (see `.cuda()` above) and `load_state_dict` copies
# values into the existing parameters, so no second `.cuda()` call is needed afterwards.
model.load_state_dict(torch.load('tfmer.pth'))

# Reverse mapping: token -> id.
stoi = {s:i for i,s in itos.items()}
# Careful: whole words are stored in the vocabulary with a trailing `</w>` end-of-word flag (e.g. `vanilla</w>`).
# Spot-check both directions of the vocabulary mapping. These are bare expressions:
# their values are only displayed when run interactively (e.g. as notebook cells).
stoi['vanilla</w>']  # token -> id; raises KeyError if the entry is missing
itos[15000]  # id -> token
def textify(ids):
    "Render the vocabulary ids in `ids` as one space-separated string, with every `</w>` marker removed."
    tokens = (itos[idx].replace('</w>', '') for idx in ids)
    return ' '.join(tokens)
def numericalize(text):
    """Convert whitespace-separated words in `text` to their vocabulary ids.

    Each word is looked up in `stoi` with the `</w>` end-of-word flag appended.
    Splitting on runs of whitespace (instead of on every single space) keeps leading,
    trailing or repeated spaces from producing empty words, which would otherwise be
    looked up as a bare `'</w>'` entry.

    Returns:
        A list with one id per word, in order.

    Raises:
        KeyError: if a word (with its `</w>` flag) is not in the vocabulary.
    """
    ids = []
    for w in text.split():
        try:
            ids.append(stoi[f'{w}</w>'])
        except KeyError:
            # Same exception type as a plain lookup, but naming the offending word.
            raise KeyError(f'word not in vocabulary: {w!r}') from None
    return ids
def predict(text, n_words, topk=10, temperature=1.):
    """Extend `text` with `n_words` tokens sampled from the language model.

    At every step the model is run on the prompt plus everything generated so far.
    The distribution at the last position is optionally reshaped with `temperature`,
    and the next token is drawn from the `topk` most probable candidates.

    Args:
        text: prompt of space-separated vocabulary words (converted by `numericalize`).
        n_words: number of tokens to generate.
        topk: how many of the most probable tokens to sample from; capped at the size
            of the model's output distribution.
        temperature: probabilities are raised to the power `1 / temperature` before the
            top-k selection; `1.` leaves them untouched.

    Returns:
        `text`, a space, and the generated tokens rendered by `textify`.

    Raises:
        KeyError: propagated from `numericalize` when the prompt contains a word that
            is not in the vocabulary.
    """
    ids = numericalize(text)
    model.reset()
    model.eval()
    new_idx = []
    # NOTE(review): the whole sequence is fed to the model at every step with no
    # truncation -- confirm that prompt length + n_words fits the model's context size.
    # Pure inference: without no_grad each forward pass would record an autograd graph.
    with torch.no_grad():
        for _ in range(n_words):
            # Batch of one: prompt ids followed by everything sampled so far.
            x = LongTensor(ids + new_idx)[None].cuda()
            # First model output, first batch item, last position -> next-token scores.
            probs = F.softmax(model(x)[0][0, -1], dim=-1)
            if temperature != 1.:
                probs = probs.pow(1 / temperature)
            values, indices = probs.topk(min(topk, probs.size(-1)))
            # multinomial takes unnormalized non-negative weights, so the top-k values
            # need no renormalization; it returns a position inside the top-k list,
            # which `gather` maps back to a vocabulary id.
            next_idx = indices.gather(-1, torch.multinomial(values, 1)).item()
            new_idx.append(next_idx)
            # Clear any state the model keeps, since the full sequence is re-fed next step.
            model.reset()
    return text + ' ' + textify(new_idx)
# Demo: continue the prompt with 50 generated tokens (value shown when run interactively).
predict(text="this state has a population of", n_words=50)