#export
from fastai2.basics import *
from fastai2.text.core import *
from fastai2.text.data import *
from fastai2.text.models.core import *
from fastai2.text.models.awdlstm import *
from fastai2.callback.rnn import *
from fastai2.callback.progress import *
#hide
from nbdev.showdoc import *
#default_exp text.learner
All the functions necessary to build a Learner suitable for transfer learning in NLP
The most important functions of this module are language_model_learner and text_classifier_learner. They will help you define a Learner using a pretrained model. See the text tutorial for examples of use.
In text, to load a pretrained model, we need to adapt the embeddings of the vocabulary used for the pre-training to the vocabulary of our current corpus.
#export
def match_embeds(old_wgts, old_vocab, new_vocab):
    "Convert the embedding in `old_wgts` to go from `old_vocab` to `new_vocab`."
    bias, wgts = old_wgts.get('1.decoder.bias', None), old_wgts['0.encoder.weight']
    wgts_m = wgts.mean(0)
    new_wgts = wgts.new_zeros((len(new_vocab),wgts.size(1)))
    if bias is not None:
        bias_m = bias.mean(0)
        new_bias = bias.new_zeros((len(new_vocab),))
    old_o2i = old_vocab.o2i if hasattr(old_vocab, 'o2i') else {w:i for i,w in enumerate(old_vocab)}
    for i,w in enumerate(new_vocab):
        idx = old_o2i.get(w, -1)
        new_wgts[i] = wgts[idx] if idx>=0 else wgts_m
        if bias is not None: new_bias[i] = bias[idx] if idx>=0 else bias_m
    old_wgts['0.encoder.weight'] = new_wgts
    if '0.encoder_dp.emb.weight' in old_wgts: old_wgts['0.encoder_dp.emb.weight'] = new_wgts.clone()
    old_wgts['1.decoder.weight'] = new_wgts.clone()
    if bias is not None: old_wgts['1.decoder.bias'] = new_bias
    return old_wgts
For words in new_vocab that don't have a corresponding match in old_vocab, we use the mean of all pretrained embeddings.
wgts = {'0.encoder.weight': torch.randn(5,3)}
new_wgts = match_embeds(wgts.copy(), ['a', 'b', 'c'], ['a', 'c', 'd', 'b'])
old,new = wgts['0.encoder.weight'],new_wgts['0.encoder.weight']
test_eq(new[0], old[0])
test_eq(new[1], old[2])
test_eq(new[2], old.mean(0))
test_eq(new[3], old[1])
#hide
#With bias
wgts = {'0.encoder.weight': torch.randn(5,3), '1.decoder.bias': torch.randn(5)}
new_wgts = match_embeds(wgts.copy(), ['a', 'b', 'c'], ['a', 'c', 'd', 'b'])
old_w,new_w = wgts['0.encoder.weight'],new_wgts['0.encoder.weight']
old_b,new_b = wgts['1.decoder.bias'], new_wgts['1.decoder.bias']
test_eq(new_w[0], old_w[0])
test_eq(new_w[1], old_w[2])
test_eq(new_w[2], old_w.mean(0))
test_eq(new_w[3], old_w[1])
test_eq(new_b[0], old_b[0])
test_eq(new_b[1], old_b[2])
test_eq(new_b[2], old_b.mean(0))
test_eq(new_b[3], old_b[1])
#export
def _get_text_vocab(dls):
    vocab = dls.vocab
    if isinstance(vocab, L): vocab = vocab[0]
    return vocab
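For classification data, dls.vocab is an L holding both the text vocabulary and the label vocabulary, and this helper returns only the text one. A minimal check with a stand-in object (SimpleNamespace here is just a placeholder for real DataLoaders):
from types import SimpleNamespace
dls_sketch = SimpleNamespace(vocab=L(['xxunk', 'xxpad', 'movie'], ['neg', 'pos']))
test_eq(_get_text_vocab(dls_sketch), ['xxunk', 'xxpad', 'movie'])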
#export
def load_ignore_keys(model, wgts):
    "Load `wgts` in `model` ignoring the names of the keys, just taking parameters in order"
    sd = model.state_dict()
    for k1,k2 in zip(sd.keys(), wgts.keys()): sd[k1].data = wgts[k2].data.clone()
    return model.load_state_dict(sd)
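A quick check of the behavior with two toy modules whose state dict keys differ but whose parameters line up in order (the renamed keys are made up for the example):
src,dst = nn.Sequential(nn.Linear(3,4)), nn.Sequential(nn.Linear(3,4))
load_ignore_keys(dst, {f'renamed.{k}':v for k,v in src.state_dict().items()})
test_eq(dst[0].weight, src[0].weight)
test_eq(dst[0].bias, src[0].bias)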
#export
def _rm_module(n):
    t = n.split('.')
    for i in range(len(t)-1, -1, -1):
        if t[i] == 'module':
            t.pop(i)
            break
    return '.'.join(t)
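_rm_module drops the last 'module' component that the weight-drop wrapper (or nn.DataParallel) can add to a parameter name, leaving other names untouched:
test_eq(_rm_module('0.rnns.0.module.weight_hh_l0'), '0.rnns.0.weight_hh_l0')
test_eq(_rm_module('0.encoder.weight'), '0.encoder.weight')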
#export
#For previous versions compatibility, remove for release
def clean_raw_keys(wgts):
    keys = list(wgts.keys())
    for k in keys:
        t = k.split('.module')
        if f'{_rm_module(k)}_raw' in keys: del wgts[k]
    return wgts
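clean_raw_keys drops a key whenever its _raw counterpart is also present in the checkpoint, keeping only the latter. A small check with dummy tensors (the key names mimic a weight-dropped AWD-LSTM layer):
wgts = {'0.rnns.0.module.weight_hh_l0': torch.zeros(2), '0.rnns.0.weight_hh_l0_raw': torch.ones(2)}
test_eq(list(clean_raw_keys(wgts).keys()), ['0.rnns.0.weight_hh_l0_raw'])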
#export
#For previous versions compatibility, remove for release
def load_model_text(file, model, opt, with_opt=None, device=None, strict=True):
    "Load `model` from `file` along with `opt` (if available, and if `with_opt`)"
    distrib_barrier()
    if isinstance(device, int): device = torch.device('cuda', device)
    elif device is None: device = 'cpu'
    state = torch.load(file, map_location=device)
    hasopt = set(state)=={'model', 'opt'}
    model_state = state['model'] if hasopt else state
    get_model(model).load_state_dict(clean_raw_keys(model_state), strict=strict)
    if hasopt and ifnone(with_opt,True):
        try: opt.load_state_dict(state['opt'])
        except:
            if with_opt: warn("Could not load the optimizer state.")
    elif with_opt: warn("Saved file doesn't contain an optimizer state.")
#export
@log_args(but_as=Learner.__init__)
@delegates(Learner.__init__)
class TextLearner(Learner):
    "Basic class for a `Learner` in NLP."
    def __init__(self, dls, model, alpha=2., beta=1., moms=(0.8,0.7,0.8), **kwargs):
        super().__init__(dls, model, moms=moms, **kwargs)
        self.add_cbs([ModelResetter(), RNNRegularizer(alpha=alpha, beta=beta)])

    def save_encoder(self, file):
        "Save the encoder to `file` in the model directory"
        if rank_distrib(): return # don't save if child proc
        encoder = get_model(self.model)[0]
        if hasattr(encoder, 'module'): encoder = encoder.module
        torch.save(encoder.state_dict(), join_path_file(file, self.path/self.model_dir, ext='.pth'))

    def load_encoder(self, file, device=None):
        "Load the encoder `file` from the model directory, optionally ensuring it's on `device`"
        encoder = get_model(self.model)[0]
        if device is None: device = self.dls.device
        if hasattr(encoder, 'module'): encoder = encoder.module
        distrib_barrier()
        wgts = torch.load(join_path_file(file,self.path/self.model_dir, ext='.pth'), map_location=device)
        encoder.load_state_dict(clean_raw_keys(wgts))
        self.freeze()
        return self

    def load_pretrained(self, wgts_fname, vocab_fname, model=None):
        "Load a pretrained model and adapt it to the data vocabulary."
        old_vocab = Path(vocab_fname).load()
        new_vocab = _get_text_vocab(self.dls)
        distrib_barrier()
        wgts = torch.load(wgts_fname, map_location = lambda storage,loc: storage)
        if 'model' in wgts: wgts = wgts['model'] #Just in case the pretrained model was saved with an optimizer
        wgts = match_embeds(wgts, old_vocab, new_vocab)
        load_ignore_keys(self.model if model is None else model, clean_raw_keys(wgts))
        self.freeze()
        return self

    #For previous versions compatibility. Remove at release
    @delegates(load_model_text)
    def load(self, file, with_opt=None, device=None, **kwargs):
        if device is None: device = self.dls.device
        if self.opt is None: self.create_opt()
        file = join_path_file(file, self.path/self.model_dir, ext='.pth')
        load_model_text(file, self.model, self.opt, device=device, **kwargs)
        return self
Adds a ModelResetter and an RNNRegularizer with alpha and beta to the callbacks; the rest is the same as the Learner init.
This Learner adds functionality to the base class:
show_doc(TextLearner.load_pretrained)
TextLearner.load_pretrained[source]
TextLearner.load_pretrained(wgts_fname, vocab_fname, model=None)
Load a pretrained model and adapt it to the data vocabulary.
wgts_fname should point to the weights of the pretrained model and vocab_fname to the vocabulary used to pretrain it.
show_doc(TextLearner.save_encoder)
TextLearner.save_encoder[source]
TextLearner.save_encoder(file)
Save the encoder to file in the model directory
The model directory is Learner.path/Learner.model_dir.
show_doc(TextLearner.load_encoder)
TextLearner.load_encoder[source]
TextLearner.load_encoder(file, device=None)
Load the encoder file from the model directory, optionally ensuring it's on device
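Together, save_encoder and load_encoder are the glue of the ULMFiT recipe: fine-tune a language model, save its encoder, then load that encoder into a classifier built on the same vocabulary. A rough sketch using the convenience constructors defined further down this page (names, learning rates, and epoch counts are placeholders):
# learn_lm = language_model_learner(dls_lm, AWD_LSTM)       # dls_lm: language-model DataLoaders
# learn_lm.fit_one_cycle(1, 2e-2)
# learn_lm.save_encoder('finetuned')                        # -> learn_lm.path/learn_lm.model_dir/'finetuned.pth'
# learn_clas = text_classifier_learner(dls_clas, AWD_LSTM)  # dls_clas must share the LM vocabulary
# learn_clas = learn_clas.load_encoder('finetuned')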
For language modeling, the predict method is quite different from the other applications, which is why it needs its own subclass.
#export
def decode_spec_tokens(tokens):
    "Decode the special tokens in `tokens`"
    new_toks,rule,arg = [],None,None
    for t in tokens:
        if t in [TK_MAJ, TK_UP, TK_REP, TK_WREP]: rule = t
        elif rule is None: new_toks.append(t)
        elif rule == TK_MAJ:
            new_toks.append(t[:1].upper() + t[1:].lower())
            rule = None
        elif rule == TK_UP:
            new_toks.append(t.upper())
            rule = None
        elif arg is None:
            try: arg = int(t)
            except: rule = None
        else:
            if rule == TK_REP: new_toks.append(t * arg)
            else: new_toks += [t] * arg
    return new_toks
test_eq(decode_spec_tokens(['xxmaj', 'text']), ['Text'])
test_eq(decode_spec_tokens(['xxup', 'text']), ['TEXT'])
test_eq(decode_spec_tokens(['xxrep', '3', 'a']), ['aaa'])
test_eq(decode_spec_tokens(['xxwrep', '3', 'word']), ['word', 'word', 'word'])
#export
@log_args(but_as=TextLearner.__init__)
class LMLearner(TextLearner):
    "Add functionality to `TextLearner` when dealing with a language model"
    def predict(self, text, n_words=1, no_unk=True, temperature=1., min_p=None, no_bar=False,
                decoder=decode_spec_tokens, only_last_word=False):
        "Return `text` and the `n_words` that come after"
        self.model.reset()
        idxs = idxs_all = self.dls.test_dl([text]).items[0].to(self.dls.device)
        if no_unk: unk_idx = self.dls.vocab.index(UNK)
        for _ in (range(n_words) if no_bar else progress_bar(range(n_words), leave=False)):
            with self.no_bar(): preds,_ = self.get_preds(dl=[(idxs[None],)])
            res = preds[0][-1]
            if no_unk: res[unk_idx] = 0.
            if min_p is not None:
                if (res >= min_p).float().sum() == 0:
                    warn(f"There is no item with probability >= {min_p}, try a lower value.")
                else: res[res < min_p] = 0.
            if temperature != 1.: res.pow_(1 / temperature)
            idx = torch.multinomial(res, 1).item()
            idxs = idxs_all = torch.cat([idxs_all, idxs.new([idx])])
            if only_last_word: idxs = idxs[-1][None]
        num = self.dls.train_ds.numericalize
        tokens = [num.vocab[i] for i in idxs_all if num.vocab[i] not in [BOS, PAD]]
        sep = self.dls.train_ds.tokenizer.sep
        return sep.join(decoder(tokens))

    @delegates(Learner.get_preds)
    def get_preds(self, concat_dim=1, **kwargs): return super().get_preds(concat_dim=1, **kwargs)
show_doc(LMLearner, title_level=3)
class LMLearner[source]
LMLearner(dls, model, alpha=2.0, beta=1.0, moms=(0.8, 0.7, 0.8), loss_func=None, opt_func='Adam', lr=0.001, splitter='trainable_params', cbs=None, metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True) :: TextLearner
Add functionality to TextLearner when dealing with a language model
show_doc(LMLearner.predict)
The words are picked randomly among the predictions, depending on the probability of each index. no_unk means we never pick the UNK token, temperature is applied to the predictions, and if min_p is passed, we don't consider the indices with a probability lower than it. Set no_bar to True if you don't want any progress bar, and you can pass along a custom decoder to process the predicted tokens.
Learner convenience functions
#export
from fastai2.text.models.core import _model_meta
#export
@log_args(to_return=True, but_as=Learner.__init__)
@delegates(Learner.__init__)
def language_model_learner(dls, arch, config=None, drop_mult=1., backwards=False, pretrained=True, pretrained_fnames=None, **kwargs):
    "Create a `Learner` with a language model from `dls` and `arch`."
    vocab = _get_text_vocab(dls)
    model = get_language_model(arch, len(vocab), config=config, drop_mult=drop_mult)
    meta = _model_meta[arch]
    learn = LMLearner(dls, model, loss_func=CrossEntropyLossFlat(), splitter=meta['split_lm'], **kwargs)
    url = 'url_bwd' if backwards else 'url'
    if pretrained or pretrained_fnames:
        if pretrained_fnames is not None:
            fnames = [learn.path/learn.model_dir/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, ['pth', 'pkl'])]
        else:
            if url not in meta:
                warn("There are no pretrained weights for that architecture yet!")
                return learn
            model_path = untar_data(meta[url], c_key='model')
            fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
        learn = learn.load_pretrained(*fnames)
    return learn
You can use config to customize the architecture used (change the values from awd_lstm_lm_config for this). pretrained will use fastai's pretrained model for this arch (if available), or you can pass specific pretrained_fnames containing your own pretrained model and the corresponding vocabulary. All other arguments are passed to Learner. An example of passing a custom config is shown after the predict calls below.
path = untar_data(URLs.IMDB_SAMPLE)
df = pd.read_csv(path/'texts.csv')
dls = TextDataLoaders.from_df(df, path=path, text_col='text', is_lm=True, valid_col='is_valid')
learn = language_model_learner(dls, AWD_LSTM)
You can then use the .predict method to generate new text.
learn.predict('This movie is about', n_words=20)
'This movie is about front - line highlights for fifteen United States Navy and the US Navy , four'
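The sampling can be tuned: a lower temperature makes the choices more conservative, and min_p discards unlikely tokens altogether (the values below are only illustrative, and the output is random):
learn.predict('This movie is about', n_words=20, temperature=0.75, min_p=0.005, no_bar=True)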
By default, the entire sentence is fed again to the model after each predicted word; this little trick improves the quality of the generated text. If you want to feed only the last word, pass only_last_word=True.
learn.predict('This movie is about', n_words=20, only_last_word=True)
'This movie is about 6 No person in Suppose he was used searching for the late in West of other important'
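As mentioned above, config lets you change the architecture; the sketch below builds a smaller AWD-LSTM from awd_lstm_lm_config (the sizes are arbitrary, and since they no longer match fastai's pretrained weights we pass pretrained=False):
config = awd_lstm_lm_config.copy()
config.update(dict(emb_sz=200, n_hid=600, n_layers=2))  # illustrative sizes only
learn_small = language_model_learner(dls, AWD_LSTM, config=config, pretrained=False)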
#export
@log_args(to_return=True, but_as=Learner.__init__)
@delegates(Learner.__init__)
def text_classifier_learner(dls, arch, seq_len=72, config=None, backwards=False, pretrained=True, drop_mult=0.5, n_out=None,
                            lin_ftrs=None, ps=None, max_len=72*20, y_range=None, **kwargs):
    "Create a `Learner` with a text classifier from `dls` and `arch`."
    vocab = _get_text_vocab(dls)
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    model = get_text_classifier(arch, len(vocab), n_out, seq_len=seq_len, config=config, y_range=y_range,
                                drop_mult=drop_mult, lin_ftrs=lin_ftrs, ps=ps, max_len=max_len)
    meta = _model_meta[arch]
    learn = TextLearner(dls, model, splitter=meta['split_clas'], **kwargs)
    url = 'url_bwd' if backwards else 'url'
    if pretrained:
        if url not in meta:
            warn("There are no pretrained weights for that architecture yet!")
            return learn
        model_path = untar_data(meta[url], c_key='model')
        fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
        learn = learn.load_pretrained(*fnames, model=learn.model[0])
        learn.freeze()
    return learn
You can use the config to customize the architecture used (change the values from awd_lstm_clas_config for this), and pretrained will use fastai's pretrained model for this arch (if available). drop_mult is a global multiplier applied to control all dropouts. n_out is usually inferred from the dls, but you may pass it.
The model uses a SentenceEncoder, which means the texts are passed seq_len tokens at a time, and will only compute the gradients on the last max_len steps. lin_ftrs and ps are passed to get_text_classifier.
All other arguments are passed to Learner.
path = untar_data(URLs.IMDB_SAMPLE)
df = pd.read_csv(path/'texts.csv')
dls = TextDataLoaders.from_df(df, path=path, text_col='text', label_col='label', valid_col='is_valid')
learn = text_classifier_learner(dls, AWD_LSTM)
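The head and the amount of dropout can be tweaked without losing the pretrained encoder weights; for instance (the sizes below are only illustrative):
learn_custom = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.3, lin_ftrs=[100], ps=[0.2], max_len=72*10)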
#export
@typedispatch
def show_results(x: LMTensorText, y, samples, outs, ctxs=None, max_n=10, **kwargs):
    if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))
    for i,l in enumerate(['input', 'target']):
        ctxs = [b.show(ctx=c, label=l, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]
    ctxs = [b.show(ctx=c, label='pred', **kwargs) for b,c,_ in zip(outs.itemgot(0),ctxs,range(max_n))]
    display_df(pd.DataFrame(ctxs))
    return ctxs
#export
@typedispatch
def show_results(x: TensorText, y, samples, outs, ctxs=None, max_n=10, trunc_at=150, **kwargs):
    if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))
    samples = L((s[0].truncate(trunc_at),*s[1:]) for s in samples)
    ctxs = show_results[object](x, y, samples, outs, ctxs=ctxs, max_n=max_n, **kwargs)
    display_df(pd.DataFrame(ctxs))
    return ctxs
#export
@typedispatch
def plot_top_losses(x: TensorText, y:TensorCategory, samples, outs, raws, losses, trunc_at=150, **kwargs):
    rows = get_empty_df(len(samples))
    samples = L((s[0].truncate(trunc_at),*s[1:]) for s in samples)
    for i,l in enumerate(['input', 'target']):
        rows = [b.show(ctx=c, label=l, **kwargs) for b,c in zip(samples.itemgot(i),rows)]
    outs = L(o + (TitledFloat(r.max().item()), TitledFloat(l.item())) for o,r,l in zip(outs, raws, losses))
    for i,l in enumerate(['predicted', 'probability', 'loss']):
        rows = [b.show(ctx=c, label=l, **kwargs) for b,c in zip(outs.itemgot(i),rows)]
    display_df(pd.DataFrame(rows))
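These typedispatched functions are picked up by Learner.show_results and by Interpretation.plot_top_losses when the inputs are texts. A quick illustration with the classifier learn built above (the untrained predictions won't mean much, and the trunc_at/max_n values are arbitrary):
from fastai2.interpret import Interpretation
learn.show_results(max_n=4, trunc_at=100)
interp = Interpretation.from_learner(learn)
interp.plot_top_losses(4, trunc_at=100)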
#hide
from nbdev.export import notebook2script
notebook2script()