# default_exp learner
#export
from fastai2.data.all import *
from fastai2.optimizer import *
from fastai2.callback.core import *
#hide
from nbdev.showdoc import *
#export
_all_ = ['CancelFitException', 'CancelEpochException', 'CancelTrainException', 'CancelValidException', 'CancelBatchException']
#export
_loop = ['Start Fit', 'begin_fit', 'Start Epoch Loop', 'begin_epoch', 'Start Train', 'begin_train',
         'Start Batch Loop', 'begin_batch', 'after_pred', 'after_loss', 'after_backward', 'after_step',
         'after_cancel_batch', 'after_batch','End Batch Loop','End Train', 'after_cancel_train', 'after_train',
         'Start Valid', 'begin_validate','Start Batch Loop', '**CBs same as train batch**', 'End Batch Loop',
         'End Valid', 'after_cancel_validate', 'after_validate', 'End Epoch Loop', 'after_cancel_epoch',
         'after_epoch', 'End Fit', 'after_cancel_fit', 'after_fit']
#hide
#For tests
from torch.utils.data import TensorDataset

def synth_dbunch(a=2, b=3, bs=16, n_train=10, n_valid=2, cuda=False):
    "A simple `DataLoaders` where `x` is random and `y = a*x + b` plus some noise."
    def get_data(n):
        x = torch.randn(int(bs*n))
        return TensorDataset(x, a*x + b + 0.1*torch.randn(int(bs*n)))
    train_ds = get_data(n_train)
    valid_ds = get_data(n_valid)
    device = default_device() if cuda else None
    train_dl = TfmdDL(train_ds, bs=bs, shuffle=True, num_workers=0)
    valid_dl = TfmdDL(valid_ds, bs=bs, num_workers=0)
    return DataLoaders(train_dl, valid_dl, device=device)

class RegModel(Module):
    "A regression model `y = a*x + b` with learnable `a` and `b`"
    def __init__(self): self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))
    def forward(self, x): return x*self.a + self.b
# export
defaults.lr = 1e-3
# export
def replacing_yield(o, attr, val):
    "Context manager to temporarily replace an attribute"
    old = getattr(o,attr)
    try:     yield setattr(o,attr,val)
    finally: setattr(o,attr,old)

class _A:
    def __init__(self, a): self.a = a
    @contextmanager
    def a_changed(self, v): return replacing_yield(self, 'a', v)

a = _A(42)
with a.a_changed(32): test_eq(a.a, 32)
test_eq(a.a, 42)
#export
def mk_metric(m):
    "Convert `m` to an `AvgMetric`, unless it's already a `Metric`"
    return m if isinstance(m, Metric) else AvgMetric(m)
#export
def save_model(file, model, opt, with_opt=True, pickle_protocol=2):
    "Save `model` to `file` along with `opt` (if available, and if `with_opt`)"
    if rank_distrib(): return # don't save if child proc
    if opt is None: with_opt=False
    state = get_model(model).state_dict()
    if with_opt: state = {'model': state, 'opt':opt.state_dict()}
    torch.save(state, file, pickle_protocol=pickle_protocol)
# export
def load_model(file, model, opt, with_opt=None, device=None, strict=True):
    "Load `model` from `file` along with `opt` (if available, and if `with_opt`)"
    distrib_barrier()
    if isinstance(device, int): device = torch.device('cuda', device)
    elif device is None: device = 'cpu'
    state = torch.load(file, map_location=device)
    hasopt = set(state)=={'model', 'opt'}
    model_state = state['model'] if hasopt else state
    get_model(model).load_state_dict(model_state, strict=strict)
    if hasopt and ifnone(with_opt,True):
        try: opt.load_state_dict(state['opt'])
        except:
            if with_opt: warn("Could not load the optimizer state.")
    elif with_opt: warn("Saved file doesn't contain an optimizer state.")
# export
def _try_concat(o):
    try:    return torch.cat(o)
    except: return sum([L(o_[i,:] for i in range_of(o_)) for o_ in o], L())
# export
from contextlib import ExitStack
#export
_before_epoch = [event.begin_fit, event.begin_epoch]
_after_epoch  = [event.after_epoch, event.after_fit]
#export
class _ConstantFunc():
    "Returns a function that returns `o`"
    def __init__(self, o): self.o = o
    def __call__(self, *args, **kwargs): return self.o
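#The two save/load helpers above are what `Learner.save`/`Learner.load` (defined below)
#build on. A small illustration, not part of the original notebook (the `_m`,`_o`,`_d`,`_f`
#names are just locals for this sketch): a state dict round-trips through them with or
#without the optimizer state.
_m = nn.Linear(2,2)
_o = SGD(_m.parameters(), lr=1e-2)
with tempfile.TemporaryDirectory() as _d:
    _f = Path(_d)/'tmp.pth'
    save_model(_f, _m, _o)                    #saves {'model': ..., 'opt': ...}
    load_model(_f, _m, _o)                    #restores both model and optimizer state
    save_model(_f, _m, None, with_opt=False)  #opt is None, so only the model state is saved
    load_model(_f, _m, None, with_opt=False)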
# export
@log_args(but='dls,model,opt_func,cbs')
class Learner():
    def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params, cbs=None,
                 metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,
                 moms=(0.95,0.85,0.95)):
        store_attr(self, "dls,model,opt_func,lr,splitter,model_dir,wd,wd_bn_bias,train_bn,metrics,moms")
        self.training,self.create_mbar,self.logger,self.opt,self.cbs = False,True,print,None,L()
        if loss_func is None:
            loss_func = getattr(dls.train_ds, 'loss_func', None)
            assert loss_func is not None, "Could not infer loss function from the data, please pass a loss function."
        self.loss_func = loss_func
        self.path = Path(path) if path is not None else getattr(dls, 'path', Path('.'))
        self.add_cbs([(cb() if isinstance(cb, type) else cb) for cb in L(defaults.callbacks)+L(cbs)])
        self.epoch,self.n_epoch,self.loss = 0,1,tensor(0.)

    @property
    def metrics(self): return self._metrics
    @metrics.setter
    def metrics(self,v): self._metrics = L(v).map(mk_metric)

    def _grab_cbs(self, cb_cls): return L(cb for cb in self.cbs if isinstance(cb, cb_cls))
    def add_cbs(self, cbs): L(cbs).map(self.add_cb)
    def remove_cbs(self, cbs): L(cbs).map(self.remove_cb)
    def add_cb(self, cb):
        old = getattr(self, cb.name, None)
        assert not old or isinstance(old, type(cb)), f"self.{cb.name} already registered"
        cb.learn = self
        setattr(self, cb.name, cb)
        self.cbs.append(cb)
        return self

    def remove_cb(self, cb):
        if isinstance(cb, type): self.remove_cbs(self._grab_cbs(cb))
        else:
            cb.learn = None
            if hasattr(self, cb.name): delattr(self, cb.name)
            if cb in self.cbs: self.cbs.remove(cb)

    @contextmanager
    def added_cbs(self, cbs):
        self.add_cbs(cbs)
        try: yield
        finally: self.remove_cbs(cbs)

    @contextmanager
    def removed_cbs(self, cbs):
        self.remove_cbs(cbs)
        try: yield self
        finally: self.add_cbs(cbs)

    def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]

    def __call__(self, event_name): L(event_name).map(self._call_one)
    def _call_one(self, event_name):
        assert hasattr(event, event_name)
        [cb(event_name) for cb in sort_by_run(self.cbs)]

    def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)
    def create_opt(self):
        self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)
        if not self.wd_bn_bias:
            for p in self._bn_bias_state(True): p['do_wd'] = False
        if self.train_bn:
            for p in self._bn_bias_state(False): p['force_train'] = True

    def _split(self, b):
        i = getattr(self.dls, 'n_inp', 1 if len(b)==1 else len(b)-1)
        self.xb,self.yb = b[:i],b[i:]

    def all_batches(self):
        self.n_iter = len(self.dl)
        for o in enumerate(self.dl): self.one_batch(*o)

    def one_batch(self, i, b):
        self.iter = i
        try:
            self._split(b);                                  self('begin_batch')
            self.pred = self.model(*self.xb);                self('after_pred')
            if len(self.yb) == 0: return
            self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')
            if not self.training: return
            self.loss.backward();                            self('after_backward')
            self.opt.step();                                 self('after_step')
            self.opt.zero_grad()
        except CancelBatchException:                         self('after_cancel_batch')
        finally:                                             self('after_batch')

    def _do_begin_fit(self, n_epoch):
        self.n_epoch,self.loss = n_epoch,tensor(0.);         self('begin_fit')

    def _do_epoch_train(self):
        try:
            self.dl = self.dls.train;                        self('begin_train')
            self.all_batches()
        except CancelTrainException:                         self('after_cancel_train')
        finally:                                             self('after_train')
    def _do_epoch_validate(self, ds_idx=1, dl=None):
        if dl is None: dl = self.dls[ds_idx]
        try:
            self.dl = dl;                                    self('begin_validate')
            with torch.no_grad(): self.all_batches()
        except CancelValidException:                         self('after_cancel_validate')
        finally:                                             self('after_validate')

    def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

    @log_args(but='cbs')
    def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):
        with self.added_cbs(cbs):
            if reset_opt or not self.opt: self.create_opt()
            if wd is None: wd = self.wd
            if wd is not None: self.opt.set_hypers(wd=wd)
            self.opt.set_hypers(lr=self.lr if lr is None else lr)

            try:
                self._do_begin_fit(n_epoch)
                for epoch in range(n_epoch):
                    try:
                        self.epoch=epoch;                    self('begin_epoch')
                        self._do_epoch_train()
                        self._do_epoch_validate()
                    except CancelEpochException:             self('after_cancel_epoch')
                    finally:                                 self('after_epoch')
            except CancelFitException:                       self('after_cancel_fit')
            finally:                                         self('after_fit')
        self._end_cleanup()

    def validate(self, ds_idx=1, dl=None, cbs=None):
        if dl is None: dl = self.dls[ds_idx]
        with self.added_cbs(cbs), self.no_logging(), self.no_mbar():
            self(_before_epoch)
            self._do_epoch_validate(ds_idx, dl)
            self(_after_epoch)
        return getattr(self, 'final_record', None)

    @delegates(GatherPredsCallback.__init__)
    def get_preds(self, ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None,
                  inner=False, reorder=True, cbs=None, **kwargs):
        if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
        if reorder and hasattr(dl, 'get_idxs'):
            idxs = dl.get_idxs()
            dl = dl.new(get_idxs = _ConstantFunc(idxs))
        cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)
        ctx_mgrs = [self.no_logging(), self.added_cbs(L(cbs)+[cb]), self.no_mbar()]
        if with_loss: ctx_mgrs.append(self.loss_not_reduced())
        try:
            with ExitStack() as stack:
                for mgr in ctx_mgrs: stack.enter_context(mgr)
                self(event.begin_epoch if inner else _before_epoch)
                self._do_epoch_validate(dl=dl)
                self(event.after_epoch if inner else _after_epoch)
                if act is None: act = getattr(self.loss_func, 'activation', noop)
                res = cb.all_tensors()
                pred_i = 1 if with_input else 0
                if res[pred_i] is not None:
                    res[pred_i] = act(res[pred_i])
                    if with_decoded: res.insert(pred_i+2, getattr(self.loss_func, 'decodes', noop)(res[pred_i]))
                if reorder and hasattr(dl, 'get_idxs'): res = nested_reorder(res, tensor(idxs).argsort())
                return tuple(res)
        finally: self._end_cleanup()

    def predict(self, item, rm_type_tfms=None, with_input=False):
        dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
        inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
        i = getattr(self.dls, 'n_inp', -1)
        inp = (inp,) if i==1 else tuplify(inp)
        dec = self.dls.decode_batch(inp + tuplify(dec_preds))[0]
        dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])
        res = dec_targ,dec_preds[0],preds[0]
        if with_input: res = (dec_inp,) + res
        return res

    def show_results(self, ds_idx=1, dl=None, max_n=9, shuffle=True, **kwargs):
        if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)
        b = dl.one_batch()
        _,_,preds = self.get_preds(dl=[b], with_decoded=True)
        self.dls.show_results(b, preds, max_n=max_n, **kwargs)

    def show_training_loop(self):
        indent = 0
        for s in _loop:
            if s.startswith('Start'): print(f'{" "*indent}{s}'); indent += 2
            elif s.startswith('End'): indent -= 2; print(f'{" "*indent}{s}')
            else: print(f'{" "*indent} - {s:15}:', self.ordered_cbs(s))

    @contextmanager
    def no_logging(self): return replacing_yield(self, 'logger', noop)
    @contextmanager
    def no_mbar(self): return replacing_yield(self, 'create_mbar', False)
    @contextmanager
    def loss_not_reduced(self):
        if hasattr(self.loss_func, 'reduction'): return replacing_yield(self.loss_func, 'reduction', 'none')
        else: return replacing_yield(self, 'loss_func', partial(self.loss_func, reduction='none'))

    @delegates(save_model)
    def save(self, file, **kwargs):
        file = join_path_file(file, self.path/self.model_dir, ext='.pth')
        save_model(file, self.model, getattr(self,'opt',None), **kwargs)

    @delegates(load_model)
    def load(self, file, with_opt=None, device=None, **kwargs):
        if device is None: device = self.dls.device
        if self.opt is None: self.create_opt()
        file = join_path_file(file, self.path/self.model_dir, ext='.pth')
        load_model(file, self.model, self.opt, device=device, **kwargs)
        return self

Learner.x,Learner.y = add_props(lambda i,x: detuplify((x.xb,x.yb)[i]))
#export
add_docs(Learner, "Group together a `model`, some `dls` and a `loss_func` to handle training",
    add_cbs="Add `cbs` to the list of `Callback` and register `self` as their learner",
    add_cb="Add `cb` to the list of `Callback` and register `self` as their learner",
    remove_cbs="Remove `cbs` from the list of `Callback` and deregister `self` as their learner",
    remove_cb="Remove `cb` from the list of `Callback` and deregister `self` as their learner",
    added_cbs="Context manager that temporarily adds `cbs`",
    removed_cbs="Context manager that temporarily removes `cbs`",
    ordered_cbs="Return the list of `Callback`, in order, for an `event` in the training loop",
    create_opt="Create an optimizer with default hyper-parameters",
    one_batch="Train or evaluate `self.model` on batch `(xb,yb)`",
    all_batches="Train or evaluate `self.model` on all the batches of `self.dl`",
    fit="Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`.",
    validate="Validate on `dl` with potential new `cbs`.",
    get_preds="Get the predictions and targets on the `ds_idx`-th dataset or `dl`, optionally `with_input` and `with_loss`",
    predict="Return the prediction on `item`, fully decoded, loss function decoded and probabilities",
    show_results="Show some predictions on `ds_idx`-th dataset or `dl`",
    show_training_loop="Show each step in the training loop",
    no_logging="Context manager to temporarily remove `logger`",
    no_mbar="Context manager to temporarily prevent the master progress bar from being created",
    loss_not_reduced="A context manager to evaluate `loss_func` with reduction set to none.",
    save="Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`",
    load="Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using `device`",
    __call__="Call `event_name` for all `Callback`s in `self.cbs`"
)

show_doc(Learner)
show_doc(Learner.fit)
#hide
def synth_learner(n_train=10, n_valid=2, cuda=False, lr=defaults.lr, **kwargs):
    data = synth_dbunch(n_train=n_train,n_valid=n_valid, cuda=cuda)
    return Learner(data, RegModel(), loss_func=MSELossFlat(), lr=lr, **kwargs)

#Training a few epochs should make the model better
learn = synth_learner(lr=5e-2)
learn.model = learn.model.cpu()
xb,yb = learn.dls.one_batch()
init_loss = learn.loss_func(learn.model(xb), yb)
learn.fit(6)
xb,yb = learn.dls.one_batch()
final_loss = learn.loss_func(learn.model(xb), yb)
assert final_loss < init_loss
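#A quick illustration (not part of the original notebook): subclassing `Callback` and
#defining methods named after the training-loop events is enough to hook into `fit`.
#Here a minimal callback (the name `_LossTracker` is made up for this sketch) collects
#the raw loss reported at every training `after_batch`.
class _LossTracker(Callback):
    def begin_fit(self): self.tracked = []
    def after_batch(self):
        if self.training: self.tracked.append(self.loss.item())

learn = synth_learner(lr=5e-2)
cb = _LossTracker()
learn.fit(2, cbs=cb)
test_eq(len(cb.tracked), 2*len(learn.dls.train))  #one entry per training batch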
#hide
#Test of TrainEvalCallback
class TestTrainEvalCallback(Callback):
    run_after,run_valid = TrainEvalCallback,False
    def begin_fit(self):
        test_eq([self.pct_train,self.train_iter], [0., 0])
        self.old_pct_train,self.old_train_iter = self.pct_train,self.train_iter

    def begin_batch(self): test_eq(next(self.model.parameters()).device, find_device(self.xb))

    def after_batch(self):
        assert self.training
        test_eq(self.pct_train, self.old_pct_train+1/(self.n_iter*self.n_epoch))
        test_eq(self.train_iter, self.old_train_iter+1)
        self.old_pct_train,self.old_train_iter = self.pct_train,self.train_iter

    def begin_train(self):
        assert self.training and self.model.training
        test_eq(self.pct_train, self.epoch/self.n_epoch)
        self.old_pct_train = self.pct_train

    def begin_validate(self):
        assert not self.training and not self.model.training

learn = synth_learner(cbs=TestTrainEvalCallback)
learn.fit(1)
#Check order is properly taken into account
learn.cbs = L(reversed(learn.cbs))
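#Another small check, not in the original notebook: after a full `fit`, `pct_train`
#ends at 1. and `train_iter` counts every training batch seen.
learn = synth_learner()
learn.fit(2)
test_close(learn.pct_train, 1.)
test_eq(learn.train_iter, 2*len(learn.dls.train))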
#hide
#cuda
#Check model is put on the GPU if needed
learn = synth_learner(cbs=TestTrainEvalCallback, cuda=True)
learn.fit(1)
#hide
#Check wd is not applied on bn/bias when option wd_bn_bias=False
class _TstModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))
        self.tst = nn.Sequential(nn.Linear(4,5), nn.BatchNorm1d(3))
        self.tst[0].bias.data,self.tst[1].bias.data = torch.randn(5),torch.randn(3)
    def forward(self, x): return x * self.a + self.b

class _PutGrad(Callback):
    def after_backward(self):
        for p in self.learn.model.tst.parameters(): p.grad = torch.ones_like(p.data)

learn = synth_learner(n_train=5, opt_func=partial(SGD, wd=1, decouple_wd=True), cbs=_PutGrad)
learn.model = _TstModel()
init = [p.clone() for p in learn.model.tst.parameters()]
learn.fit(1, lr=1e-2)
end = list(learn.model.tst.parameters())
#the linear weight gets weight decay applied...
for i in [0]:     assert not torch.allclose(end[i]-init[i], -0.05 * torch.ones_like(end[i]))
#...but biases and batchnorm parameters only move by -lr*grad
for i in [1,2,3]: test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]))

show_doc(Learner.one_batch)
b = learn.dls.one_batch()
learn.one_batch(0, b)
test_eq(learn.x, b[0])
test_eq(learn.y, b[1])
out = learn.model(learn.x)
test_eq(learn.pred, out)
test_eq(learn.loss, learn.loss_func(out, b[1]))
#hide
class VerboseCallback(Callback):
    "Callback that prints the name of each event called"
    def __call__(self, event_name):
        print(event_name)
        super().__call__(event_name)
#hide
class TestOneBatch(VerboseCallback):
    def __init__(self, xb, yb, i):
        self.save_xb,self.save_yb,self.i = xb,yb,i
        self.old_pred,self.old_loss = None,tensor(0.)

    def begin_batch(self):
        self.old_a,self.old_b = self.model.a.data.clone(),self.model.b.data.clone()
        test_eq(self.iter,    self.i)
        test_eq(self.save_xb, *self.xb)
        test_eq(self.save_yb, *self.yb)
        if hasattr(self.learn, 'pred'): test_eq(self.pred, self.old_pred)

    def after_pred(self):
        self.old_pred = self.pred
        test_eq(self.pred, self.model.a.data * self.x + self.model.b.data)
        test_eq(self.loss, self.old_loss)

    def after_loss(self):
        self.old_loss = self.loss
        test_eq(self.loss, self.loss_func(self.old_pred, self.save_yb))
        for p in self.model.parameters():
            if not hasattr(p, 'grad') or p.grad is not None: test_eq(p.grad, tensor([0.]))

    def after_backward(self):
        self.grad_a = (2 * self.x * (self.pred.data - self.y)).mean()
        self.grad_b = 2 * (self.pred.data - self.y).mean()
        test_close(self.model.a.grad.data, self.grad_a)
        test_close(self.model.b.grad.data, self.grad_b)
        test_eq(self.model.a.data, self.old_a)
        test_eq(self.model.b.data, self.old_b)

    def after_step(self):
        test_close(self.model.a.data, self.old_a - self.lr * self.grad_a)
        test_close(self.model.b.data, self.old_b - self.lr * self.grad_b)
        self.old_a,self.old_b = self.model.a.data.clone(),self.model.b.data.clone()
        test_close(self.model.a.grad.data, self.grad_a)
        test_close(self.model.b.grad.data, self.grad_b)

    def after_batch(self):
        for p in self.model.parameters(): test_eq(p.grad, tensor([0.]))
#hide
learn = synth_learner()
b = learn.dls.one_batch()
learn = synth_learner(cbs=TestOneBatch(*b, 42), lr=1e-2)
#Remove train/eval
learn.cbs = learn.cbs[1:]
#Setup
learn.loss,learn.training = tensor(0.),True
learn.opt = SGD(learn.model.parameters(), lr=learn.lr)
learn.model.train()
batch_events = ['begin_batch', 'after_pred', 'after_loss', 'after_backward', 'after_step', 'after_batch']
test_stdout(lambda: learn.one_batch(42, b), '\n'.join(batch_events))
test_stdout(lambda: learn.one_batch(42, b), '\n'.join(batch_events)) #Check it works for a second batch

show_doc(Learner.all_batches)
#hide
learn = synth_learner(n_train=5, cbs=VerboseCallback())
learn.opt = SGD(learn.model.parameters(), lr=learn.lr)
with redirect_stdout(io.StringIO()):
    learn._do_begin_fit(1)
    learn.epoch,learn.dl = 0,learn.dls.train
    learn('begin_epoch')
    learn('begin_train')
test_stdout(learn.all_batches, '\n'.join(batch_events * 5))
test_eq(learn.train_iter, 5)

valid_events = ['begin_batch', 'after_pred', 'after_loss', 'after_batch']
with redirect_stdout(io.StringIO()):
    learn.dl = learn.dls.valid
    learn('begin_validate')
test_stdout(learn.all_batches, '\n'.join(valid_events * 2))
test_eq(learn.train_iter, 5)
#hide
learn = synth_learner(n_train=5, cbs=VerboseCallback())
test_stdout(lambda: learn._do_begin_fit(42), 'begin_fit')
test_eq(learn.n_epoch, 42)
test_eq(learn.loss, tensor(0.))
#hide
learn.opt = SGD(learn.model.parameters(), lr=learn.lr)
learn.epoch = 0
test_stdout(lambda: learn._do_epoch_train(), '\n'.join(['begin_train'] + batch_events * 5 + ['after_train']))
#hide
test_stdout(learn._do_epoch_validate, '\n'.join(['begin_validate'] + valid_events * 2 + ['after_validate']))

show_doc(Learner.create_opt)
learn = synth_learner(n_train=5, cbs=VerboseCallback())
assert learn.opt is None
learn.create_opt()
assert learn.opt is not None
test_eq(learn.opt.hypers[0]['lr'], learn.lr)
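#Illustration, not in the original notebook: `splitter` controls how `create_opt`
#builds parameter groups (which is what discriminative learning rates rely on).
#Splitting RegModel's `a` and `b` into two groups gives one set of hyper-parameters
#per group.
learn = synth_learner(splitter=lambda m: [[m.a], [m.b]])
learn.create_opt()
test_eq(len(learn.opt.hypers), 2)
test_eq([h['lr'] for h in learn.opt.hypers], [learn.lr, learn.lr])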
show_doc(Learner.save)
show_doc(Learner.load)
with tempfile.TemporaryDirectory() as d:
    learn = synth_learner(path=d)
    learn.fit(1)

    #Test save created a file
    learn.save('tmp')
    assert (Path(d)/'models/tmp.pth').exists()

    #Test load did load the model
    learn1 = synth_learner(path=d)
    learn1 = learn1.load('tmp')
    test_eq(learn.model.a, learn1.model.a)
    test_eq(learn.model.b, learn1.model.b)
    test_eq(learn.opt.state_dict(), learn1.opt.state_dict())
#hide
#Test load works when the model is saved without opt
with tempfile.TemporaryDirectory() as d:
    learn = synth_learner(path=d)
    learn.fit(1)
    learn.save('tmp', with_opt=False)
    learn1 = synth_learner(path=d)
    learn1 = learn1.load('tmp')
    test_eq(learn.model.a, learn1.model.a)
    test_eq(learn.model.b, learn1.model.b)
    test_ne(learn.opt.state_dict(), learn1.opt.state_dict())

#Test init with callbacks
class TstCallback(Callback):
    def batch_begin(self): self.learn.a = self.a + 1

tst_learn = synth_learner()
test_eq(len(tst_learn.cbs), 1)
assert isinstance(tst_learn.cbs[0], TrainEvalCallback)
assert hasattr(tst_learn, ('train_eval'))

tst_learn = synth_learner(cbs=TstCallback())
test_eq(len(tst_learn.cbs), 2)
assert isinstance(tst_learn.cbs[1], TstCallback)
assert hasattr(tst_learn, ('tst'))

#A callback whose name clashes with an existing Learner attribute (here the `add_cb` method) can't be registered
class AddCbCallback(Callback): pass
test_fail(lambda: synth_learner(cbs=AddCbCallback()))

show_doc(Learner.__call__)
learn = synth_learner(cbs=VerboseCallback())
learn('after_fit')

show_doc(Learner.add_cb)
learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
test_eq(len(learn.cbs), 2)
assert isinstance(learn.cbs[1], TestTrainEvalCallback)
test_eq(learn.train_eval.learn, learn)

show_doc(Learner.add_cbs)
learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])
test_eq(len(learn.cbs), 4)

show_doc(Learner.added_cbs)
learn = synth_learner()
test_eq(len(learn.cbs), 1)
with learn.added_cbs(TestTrainEvalCallback()): test_eq(len(learn.cbs), 2)

show_doc(Learner.ordered_cbs)
learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
learn.ordered_cbs('begin_fit')

show_doc(Learner.remove_cb)
learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
cb = learn.cbs[1]
learn.remove_cb(learn.cbs[1])
test_eq(len(learn.cbs), 1)
assert cb.learn is None
assert not getattr(learn,'test_train_eval',None)

learn = synth_learner()
learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])
learn.remove_cb(TestTrainEvalCallback)
test_eq(len(learn.cbs), 1)
assert not getattr(learn,'test_train_eval',None)

show_doc(Learner.remove_cbs)
learn = synth_learner()
learn.add_cbs([TestTrainEvalCallback() for _ in range(3)])
cb = learn.cbs[1]
learn.remove_cbs(learn.cbs[1:])
test_eq(len(learn.cbs), 1)

show_doc(Learner.removed_cbs)
learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
with learn.removed_cbs(learn.cbs[1]): test_eq(len(learn.cbs), 1)
test_eq(len(learn.cbs), 2)

show_doc(Learner.show_training_loop)
learn = synth_learner()
learn.show_training_loop()
#export
def begin_batch_cb(f):
    "Shortcut for creating a Callback on the `begin_batch` event, which takes and returns `xb,yb`"
    def _begin_batch_cb(self):
        xb,yb = f(self, self.xb, self.yb)
        self.learn.xb,self.learn.yb = xb,yb
    return Callback(begin_batch=_begin_batch_cb)

#These two callbacks are equivalent:
class TstCallback(Callback):
    def begin_batch(self):
        self.learn.xb = self.xb + 1000
        self.learn.yb = self.yb - 1000

@begin_batch_cb
def cb(self, xb, yb): return xb+1000,yb-1000
#hide
batch_events  = ['begin_batch', 'after_pred', 'after_loss', 'after_backward', 'after_step', 'after_batch']
batchv_events = ['begin_batch', 'after_pred', 'after_loss', 'after_batch']
train_events = ['begin_train'] + batch_events + ['after_train']
valid_events = ['begin_validate'] + batchv_events + ['after_validate']
epoch_events = ['begin_epoch'] + train_events + valid_events + ['after_epoch']
cycle_events = ['begin_fit'] + epoch_events + ['after_fit']
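#Lists like the ones above can be passed straight to the learner: `Learner.__call__`
#maps over a list of event names and dispatches each in turn (this is how
#`_before_epoch` and `_after_epoch` are used by `validate` and `get_preds`). A quick
#check, not in the original notebook:
learn = synth_learner(cbs=VerboseCallback())
test_stdout(lambda: learn(['begin_fit', 'after_fit']), 'begin_fit\nafter_fit')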
#hide
learn = synth_learner(n_train=1, n_valid=1)
test_stdout(lambda: learn.fit(1, cbs=VerboseCallback()), '\n'.join(cycle_events))
#hide
class TestCancelCallback(VerboseCallback):
    def __init__(self, cancel_at=event.begin_batch, exception=CancelBatchException, train=None):
        def _interrupt():
            if train is None or train == self.training: raise exception()
        setattr(self, cancel_at, _interrupt)
#hide
#test cancel batch
for i,e in enumerate(batch_events[:-1]):
    be = batch_events[:i+1] + ['after_cancel_batch', 'after_batch']
    bev = be if i <3 else batchv_events
    cycle = cycle_events[:3] + be + ['after_train', 'begin_validate'] + bev + cycle_events[-3:]
    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(cancel_at=e)), '\n'.join(cycle))

#CancelBatchException not caught if thrown in any other event
for e in cycle_events:
    if e not in batch_events[:-1]:
        with redirect_stdout(io.StringIO()):
            cb = TestCancelCallback(cancel_at=e)
            test_fail(lambda: learn.fit(1, cbs=cb))
            learn.remove_cb(cb) #Have to remove it manually
#hide
#test cancel train
for i,e in enumerate(['begin_train'] + batch_events):
    be = batch_events[:i] + (['after_batch'] if i >=1 and i < len(batch_events) else [])
    be += ['after_cancel_train', 'after_train']
    cycle = cycle_events[:3] + be + ['begin_validate'] + batchv_events + cycle_events[-3:]
    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelTrainException, True)), '\n'.join(cycle))

#CancelTrainException not caught if thrown in any other event
for e in cycle_events:
    if e not in ['begin_train'] + batch_events[:-1]:
        with redirect_stdout(io.StringIO()):
            cb = TestCancelCallback(e, CancelTrainException)
            test_fail(lambda: learn.fit(1, cbs=cb))
            learn.remove_cb(cb) #Have to remove it manually
#hide
#test cancel valid
for i,e in enumerate(['begin_validate'] + batchv_events):
    bev = batchv_events[:i] + (['after_batch'] if i >=1 and i < len(batchv_events) else []) + ['after_cancel_validate']
    cycle = cycle_events[:3] + batch_events + ['after_train', 'begin_validate'] + bev + cycle_events[-3:]
    test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelValidException, False)), '\n'.join(cycle))

#CancelValidException not caught if thrown in any other event
for e in cycle_events:
    if e not in ['begin_validate'] + batch_events[:3]:
        with redirect_stdout(io.StringIO()):
            cb = TestCancelCallback(e, CancelValidException)
            test_fail(lambda: learn.fit(1, cbs=cb))
            learn.remove_cb(cb) #Have to remove it manually
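#These exceptions are also the supported way to interrupt training from user code. A
#hedged sketch, not in the original notebook (the name `_StopAtLoss` and the threshold
#are made up): a callback raising `CancelFitException` once the training loss drops
#below a threshold still gets a clean shutdown, since `fit` runs `after_cancel_fit`
#and `after_fit` on the way out.
class _StopAtLoss(Callback):
    def __init__(self, thresh): self.thresh = thresh
    def after_batch(self):
        if self.training and self.loss.item() < self.thresh: raise CancelFitException()

learn = synth_learner(lr=5e-2)
learn.fit(20, cbs=_StopAtLoss(1000.))  #threshold chosen high so it triggers on the first batch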
#hide
#test cancel epoch
#In train
for i,e in enumerate(['begin_train'] + batch_events):
    be = batch_events[:i] + (['after_batch'] if i >=1 and i < len(batch_events) else [])
#export
class Metric():
    "Blueprint for defining a metric"
    def reset(self): pass
    def accumulate(self, learn): pass
    @property
    def value(self): raise NotImplementedError
    @property
    def name(self): return class2attr(self, 'Metric')
#export
def _maybe_reduce(val):
    if num_distrib()>1:
        val = val.clone()
        torch.distributed.all_reduce(val, op=torch.distributed.ReduceOp.SUM)
        val /= num_distrib()
    return val
#export
class AvgMetric(Metric):
    "Average the values of `func` taking into account potential different batch sizes"
    def __init__(self, func): self.func = func
    def reset(self): self.total,self.count = 0.,0
    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        self.total += to_detach(self.func(learn.pred, *learn.yb))*bs
        self.count += bs
    @property
    def value(self): return self.total/self.count if self.count != 0 else None
    @property
    def name(self): return self.func.func.__name__ if hasattr(self.func, 'func') else self.func.__name__

show_doc(AvgMetric, title_level=3)
learn = synth_learner()
tst = AvgMetric(lambda x,y: (x-y).abs().mean())
t,u = torch.randn(100),torch.randn(100)
tst.reset()
for i in range(0,100,25):
    learn.pred,learn.yb = t[i:i+25],(u[i:i+25],)
    tst.accumulate(learn)
test_close(tst.value, (t-u).abs().mean())
#hide
#With varying batch size
tst.reset()
splits = [0, 30, 50, 60, 100]
for i in range(len(splits)-1):
    learn.pred,learn.yb = t[splits[i]:splits[i+1]],(u[splits[i]:splits[i+1]],)
    tst.accumulate(learn)
test_close(tst.value, (t-u).abs().mean())
#export
class AvgLoss(Metric):
    "Average the losses taking into account potential different batch sizes"
    def reset(self): self.total,self.count = 0.,0
    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        self.total += to_detach(learn.loss.mean())*bs
        self.count += bs
    @property
    def value(self): return self.total/self.count if self.count != 0 else None
    @property
    def name(self): return "loss"

show_doc(AvgLoss, title_level=3)
tst = AvgLoss()
t = torch.randn(100)
tst.reset()
for i in range(0,100,25):
    learn.yb,learn.loss = t[i:i+25],t[i:i+25].mean()
    tst.accumulate(learn)
test_close(tst.value, t.mean())
#hide
#With varying batch size
tst.reset()
splits = [0, 30, 50, 60, 100]
for i in range(len(splits)-1):
    learn.yb,learn.loss = t[splits[i]:splits[i+1]],t[splits[i]:splits[i+1]].mean()
    tst.accumulate(learn)
test_close(tst.value, t.mean())
#export
class AvgSmoothLoss(Metric):
    "Smooth average of the losses (exponentially weighted with `beta`)"
    def __init__(self, beta=0.98): self.beta = beta
    def reset(self): self.count,self.val = 0,tensor(0.)
    def accumulate(self, learn):
        self.count += 1
        self.val = torch.lerp(to_detach(learn.loss.mean(), gather=False), self.val, self.beta)
    @property
    def value(self): return self.val/(1-self.beta**self.count)

show_doc(AvgSmoothLoss, title_level=3)
tst = AvgSmoothLoss()
t = torch.randn(100)
tst.reset()
val = tensor(0.)
for i in range(4):
    learn.loss = t[i*25:(i+1)*25].mean()
    tst.accumulate(learn)
    val = val*0.98 + t[i*25:(i+1)*25].mean()*(1-0.98)
    test_close(val/(1-0.98**(i+1)), tst.value)
#export
class ValueMetric(Metric):
    "Use to include a pre-calculated metric value (for instance calculated in a `Callback`) that is returned by `func`"
    def __init__(self, func, metric_name=None): store_attr(self, 'func, metric_name')
    @property
    def value(self): return self.func()
    @property
    def name(self): return self.metric_name if self.metric_name else self.func.__name__

show_doc(ValueMetric, title_level=3)
def metric_value_fn(): return 5e-3

vm = ValueMetric(metric_value_fn, 'custom_value_metric')
test_eq(vm.value, 5e-3)
test_eq(vm.name, 'custom_value_metric')

vm = ValueMetric(metric_value_fn)
test_eq(vm.name, 'metric_value_fn')
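#Writing a new metric only requires subclassing `Metric` and implementing `reset`,
#`accumulate` and `value`. A hedged sketch, not part of the original notebook (the
#name `_RMSE` is made up): a root mean squared error that accumulates the sum of
#squared errors across batches.
class _RMSE(Metric):
    def reset(self): self.sq_err,self.count = 0.,0
    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        self.sq_err += to_detach(((learn.pred-learn.yb[0])**2).sum())
        self.count += bs
    @property
    def value(self): return (self.sq_err/self.count).sqrt() if self.count != 0 else None

tst = _RMSE()
t,u = torch.randn(100),torch.randn(100)
tst.reset()
for i in range(0,100,25):
    learn.pred,learn.yb = t[i:i+25],(u[i:i+25],)
    tst.accumulate(learn)
test_close(tst.value, ((t-u)**2).mean().sqrt())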
#export
from fastprogress.fastprogress import format_time

def _maybe_item(t):
    t = t.value
    return t.item() if isinstance(t, Tensor) and t.numel()==1 else t
#export
class Recorder(Callback):
    "Callback that registers statistics (lr, loss and metrics) during training"
    remove_on_fetch,run_after = True,TrainEvalCallback

    def __init__(self, add_time=True, train_metrics=False, valid_metrics=True, beta=0.98):
        store_attr(self, 'add_time,train_metrics,valid_metrics')
        self.loss,self.smooth_loss = AvgLoss(),AvgSmoothLoss(beta=beta)

    def begin_fit(self):
        "Prepare state for training"
        self.lrs,self.iters,self.losses,self.values = [],[],[],[]
        names = self.metrics.attrgot('name')
        if self.train_metrics and self.valid_metrics:
            names = L('loss') + names
            names = names.map('train_{}') + names.map('valid_{}')
        elif self.valid_metrics: names = L('train_loss', 'valid_loss') + names
        else: names = L('train_loss') + names
        if self.add_time: names.append('time')
        self.metric_names = 'epoch'+names
        self.smooth_loss.reset()

    def after_batch(self):
        "Update all metrics and record lr and smooth loss in training"
        if len(self.yb) == 0: return
        mets = self._train_mets if self.training else self._valid_mets
        for met in mets: met.accumulate(self.learn)
        if not self.training: return
        self.lrs.append(self.opt.hypers[-1]['lr'])
        self.losses.append(self.smooth_loss.value)
        self.learn.smooth_loss = self.smooth_loss.value

    def begin_epoch(self):
        "Set timer if `self.add_time=True`"
        self.cancel_train,self.cancel_valid = False,False
        if self.add_time: self.start_epoch = time.time()
        self.log = L(getattr(self, 'epoch', 0))

    def begin_train   (self): self._train_mets[1:].map(Self.reset())
    def begin_validate(self): self._valid_mets.map(Self.reset())
    def after_train   (self): self.log += self._train_mets.map(_maybe_item)
    def after_validate(self): self.log += self._valid_mets.map(_maybe_item)
    def after_cancel_train(self):    self.cancel_train = True
    def after_cancel_validate(self): self.cancel_valid = True

    def after_epoch(self):
        "Store and log the loss/metric values"
        self.learn.final_record = self.log[1:].copy()
        self.values.append(self.learn.final_record)
        if self.add_time: self.log.append(format_time(time.time() - self.start_epoch))
        self.logger(self.log)
        self.iters.append(self.smooth_loss.count)

    @property
    def _train_mets(self):
        if getattr(self, 'cancel_train', False): return L()
        return L(self.smooth_loss) + (self.metrics if self.train_metrics else L())

    @property
    def _valid_mets(self):
        if getattr(self, 'cancel_valid', False): return L()
        return (L(self.loss) + self.metrics if self.valid_metrics else L())

    def plot_loss(self, skip_start=5, with_valid=True):
        plt.plot(list(range(skip_start, len(self.losses))), self.losses[skip_start:], label='train')
        if with_valid:
            idx = (np.array(self.iters)