Coverage for src/configuraptor/helpers.py: 100%

116 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2026-05-01 17:15 +0200

1""" 

2Contains stand-alone helper functions. 

3""" 

4 

5import contextlib 

6import dataclasses as dc 

7import io 

8import math 

9import re 

10import types 

11import typing 

12import warnings 

13from collections import ChainMap 

14from pathlib import Path 

15 

16from expandvars import expand 

17from typeguard import TypeCheckError 

18from typeguard import check_type as _check_type 

19 

20try: 

21 import annotationlib 

22except ImportError: # pragma: no cover 

23 annotationlib = None 

24 

25 

def camel_to_snake(s: str) -> str:
    """
    Turn a CamelCase identifier into its snake_case form.

    Every uppercase character is replaced by an underscore followed by its
    lowercase variant; any leading underscores are stripped from the result.

    Source:
        https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case
    """
    pieces = []
    for char in s:
        pieces.append("_" + char.lower() if char.isupper() else char)
    return "".join(pieces).lstrip("_")

34 

35 

36# def find_pyproject_toml() -> typing.Optional[str]: 

37# """ 

38# Find the project's config toml, looks up until it finds the project root (black's logic). 

39# """ 

40# return black.files.find_pyproject_toml((os.getcwd(),)) 

41 

42 

43def find_pyproject_toml(start_dir: typing.Optional[Path | str] = None) -> Path | None: 

44 """ 

45 Search for pyproject.toml starting from the current working directory \ 

46 and moving upwards in the directory tree. 

47 

48 Args: 

49 start_dir: Starting directory to begin the search. 

50 If not provided, uses the current working directory. 

51 

52 Returns: 

53 Path or None: Path object to the found pyproject.toml file, or None if not found. 

54 """ 

55 start_dir = Path.cwd() if start_dir is None else Path(start_dir).resolve() 

56 

57 current_dir = start_dir 

58 

59 while str(current_dir) != str(current_dir.root): 

60 pyproject_toml = current_dir / "pyproject.toml" 

61 if pyproject_toml.is_file(): 

62 return pyproject_toml 

63 current_dir = current_dir.parent 

64 

65 # If not found anywhere 

66 return None 

67 

68 

# Shorthand for "a class object of any kind"; used in the signatures of the type helpers in this module.
Type = typing.Type[typing.Any]

70 

71 

def strip_annotated(_type: type) -> type:
    """
    Peel typing.Annotated[T, ...] wrappers off a type hint and return the bare T.

    Anything that is not an Annotated hint is returned unchanged.
    """
    unwrapped = _type
    while True:
        if typing.get_origin(unwrapped) is not typing.Annotated:
            return unwrapped
        inner = typing.get_args(unwrapped)
        if not inner:  # pragma: no cover
            return unwrapped
        # the first argument of Annotated is the actual type, the rest is metadata
        unwrapped = typing.cast(type, inner[0])

82 

83 

def _cls_annotations(c: type) -> dict[str, type]:  # pragma: no cover
    """
    Return the annotations declared directly on class `c` (inherited ones are excluded,
    use _all_annotations for those).

    When `annotationlib` could be imported (Python 3.14+), it is used and forward
    references are resolved right away; otherwise the class' own `__dict__` is consulted.
    """
    if not annotationlib:
        # note: reading c.__dict__ is deliberately used instead of getattr(c, "__annotations__", {});
        # the original author found the getattr variant not to be equivalent.
        own = c.__dict__.get("__annotations__")
        return own if own else {}

    resolved = annotationlib.get_annotations(c, format=annotationlib.Format.VALUE, eval_str=True)
    return typing.cast(dict[str, type], resolved)

99 

100 

def _all_annotations(cls: type) -> ChainMap[str, type]:
    """
    Build a dictionary-like ChainMap holding the annotations of every attribute \
    defined in cls or inherited from one of its superclasses.
    """
    # One mapping per class in MRO order: a ChainMap lookup returns the first hit,
    # so an annotation on a subclass takes precedence over the one on its base class.
    lineage = getattr(cls, "__mro__", [])
    layers = [_cls_annotations(klass) for klass in lineage]
    return ChainMap(*layers)

109 

110 

def all_annotations(cls: Type, _except: typing.Optional[typing.Iterable[str]] = None) -> dict[str, type[object]]:
    """
    Wrapper around `_all_annotations` that filters away any keys in _except.

    It also flattens the ChainMap to a regular dict and unwraps `typing.Annotated` hints.

    Args:
        cls: the class whose (own and inherited) annotations are collected.
        _except: names to leave out of the result; any iterable of strings
                 (including one-shot iterators/generators) or None for no filtering.

    Returns:
        dict: attribute name -> type hint.
    """
    if _except is None:
        excluded: frozenset[str] = frozenset()
    elif isinstance(_except, str):
        # a bare string is treated as one name, not as a collection of characters
        excluded = frozenset((_except,))
    else:
        # materialise once: a generator would otherwise be exhausted by the first membership tests
        excluded = frozenset(_except)

    _all = _all_annotations(cls)
    return {name: strip_annotated(hint) for name, hint in _all.items() if name not in excluded}

122 

123 

T = typing.TypeVar("T")


def check_type(value: typing.Any, expected_type: typing.Type[T]) -> typing.TypeGuard[T]:
    """
    Tell whether `value` conforms to 'expected_type' (which can be a Union, parameterized generic etc.).

    Thin wrapper around typeguard's check_type: rather than returning the value or
    raising a TypeCheckError, the outcome is reported as a boolean.
    """
    try:
        _check_type(value, expected_type)
    except TypeCheckError:
        return False
    else:
        return True

138 

139 

def is_builtin_type(_type: Type) -> bool:
    """
    Tell whether _type is defined in the builtins module (e.g. int, str, dict).
    """
    origin_module = _type.__module__
    return origin_module in ("__builtin__", "builtins")

145 

146 

147# def is_builtin_class_instance(obj: typing.Any) -> bool: 

148# return is_builtin_type(obj.__class__) 

149 

150 

def is_from_types_or_typing(_type: Type) -> bool:
    """
    Tell whether _type is defined in the stdlib `types` or `typing` module.

    e.g. types.UnionType or typing.Union
    """
    origin_module = _type.__module__
    return origin_module in ("types", "typing")

158 

159 

def is_from_other_toml_supported_module(_type: Type) -> bool:
    """
    Tell whether _type is defined in the stdlib 'datetime' or 'math' module, \
    whose values toml supports in addition to the builtins.
    """
    origin_module = _type.__module__
    return origin_module in ("datetime", "math")

166 

167 

def is_parameterized(_type: Type) -> bool:
    """
    Tell whether _type is a parameterized type, i.e. whether typing can find an origin for it.

    Examples:
        list[str] -> True
        str -> False
    """
    origin = typing.get_origin(_type)
    return origin is not None

177 

178 

def is_custom_class(_type: Type) -> bool:
    """
    Guess whether _type is a custom (user-defined) class rather than a builtin or stdlib typing construct.

    Other logic in this module depends on knowing that.
    """
    # only plain classes qualify (their metaclass must be exactly `type`)
    if type(_type) is not type:
        return False
    if is_builtin_type(_type):
        return False
    if is_from_other_toml_supported_module(_type):
        return False
    return not is_from_types_or_typing(_type)

191 

192 

def instance_of_custom_class(var: typing.Any) -> bool:
    """
    Apply `is_custom_class` to the class of `var`, an instance of a (possibly custom) class.
    """
    klass = var.__class__
    return is_custom_class(klass)

198 

199 

200def is_union(sometype: typing.Type[typing.Any] | typing.Any) -> bool: 

201 """ 

202 Determines if a given type is a Union type. 

203 

204 A Union type in Python is used to represent a type that can be one of multiple 

205 types. This function checks whether the provided type object corresponds to a 

206 Union type as defined in Python's type hints or annotations. 

207 

208 Returns: 

209 bool 

210 True if the provided type is a Union type, False otherwise. 

211 """ 

212 origin = typing.get_origin(sometype) 

213 return origin in (typing.Union, types.UnionType) 

214 

215 

216def is_optional(_type: Type | typing.Any) -> bool: 

217 """ 

218 Tries to guess if _type could be optional. 

219 

220 Examples: 

221 None -> True 

222 NoneType -> True 

223 typing.Union[str, None] -> True 

224 str | None -> True 

225 list[str | None] -> False 

226 list[str] -> False 

227 """ 

228 if _type and (is_parameterized(_type) and typing.get_origin(_type) in (dict, list)) or (_type is math.nan): 

229 # e.g. list[str] 

230 # will crash issubclass to test it first here 

231 return False 

232 

233 try: 

234 return ( 

235 _type is None 

236 or types.NoneType in typing.get_args(_type) # union with Nonetype 

237 or issubclass(types.NoneType, _type) 

238 or issubclass(types.NoneType, type(_type)) # no type # Nonetype 

239 ) 

240 except TypeError: 

241 # probably some weird input that's not a type 

242 return False 

243 

244 

def dataclass_field(cls: Type, key: str) -> typing.Optional[dc.Field[typing.Any]]:
    """
    Look up the dataclass Field named `key` on cls.

    Returns None when cls has no `__dataclass_fields__` or no field with that name.
    """
    return getattr(cls, "__dataclass_fields__", {}).get(key)

251 

252 

@contextlib.contextmanager
def uncloseable(fd: typing.BinaryIO) -> typing.Generator[typing.BinaryIO, typing.Any, None]:
    """
    Context manager which turns the fd's close operation to no-op for the duration of the context.

    The original `close` is put back when the context exits, also when the body raises.

    Args:
        fd: the file-like object whose `close` attribute is temporarily replaced.

    Yields:
        The same `fd` object.
    """
    original_close = fd.close
    fd.close = lambda: None  # type: ignore
    try:
        yield fd
    finally:
        # without the finally, an exception in the with-body would leave close() disabled for good
        fd.close = original_close  # type: ignore

262 

263 

264def as_binaryio(file: str | Path | typing.BinaryIO | None, mode: typing.Literal["rb", "wb"] = "rb") -> typing.BinaryIO: 

265 """ 

266 Convert a number of possible 'file' descriptions into a single BinaryIO interface. 

267 """ 

268 if isinstance(file, str): 

269 file = Path(file) 

270 if isinstance(file, Path): 

271 file = file.open(mode) 

272 if file is None: 

273 file = io.BytesIO() 

274 if isinstance(file, io.BytesIO): 

275 # so .read() works after .write(): 

276 file.seek(0) 

277 # so the with-statement doesn't close the in-memory file: 

278 file = uncloseable(file) # type: ignore 

279 

280 return file 

281 

282 

# Matches ${VAR:default} where the default does not start with one of the POSIX operators - ? = +
_LEGACY_ENV_DEFAULTS_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*):([^\-?=+][^}]*)}")


def _normalize_legacy_env_defaults(value: str) -> str:
    """
    Convert every legacy ${VAR:default} occurrence into ${VAR:-default}.

    A DeprecationWarning is emitted when at least one occurrence was rewritten;
    strings without the legacy form are returned as-is, without a warning.
    """
    rewritten, replacements = _LEGACY_ENV_DEFAULTS_RE.subn(r"${\1:-\2}", value)
    if replacements == 0:
        return value

    warnings.warn(
        "Legacy ${VAR:default} syntax is deprecated; use ${VAR:-default}. "
        "Support for the legacy form may be removed in a future release.",
        DeprecationWarning,
        stacklevel=3,
    )
    return rewritten

301 

302 

def expand_posix_vars(posix_expr: str, context: dict[str, str]) -> str:
    """
    Substitute POSIX and Docker Compose-like variables in a string with values taken from `context`.

    Legacy ${VAR:default} expressions are first rewritten to ${VAR:-default}
    (with a deprecation warning); the actual expansion is done by `expandvars.expand`.

    Args:
        posix_expr (str): The input string possibly containing variable expressions.
        context (dict): A dictionary containing variable names and their respective values.

    Returns:
        str: The string with the variable expressions replaced.
    """
    normalized = _normalize_legacy_env_defaults(posix_expr)
    expanded = expand(normalized, environ=context)
    return typing.cast(str, expanded)

316 

317 

def expand_env_vars_into_toml_values(
    toml: dict[str, typing.Any],
    env: dict[str, typing.Any],
    *,
    case_insensitive: bool = True,
) -> None:
    """
    Recursively expands POSIX/Docker Compose-like environment variables in a TOML dictionary.

    This function traverses a TOML dictionary and expands POSIX/Docker Compose-like
    environment variables (${VAR:-default}) using values provided in the 'env' dictionary.
    It performs in-place modification of the 'toml' dictionary.

    Args:
        toml (dict): A TOML dictionary with string values possibly containing environment variables.
        env (dict): A dictionary containing environment variable names and their respective values.
        case_insensitive (bool): If True, treat environment keys as case-insensitive by adding
            upper/lower variants for lookup. Defaults to True.

    Returns:
        None: The function modifies the 'toml' dictionary in place.

    Notes:
        Strings are expanded wherever they occur as a value: directly in a table, in nested tables,
        in lists, in lists nested inside lists and in tables nested inside lists (arrays of tables).
        Lists are replaced by a new list; tables are modified in place. Other value types are left alone.
        Nothing happens when either 'toml' or 'env' is empty.

    Example:
        toml_data = {
            'key1': 'This has ${ENV_VAR:-default}',
            'key2': ['String with ${ANOTHER_VAR}', 'Another ${YET_ANOTHER_VAR}']
        }
        environment = {
            'ENV_VAR': 'replaced_value',
            'ANOTHER_VAR': 'value_1',
            'YET_ANOTHER_VAR': 'value_2'
        }

        expand_env_vars_into_toml_values(toml_data, environment)
        # 'toml_data' will be modified in place:
        # {
        #     'key1': 'This has replaced_value',
        #     'key2': ['String with value_1', 'Another value_2']
        # }
    """
    if not toml or not env:  # pragma: no cover
        return

    if case_insensitive:
        # exact-case keys always win; upper/lower variants are only added when not present yet
        env_case: dict[str, typing.Any] = dict(env)
        for key, value in env.items():
            env_case.setdefault(key.upper(), value)
            env_case.setdefault(key.lower(), value)
        env = env_case

    for key, var in toml.items():
        if isinstance(var, dict):
            # env already carries the case variants (if requested), no need to rebuild them per level
            expand_env_vars_into_toml_values(var, env, case_insensitive=False)
        elif isinstance(var, (list, str)):
            toml[key] = _expand_toml_value(var, env)
        # else: nothing to substitute


def _expand_toml_value(value: typing.Any, env: dict[str, typing.Any]) -> typing.Any:
    """
    Expand variables in a single TOML value: strings are expanded, lists are rebuilt
    element by element, tables are expanded in place; anything else is returned unchanged.
    """
    if isinstance(value, str):
        return expand_posix_vars(value, env)
    if isinstance(value, list):
        return [_expand_toml_value(item, env) for item in value]
    if isinstance(value, dict):
        expand_env_vars_into_toml_values(value, env, case_insensitive=False)
    return value