Coverage for yield_analysis_sdk\validators.py: 78%

60 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-15 13:20 +0800

1""" 

2Common validators and mixins for the yield analysis SDK. 

3""" 

4 

5import re 

6from typing import TYPE_CHECKING, Any, Union 

7 

8from pydantic import ConfigDict, field_validator, field_serializer 

9 

10from .exceptions import ValidationError 

11 

12if TYPE_CHECKING: 

13 from .type import Chain 

14 

15 

class ChainMixin:
    """Mixin class that provides chain validation functionality.

    Intended to be mixed into a pydantic model that declares a ``chain``
    field. Unknown or non-string inputs are coerced to ``Chain.OTHER``
    instead of being rejected.
    """

    # pydantic option: store enum-typed fields by their ``.value`` rather
    # than as the enum member itself.
    model_config = ConfigDict(use_enum_values=True)

    @field_validator("chain", mode="before")
    @classmethod
    def validate_chain(cls, v: Any) -> "Chain":
        """Validate chain and return OTHER if not found.

        Delegates to the module-level ``validate_chain_value`` so that the
        mixin and the standalone helper share one implementation.

        Args:
            v: Raw input for the ``chain`` field (runs before pydantic's own
                type validation because of ``mode="before"``).

        Returns:
            A ``Chain`` member; ``Chain.OTHER`` for unrecognised strings or
            values that are neither ``str`` nor ``Chain``.
        """
        return validate_chain_value(v)

35 

36 

class VaultAddressValidatorMixin:
    """Mixin class that provides vault address validation functionality.

    Intended for a pydantic model that declares a ``vault_address`` field.
    """

    @field_validator("vault_address", mode="before")
    @classmethod
    def validate_vault_address(cls, v: Any) -> str:
        """Validate vault address format and normalize it.

        Args:
            v: Raw input for the ``vault_address`` field.

        Returns:
            The address as returned by ``normalize_address``: lowercase,
            ``0x``-prefixed, followed by 40 hex characters.

        Raises:
            ValidationError: If ``v`` is None, or if it (or its ``str()``
                form, for non-string inputs) is not a valid address.
        """
        if v is None:
            raise ValidationError("Vault address cannot be None")
        # Non-string inputs used to be returned as ``str(v)`` unchecked, which
        # let arbitrary text bypass the format check; validate them too.
        return normalize_address(v if isinstance(v, str) else str(v))

50 

51 

class UnderlyingTokenValidatorMixin:
    """Mixin class that provides token address validation functionality.

    Intended for a pydantic model that declares an ``underlying_token`` field.
    """

    @field_validator("underlying_token", mode="before")
    @classmethod
    def validate_underlying_token(cls, v: Any) -> str:
        """Validate underlying token address format and normalize it.

        Args:
            v: Raw input for the ``underlying_token`` field.

        Returns:
            The address as returned by ``normalize_address``: lowercase,
            ``0x``-prefixed, followed by 40 hex characters.

        Raises:
            ValidationError: If ``v`` is None, or if it (or its ``str()``
                form, for non-string inputs) is not a valid address.
        """
        if v is None:
            raise ValidationError("Underlying token cannot be None")
        # Non-string inputs used to be returned as ``str(v)`` unchecked, which
        # let arbitrary text bypass the format check; validate them too.
        return normalize_address(v if isinstance(v, str) else str(v))

65 

66 

def validate_chain_value(value: Any) -> "Chain":
    """
    Coerce an arbitrary value into a ``Chain`` member.

    Strings are looked up by enum value, and existing ``Chain`` members are
    returned unchanged. Anything else, including strings that match no
    member, falls back to ``Chain.OTHER``.

    Args:
        value: The value to validate

    Returns:
        The matching ``Chain`` member, or ``Chain.OTHER`` if no match exists.
    """
    from .type import Chain  # Import here to avoid circular import

    # Non-string path: keep genuine members, collapse everything else.
    if not isinstance(value, str):
        return value if isinstance(value, Chain) else Chain.OTHER

    # String path: enum lookup by value; an unknown value raises ValueError.
    try:
        return Chain(value)
    except ValueError:
        return Chain.OTHER

88 

89 

def normalize_address(address: str) -> str:
    """
    Normalize address format.

    Normalization steps:

    1. Strip surrounding whitespace.
    2. Lowercase the address.
    3. Add a ``0x`` prefix if it is missing.

    The result must be ``0x`` followed by exactly 40 hexadecimal characters.

    Args:
        address: The address to normalize. Either ``0x`` or ``0X`` is
            accepted as a prefix, and hex digits may be in any case.

    Returns:
        Normalized address (lowercase, with 0x prefix)

    Raises:
        ValidationError: If the address is empty or whitespace-only, or if it
            is not ``0x`` followed by 40 hex characters after normalization.
    """
    # Falsy input (None, "") maps to "" so it reaches the same "empty" error
    # as a whitespace-only string, instead of failing on ``.strip()``.
    cleaned = address.strip().lower() if address else ""
    if not cleaned:
        raise ValidationError("Address cannot be empty")

    # Lowercase before the prefix check so "0X..." counts as already
    # prefixed rather than becoming "0x0X...".
    if not cleaned.startswith("0x"):
        cleaned = "0x" + cleaned

    # fullmatch anchors both ends without relying on ``$`` semantics.
    if re.fullmatch(r"0x[0-9a-f]{40}", cleaned) is None:
        raise ValidationError(f"Invalid address format: {cleaned}")

    return cleaned

118 

119 

def validate_address_value(address: str) -> str:
    """
    Check an address and return it in normalized form.

    This is the public wrapper around ``normalize_address``. It adds no
    behavior of its own.

    Args:
        address: Candidate address string.

    Returns:
        The normalized address produced by ``normalize_address``.

    Raises:
        ValidationError: Propagated from ``normalize_address`` when the
            address is empty or malformed.
    """
    normalized = normalize_address(address)
    return normalized