scitex_ml.loss.multi_task_loss

Classes

MultiTaskLoss(*args, **kwargs)

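Reference: A. Kendall, Y. Gal, and R. Cipolla, "Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics", CVPR 2018. https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf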

class scitex_ml.loss.multi_task_loss.MultiTaskLoss(*args: Any, **kwargs: Any)[source]

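Reference: A. Kendall, Y. Gal, and R. Cipolla, "Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics", CVPR 2018. https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf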

Example

```python
are_regression = [False, False]
mtl = MultiTaskLoss(are_regression)
losses = [torch.rand(1, requires_grad=True) for _ in range(len(are_regression))]
loss = mtl(losses)
print(loss)
# [tensor([0.4215], grad_fn=<AddBackward0>), tensor([0.6190], grad_fn=<AddBackward0>)]
```

__init__(are_regression=[False, False], reduction='none')[source]
forward(losses)[source]