Coverage for src / dynapydantic / tracking_group.py: 100%
148 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 17:07 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 17:07 +0000
1"""Base class for dynamic pydantic models"""
3import contextlib
4import functools
5import operator
6import typing as ty
7import warnings
9import pydantic
10import pydantic.fields
11import pydantic_core
13from .exceptions import (
14 AmbiguousDiscriminatorValueError,
15 NoRegisteredTypesError,
16 RegistrationError,
17)
18from .union_mode import DiscriminatedConfig, UnionMode
def _inject_discriminator_field(
    cls: type[pydantic.BaseModel],
    disc_field: str,
    value: str,
) -> pydantic.fields.FieldInfo:
    """Add a frozen ``Literal`` discriminator field to an existing model

    Parameters
    ----------
    cls
        The BaseModel subclass to modify in place
    disc_field
        Name of the discriminator field to add
    value
        Value of the discriminator field; used both as the field default and
        as the sole member of its ``Literal`` annotation

    Returns
    -------
    pydantic.fields.FieldInfo
        The entry stored under ``disc_field`` in ``cls.model_fields`` after
        the rebuild attempt.

    Raises
    ------
    RegistrationError
        If ``cls`` already has an attribute named ``disc_field``.
    """
    name_taken = hasattr(cls, disc_field)
    if name_taken:
        msg = (
            f'Cannot inject discriminator field "{disc_field}" into '
            f"{cls.__name__}: an attribute with that name already exists. "
            "Rename either the attribute or the discriminator_field to avoid "
            "the conflict."
        )
        raise RegistrationError(msg)

    injected = pydantic.fields.FieldInfo(
        annotation=ty.Literal[value],  # type: ignore[not-a-type]
        default=value,
        frozen=True,
    )
    cls.model_fields[disc_field] = injected

    # The rebuild is best-effort: an undefined-annotation error raised by
    # model_rebuild is deliberately swallowed here.
    try:
        cls.model_rebuild(force=True)
    except pydantic.errors.PydanticUndefinedAnnotation:
        pass

    # Look the field up again rather than returning ``injected`` so the
    # caller gets whatever is in model_fields after the rebuild.
    return cls.model_fields[disc_field]
class TrackingGroup(pydantic.BaseModel):
    """Tracker for pydantic models

    A group holds a mapping of string keys to registered ``BaseModel``
    subclasses (``models``) together with the configuration that decides how
    those models are combined into a union (``union_mode``).
    """

    # Required: no default is supplied.
    name: str = pydantic.Field(
        description=(
            "Name of the tracking group. This is for human display, so it "
            "doesn't technically need to be globally unique, but it should be "
            "meaningfully named, as it will be used in error messages."
        ),
    )
    # Declared optional, but the model validators reject a group whose
    # union_mode is still None after coercion.
    union_mode: UnionMode | None = pydantic.Field(
        None,
        description=(
            "Union validation strategy. Pass a DiscriminatedConfig instance "
            'or one of the plain strings "smart" or "left_to_right". You can '
            "also just pass the fields for DiscriminatedConfig to this "
            "model and they will be forwarded."
        ),
    )
    # Top-level mirror of union_mode.discriminator_field; re-synchronised by
    # the "after" validator (None when union_mode is not discriminated).
    discriminator_field: str | None = pydantic.Field(
        None,
        description=(
            "Name of the discriminator field. NOTE: This field is "
            "here as an alias for union_mode.discriminator_field. Passing "
            "both a discriminator_field and a union_mode will result in an "
            "error."
        ),
    )
    # Top-level mirror of union_mode.discriminator_value_generator; called
    # with the class being registered when no explicit value is passed.
    discriminator_value_generator: ty.Callable[[type], str] | None = pydantic.Field(
        None,
        description=(
            "A callable that produces default values for the discriminator field"
        ),
    )
    # Entry-point group name consumed by load_plugins(); None disables plugins.
    plugin_entry_point: str | None = pydantic.Field(
        None,
        description=(
            "If given, then plugins packages will be supported through this "
            "Python entrypoint. The entrypoint can either be a function, "
            "which will be called, or simply a module, which will be "
            "imported. In either case, models found along the import path of "
            "the entrypoint will be registered. If the entrypoint is a "
            "function, additional models may be declared in the function."
        ),
    )
    # Registry of tracked models. Keys are the discriminator value in
    # discriminated mode, or str(id(cls)) in the plain modes.
    models: dict[str, type[pydantic.BaseModel]] = pydantic.Field(
        {},
        description="The tracked models",
    )

    # Incremented each time a previously unseen key is registered.
    _generation: int = pydantic.PrivateAttr(default=0)
    # Cached TypeAdapter for the union of all members; built lazily.
    _adapter: pydantic.TypeAdapter | None = pydantic.PrivateAttr(default=None)
    # Generation at which _adapter was built. Starts at -1 so it never equals
    # the initial generation (0) and the first access always builds.
    _adapter_generation: int = pydantic.PrivateAttr(default=-1)
110 @pydantic.model_validator(mode="after")
111 def _ensure_union_mode(self) -> "TrackingGroup":
112 """There must be a union_mode
114 This validator works as a guard on _coerce_union_mode to make
115 """
116 if self.union_mode is None:
117 msg = (
118 "union_mode is required. This normally indicates that you "
119 "subclasses TrackingGroup and wrote an invalid validator, but "
120 "could also be a bug with dynapydantic, so please file a bug "
121 "report with a reproducer on how you got here if you suspect "
122 "a bug."
123 )
124 raise ValueError(msg)
126 # Ensure the top-level fields are in-sync
127 if isinstance(self.union_mode, DiscriminatedConfig):
128 self.discriminator_field = self.union_mode.discriminator_field
129 self.discriminator_value_generator = (
130 self.union_mode.discriminator_value_generator
131 )
132 else:
133 self.discriminator_field = None
134 self.discriminator_value_generator = None
136 return self
138 @pydantic.model_validator(mode="before")
139 @classmethod
140 def _coerce_union_mode(cls, data: ty.Any) -> ty.Any: # noqa: ANN401
141 """Coerce flat discriminator kwargs into a DiscriminatedConfig.
143 Allows callers to pass ``discriminator_field`` and
144 ``discriminator_value_generator`` at the top level and transparently
145 assembles a ``DiscriminatedConfig`` from them. This avoids an extra
146 import/nesting layer for the user.
147 """
148 if not isinstance(data, dict):
149 return data
151 disc_field = data.get("discriminator_field", None)
152 has_disc_field = disc_field is not None
153 union_mode = data.get("union_mode", None)
154 has_union_mode = union_mode is not None
156 if has_disc_field and has_union_mode:
157 msg = (
158 "Received both union_mode and discriminator_field; pass one "
159 "or the other."
160 )
161 raise ValueError(msg)
163 if has_disc_field and not has_union_mode:
164 # Forward arguments to DiscriminatedConfig
165 data["union_mode"] = {
166 "discriminator_field": disc_field,
167 "discriminator_value_generator": data.get(
168 "discriminator_value_generator",
169 ),
170 }
171 elif not has_disc_field and not has_union_mode:
172 msg = "Either union_mode or discriminator_field must be given"
173 raise ValueError(msg)
175 return data
177 @property
178 def _discriminated(self) -> DiscriminatedConfig | None:
179 """Return the DiscriminatedMode config, or None if not discriminated."""
180 return (
181 self.union_mode
182 if isinstance(self.union_mode, DiscriminatedConfig)
183 else None
184 )
186 def load_plugins(self) -> None:
187 """Load plugins to discover/register additional models"""
188 if self.plugin_entry_point is None:
189 return
191 from importlib.metadata import entry_points # noqa: PLC0415
193 for ep in entry_points().select(group=self.plugin_entry_point):
194 plugin = ep.load()
195 if callable(plugin):
196 plugin()
198 @ty.overload
199 def register(self, value: str | None = None) -> ty.Callable[[type], type]: ...
201 @ty.overload
202 def register(self, value: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]: ...
204 def register(
205 self,
206 value: str | type[pydantic.BaseModel] | None = None,
207 ) -> ty.Callable[[type], type] | type[pydantic.BaseModel]:
208 """Register a model into this group (decorator)
210 Parameters
211 ----------
212 value
213 Value for the discriminator field. If not given, then
214 discriminator_value_generator must be non-None or the
215 discriminator field must be declared by hand. Can also be the type
216 itself to register (if the ()'s are omitted from the decorator).
217 """
218 if isinstance(value, type):
219 self.register_model(value)
220 return value
222 def _wrapper(cls: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]:
223 self.register_model(cls, value)
224 return cls
226 return _wrapper
228 def register_model(
229 self,
230 cls: type[pydantic.BaseModel],
231 discriminator_value: str | None = None,
232 ) -> None:
233 """Register the given model into this group
235 Parameters
236 ----------
237 cls
238 The model to register
239 discriminator_value
240 Value for the discriminator field. If not given, then
241 discriminator_value_generator must be non-None or the
242 discriminator field must be declared by hand.
243 """
244 if discriminator_value is not None and not isinstance(discriminator_value, str):
245 msg = (
246 "discriminator_value must be a str if given, was "
247 f"{type(discriminator_value).__name__}"
248 )
249 raise RegistrationError(msg)
251 if not isinstance(cls, type) or not issubclass(cls, pydantic.BaseModel):
252 msg = (
253 "only pydantic BaseModel subclasses can be registered in a "
254 f"TrackingGroup. Got {cls}, which was not."
255 )
256 raise RegistrationError(msg)
258 if (dm := self._discriminated) is not None:
259 disc = dm.discriminator_field
260 field = cls.model_fields.get(disc)
262 if field is None:
263 if discriminator_value is not None:
264 _inject_discriminator_field(cls, disc, discriminator_value)
265 elif dm.discriminator_value_generator is not None:
266 _inject_discriminator_field(
267 cls,
268 disc,
269 dm.discriminator_value_generator(cls),
270 )
271 else:
272 msg = (
273 f"unable to determine a discriminator value for "
274 f'{cls.__name__} in tracking group "{self.name}". '
275 "No value was passed, discriminator_value_generator "
276 f'was None and the "{disc}" field was not defined.'
277 )
278 raise RegistrationError(msg)
279 elif ty.get_origin(field.annotation) is not ty.Literal:
280 msg = (
281 f'the discriminator field "{disc}" already existed in '
282 f"{cls.__name__}, but its type annotation was "
283 f"{field.annotation}, not Literal."
284 )
285 raise RegistrationError(msg)
286 elif (
287 discriminator_value is not None and field.default != discriminator_value
288 ):
289 msg = (
290 f"the discriminator value for {cls.__name__} was "
291 f'ambiguous, the passed value was "{discriminator_value}" '
292 f' and "{field.default}" via the discriminator '
293 f"field ({disc})."
294 )
295 raise AmbiguousDiscriminatorValueError(msg)
297 self._register_with_discriminator_field(cls)
298 else:
299 if discriminator_value is not None:
300 warnings.warn(
301 f'A discriminator_value of "{discriminator_value}" was '
302 f"explicitly passed for {cls.__name__}, but "
303 f'union_mode="{self.union_mode}" does not use a '
304 "discriminator. The value will be ignored.",
305 stacklevel=2,
306 )
307 self._register_plain(cls)
309 def union(
310 self,
311 *,
312 plain: bool | None = None,
313 annotated: bool | None = None,
314 ) -> ty.Any: # noqa: ANN401
315 """Return the union of all registered models
317 Parameters
318 ----------
319 plain
320 If set to `True`, a plain union of all members will be returned.
321 Otherwise, the returned union will be annotated in accordance with
322 the union mode.
323 annotated
324 Deprecated. Use `plain=True` when you would have used
325 `annotated=False`.
327 Returns
328 -------
329 Any
330 If there is 1 registered type, the type itself. If there is > 1, a
331 union of all registered types. This union may be annotated if
332 `plain` is not `True`.
334 Raises
335 ------
336 NoRegisteredTypesError
337 If no types have been registered yet.
338 """
339 if annotated is not None:
340 warnings.warn(
341 "The `annotated` parameter is deprecated. Use `plain=True` to "
342 "get a plain union. By default, behavior is governed by "
343 "`union_mode`. Will be removed in version 0.5.0.",
344 DeprecationWarning,
345 stacklevel=2,
346 )
347 plain = True if not annotated else plain
349 n = len(self.models)
350 if n == 0:
351 msg = (
352 "Unable to produce a union from the tracking group "
353 f'"{self.name}", as no types have been registered yet.'
354 )
355 raise NoRegisteredTypesError(msg)
356 if n == 1:
357 return next(iter(self.models.values()))
359 union_mode = "smart" if plain else self.union_mode
361 if isinstance(union_mode, DiscriminatedConfig):
362 return ty.Annotated[
363 functools.reduce(
364 operator.or_,
365 tuple(
366 ty.Annotated[x, pydantic.Tag(v)] for v, x in self.models.items()
367 ),
368 ),
369 pydantic.Field(discriminator=union_mode.discriminator_field),
370 ]
372 plain_union = functools.reduce(operator.or_, self.models.values())
373 if union_mode == "left_to_right":
374 return ty.Annotated[plain_union, pydantic.Field(union_mode="left_to_right")]
376 # "smart" mode is pydantic's default behavior on a plain union
377 return plain_union
379 @property
380 def generation(self) -> int:
381 """The generation of the tracking group.
383 This is a counter that increments every time a new registration occurs
384 """
385 return self._generation
387 @property
388 def type_adapter(self) -> pydantic.TypeAdapter:
389 """Get the pydantic TypeAdapter for the union of all group members"""
390 if self.generation != self._adapter_generation:
391 self._adapter = pydantic.TypeAdapter(self.union())
392 self._adapter_generation = self.generation
394 # casting because the if statement ensures it is non-None (because
395 # _adapter_generation starts at -1 and generation increments from 0.
396 return ty.cast("pydantic.TypeAdapter", self._adapter)
398 def _register_with_discriminator_field(self, cls: type[pydantic.BaseModel]) -> None:
399 """Register the model with the default of the discriminator field
401 Parameters
402 ----------
403 cls
404 The class to register, must have the disciminator field set with a
405 unique default value in the group.
406 """
407 disc = ty.cast("DiscriminatedConfig", self.union_mode).discriminator_field
408 value = cls.model_fields[disc].default
409 if value == pydantic_core.PydanticUndefined:
410 msg = (
411 f"{cls.__name__}.{disc} had no default value, it must "
412 "have one which is unique among all tracked models."
413 )
414 raise RegistrationError(msg)
415 if not isinstance(value, str):
416 msg = (
417 f"{cls.__name__}.{disc} had a default value of {value}, which "
418 f"was of type {type(value).__name__}, not str."
419 )
420 raise RegistrationError(msg)
422 self._do_register(value, cls)
424 def _register_plain(self, cls: type[pydantic.BaseModel]) -> None:
425 """Register the model keyed by its class name.
427 Used for smart / left_to_right modes where no discriminator field
428 is involved.
430 Parameters
431 ----------
432 cls
433 The model to register.
434 """
435 self._do_register(str(id(cls)), cls)
437 def _do_register(self, key: str, cls: type[pydantic.BaseModel]) -> None:
438 """Register the given model under the given key
440 Parameters
441 ----------
442 key
443 The key under which to register the model
444 cls
445 The model to register.
446 """
447 if (other := self.models.get(key)) is not None:
448 if other is not cls:
449 msg = (
450 f'Cannot register {cls.__name__} under the "{key}" '
451 f"identifier, which is already in use by {other.__name__}."
452 )
453 raise RegistrationError(msg)
454 else:
455 self._generation += 1
456 self.models[key] = cls