%reload_ext autoreload
%autoreload 2
from fastai import *
from fastai.vision import *
from fastai.vision.models.darknet import Darknet
PATH = Path('../data') / 'cifar10'
# Training-time transforms: the `rand_pad(4, 32)` transforms plus a horizontal flip
# applied with probability 0.5; the validation side gets no transforms.
ds_tfms = (list(rand_pad(4, 32)) + [flip_lr(p=0.5)], [])
# NOTE(review): `valid='test'` presumably takes the validation images from the
# 'test' sub-folder of PATH — confirm against the fastai version in use.
data = image_data_from_folder(PATH, bs=64, valid='test', tfms=cifar_norm, ds_tfms=ds_tfms)
# Darknet with block counts [1, 1, 1] and 10 outputs (presumably one per CIFAR-10 class).
learn = Learner(data, Darknet([1, 1, 1], 10), metrics=accuracy)
This is a toy example: we pretend that our final loss is the sum of two different losses, `loss1` and `loss2`, and we want to track each of them as a separate metric during validation.
class CombinedLoss(nn.Module):
    """Cross-entropy loss that also exposes a random split of itself.

    On every call the full cross-entropy is returned unchanged, and two parts
    that add up to it are stored on the module as `self.loss1` and
    `self.loss2` so a callback can read them afterwards.
    """

    def forward(self, output, target):
        # Toy split: a uniformly random fraction goes to loss1, the rest to loss2.
        share = uniform(0, 1)
        total = F.cross_entropy(output, target)
        self.loss1, self.loss2 = share * total, (1 - share) * total
        return total
class HandleDualLoss(LearnerCallback):
    """Report the two components of the learner's loss as extra metrics.

    Expects `self.learn.loss_fn` to expose `loss1` and `loss2` tensors after
    each forward pass (as `CombinedLoss` does). Over the validation batches of
    each epoch it accumulates a batch-size-weighted sum of each component, and
    at the end of the epoch it hands their averages to the recorder under the
    metric names 'loss1' and 'loss2'.
    """
    _order = -20  # Needs to run before the recorder

    def on_train_begin(self, **kwargs):
        self.learn.recorder.add_metric_names(['loss1', 'loss2'])

    def on_epoch_begin(self, **kwargs):
        # Reset the weighted sums and the sample count for the new epoch.
        self.avg1, self.avg2, self.nums = 0., 0., 0

    def on_batch_end(self, last_target, train, **kwargs):
        # Only validation batches contribute to the reported averages.
        if train:
            return
        bs = last_target.size(0)
        # Read the loss function from the learner this callback is attached to,
        # not from a module-level `learn` global (which may be a different Learner).
        loss_fn = self.learn.loss_fn
        # detach() so the accumulated sums do not keep the autograd graph alive.
        self.avg1 += bs * loss_fn.loss1.detach()
        self.avg2 += bs * loss_fn.loss2.detach()
        self.nums += bs

    def on_epoch_end(self, **kwargs):
        self.learn.recorder.add_metrics([self.avg1 / self.nums, self.avg2 / self.nums])
# Register the metric-reporting callback and swap in the two-part loss, then train.
learn.callback_fns += [HandleDualLoss]
learn.loss_fn = CombinedLoss()
learn.fit_one_cycle(2, 3e-3, pct_start=0.5, div_factor=10, wd=0.4)
Next, a callback that computes the precision for the first class (class index 0) on the validation set.
class Precision(LearnerCallback):
    """Report the validation precision for one class as an extra metric.

    Precision = true positives / predicted positives, where "positive" means
    class `target_class` (0 by default). Counts are accumulated over all
    validation batches of an epoch and reported to the recorder under the
    metric name 'precision'. If the class is never predicted during the epoch,
    precision is undefined and 0. is reported instead of dividing by zero.
    """
    _order = -20  # Needs to run before the recorder
    target_class = 0  # index of the class whose precision is reported

    def on_train_begin(self, **kwargs):
        self.learn.recorder.add_metric_names(['precision'])

    def on_epoch_begin(self, **kwargs):
        # Reset true-positive and predicted-positive counts for the new epoch.
        self.correct, self.total = 0, 0

    def on_batch_end(self, last_output, last_target, train, **kwargs):
        # Only validation batches contribute.
        if train:
            return
        # NOTE(review): assumes last_output is (batch, n_classes) scores — argmax over dim 1.
        preds = last_output.argmax(1)
        pred_pos = preds == self.target_class
        # Keep both counts as float tensors so the final division never mixes
        # float and long tensors (an error on older PyTorch versions).
        self.correct += (pred_pos & (last_target == self.target_class)).float().sum()
        self.total += pred_pos.float().sum()

    def on_epoch_end(self, **kwargs):
        # Guard against 0/0 when the class was never predicted (or no validation batches ran).
        precision = self.correct / self.total if self.total > 0 else 0.
        self.learn.recorder.add_metrics([precision])
# Start from a fresh learner (new model, without the CombinedLoss/HandleDualLoss
# setup applied earlier) and train it with the precision callback attached.
learn = Learner(data, Darknet([1, 1, 1], 10), metrics=accuracy)
learn.callback_fns += [Precision]
learn.fit_one_cycle(2, 3e-3, pct_start=0.5, div_factor=10, wd=0.4)
import pdb