Source code for scitex_ml.loss._L1L2Losses

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Time-stamp: "2024-11-07 18:53:03 (ywatanabe)"
# File: ./scitex_repo/src/scitex/ai/loss/_L1L2Losses.py

import torch


def l1(model, lambda_l1=0.01):
    """Return the sum of absolute values of every parameter of ``model``.

    The result is an unscaled L1 penalty: ``sum(|p|.sum() for p in params)``.
    It is differentiable with respect to the model parameters.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        lambda_l1: Accepted for backward compatibility only; the value is NOT
            applied to the result. Multiply the returned penalty yourself.

    Returns:
        A 0-dim tensor of the default dtype. It lives on the device of the
        model's first parameter, or on CPU when the model has no parameters.
    """
    params = list(model.parameters())
    # Follow the model's device instead of hard-coding CUDA, so the penalty
    # works on CPU-only machines and on non-default GPUs.
    device = params[0].device if params else torch.device("cpu")
    total = torch.zeros((), device=device)
    for param in params:
        # ``.to(device)`` keeps the in-place add valid for models split
        # across devices; it is a no-op when the devices already match.
        total += torch.abs(param).sum().to(device)
    return total
def l2(model, lambda_l2=0.01):
    """Return the sum of the Euclidean norms of every parameter of ``model``.

    Each parameter tensor contributes ``torch.norm(param)``, its (non-squared)
    2-norm over all elements. The per-tensor norms are summed. Note that this
    is not the sum of squared weights used by classic weight decay.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        lambda_l2: Accepted for backward compatibility only; the value is NOT
            applied to the result. Multiply the returned penalty yourself.

    Returns:
        A 0-dim tensor of the default dtype. It lives on the device of the
        model's first parameter, or on CPU when the model has no parameters.
    """
    params = list(model.parameters())
    # Follow the model's device instead of hard-coding CUDA.
    device = params[0].device if params else torch.device("cpu")
    total = torch.zeros((), device=device)
    for param in params:
        # torch.norm already yields a 0-dim tensor, so no extra .sum() is needed.
        total += torch.norm(param).to(device)
    return total
def elastic(model, alpha=1.0, l1_ratio=0.5):
    """Return an elastic-net style penalty mixing ``l1`` and ``l2`` of ``model``.

    Computes ``alpha * (l1_ratio * l1(model) + (1 - l1_ratio) * l2(model))``.
    It uses this module's ``l1`` and ``l2`` helpers with their default
    arguments.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        alpha: Overall multiplier applied to the mixed penalty.
        l1_ratio: Weight of the L1 term, in ``[0, 1]``. The L2 term gets
            ``1 - l1_ratio``.

    Returns:
        The mixed penalty (a 0-dim tensor when ``l1``/``l2`` return tensors).

    Raises:
        ValueError: If ``l1_ratio`` is outside ``[0, 1]`` (NaN is rejected too).
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not 0 <= l1_ratio <= 1:
        raise ValueError(f"l1_ratio must be in [0, 1], got {l1_ratio!r}")
    l1_term = l1(model)
    l2_term = l2(model)
    return alpha * (l1_ratio * l1_term + (1 - l1_ratio) * l2_term)
# EOF