# Documentation script for fastai's `Learner` / `Recorder` training API,
# exported from a notebook: each `show_doc` call renders the docs of one
# object, interleaved with small runnable examples on the MNIST sample data.
# Star-import order is preserved on purpose: later imports may shadow names
# from earlier ones.
from fastai.basic_train import *
from fastai.gen_doc.nbdoc import *
from fastai import *
from fastai.vision import *

show_doc(Learner, title_level=2)

# Build a resnet18 CNN learner on the MNIST sample and train for one epoch.
path = untar_data(URLs.MNIST_SAMPLE)
data = ImageDataBunch.from_folder(path)
learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.fit(1)

show_doc(Learner.fit)
show_doc(Learner.fit_one_cycle)
show_doc(Learner.lr_find)
show_doc(Learner.get_preds)
show_doc(Learner.validate)
show_doc(Learner.TTA, full_name='TTA')
show_doc(Learner.to_fp16)

# Layer-group example: split the model, freeze/unfreeze, and pass
# per-group hyper-parameters to fit_one_cycle.
# creates 3 layer groups
learn.split(lambda m: (m[0][6], m[1]))
# only randomly initialized head now trainable
learn.freeze()
learn.fit_one_cycle(1)
# all layers now trainable
learn.unfreeze()
# optionally, separate LR and WD for each group
learn.fit_one_cycle(1, max_lr=(1e-4, 1e-3, 1e-2), wd=(1e-4, 1e-4, 1e-1))

show_doc(Learner.lr_range)
# In the notebook this bare expression was auto-displayed; print it so the
# result is actually visible when the file runs as a script.
print(learn.lr_range(slice(1e-5, 1e-3)), learn.lr_range(slice(3e-4)))

show_doc(Learner.unfreeze)
show_doc(Learner.freeze)
show_doc(Learner.freeze_to)
show_doc(Learner.split)
show_doc(Learner.load)
show_doc(Learner.save)
show_doc(Learner.show_results)
show_doc(Learner.predict)
# NOTE(review): `Learner.validate` is already documented above; this repeat
# is kept to preserve the original notebook output — confirm it is intended.
show_doc(Learner.validate)
show_doc(Learner.create_unet, doc_string=False)
show_doc(Learner.init)
show_doc(Learner.mixup)
show_doc(Learner.pred_batch)
show_doc(Learner.create_opt)
show_doc(Learner.dl)

show_doc(Recorder, title_level=2)

show_doc(Recorder.plot)
# Re-create the learner before running the LR finder (presumably so it runs
# on a model that was not already trained by the examples above).
learn = create_cnn(data, models.resnet18, metrics=accuracy)
learn.lr_find()
learn.recorder.plot()

show_doc(Recorder.plot_losses)
learn.fit_one_cycle(2)
learn.recorder.plot_losses()

show_doc(Recorder.plot_lr)
learn.recorder.plot_lr(show_moms=True)

show_doc(Recorder.plot_metrics)
learn.recorder.plot_metrics()

show_doc(Recorder.on_backward_begin)
show_doc(Recorder.on_batch_begin)
show_doc(Recorder.on_epoch_end)
show_doc(Recorder.on_train_begin)

# Free functions (not Learner methods) brought in by the star imports.
show_doc(fit)
show_doc(train_epoch)
show_doc(validate)
show_doc(get_preds)
show_doc(loss_batch)

show_doc(LearnerCallback, title_level=3)
show_doc(Learner.tta_only)
# Additional `show_doc` calls. `Learner.get_preds` and `Learner.TTA` are also
# rendered earlier in this file; the repeats are kept to preserve the
# notebook's output.
show_doc(Learner.get_preds)
show_doc(Learner.TTA)
show_doc(Recorder.format_stats)
show_doc(Recorder.add_metrics)
show_doc(Recorder.add_metric_names)