Coverage for C:\Users\peaco\OneDrive\Documents\GitHub\mt_metadata\mt_metadata\common\enumerations.py: 92%

170 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-10 00:11 -0800

1# ===================================================== 

2# Imports 

3# ===================================================== 

4import json 

5from enum import Enum, EnumMeta 

6from pathlib import Path 

7 

8from pydantic import GetCoreSchemaHandler 

9from pydantic_core import core_schema, CoreSchema 

10 

11 

12# ===================================================== 

13 

14 

class StrEnumerationBase(str, Enum):
    """
    Base class for string enumerations that pydantic validates case-insensitively.

    Subclasses only declare members. When used as a pydantic field type, an
    incoming string is matched against the member values ignoring case and
    the matching enum member is what gets validated and stored.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: type[Enum], handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        # Two stages: ``_normalize_case`` maps raw input onto a member, then the
        # enum schema checks that the result really is a member of ``cls``.
        member_schema = core_schema.enum_schema(cls, list(cls))
        return core_schema.no_info_before_validator_function(
            cls._normalize_case, member_schema
        )

    @classmethod
    def _normalize_case(cls, value):
        """
        Return the member of ``cls`` whose value equals ``value`` ignoring case.

        Parameters
        ----------
        value : cls or str
            An existing member (returned as-is) or a string to look up.

        Returns
        -------
        cls
            The matching enum member (not the bare string).

        Raises
        ------
        TypeError
            If ``value`` is neither a member of ``cls`` nor a ``str``.
        ValueError
            If no member value matches ``value`` case-insensitively.
        """
        if isinstance(value, cls):
            return value

        if not isinstance(value, str):
            # NOTE(review): pydantic v2 documents ValueError/AssertionError as the
            # validator exceptions it converts to ValidationError; a TypeError
            # here may escape as-is. Kept unchanged for existing callers.
            raise TypeError(f"Expected string, got {type(value)}")

        # Lower-case the input once; compare against each member's lower-cased value.
        wanted = value.lower()
        match = next(
            (member for member in cls if member.value.lower() == wanted), None
        )
        if match is not None:
            return match

        valid_values = [member.value for member in cls]
        raise ValueError(
            f"Invalid value: {value}. Must be one of {valid_values} (case-insensitive)."
        )

45 

46 

class YesNoEnum(StrEnumerationBase):
    """Yes/no flag; matched case-insensitively through ``StrEnumerationBase``."""

    yes = "yes"
    no = "no"

50 

51 

class DataTypeEnum(StrEnumerationBase):
    """
    Data-type codes; each member's name equals its value.

    Matched case-insensitively through ``StrEnumerationBase``.
    """

    # Survey-type codes
    RMT = "RMT"
    AMT = "AMT"
    BBMT = "BBMT"
    LPMT = "LPMT"
    ULPMT = "ULPMT"
    MT = "MT"
    LP = "LP"
    BB = "BB"
    WB = "WB"
    # "_TF" variants
    MT_TF = "MT_TF"
    BBMT_TF = "BBMT_TF"
    WBMT_TF = "WBMT_TF"
    LPMT_TF = "LPMT_TF"
    # Controlled/natural-source AMT codes
    CSAMT = "CSAMT"
    NSAMT = "NSAMT"

68 

69 

class ArrayDTypeEnum(str, Enum):
    """Array data-type names, both generic kinds and sized dtype spellings."""

    # Generic kinds
    real_type = "real"
    complex_type = "complex"
    float_type = "float"
    int_type = "int"
    # Sized complex
    complex_128_type = "complex128"
    complex_64_type = "complex64"
    # Sized floating point
    float_64_type = "float64"
    float_32_type = "float32"
    float_16_type = "float16"
    # Sized integer
    int_64_type = "int64"
    int_32_type = "int32"
    int_16_type = "int16"
    int_8_type = "int8"

84 

85 

class EstimateIntentionEnum(str, Enum):
    """Intended use of an estimate; values are lower-case phrases with spaces."""

    error_estimate = "error estimate"
    signal_coherence = "signal coherence"
    signal_power_estimate = "signal power estimate"
    primary_data_type = "primary data type"
    derived_data_type = "derived data type"

92 

93 

class ChannelLayoutEnum(StrEnumerationBase):
    """Channel layout symbols ``L``, ``X`` and ``+`` (case-insensitive match)."""

    L = "L"
    X = "X"
    plus = "+"

98 

99 

class ElectrodeLocationEnum(StrEnumerationBase):
    """Electrode location as a compass letter, or the empty string (``NONE``)."""

    N = "N"
    S = "S"
    E = "E"
    W = "W"
    NONE = ""

106 

107 

class OrientationMethodEnum(StrEnumerationBase):
    """Method used to orient a sensor (case-insensitive match)."""

    compass = "compass"
    GPS = "GPS"
    theodolite = "theodolite"

112 

113 

class GeographicReferenceFrameEnum(StrEnumerationBase):
    """Reference frame names; note ``site_layout`` is spelled ``"sitelayout"``."""

    geographic = "geographic"
    geomagnetic = "geomagnetic"
    station = "station"
    site_layout = "sitelayout"

119 

120 

class ChannelOrientationEnum(StrEnumerationBase):
    """Channel orientation scheme; ``site_layout`` is spelled ``"sitelayout"``."""

    orthogonal = "orthogonal"
    station = "station"
    site_layout = "sitelayout"

125 

126 

class GeomagneticModelEnum(str, Enum):
    """
    Geomagnetic model names; split by - if needed.

    Validation compares only the text before the first ``-`` with the member
    values, ignoring case, so a versioned name such as ``"IGRF-13"`` is
    accepted through its ``"igrf"`` prefix. The input string itself is
    returned unchanged (it is neither lower-cased nor converted to a member).
    """

    EMAG2 = "EMAG2"
    EMM = "EMM"
    HDGM = "HDGM"
    IGRF = "IGRF"
    WMM = "WMM"
    unknown = "unknown"

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: type[Enum], handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        # A plain validator replaces pydantic's own enum handling, so whatever
        # ``_validate_lowercase`` returns is the value the model keeps.
        validator = cls._validate_lowercase
        return core_schema.no_info_plain_validator_function(validator)

    @classmethod
    def _validate_lowercase(cls, value: str) -> str:
        """
        Check that ``value`` names a known model and return it unchanged.

        Raises
        ------
        TypeError
            If ``value`` is not a string.
        ValueError
            If the part of ``value`` before the first ``-`` does not match a
            member value case-insensitively.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected string, got {type(value)}")

        allowed = [member.value.lower() for member in cls]
        # ``split`` always yields at least one piece, so this unpack is safe.
        model_name, *_ = value.lower().split("-")
        if model_name in allowed:
            return value
        raise ValueError(f"Invalid value: {value}. Must be one of {allowed}.")

153 

154 

class FilterTypeEnum(StrEnumerationBase):
    """Filter type tags; note ``fap_table`` has the value ``"fap"``."""

    fap_table = "fap"
    zpk = "zpk"
    time_delay = "time_delay"
    coefficient = "coefficient"
    fir = "fir"

161 

162 

class SymmetryEnum(StrEnumerationBase):
    """Symmetry tag (upper-case values, matched case-insensitively)."""

    NONE = "NONE"
    ODD = "ODD"
    EVEN = "EVEN"

167 

168 

class SignConventionEnum(str, Enum):
    """Accepted spellings of the +/- sign convention (bare, exp(...), LaTeX-like)."""

    # Bare sign
    plus = "+"
    minus = "-"
    # Compact exponential form
    exp_plus = "exp(+iwt)"
    exp_minus = "exp(-iwt)"
    # Spelled with a literal backslash before "omega"
    exp_plus_iwt = "exp(+ i\\omega t)"
    exp_minus_iwt = "exp(- i\\omega t)"

176 

177 

class StdEDIversionsEnum(str, Enum):
    """Recognised EDI version strings, with both space and underscore spellings."""

    # Version 1.0
    SEG_1 = "SEG 1.0"
    one = "1.0"
    SEG_10 = "SEG_1.0"
    # Version 1.01
    SEG_101 = "SEG 1.01"
    SEG_1011 = "SEG_1.01"

184 

185 

class ReleaseStatusEnum(StrEnumerationBase):
    """Release-status phrases (title case, matched case-insensitively)."""

    Unrestricted_release = "Unrestricted Release"
    Restricted_release = "Restricted Release"
    Paper_Citation_Required = "Paper Citation Required"
    Academic_Use_Only = "Academic Use Only"
    Conditions_Apply = "Conditions Apply"
    Data_Citation_Required = "Data Citation Required"

193 

194 

class ChannelEnum(StrEnumerationBase):
    """Channel component names, or the empty string (``null``)."""

    # Electric components
    ex = "ex"
    ey = "ey"
    # Magnetic components
    hx = "hx"
    hy = "hy"
    hz = "hz"
    null = ""

202 

203 

## A better way to make a pydantic-compatible enumeration with a validator

class LicenseEnumMeta(EnumMeta):
    """
    Metaclass that swaps the class it builds for an enum of license IDs.

    Members are a fixed list of Creative Commons short names plus every
    ``licenseId`` from ``data/licenses.json``, located in the parent of this
    module's directory. Member names are derived from the ID by mapping
    ``- .()/:`` to ``_``; member values are the IDs themselves.
    """

    def __new__(metacls, cls, bases, classdict):
        # Build the declared class first; only its name is reused below.
        enum_class = super().__new__(metacls, cls, bases, classdict)

        filename = Path(__file__).parent.parent.joinpath("data", "licenses.json")
        # JSON text is UTF-8 (RFC 8259); don't rely on the platform's locale
        # encoding (e.g. cp1252 on Windows).
        with open(filename, "r", encoding="utf-8") as fid:
            licenses = json.load(fid)

        # Base licenses. Entries that share a value become enum aliases.
        base_licenses = [
            ("CC_0", "CC0"),
            ("CC0", "CC0"),
            ("CC_BY", "CC BY"),
            ("CC_BY_SA", "CC BY-SA"),
            ("CC_BY_NC", "CC BY-NC"),
            # "NA" appears to be a typo for "NC". It stays first so the
            # canonical member name is unchanged; the correct spelling follows
            # as an alias.
            ("CC_BY_NA_SA", "CC BY-NC-SA"),
            ("CC_BY_NC_SA", "CC BY-NC-SA"),
            ("CC_BY_ND", "CC BY-ND"),
            ("CC_BY_NC_ND", "CC BY-NC-ND"),
        ]
        member_dict = dict(base_licenses)

        # One translate pass replaces each character that can't appear in a
        # member name with "_".
        to_member_name = str.maketrans({char: "_" for char in "- .()/:"})
        for entry in licenses["licenses"]:
            license_id = entry["licenseId"]
            member_dict[license_id.translate(to_member_name)] = license_id

        # NOTE(review): ``Enum(name, members)`` builds a brand-new plain Enum.
        # It is not a ``str`` subclass and does not carry any methods defined
        # in the body of the class using this metaclass (e.g. pydantic hooks).
        # Confirm this is intended before relying on those methods.
        return Enum(enum_class.__name__, member_dict)

251 

252 

class LicenseEnum(str, Enum, metaclass=LicenseEnumMeta):
    """
    Enumeration of license identifiers, populated by ``LicenseEnumMeta`` from
    built-in Creative Commons names and JSON data.

    NOTE(review): ``LicenseEnumMeta.__new__`` returns ``Enum(name, members)``,
    a new plain ``Enum`` built from the members only. The ``str`` mixin and the
    classmethods below therefore appear to be absent from the resulting
    ``LicenseEnum``. Confirm whether pydantic is meant to use ``_validate``.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: type[Enum], handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        # A plain validator fully replaces pydantic's default enum validation.
        validator = cls._validate
        return core_schema.no_info_plain_validator_function(validator)

    @classmethod
    def _validate(cls, value: str) -> str:
        """
        Return the license string matching ``value``.

        An exact match against a member value wins. Otherwise values are
        compared after upper-casing and mapping ``-`` and spaces to ``_``, and
        the first matching member value is returned.

        Raises
        ------
        TypeError
            If ``value`` is not a string.
        ValueError
            If no member value matches.
        """
        if not isinstance(value, str):
            raise TypeError(f"Expected string, got {type(value)}")

        known = [member.value for member in cls]
        if value in known:
            return value

        def _canonical(text):
            # Loose comparison key: case-insensitive, "-"/" " treated as "_".
            return text.upper().replace("-", "_").replace(" ", "_")

        wanted = _canonical(value)
        for candidate in known:
            if _canonical(candidate) == wanted:
                return candidate

        raise ValueError(
            f"Invalid license: {value}. Must be one of the valid licenses."
        )

284 )