Coverage for src / dynapydantic / tracking_group.py: 100%

148 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 17:07 +0000

1"""Base class for dynamic pydantic models""" 

2 

3import contextlib 

4import functools 

5import operator 

6import typing as ty 

7import warnings 

8 

9import pydantic 

10import pydantic.fields 

11import pydantic_core 

12 

13from .exceptions import ( 

14 AmbiguousDiscriminatorValueError, 

15 NoRegisteredTypesError, 

16 RegistrationError, 

17) 

18from .union_mode import DiscriminatedConfig, UnionMode 

19 

20 

21def _inject_discriminator_field( 

22 cls: type[pydantic.BaseModel], 

23 disc_field: str, 

24 value: str, 

25) -> pydantic.fields.FieldInfo: 

26 """Injects the discriminator field into the given model 

27 

28 Parameters 

29 ---------- 

30 cls 

31 The BaseModel subclass 

32 disc_field 

33 Name of the discriminator field 

34 value 

35 Value of the discriminator field 

36 """ 

37 if hasattr(cls, disc_field): 

38 msg = ( 

39 f'Cannot inject discriminator field "{disc_field}" into ' 

40 f"{cls.__name__}: an attribute with that name already exists. " 

41 "Rename either the attribute or the discriminator_field to avoid " 

42 "the conflict." 

43 ) 

44 raise RegistrationError(msg) 

45 

46 cls.model_fields[disc_field] = pydantic.fields.FieldInfo( 

47 default=value, 

48 annotation=ty.Literal[value], # type: ignore[not-a-type] 

49 frozen=True, 

50 ) 

51 with contextlib.suppress(pydantic.errors.PydanticUndefinedAnnotation): 

52 cls.model_rebuild(force=True) 

53 return cls.model_fields[disc_field] 

54 

55 

56class TrackingGroup(pydantic.BaseModel): 

57 """Tracker for pydantic models""" 

58 

59 name: str = pydantic.Field( 

60 description=( 

61 "Name of the tracking group. This is for human display, so it " 

62 "doesn't technically need to be globally unique, but it should be " 

63 "meaningfully named, as it will be used in error messages." 

64 ), 

65 ) 

66 union_mode: UnionMode | None = pydantic.Field( 

67 None, 

68 description=( 

69 "Union validation strategy. Pass a DiscriminatedConfig instance " 

70 'or one of the plain strings "smart" or "left_to_right". You can ' 

71 "also just pass the fields for DiscriminatedConfig to this " 

72 "model and they will be forwarded." 

73 ), 

74 ) 

75 discriminator_field: str | None = pydantic.Field( 

76 None, 

77 description=( 

78 "Name of the discriminator field. NOTE: This field is " 

79 "here as an alias for union_mode.discriminator_field. Passing " 

80 "both a discriminator_field and a union_mode will result in an " 

81 "error." 

82 ), 

83 ) 

84 discriminator_value_generator: ty.Callable[[type], str] | None = pydantic.Field( 

85 None, 

86 description=( 

87 "A callable that produces default values for the discriminator field" 

88 ), 

89 ) 

90 plugin_entry_point: str | None = pydantic.Field( 

91 None, 

92 description=( 

93 "If given, then plugins packages will be supported through this " 

94 "Python entrypoint. The entrypoint can either be a function, " 

95 "which will be called, or simply a module, which will be " 

96 "imported. In either case, models found along the import path of " 

97 "the entrypoint will be registered. If the entrypoint is a " 

98 "function, additional models may be declared in the function." 

99 ), 

100 ) 

101 models: dict[str, type[pydantic.BaseModel]] = pydantic.Field( 

102 {}, 

103 description="The tracked models", 

104 ) 

105 

106 _generation: int = pydantic.PrivateAttr(default=0) 

107 _adapter: pydantic.TypeAdapter | None = pydantic.PrivateAttr(default=None) 

108 _adapter_generation: int = pydantic.PrivateAttr(default=-1) 

109 

110 @pydantic.model_validator(mode="after") 

111 def _ensure_union_mode(self) -> "TrackingGroup": 

112 """There must be a union_mode 

113 

114 This validator works as a guard on _coerce_union_mode to make 

115 """ 

116 if self.union_mode is None: 

117 msg = ( 

118 "union_mode is required. This normally indicates that you " 

119 "subclasses TrackingGroup and wrote an invalid validator, but " 

120 "could also be a bug with dynapydantic, so please file a bug " 

121 "report with a reproducer on how you got here if you suspect " 

122 "a bug." 

123 ) 

124 raise ValueError(msg) 

125 

126 # Ensure the top-level fields are in-sync 

127 if isinstance(self.union_mode, DiscriminatedConfig): 

128 self.discriminator_field = self.union_mode.discriminator_field 

129 self.discriminator_value_generator = ( 

130 self.union_mode.discriminator_value_generator 

131 ) 

132 else: 

133 self.discriminator_field = None 

134 self.discriminator_value_generator = None 

135 

136 return self 

137 

138 @pydantic.model_validator(mode="before") 

139 @classmethod 

140 def _coerce_union_mode(cls, data: ty.Any) -> ty.Any: # noqa: ANN401 

141 """Coerce flat discriminator kwargs into a DiscriminatedConfig. 

142 

143 Allows callers to pass ``discriminator_field`` and 

144 ``discriminator_value_generator`` at the top level and transparently 

145 assembles a ``DiscriminatedConfig`` from them. This avoids an extra 

146 import/nesting layer for the user. 

147 """ 

148 if not isinstance(data, dict): 

149 return data 

150 

151 disc_field = data.get("discriminator_field", None) 

152 has_disc_field = disc_field is not None 

153 union_mode = data.get("union_mode", None) 

154 has_union_mode = union_mode is not None 

155 

156 if has_disc_field and has_union_mode: 

157 msg = ( 

158 "Received both union_mode and discriminator_field; pass one " 

159 "or the other." 

160 ) 

161 raise ValueError(msg) 

162 

163 if has_disc_field and not has_union_mode: 

164 # Forward arguments to DiscriminatedConfig 

165 data["union_mode"] = { 

166 "discriminator_field": disc_field, 

167 "discriminator_value_generator": data.get( 

168 "discriminator_value_generator", 

169 ), 

170 } 

171 elif not has_disc_field and not has_union_mode: 

172 msg = "Either union_mode or discriminator_field must be given" 

173 raise ValueError(msg) 

174 

175 return data 

176 

177 @property 

178 def _discriminated(self) -> DiscriminatedConfig | None: 

179 """Return the DiscriminatedMode config, or None if not discriminated.""" 

180 return ( 

181 self.union_mode 

182 if isinstance(self.union_mode, DiscriminatedConfig) 

183 else None 

184 ) 

185 

186 def load_plugins(self) -> None: 

187 """Load plugins to discover/register additional models""" 

188 if self.plugin_entry_point is None: 

189 return 

190 

191 from importlib.metadata import entry_points # noqa: PLC0415 

192 

193 for ep in entry_points().select(group=self.plugin_entry_point): 

194 plugin = ep.load() 

195 if callable(plugin): 

196 plugin() 

197 

198 @ty.overload 

199 def register(self, value: str | None = None) -> ty.Callable[[type], type]: ... 

200 

201 @ty.overload 

202 def register(self, value: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]: ... 

203 

204 def register( 

205 self, 

206 value: str | type[pydantic.BaseModel] | None = None, 

207 ) -> ty.Callable[[type], type] | type[pydantic.BaseModel]: 

208 """Register a model into this group (decorator) 

209 

210 Parameters 

211 ---------- 

212 value 

213 Value for the discriminator field. If not given, then 

214 discriminator_value_generator must be non-None or the 

215 discriminator field must be declared by hand. Can also be the type 

216 itself to register (if the ()'s are omitted from the decorator). 

217 """ 

218 if isinstance(value, type): 

219 self.register_model(value) 

220 return value 

221 

222 def _wrapper(cls: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]: 

223 self.register_model(cls, value) 

224 return cls 

225 

226 return _wrapper 

227 

228 def register_model( 

229 self, 

230 cls: type[pydantic.BaseModel], 

231 discriminator_value: str | None = None, 

232 ) -> None: 

233 """Register the given model into this group 

234 

235 Parameters 

236 ---------- 

237 cls 

238 The model to register 

239 discriminator_value 

240 Value for the discriminator field. If not given, then 

241 discriminator_value_generator must be non-None or the 

242 discriminator field must be declared by hand. 

243 """ 

244 if discriminator_value is not None and not isinstance(discriminator_value, str): 

245 msg = ( 

246 "discriminator_value must be a str if given, was " 

247 f"{type(discriminator_value).__name__}" 

248 ) 

249 raise RegistrationError(msg) 

250 

251 if not isinstance(cls, type) or not issubclass(cls, pydantic.BaseModel): 

252 msg = ( 

253 "only pydantic BaseModel subclasses can be registered in a " 

254 f"TrackingGroup. Got {cls}, which was not." 

255 ) 

256 raise RegistrationError(msg) 

257 

258 if (dm := self._discriminated) is not None: 

259 disc = dm.discriminator_field 

260 field = cls.model_fields.get(disc) 

261 

262 if field is None: 

263 if discriminator_value is not None: 

264 _inject_discriminator_field(cls, disc, discriminator_value) 

265 elif dm.discriminator_value_generator is not None: 

266 _inject_discriminator_field( 

267 cls, 

268 disc, 

269 dm.discriminator_value_generator(cls), 

270 ) 

271 else: 

272 msg = ( 

273 f"unable to determine a discriminator value for " 

274 f'{cls.__name__} in tracking group "{self.name}". ' 

275 "No value was passed, discriminator_value_generator " 

276 f'was None and the "{disc}" field was not defined.' 

277 ) 

278 raise RegistrationError(msg) 

279 elif ty.get_origin(field.annotation) is not ty.Literal: 

280 msg = ( 

281 f'the discriminator field "{disc}" already existed in ' 

282 f"{cls.__name__}, but its type annotation was " 

283 f"{field.annotation}, not Literal." 

284 ) 

285 raise RegistrationError(msg) 

286 elif ( 

287 discriminator_value is not None and field.default != discriminator_value 

288 ): 

289 msg = ( 

290 f"the discriminator value for {cls.__name__} was " 

291 f'ambiguous, the passed value was "{discriminator_value}" ' 

292 f' and "{field.default}" via the discriminator ' 

293 f"field ({disc})." 

294 ) 

295 raise AmbiguousDiscriminatorValueError(msg) 

296 

297 self._register_with_discriminator_field(cls) 

298 else: 

299 if discriminator_value is not None: 

300 warnings.warn( 

301 f'A discriminator_value of "{discriminator_value}" was ' 

302 f"explicitly passed for {cls.__name__}, but " 

303 f'union_mode="{self.union_mode}" does not use a ' 

304 "discriminator. The value will be ignored.", 

305 stacklevel=2, 

306 ) 

307 self._register_plain(cls) 

308 

309 def union( 

310 self, 

311 *, 

312 plain: bool | None = None, 

313 annotated: bool | None = None, 

314 ) -> ty.Any: # noqa: ANN401 

315 """Return the union of all registered models 

316 

317 Parameters 

318 ---------- 

319 plain 

320 If set to `True`, a plain union of all members will be returned. 

321 Otherwise, the returned union will be annotated in accordance with 

322 the union mode. 

323 annotated 

324 Deprecated. Use `plain=True` when you would have used 

325 `annotated=False`. 

326 

327 Returns 

328 ------- 

329 Any 

330 If there is 1 registered type, the type itself. If there is > 1, a 

331 union of all registered types. This union may be annotated if 

332 `plain` is not `True`. 

333 

334 Raises 

335 ------ 

336 NoRegisteredTypesError 

337 If no types have been registered yet. 

338 """ 

339 if annotated is not None: 

340 warnings.warn( 

341 "The `annotated` parameter is deprecated. Use `plain=True` to " 

342 "get a plain union. By default, behavior is governed by " 

343 "`union_mode`. Will be removed in version 0.5.0.", 

344 DeprecationWarning, 

345 stacklevel=2, 

346 ) 

347 plain = True if not annotated else plain 

348 

349 n = len(self.models) 

350 if n == 0: 

351 msg = ( 

352 "Unable to produce a union from the tracking group " 

353 f'"{self.name}", as no types have been registered yet.' 

354 ) 

355 raise NoRegisteredTypesError(msg) 

356 if n == 1: 

357 return next(iter(self.models.values())) 

358 

359 union_mode = "smart" if plain else self.union_mode 

360 

361 if isinstance(union_mode, DiscriminatedConfig): 

362 return ty.Annotated[ 

363 functools.reduce( 

364 operator.or_, 

365 tuple( 

366 ty.Annotated[x, pydantic.Tag(v)] for v, x in self.models.items() 

367 ), 

368 ), 

369 pydantic.Field(discriminator=union_mode.discriminator_field), 

370 ] 

371 

372 plain_union = functools.reduce(operator.or_, self.models.values()) 

373 if union_mode == "left_to_right": 

374 return ty.Annotated[plain_union, pydantic.Field(union_mode="left_to_right")] 

375 

376 # "smart" mode is pydantic's default behavior on a plain union 

377 return plain_union 

378 

379 @property 

380 def generation(self) -> int: 

381 """The generation of the tracking group. 

382 

383 This is a counter that increments every time a new registration occurs 

384 """ 

385 return self._generation 

386 

387 @property 

388 def type_adapter(self) -> pydantic.TypeAdapter: 

389 """Get the pydantic TypeAdapter for the union of all group members""" 

390 if self.generation != self._adapter_generation: 

391 self._adapter = pydantic.TypeAdapter(self.union()) 

392 self._adapter_generation = self.generation 

393 

394 # casting because the if statement ensures it is non-None (because 

395 # _adapter_generation starts at -1 and generation increments from 0. 

396 return ty.cast("pydantic.TypeAdapter", self._adapter) 

397 

398 def _register_with_discriminator_field(self, cls: type[pydantic.BaseModel]) -> None: 

399 """Register the model with the default of the discriminator field 

400 

401 Parameters 

402 ---------- 

403 cls 

404 The class to register, must have the disciminator field set with a 

405 unique default value in the group. 

406 """ 

407 disc = ty.cast("DiscriminatedConfig", self.union_mode).discriminator_field 

408 value = cls.model_fields[disc].default 

409 if value == pydantic_core.PydanticUndefined: 

410 msg = ( 

411 f"{cls.__name__}.{disc} had no default value, it must " 

412 "have one which is unique among all tracked models." 

413 ) 

414 raise RegistrationError(msg) 

415 if not isinstance(value, str): 

416 msg = ( 

417 f"{cls.__name__}.{disc} had a default value of {value}, which " 

418 f"was of type {type(value).__name__}, not str." 

419 ) 

420 raise RegistrationError(msg) 

421 

422 self._do_register(value, cls) 

423 

424 def _register_plain(self, cls: type[pydantic.BaseModel]) -> None: 

425 """Register the model keyed by its class name. 

426 

427 Used for smart / left_to_right modes where no discriminator field 

428 is involved. 

429 

430 Parameters 

431 ---------- 

432 cls 

433 The model to register. 

434 """ 

435 self._do_register(str(id(cls)), cls) 

436 

437 def _do_register(self, key: str, cls: type[pydantic.BaseModel]) -> None: 

438 """Register the given model under the given key 

439 

440 Parameters 

441 ---------- 

442 key 

443 The key under which to register the model 

444 cls 

445 The model to register. 

446 """ 

447 if (other := self.models.get(key)) is not None: 

448 if other is not cls: 

449 msg = ( 

450 f'Cannot register {cls.__name__} under the "{key}" ' 

451 f"identifier, which is already in use by {other.__name__}." 

452 ) 

453 raise RegistrationError(msg) 

454 else: 

455 self._generation += 1 

456 self.models[key] = cls