Source code for scitex_ml.optim._optimizers

#!/usr/bin/env python3
"""Optimizer utilities using external packages."""

import torch.optim as optim

# Use pytorch-optimizer package for Ranger when available
try:
    from pytorch_optimizer import Ranger21 as Ranger

    RANGER_AVAILABLE = True
except ImportError:
    # Fallback to vendored version temporarily
    try:
        from .Ranger_Deep_Learning_Optimizer.ranger.ranger2020 import Ranger

        RANGER_AVAILABLE = True
    except ImportError:
        RANGER_AVAILABLE = False
        Ranger = None


[docs] def get_optimizer(name: str): """Get optimizer class by name. Args: name: Optimizer name (adam, ranger, rmsprop, sgd) Returns: Optimizer class Raises: ValueError: If optimizer name is not supported """ optimizers = {"adam": optim.Adam, "rmsprop": optim.RMSprop, "sgd": optim.SGD} if name == "ranger": if not RANGER_AVAILABLE: raise ImportError( "Ranger optimizer not available. " "Please install pytorch-optimizer: pip install pytorch-optimizer" ) optimizers["ranger"] = Ranger if name not in optimizers: raise ValueError( f"Unknown optimizer: {name}. Available: {list(optimizers.keys())}" ) return optimizers[name]
[docs] def set_optimizer(models, optimizer_name: str, lr: float): """Set optimizer for models. Args: models: Model or list of models optimizer_name: Name of optimizer lr: Learning rate Returns: Configured optimizer instance """ if not isinstance(models, list): models = [models] learnable_params = [] for model in models: learnable_params.extend(list(model.parameters())) optimizer_class = get_optimizer(optimizer_name) return optimizer_class(learnable_params, lr=lr)