scitex_ml.loss

Scitex loss module.

scitex_ml.loss.elastic(model, alpha=1.0, l1_ratio=0.5)
scitex_ml.loss.l1(model, lambda_l1=0.01)
scitex_ml.loss.l2(model, lambda_l2=0.01)
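The listing gives only the signatures of l1, l2 and elastic. The sketch below shows the common convention for parameter-norm regularizers that take a model and a strength coefficient. The exact formulas are assumptions, not scitex_ml's implementation. This applies in particular to the elastic-net mix, because libraries differ on details such as a 0.5 factor on the L2 term. The names l1_penalty, l2_penalty and elastic_penalty are hypothetical stand-ins.

```python
import torch
import torch.nn as nn


def l1_penalty(model: nn.Module, lambda_l1: float = 0.01) -> torch.Tensor:
    # Hypothetical stand-in for l1(): lambda_l1 * sum(|w|) over every parameter tensor.
    return lambda_l1 * sum(p.abs().sum() for p in model.parameters())


def l2_penalty(model: nn.Module, lambda_l2: float = 0.01) -> torch.Tensor:
    # Hypothetical stand-in for l2(): lambda_l2 * sum(w ** 2), the squared L2 norm.
    return lambda_l2 * sum(p.pow(2).sum() for p in model.parameters())


def elastic_penalty(model: nn.Module, alpha: float = 1.0, l1_ratio: float = 0.5) -> torch.Tensor:
    # Hypothetical stand-in for elastic(): alpha * (l1_ratio * L1 + (1 - l1_ratio) * squared L2).
    l1 = sum(p.abs().sum() for p in model.parameters())
    l2 = sum(p.pow(2).sum() for p in model.parameters())
    return alpha * (l1_ratio * l1 + (1.0 - l1_ratio) * l2)


# Tiny model with fixed weights so the penalties are easy to check by hand.
model = nn.Linear(2, 1)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[1.0, -2.0]]))
    model.bias.fill_(0.5)

# Typical use: add the penalty to the data loss before backpropagating.
x, y = torch.randn(4, 2), torch.randn(4, 1)
loss = nn.functional.mse_loss(model(x), y) + l1_penalty(model)
loss.backward()
print(l1_penalty(model).item(), l2_penalty(model).item(), elastic_penalty(model).item())
```

With the fixed weights above, the sum of absolute values is 3.5 and the sum of squares is 5.25. The three penalties therefore come to 0.035, 0.0525 and 4.375, up to float32 rounding.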
class scitex_ml.loss.MultiTaskLoss(are_regression=[False, False], reduction='none')

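Reference: A. Kendall, Y. Gal and R. Cipolla, "Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics", CVPR 2018. https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf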
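The first formula below is the cited paper's combined objective for one regression task with loss $\mathcal{L}_1$ and one classification task with loss $\mathcal{L}_2$. Here $\sigma_1$ and $\sigma_2$ are learned task-noise scales. For numerical stability the paper trains the log variance $s_i = \log\sigma_i^2$ instead of $\sigma_i$, which gives the per-task form on the second line. Two points are assumptions, since the listing does not state them. The first is that MultiTaskLoss uses exactly this form. The second is that are_regression selects $c_i$ for each task.

$$
\mathcal{L}(W, \sigma_1, \sigma_2) = \frac{1}{2\sigma_1^2}\,\mathcal{L}_1(W) + \frac{1}{\sigma_2^2}\,\mathcal{L}_2(W) + \log\sigma_1 + \log\sigma_2
$$

$$
\mathcal{L} = \sum_i \left( c_i\, e^{-s_i}\, \mathcal{L}_i + \tfrac{1}{2} s_i \right),
\qquad s_i = \log\sigma_i^2,
\qquad c_i = \begin{cases} \tfrac{1}{2} & \text{regression task} \\ 1 & \text{classification task} \end{cases}
$$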

Example

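    import torch
    from scitex_ml.loss import MultiTaskLoss

    are_regression = [False, False]  # neither task is a regression task
    mtl = MultiTaskLoss(are_regression)
    losses = [torch.rand(1, requires_grad=True) for _ in range(len(are_regression))]
    loss = mtl(losses)
    print(loss)  # values vary with the random inputs, e.g.
    # [tensor([0.4215], grad_fn=<AddBackward0>), tensor([0.6190], grad_fn=<AddBackward0>)]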

__init__(are_regression=[False, False], reduction='none')
forward(losses)
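The listing does not say what forward computes. The following is a minimal sketch of uncertainty weighting as described in the cited paper. It rests on three assumptions. First, each task learns a log variance s = log(sigma^2), initialised to 0. Second, are_regression picks the 0.5 * exp(-s) regression weight or the exp(-s) classification weight. Third, reduction='none' returns the per-task terms as a list, which matches the printed output of the example above. The class name UncertaintyWeightedLoss, the attribute log_vars and the "sum" option are hypothetical and not part of scitex_ml's API.

```python
import torch
import torch.nn as nn


class UncertaintyWeightedLoss(nn.Module):
    """Hypothetical sketch of Kendall et al. (2018) uncertainty weighting."""

    def __init__(self, are_regression, reduction="none"):
        super().__init__()
        # True -> regression weight 1/(2 sigma^2); False -> classification weight 1/sigma^2.
        self.are_regression = list(are_regression)
        # One learnable log variance s_i = log(sigma_i^2) per task; s = 0 means sigma = 1.
        self.log_vars = nn.Parameter(torch.zeros(len(self.are_regression)))
        self.reduction = reduction

    def forward(self, losses):
        weighted = []
        for loss, s, is_reg in zip(losses, self.log_vars, self.are_regression):
            scale = 0.5 if is_reg else 1.0
            # scale * exp(-s) * L_i + 0.5 * s  ==  scale / sigma^2 * L_i + log(sigma)
            weighted.append(scale * torch.exp(-s) * loss + 0.5 * s)
        if self.reduction == "sum":
            return torch.stack(weighted).sum()
        return weighted  # "none": one weighted term per task


mtl = UncertaintyWeightedLoss([True, False])
terms = mtl([torch.tensor([2.0]), torch.tensor([3.0])])
print(terms)
```

In this sketch the log variances are ordinary nn.Parameters, so they must reach the optimizer together with the model's own parameters. One way to do that is torch.optim.Adam(list(model.parameters()) + list(mtl.parameters())). Otherwise the task weights stay fixed at their initial values.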

Modules

multi_task_loss