Coverage for yield_analysis_sdk\validators.py: 78%

59 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-02 10:00 +0800

1""" 

2Common validators and mixins for the yield analysis SDK. 

3""" 

4 

5import re 

6from typing import Any, TYPE_CHECKING, Union 

7from pydantic import field_validator 

8from .exceptions import ValidationError 

9 

10if TYPE_CHECKING: 

11 from .type import Chain 

12 

13 

class ChainValidatorMixin:
    """Mixin class that provides chain validation functionality.

    Adds a ``mode="before"`` pydantic field validator for a field named
    ``chain``. The raw value is coerced to a ``Chain`` member, and
    ``Chain.OTHER`` is used as a fallback instead of raising.
    """

    @field_validator("chain", mode="before")
    @classmethod
    def validate_chain(cls, v: Any) -> "Chain":
        """Validate chain and return OTHER if not found.

        Args:
            v: Raw value supplied for the ``chain`` field.

        Returns:
            ``Chain(v)`` for a string that is a valid ``Chain`` value, ``v``
            itself if it is already a ``Chain``, otherwise ``Chain.OTHER``.
        """
        # Delegate to the module-level helper so the model validator and the
        # standalone function share one implementation and cannot drift apart.
        return validate_chain_value(v)

32 

33 

class VaultAddressValidatorMixin:
    """Mixin class that provides vault address validation functionality.

    Adds a ``mode="before"`` pydantic field validator for a field named
    ``vault_address``.
    """

    @field_validator("vault_address", mode="before")
    @classmethod
    def validate_vault_address(cls, v: Any) -> str:
        """Validate vault address format and normalize it.

        Args:
            v: Raw value supplied for the ``vault_address`` field.

        Returns:
            The address as returned by ``normalize_address`` (lowercase,
            ``0x``-prefixed).

        Raises:
            ValidationError: If ``v`` is None, or if ``normalize_address``
                rejects it.
        """
        if v is None:
            raise ValidationError("Vault address cannot be None")
        # Non-string inputs are stringified and then checked/normalized like
        # any other value, instead of being passed through unvalidated.
        return normalize_address(v if isinstance(v, str) else str(v))

47 

48 

class UnderlyingTokenValidatorMixin:
    """Mixin class that provides token address validation functionality.

    Adds a ``mode="before"`` pydantic field validator for a field named
    ``underlying_token``.
    """

    @field_validator("underlying_token", mode="before")
    @classmethod
    def validate_underlying_token(cls, v: Any) -> str:
        """Validate underlying token address format and normalize it.

        Args:
            v: Raw value supplied for the ``underlying_token`` field.

        Returns:
            The address as returned by ``normalize_address`` (lowercase,
            ``0x``-prefixed).

        Raises:
            ValidationError: If ``v`` is None, or if ``normalize_address``
                rejects it.
        """
        if v is None:
            raise ValidationError("Underlying token cannot be None")
        # Non-string inputs are stringified and then checked/normalized like
        # any other value, instead of being passed through unvalidated.
        return normalize_address(v if isinstance(v, str) else str(v))

62 

63 

def validate_chain_value(value: Any) -> "Chain":
    """
    Coerce an arbitrary value into a ``Chain`` member without raising.

    Args:
        value: The value to validate

    Returns:
        ``Chain(value)`` when ``value`` is a string that is a valid ``Chain``
        value; ``value`` itself when it is already a ``Chain``; otherwise
        ``Chain.OTHER``.
    """
    from .type import Chain  # Import here to avoid circular import

    # Non-strings: keep genuine Chain members, map everything else to OTHER.
    if not isinstance(value, str):
        return value if isinstance(value, Chain) else Chain.OTHER

    # Strings: look the value up, falling back to OTHER for unknown values.
    try:
        return Chain(value)
    except ValueError:
        return Chain.OTHER

85 

86 

def normalize_address(address: str) -> str:
    """
    Normalize address format.

    The steps are applied in this order:

    1. Strip surrounding whitespace.
    2. Lowercase the value.
    3. Add a ``0x`` prefix if it is missing.

    The result must then be ``0x`` followed by exactly 40 hexadecimal
    characters.

    Args:
        address: The address to normalize

    Returns:
        Normalized address (lowercase, with 0x prefix)

    Raises:
        ValidationError: If the address is empty or whitespace-only, or does
            not match the expected format
    """
    # Strip before the emptiness check so whitespace-only input is reported
    # as empty rather than as "Invalid address format: 0x".
    if address:
        address = address.strip()
    if not address:
        raise ValidationError("Address cannot be empty")

    # Lowercase before the prefix check so an upper-case "0X" prefix is
    # recognized instead of being turned into "0x0x...".
    address = address.lower()
    if not address.startswith("0x"):
        address = "0x" + address

    # Validate format (0x followed by 40 hex characters)
    if re.fullmatch(r"0x[a-f0-9]{40}", address) is None:
        raise ValidationError(f"Invalid address format: {address}")

    return address

115 

116 

def validate_address_value(address: str) -> str:
    """
    Validate and normalize a single address outside of a pydantic model.

    This is a thin public wrapper around ``normalize_address``.

    Args:
        address: The address to validate

    Returns:
        The lowercase, ``0x``-prefixed address produced by
        ``normalize_address``

    Raises:
        ValidationError: If the address is empty or its format is invalid
    """
    normalized = normalize_address(address)
    return normalized