Coverage for yield_analysis_sdk\validators.py: 78%
59 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-02 10:00 +0800
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-02 10:00 +0800
1"""
2Common validators and mixins for the yield analysis SDK.
3"""
5import re
6from typing import Any, TYPE_CHECKING, Union
7from pydantic import field_validator
8from .exceptions import ValidationError
10if TYPE_CHECKING:
11 from .type import Chain
class ChainValidatorMixin:
    """Mixin that registers a ``before`` validator for the ``chain`` field."""

    @field_validator("chain", mode="before")
    @classmethod
    def validate_chain(cls, v: Any) -> "Chain":
        """Coerce the raw ``chain`` input into a ``Chain`` member.

        Strings are handed to the ``Chain`` constructor; a value that is
        already a ``Chain`` is returned unchanged. Everything else,
        including strings the constructor rejects with ``ValueError``,
        becomes ``Chain.OTHER``.
        """
        # Deferred to call time: a module-level import would be circular.
        from .type import Chain

        if not isinstance(v, str):
            return v if isinstance(v, Chain) else Chain.OTHER

        try:
            return Chain(v)
        except ValueError:
            # Unknown chain name: fall back instead of failing validation.
            return Chain.OTHER
class VaultAddressValidatorMixin:
    """Mixin that registers a ``before`` validator for ``vault_address``."""

    @field_validator("vault_address", mode="before")
    @classmethod
    def validate_vault_address(cls, v: Any) -> str:
        """Check the raw ``vault_address`` input and return it as a string.

        Raises:
            ValidationError: If the value is ``None``, or if
                ``normalize_address`` rejects a string value.
        """
        if v is None:
            raise ValidationError("Vault address cannot be None")

        if not isinstance(v, str):
            # Non-string inputs are only stringified here; they are not
            # run through normalize_address.
            return str(v)

        return normalize_address(v)
class UnderlyingTokenValidatorMixin:
    """Mixin that registers a ``before`` validator for ``underlying_token``."""

    @field_validator("underlying_token", mode="before")
    @classmethod
    def validate_underlying_token(cls, v: Any) -> str:
        """Check the raw ``underlying_token`` input and return it as a string.

        Raises:
            ValidationError: If the value is ``None``, or if
                ``normalize_address`` rejects a string value.
        """
        if v is None:
            raise ValidationError("Underlying token cannot be None")

        if not isinstance(v, str):
            # Non-string inputs are only stringified here; they are not
            # run through normalize_address.
            return str(v)

        return normalize_address(v)
def validate_chain_value(value: Any) -> "Chain":
    """
    Convert an arbitrary value into a ``Chain`` member.

    Args:
        value: The value to validate

    Returns:
        The matching ``Chain`` member for a string the ``Chain``
        constructor accepts, the value itself when it is already a
        ``Chain``, and ``Chain.OTHER`` in every other case
    """
    # Deferred to call time: a module-level import would be circular.
    from .type import Chain

    if not isinstance(value, str):
        return value if isinstance(value, Chain) else Chain.OTHER

    try:
        return Chain(value)
    except ValueError:
        # Unknown chain name: fall back instead of raising.
        return Chain.OTHER
def normalize_address(address: str) -> str:
    """
    Normalize an address to lowercase hex with a ``0x`` prefix.

    Surrounding whitespace is stripped, the text is lowercased, a ``0x``
    prefix is added when it is missing, and the result must be ``0x``
    followed by exactly 40 hexadecimal characters.

    Args:
        address: The address to normalize

    Returns:
        Normalized address (lowercase, with 0x prefix)

    Raises:
        ValidationError: If the address is empty or whitespace-only, or
            if it does not match the expected format
    """
    if not address:
        raise ValidationError("Address cannot be empty")

    # Strip and lowercase BEFORE the prefix check, so an uppercase "0X"
    # prefix is recognized instead of getting a second "0x" prepended
    # (which would make an otherwise valid address fail validation).
    address = address.strip().lower()

    # A whitespace-only input is empty once stripped; report it as such
    # rather than as a malformed "0x".
    if not address:
        raise ValidationError("Address cannot be empty")

    # Ensure it starts with 0x
    if not address.startswith("0x"):
        address = "0x" + address

    # Validate format (0x followed by 40 hex characters)
    if not re.fullmatch(r"0x[a-f0-9]{40}", address):
        raise ValidationError(f"Invalid address format: {address}")

    return address
def validate_address_value(address: str) -> str:
    """
    Validate an address and return its normalized form.

    Standalone entry point that delegates all of the work to
    ``normalize_address``.

    Args:
        address: The address to validate

    Returns:
        The value returned by ``normalize_address`` for ``address``

    Raises:
        ValidationError: If ``normalize_address`` rejects the address
    """
    normalized = normalize_address(address)
    return normalized