from fastai import *
from fastai.vision import *

# Losses that consume raw (pre-activation) model output, each paired with the
# activation that turns that output into predictions. The instances are only
# created so their class names can be read below.
_class_activ_pairs = [
    (nn.CrossEntropyLoss(), partial(F.softmax, dim=1)),
    (nn.NLLLoss(), torch.exp),
    (nn.PoissonNLLLoss(), torch.exp),
    (nn.KLDivLoss(), torch.exp),
    # torch.sigmoid replaces the deprecated F.sigmoid (same result, no warning).
    (nn.BCEWithLogitsLoss(), torch.sigmoid),
    (CrossEntropyFlat(), partial(F.softmax, dim=1)),
]
class_need_activ = [loss for loss, _ in _class_activ_pairs]
activs = [activ for _, activ in _class_activ_pairs]
class_names = [camel2snake(loss.__class__.__name__) for loss in class_need_activ]
# Lookup table: snake-cased loss class name or functional __name__ -> activation.
loss_func_name2activ = dict(zip(class_names, activs))

# Functional counterparts, looked up by their __name__ (e.g. 'cross_entropy').
# Paired explicitly so the mapping does not depend on list positions lining up.
_F_activ_pairs = [
    (F.cross_entropy, partial(F.softmax, dim=1)),
    (F.nll_loss, torch.exp),
    (F.poisson_nll_loss, torch.exp),
    (F.kl_div, torch.exp),
    (F.binary_cross_entropy_with_logits, torch.sigmoid),
]
F_need_activ = [f for f, _ in _F_activ_pairs]
for f, activ in _F_activ_pairs:
    # Keep any entry already registered from a loss class with the same name
    # (e.g. 'nll_loss', 'poisson_nll_loss').
    loss_func_name2activ.setdefault(f.__name__, activ)


def loss_func2activ(loss_func):
    """Return the activation to apply to model output for `loss_func`.

    Resolution order:
      1. A wrapper whose class snake-cases to 'mix_up_loss' is replaced by the
         loss stored in its `crit` attribute.
      2. The snake-cased class name is looked up in `loss_func_name2activ`
         (covers loss modules such as `nn.CrossEntropyLoss()`).
      3. Objects exposing a `func` attribute (e.g. `functools.partial`) are
         unwrapped, and the wrapped callable is resolved recursively.
      4. The callable's `__name__` is looked up (covers plain functions such
         as `F.cross_entropy`).

    Returns `noop` when nothing matches. It also returns `noop` for Poisson NLL
    losses configured with `log_input=False`, whether set on the module or as
    a partial keyword, because no exp activation is applied in that case.
    """
    cls_name = camel2snake(loss_func.__class__.__name__)
    if cls_name == 'mix_up_loss':
        loss_func = loss_func.crit
        cls_name = camel2snake(loss_func.__class__.__name__)

    if cls_name in loss_func_name2activ:
        if cls_name == 'poisson_nll_loss' and not getattr(loss_func, 'log_input', True):
            return noop
        return loss_func_name2activ[cls_name]

    inner = getattr(loss_func, 'func', None)
    # Guard against self-reference so the recursion always terminates.
    if inner is not None and inner is not loss_func:
        # `keywords` exists on functools.partial; other wrappers may lack it.
        keywords = getattr(loss_func, 'keywords', None) or {}
        if (getattr(inner, '__name__', '') == 'poisson_nll_loss'
                and not keywords.get('log_input', True)):
            return noop
        return loss_func2activ(inner)

    name = getattr(loss_func, '__name__', '')
    if name in loss_func_name2activ:
        return loss_func_name2activ[name]
    return noop


if __name__ == '__main__':
    # Print the activation resolved for a set of representative losses.
    _examples = [
        nn.CrossEntropyLoss(),
        nn.NLLLoss(),
        nn.KLDivLoss(),
        nn.PoissonNLLLoss(log_input=False),
        nn.PoissonNLLLoss(),
        nn.MSELoss(),
        nn.BCEWithLogitsLoss(),
        nn.BCELoss(),
        F.cross_entropy,
        partial(F.cross_entropy, reduce=True),
        partial(F.poisson_nll_loss, log_input=False),
        F.poisson_nll_loss,
    ]
    for _lf in _examples:
        print(_lf, '->', loss_func2activ(_lf))