Coverage for src/configuraptor/helpers.py: 100%

116 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2026-05-14 16:50 +0200

1""" 

2Contains stand-alone helper functions. 

3""" 

4 

5import contextlib 

6import dataclasses as dc 

7import io 

8import math 

9import re 

10import types 

11import typing 

12import warnings 

13from collections import ChainMap 

14from pathlib import Path 

15 

16from expandvars import expand 

17from typeguard import TypeCheckError 

18from typeguard import check_type as _check_type 

19 

20try: 

21 import annotationlib 

22except ImportError: # pragma: no cover 

23 annotationlib = None 

24 

25 

def camel_to_snake(s: str) -> str:
    """
    Turn a CamelCase identifier into its snake_case spelling.

    Every uppercase character is lowercased and prefixed with an underscore;
    any underscores that end up at the very start of the result are removed.

    Source:
        https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case
    """
    pieces = []
    for char in s:
        if char.isupper():
            pieces.append("_" + char.lower())
        else:
            pieces.append(char)
    # lstrip drops the underscore produced by a leading capital (and any other leading underscores)
    return "".join(pieces).lstrip("_")

34 

35 

36# def find_pyproject_toml() -> typing.Optional[str]: 

37# """ 

38# Find the project's config toml, looks up until it finds the project root (black's logic). 

39# """ 

40# return black.files.find_pyproject_toml((os.getcwd(),)) 

41 

42 

43def find_pyproject_toml(start_dir: typing.Optional[Path | str] = None) -> Path | None: 

44 """ 

45 Search for pyproject.toml starting from the current working directory \ 

46 and moving upwards in the directory tree. 

47 

48 Args: 

49 start_dir: Starting directory to begin the search. 

50 If not provided, uses the current working directory. 

51 

52 Returns: 

53 Path or None: Path object to the found pyproject.toml file, or None if not found. 

54 """ 

55 start_dir = Path.cwd() if start_dir is None else Path(start_dir).resolve() 

56 

57 current_dir = start_dir 

58 

59 while str(current_dir) != str(current_dir.root): 

60 pyproject_toml = current_dir / "pyproject.toml" 

61 if pyproject_toml.is_file(): 

62 return pyproject_toml 

63 current_dir = current_dir.parent 

64 

65 # If not found anywhere 

66 return None 

67 

68 

# Shorthand for "any class object"; used in annotations across this module.
Type = typing.Type[typing.Any]


def strip_annotated(_type: type) -> type:
    """
    Peel typing.Annotated[T, ...] wrappers off a type hint and return the bare T.

    Hints that are not Annotated are returned unchanged.
    """
    current = _type
    while typing.get_origin(current) is typing.Annotated:
        inner = typing.get_args(current)
        if not inner:  # pragma: no cover
            # nothing to unwrap; give back what we have
            return current
        # the first argument of Annotated is the actual type, the rest is metadata
        current = typing.cast(type, inner[0])
    return current

82 

83 

def _cls_annotations(c: type) -> dict[str, type]:  # pragma: no cover
    """
    Return the annotations declared on class `c` itself (inherited ones are excluded; \
    use _all_annotations for those).

    When `annotationlib` could be imported (Python 3.14+), it is used with eval_str=True,
    so forward references are resolved immediately.
    """
    if not annotationlib:
        # Read from the class' own __dict__ on purpose:
        # getattr(c, "__annotations__", {}) was found not to behave the same way here.
        own = c.__dict__.get("__annotations__")
        return own if own else {}

    resolved = annotationlib.get_annotations(c, format=annotationlib.Format.VALUE, eval_str=True)
    return typing.cast(dict[str, type], resolved)

99 

100 

def _all_annotations(cls: type) -> ChainMap[str, type]:
    """
    Build a dictionary-like ChainMap holding the annotations of every attribute \
    defined on cls or on any of its superclasses.

    Objects without an `__mro__` yield an empty ChainMap.
    """
    # One mapping per class, in MRO order: ChainMap looks keys up in the first mapping
    # that has them, so a subclass' annotation wins over the one it inherited.
    lineage = getattr(cls, "__mro__", [])
    per_class = [_cls_annotations(klass) for klass in lineage]
    return ChainMap(*per_class)

109 

110 

def all_annotations(
    cls: Type, _except: typing.Optional[typing.Iterable[str]] = None
) -> dict[str, type[object]]:
    """
    Wrapper around `_all_annotations` that filters away any keys in _except.

    It also flattens the ChainMap to a regular dict and unwraps typing.Annotated hints.

    Args:
        cls: the class whose (own and inherited) annotations are collected.
        _except: names to leave out of the result; None means nothing is excluded.

    Returns:
        A plain dict mapping attribute name to its (un-Annotated) type hint.
    """
    if _except is None:
        excluded: typing.Container[str] = frozenset()
    elif isinstance(_except, typing.Container):
        # lists, sets, dicts, ... support repeated `in` tests as-is
        excluded = _except
    else:
        # one-shot iterables (generators, iterators) would be consumed by the first
        # membership test, so materialize them once up front
        excluded = frozenset(_except)

    _all = _all_annotations(cls)
    return {name: strip_annotated(hint) for name, hint in _all.items() if name not in excluded}

122 

123 

T = typing.TypeVar("T")


def check_type(value: typing.Any, expected_type: typing.Type[T]) -> typing.TypeGuard[T]:
    """
    Report whether `value` matches 'expected_type' (which can be a Union, parameterized generic etc.).

    Thin wrapper around typeguard's check_type: instead of handing back the value or
    raising a TypeCheckError, the outcome is expressed as a boolean.
    """
    try:
        _check_type(value, expected_type)
    except TypeCheckError:
        return False
    else:
        return True

138 

139 

def is_builtin_type(_type: Type) -> bool:
    """
    Tell whether _type is defined in Python's builtins module.
    """
    builtin_modules = ("__builtin__", "builtins")
    origin_module = _type.__module__
    return origin_module in builtin_modules

145 

146 

147# def is_builtin_class_instance(obj: typing.Any) -> bool: 

148# return is_builtin_type(obj.__class__) 

149 

150 

def is_from_types_or_typing(_type: Type) -> bool:
    """
    Tell whether _type is defined in the stdlib `types` or `typing` module.

    e.g. types.UnionType or typing.Union
    """
    typing_modules = ("types", "typing")
    origin_module = _type.__module__
    return origin_module in typing_modules

158 

159 

def is_from_other_toml_supported_module(_type: Type) -> bool:
    """
    Tell whether _type comes from the stdlib 'datetime' or 'math' module, \
    whose types toml supports in addition to the builtins.
    """
    toml_modules = ("datetime", "math")
    origin_module = _type.__module__
    return origin_module in toml_modules

166 

167 

def is_parameterized(_type: Type) -> bool:
    """
    Tell whether _type is a parameterized type, i.e. typing.get_origin() knows an origin for it.

    Examples:
        list[str]      -> True
        str            -> False
    """
    origin = typing.get_origin(_type)
    return not (origin is None)

177 

178 

def is_custom_class(_type: Type) -> bool:
    """
    Best-effort guess whether _type is a custom (user-defined) class rather than a builtin.

    Other logic in this module depends on knowing that.
    """
    # only objects whose type is exactly `type` qualify
    if type(_type) is not type:
        return False
    if is_builtin_type(_type):
        return False
    if is_from_other_toml_supported_module(_type):
        return False
    return not is_from_types_or_typing(_type)

191 

192 

def instance_of_custom_class(var: typing.Any) -> bool:
    """
    Apply `is_custom_class` to the class of `var`, an instance of a (possibly custom) class.
    """
    klass = var.__class__
    return is_custom_class(klass)

198 

199 

200def is_union(sometype: typing.Type[typing.Any] | typing.Any) -> bool: 

201 """ 

202 Determines if a given type is a Union type. 

203 

204 A Union type in Python is used to represent a type that can be one of multiple 

205 types. This function checks whether the provided type object corresponds to a 

206 Union type as defined in Python's type hints or annotations. 

207 

208 Returns: 

209 bool 

210 True if the provided type is a Union type, False otherwise. 

211 """ 

212 origin = typing.get_origin(sometype) 

213 return origin in (typing.Union, types.UnionType) 

214 

215 

216def is_optional(_type: Type | typing.Any) -> bool: 

217 """ 

218 Tries to guess if _type could be optional. 

219 

220 Examples: 

221 None -> True 

222 NoneType -> True 

223 typing.Union[str, None] -> True 

224 str | None -> True 

225 list[str | None] -> False 

226 list[str] -> False 

227 """ 

228 if _type and (is_parameterized(_type) and typing.get_origin(_type) in (dict, list)) or (_type is math.nan): 

229 # e.g. list[str] 

230 # will crash issubclass to test it first here 

231 return False 

232 

233 try: 

234 return ( 

235 _type is None 

236 or types.NoneType in typing.get_args(_type) # union with Nonetype 

237 or issubclass(types.NoneType, _type) 

238 or issubclass(types.NoneType, type(_type)) # no type # Nonetype 

239 ) 

240 except TypeError: 

241 # probably some weird input that's not a type 

242 return False 

243 

def dataclass_field(cls: Type, key: str) -> typing.Optional[dc.Field[typing.Any]]:
    """
    Look up the dataclass Field named `key` on cls.

    Returns None when cls has no `__dataclass_fields__` or no field with that name.
    """
    return getattr(cls, "__dataclass_fields__", {}).get(key)

250 

251 

@contextlib.contextmanager
def uncloseable(fd: typing.BinaryIO) -> typing.Generator[typing.BinaryIO, typing.Any, None]:
    """
    Context manager which turns the fd's close operation to no-op for the duration of the context.

    The original `close` is put back when the context exits, also when the body raises.

    Args:
        fd: the file-like object whose `close` should temporarily do nothing.

    Yields:
        The same `fd` object.
    """
    original_close = fd.close
    fd.close = lambda: None  # type: ignore
    try:
        yield fd
    finally:
        # without try/finally an exception in the with-body would leave fd permanently uncloseable
        fd.close = original_close  # type: ignore

261 

262 

263def as_binaryio(file: str | Path | typing.BinaryIO | None, mode: typing.Literal["rb", "wb"] = "rb") -> typing.BinaryIO: 

264 """ 

265 Convert a number of possible 'file' descriptions into a single BinaryIO interface. 

266 """ 

267 if isinstance(file, str): 

268 file = Path(file) 

269 if isinstance(file, Path): 

270 file = file.open(mode) 

271 if file is None: 

272 file = io.BytesIO() 

273 if isinstance(file, io.BytesIO): 

274 # so .read() works after .write(): 

275 file.seek(0) 

276 # so the with-statement doesn't close the in-memory file: 

277 file = uncloseable(file) # type: ignore 

278 

279 return file 

280 

281 

# Matches the legacy `${VAR:default}` form: group 1 is the variable name, group 2 the default text.
# The first character after the colon may not be one of - ? = + so `${VAR:-default}` and
# the other colon-operator forms are left alone.
_LEGACY_ENV_DEFAULTS_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*):([^\-?=+][^}]*)}")


def _normalize_legacy_env_defaults(value: str) -> str:
    """
    Convert every legacy ${VAR:default} in `value` to ${VAR:-default}.

    A single DeprecationWarning is issued per call when at least one legacy
    expression was found; strings without one are returned as-is.
    """
    rewritten, hits = _LEGACY_ENV_DEFAULTS_RE.subn(r"${\1:-\2}", value)
    if not hits:
        return value

    warnings.warn(
        "Legacy ${VAR:default} syntax is deprecated; use ${VAR:-default}. "
        "Support for the legacy form may be removed in a future release.",
        DeprecationWarning,
        stacklevel=3,
    )

    return rewritten

300 

301 

def expand_posix_vars(posix_expr: str, context: dict[str, str]) -> str:
    """
    Substitute POSIX and Docker Compose-like environment variables in a string with their values.

    Legacy `${VAR:default}` expressions are first normalized to `${VAR:-default}`;
    the actual expansion is then done by `expandvars.expand`, reading values from `context`.

    Args:
        posix_expr (str): The input string containing POSIX or Docker Compose-like variables.
        context (dict): A dictionary containing variable names and their respective values.

    Returns:
        str: The string with replaced variable values.
    """
    normalized = _normalize_legacy_env_defaults(posix_expr)
    expanded = expand(normalized, environ=context)
    return typing.cast(str, expanded)

315 

316 

def expand_env_vars_into_toml_values(
    toml: dict[str, typing.Any],
    env: dict[str, typing.Any],
    *,
    case_insensitive: bool = True,
) -> None:
    """
    Recursively expands POSIX/Docker Compose-like environment variables in a TOML dictionary.

    This function traverses a TOML dictionary and expands POSIX/Docker Compose-like
    environment variables (${VAR:-default}) using values provided in the 'env' dictionary.
    It performs in-place modification of the 'toml' dictionary.

    Args:
        toml (dict): A TOML dictionary with string values possibly containing environment variables.
        env (dict): A dictionary containing environment variable names and their respective values.
            It is never modified.
        case_insensitive (bool): If True, treat environment keys as case-insensitive by adding
            upper/lower variants for lookup. Defaults to True.

    Returns:
        None: The function modifies the 'toml' dictionary in place.

    Notes:
        Strings are expanded wherever they appear as a value: directly in a table, inside lists
        (including nested lists) and inside tables that are list items (TOML arrays of tables).
        Nested dictionaries are updated in place; lists that are table values are replaced by new lists.
        Values of any other type are left untouched.
        Nothing is expanded when either 'toml' or 'env' is empty.

    Example:
        toml_data = {
            'key1': 'This has ${ENV_VAR:-default}',
            'key2': ['String with ${ANOTHER_VAR}', 'Another ${YET_ANOTHER_VAR}']
        }
        environment = {
            'ENV_VAR': 'replaced_value',
            'ANOTHER_VAR': 'value_1',
            'YET_ANOTHER_VAR': 'value_2'
        }

        expand_env_vars_into_toml_values(toml_data, environment)
        # 'toml_data' will be modified in place:
        # {
        #     'key1': 'This has replaced_value',
        #     'key2': ['String with value_1', 'Another value_2']
        # }
    """
    if not toml or not env:  # pragma: no cover
        return

    if case_insensitive:
        # build the lookup table once, instead of once per nesting level
        env = _env_with_case_variants(env)

    _expand_table_in_place(toml, env)


def _env_with_case_variants(env: dict[str, typing.Any]) -> dict[str, typing.Any]:
    """Return a copy of env that additionally maps the UPPER and lower spelling of every key."""
    variants: dict[str, typing.Any] = dict(env)
    for key, value in env.items():
        # setdefault: an existing (exact) key always wins over a derived spelling
        variants.setdefault(key.upper(), value)
        variants.setdefault(key.lower(), value)
    return variants


def _expand_table_in_place(table: dict[str, typing.Any], env: dict[str, typing.Any]) -> None:
    """Expand variables in all (nested) string values of one table, mutating the table."""
    for key, value in table.items():
        if isinstance(value, dict):
            _expand_table_in_place(value, env)
        elif isinstance(value, (list, str)):
            # only existing keys are reassigned, which is safe while iterating
            table[key] = _expand_item(value, env)
        # else: nothing to substitute


def _expand_item(item: typing.Any, env: dict[str, typing.Any]) -> typing.Any:
    """Return the expanded form of a single value; lists yield a new list, tables are updated in place."""
    if isinstance(item, str):
        return expand_posix_vars(item, env)
    if isinstance(item, list):
        return [_expand_item(element, env) for element in item]
    if isinstance(item, dict):
        # e.g. an entry of a TOML array of tables
        _expand_table_in_place(item, env)
    return item