# Source code for scitex_ml.loss.multi_task_loss

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-11-07 19:07:29 (ywatanabe)"
# File: ./scitex_repo/src/scitex/ai/loss/MultiTaskLoss.py

import numpy as np
import torch
import torch.nn as nn

from scitex_repro import fix_seeds


class MultiTaskLoss(nn.Module):
    """Uncertainty-weighted multi-task loss (Kendall, Gal & Cipolla, CVPR 2018).

    Each task ``i`` owns a learnable log-variance ``s_i = log(sigma_i**2)``.
    A raw task loss ``L_i`` is rescaled to::

        L_i / (c_i * sigma_i**2) + log(sigma_i)

    where ``c_i = 2`` for regression tasks and ``c_i = 1`` otherwise.
    The coefficient shrinks as ``sigma_i`` grows, and the ``log(sigma_i)``
    term grows with it, so the two terms trade off during training.

    Reference:
        https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf

    Parameters
    ----------
    are_regression : sequence of bool
        One flag per task; ``True`` marks a regression task. Its length
        fixes the number of tasks.
    reduction : {"none", "mean", "sum"}
        ``"none"`` (default) returns the list of per-task scaled losses;
        ``"sum"`` / ``"mean"`` return their sum / mean.

    Raises
    ------
    ValueError
        If ``reduction`` is not one of the supported modes.

    Example::

        are_regression = [False, False]
        mtl = MultiTaskLoss(are_regression)
        losses = [torch.rand(1, requires_grad=True) for _ in range(len(are_regression))]
        loss = mtl(losses)
        print(loss)
        # e.g. [tensor([0.4215], grad_fn=<AddBackward0>), tensor([0.6190], grad_fn=<AddBackward0>)]
    """

    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(self, are_regression=(False, False), reduction="none"):
        # Tuple default avoids a shared mutable default argument.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        super().__init__()
        # NOTE(review): presumably seeds the global NumPy / PyTorch RNGs as a
        # side effect of construction — confirm; kept for backward compatibility.
        fix_seeds(np=np, torch=torch, show=False)
        n_tasks = len(are_regression)
        # A buffer (not a Parameter): it follows the module across devices
        # but is not trained.
        self.register_buffer("are_regression", torch.tensor(are_regression))
        # For numerical stability, log(sigma^2) is learned instead of sigma^2.
        # Zeros mean every task starts at unit variance.
        self.log_vars = nn.Parameter(torch.zeros(n_tasks))
        self.reduction = reduction

    def forward(self, losses):
        """Rescale per-task losses by their learned uncertainties.

        Parameters
        ----------
        losses : sequence of torch.Tensor
            One raw loss per task, in the same order as ``are_regression``.
            Entries are combined with Python arithmetic, so they should share
            a shape or be broadcastable (assumed scalar or shape ``(1,)`` —
            TODO confirm against callers).

        Returns
        -------
        list of torch.Tensor or torch.Tensor
            The per-task scaled losses when ``reduction == "none"``; otherwise
            their sum (``"sum"``) or mean (``"mean"``).

        Raises
        ------
        ValueError
            If ``losses`` is empty, its length differs from the number of
            tasks, or ``self.reduction`` is not a supported mode.
        """
        if not losses:
            raise ValueError("losses must contain at least one tensor")
        n_tasks = self.log_vars.numel()
        if len(losses) != n_tasks:
            raise ValueError(
                f"expected {n_tasks} losses (one per task), got {len(losses)}"
            )

        ref = losses[0]
        # Match the losses' device/dtype so the result lives alongside them.
        log_vars = self.log_vars.to(device=ref.device, dtype=ref.dtype)
        # bool + 1 -> 2 for regression tasks, 1 otherwise.
        task_scale = (self.are_regression + 1).to(log_vars.device)
        # 1 / (c * sigma^2) == exp(-s) / c; avoids forming exp(s) and inverting.
        coeffs = torch.exp(-log_vars) / task_scale
        # log(sigma) == 0.5 * log(sigma^2); avoids exp -> sqrt -> log round trip.
        log_stds = 0.5 * log_vars

        scaled_losses = [
            coeff * loss + log_std
            for coeff, loss, log_std in zip(coeffs, losses, log_stds)
        ]

        if self.reduction == "none":
            return scaled_losses
        total = sum(scaled_losses)
        if self.reduction == "sum":
            return total
        if self.reduction == "mean":
            return total / len(scaled_losses)
        raise ValueError(
            f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}"
        )
# EOF