scitex_ml.loss.multi_task_loss
Classes
class scitex_ml.loss.multi_task_loss.MultiTaskLoss(*args: Any, **kwargs: Any)
Example
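Pass one flag per task in are_regression (True for a regression task, False otherwise). Then call the instance on a list of per-task losses. The call returns a list with one weighted loss tensor per task:

    import torch
    from scitex_ml.loss.multi_task_loss import MultiTaskLoss

    are_regression = [False, False]  # two tasks, neither is regression
    mtl = MultiTaskLoss(are_regression)

    # One scalar loss per task; requires_grad=True so the result can be backpropagated
    losses = [torch.rand(1, requires_grad=True) for _ in range(len(are_regression))]
    loss = mtl(losses)
    print(loss)
    # [tensor([0.4215], grad_fn=<AddBackward0>), tensor([0.6190], grad_fn=<AddBackward0>)]
    # (exact values differ between runs because the inputs are random)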
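This page does not document how MultiTaskLoss weights each task. The sketch below is an assumption, not the library's code. It follows the homoscedastic-uncertainty weighting of Kendall et al. (CVPR 2018), which matches the are_regression flag. Each task i has a log-variance s_i = log(σ_i²). A regression loss is scaled by 1/(2σ_i²), any other loss by 1/σ_i², and log σ_i is added to both. The function name scale_task_losses and its arguments are hypothetical. Plain floats stand in for tensors so the arithmetic is easy to follow. In a PyTorch module the log-variances would typically be held in an nn.Parameter so the weights are learned.

```python
import math
from typing import List, Sequence


def scale_task_losses(
    losses: Sequence[float],
    log_vars: Sequence[float],
    are_regression: Sequence[bool],
) -> List[float]:
    """Hypothetical uncertainty weighting (assumed, not taken from scitex_ml).

    For task i with log-variance s_i = log(sigma_i**2):
      regression:      L_i / (2 * sigma_i**2) + log(sigma_i)
      non-regression:  L_i / sigma_i**2       + log(sigma_i)
    """
    if not (len(losses) == len(log_vars) == len(are_regression)):
        raise ValueError("need one loss, log-variance and flag per task")
    scaled = []
    for loss, s, is_reg in zip(losses, log_vars, are_regression):
        var = math.exp(s)   # sigma_i**2
        log_std = 0.5 * s   # log(sigma_i) = 0.5 * log(sigma_i**2)
        coeff = 1.0 / ((2.0 if is_reg else 1.0) * var)
        scaled.append(coeff * loss + log_std)
    return scaled


# With all log-variances at 0 (sigma = 1), the non-regression loss passes
# through unchanged and the regression loss is halved.
print(scale_task_losses([0.4, 0.6], [0.0, 0.0], [False, True]))
# [0.4, 0.3]
```

Under this assumed scheme, a larger σ_i down-weights a noisy task. The added log σ_i term keeps the optimizer from shrinking every weight toward zero by letting σ_i grow without bound.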