Coverage for src / dynapydantic / tracking_group.py: 100%
147 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-13 20:14 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-13 20:14 +0000
1"""Base class for dynamic pydantic models"""
3import contextlib
4import functools
5import operator
6import typing as ty
7import warnings
9import pydantic
10import pydantic.fields
11import pydantic_core
13from .exceptions import (
14 AmbiguousDiscriminatorValueError,
15 NoRegisteredTypesError,
16 RegistrationError,
17)
18from .union_mode import DiscriminatedConfig, UnionMode
def _inject_discriminator_field(
    cls: type[pydantic.BaseModel],
    disc_field: str,
    value: str,
) -> pydantic.fields.FieldInfo:
    """Add a frozen ``Literal`` discriminator field to an existing model

    The new field defaults to ``value`` and is annotated as
    ``Literal[value]``. After storing it in ``cls.model_fields`` the model
    is force-rebuilt; a rebuild failure caused by an annotation that cannot
    be resolved yet is suppressed.

    Parameters
    ----------
    cls
        The BaseModel subclass to modify in place
    disc_field
        Name of the discriminator field to add
    value
        Default (and only permitted) value of the discriminator field

    Returns
    -------
    pydantic.fields.FieldInfo
        The field info stored on ``cls`` under ``disc_field`` after the
        rebuild.

    Raises
    ------
    RegistrationError
        If ``cls`` already has an attribute named ``disc_field``.
    """
    if not hasattr(cls, disc_field):
        new_field = pydantic.fields.FieldInfo(
            default=value,
            annotation=ty.Literal[value],  # type: ignore[not-a-type]
            frozen=True,
        )
        cls.model_fields[disc_field] = new_field
        # The model may still contain unresolved forward references; in that
        # case the rebuild is skipped rather than failing the registration.
        with contextlib.suppress(pydantic.errors.PydanticUndefinedAnnotation):
            cls.model_rebuild(force=True)
        # Re-read from the class in case the rebuild replaced the mapping.
        return cls.model_fields[disc_field]

    # An attribute of that name is already present: refuse to shadow it.
    conflict_msg = (
        f'Cannot inject discriminator field "{disc_field}" into '
        f"{cls.__name__}: an attribute with that name already exists. "
        "Rename either the attribute or the discriminator_field to avoid "
        "the conflict."
    )
    raise RegistrationError(conflict_msg)
class TrackingGroup(pydantic.BaseModel):
    """Tracker for pydantic models

    Models are added with :meth:`register` (decorator form) or
    :meth:`register_model`. :meth:`union` and :attr:`type_adapter` then
    expose the union of every registered model, annotated according to
    ``union_mode``.
    """

    name: str = pydantic.Field(
        description=(
            "Name of the tracking group. This is for human display, so it "
            "doesn't technically need to be globally unique, but it should be "
            "meaningfully named, as it will be used in error messages."
        ),
    )
    union_mode: UnionMode | None = pydantic.Field(
        None,
        description=(
            "Union validation strategy. Pass a DiscriminatedConfig instance "
            'or one of the plain strings "smart" or "left_to_right". You can '
            "also just pass the fields for DiscriminatedConfig to this "
            "model and they will be forwarded."
        ),
    )
    discriminator_field: str | None = pydantic.Field(
        None,
        description=(
            "Name of the discriminator field. NOTE: This field is "
            "here as an alias for union_mode.discriminator_field. Passing "
            "both a discriminator_field and a union_mode will result in an "
            "error."
        ),
    )
    discriminator_value_generator: ty.Callable[[type], str] | None = pydantic.Field(
        None,
        description=(
            "A callable that produces default values for the discriminator field"
        ),
    )
    plugin_entry_point: str | None = pydantic.Field(
        None,
        description=(
            "If given, then plugins packages will be supported through this "
            "Python entrypoint. The entrypoint can either be a function, "
            "which will be called, or simply a module, which will be "
            "imported. In either case, models found along the import path of "
            "the entrypoint will be registered. If the entrypoint is a "
            "function, additional models may be declared in the function."
        ),
    )
    models: dict[str, type[pydantic.BaseModel]] = pydantic.Field(
        {},
        description="The tracked models",
    )

    # Incremented by _do_register each time a new model is added; used to
    # invalidate the cached TypeAdapter.
    _generation: int = pydantic.PrivateAttr(default=0)
    # Lazily built TypeAdapter over union(); see the type_adapter property.
    _adapter: pydantic.TypeAdapter | None = pydantic.PrivateAttr(default=None)
    # Value of _generation when _adapter was built; -1 means "never built".
    _adapter_generation: int = pydantic.PrivateAttr(default=-1)

    @pydantic.model_validator(mode="after")
    def _ensure_union_mode(self) -> "TrackingGroup":
        """There must be a union_mode

        ``_coerce_union_mode`` (the "before" validator) rejects input that
        supplies neither ``union_mode`` nor ``discriminator_field``, so
        ``union_mode`` should always be set by the time this runs. This
        validator guards against subclasses (or bugs) that break that
        guarantee. It then mirrors the discriminator settings of
        ``union_mode`` onto the top-level ``discriminator_field`` and
        ``discriminator_value_generator`` fields, clearing them when the
        mode is not discriminated.

        Raises
        ------
        ValueError
            If ``union_mode`` is None.
        """
        if self.union_mode is None:
            msg = (
                "union_mode is required. This normally indicates that you "
                "subclassed TrackingGroup and wrote an invalid validator, but "
                "could also be a bug with dynapydantic, so please file a bug "
                "report with a reproducer on how you got here if you suspect "
                "a bug."
            )
            raise ValueError(msg)

        # Ensure the top-level fields are in-sync
        if isinstance(self.union_mode, DiscriminatedConfig):
            self.discriminator_field = self.union_mode.discriminator_field
            self.discriminator_value_generator = (
                self.union_mode.discriminator_value_generator
            )
        else:
            self.discriminator_field = None
            self.discriminator_value_generator = None

        return self

    @pydantic.model_validator(mode="before")
    @classmethod
    def _coerce_union_mode(cls, data: ty.Any) -> ty.Any:  # noqa: ANN401
        """Coerce flat discriminator kwargs into a DiscriminatedConfig.

        Allows callers to pass ``discriminator_field`` and
        ``discriminator_value_generator`` at the top level and transparently
        assembles a ``DiscriminatedConfig`` from them. This avoids an extra
        import/nesting layer for the user.

        The input mapping is never modified; when arguments are forwarded a
        new dict is returned instead.

        Raises
        ------
        ValueError
            If both ``union_mode`` and ``discriminator_field`` are given but
            disagree, or if neither is given.
        """
        if not isinstance(data, dict):
            return data

        disc_field = data.get("discriminator_field", None)
        has_disc_field = disc_field is not None
        union_mode = data.get("union_mode", None)
        has_union_mode = union_mode is not None

        # If the user passed us both a discriminator field and a union_mode,
        # things must be perfectly consistent
        if has_disc_field and has_union_mode:
            consistent = (
                isinstance(union_mode, DiscriminatedConfig)
                and disc_field == union_mode.discriminator_field
                and data.get("discriminator_value_generator")
                is union_mode.discriminator_value_generator
            ) or (
                isinstance(union_mode, dict)
                and disc_field == union_mode.get("discriminator_field")
                and data.get("discriminator_value_generator")
                is union_mode.get("discriminator_value_generator")
            )
            if not consistent:
                msg = (
                    "Received both union_mode and discriminator_field; pass one "
                    "or the other."
                )
                raise ValueError(msg)

        if has_disc_field and not has_union_mode:
            # Forward arguments to DiscriminatedConfig. Build a new dict
            # rather than assigning into ``data``: ``data`` may be the
            # caller's own dict (e.g. one passed to ``model_validate``).
            return {
                **data,
                "union_mode": {
                    "discriminator_field": disc_field,
                    "discriminator_value_generator": data.get(
                        "discriminator_value_generator",
                    ),
                },
            }
        if not has_disc_field and not has_union_mode:
            msg = "Either union_mode or discriminator_field must be given"
            raise ValueError(msg)

        return data

    @property
    def _discriminated(self) -> DiscriminatedConfig | None:
        """Return the DiscriminatedConfig, or None if not discriminated."""
        return (
            self.union_mode
            if isinstance(self.union_mode, DiscriminatedConfig)
            else None
        )

    def load_plugins(self) -> None:
        """Load plugins to discover/register additional models

        Every entry point in the group named by ``plugin_entry_point`` is
        loaded; if the loaded object is callable it is also called. Does
        nothing when ``plugin_entry_point`` is None.
        """
        if self.plugin_entry_point is None:
            return

        from importlib.metadata import entry_points  # noqa: PLC0415

        for ep in entry_points().select(group=self.plugin_entry_point):
            plugin = ep.load()
            if callable(plugin):
                plugin()

    @ty.overload
    def register(self, value: str | None = None) -> ty.Callable[[type], type]: ...

    @ty.overload
    def register(self, value: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]: ...

    def register(
        self,
        value: str | type[pydantic.BaseModel] | None = None,
    ) -> ty.Callable[[type], type] | type[pydantic.BaseModel]:
        """Register a model into this group (decorator)

        Parameters
        ----------
        value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand. Can also be the type
            itself to register (if the ()'s are omitted from the decorator).

        Returns
        -------
        type or callable
            When ``value`` is a type it is registered and returned unchanged.
            Otherwise a decorator is returned that registers the decorated
            class with ``value`` and returns the class unchanged.
        """
        if isinstance(value, type):
            # Bare ``@group.register`` usage: the class itself was passed.
            self.register_model(value)
            return value

        def _wrapper(cls: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]:
            self.register_model(cls, value)
            return cls

        return _wrapper

    def register_model(
        self,
        cls: type[pydantic.BaseModel],
        discriminator_value: str | None = None,
    ) -> None:
        """Register the given model into this group

        Parameters
        ----------
        cls
            The model to register
        discriminator_value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand.

        Raises
        ------
        RegistrationError
            If the arguments are of the wrong type, no discriminator value
            can be determined, an existing discriminator field is not a
            ``Literal``, or the resulting key is already used by another
            model.
        AmbiguousDiscriminatorValueError
            If ``discriminator_value`` disagrees with the default of an
            already-declared discriminator field.
        """
        if discriminator_value is not None and not isinstance(discriminator_value, str):
            msg = (
                "discriminator_value must be a str if given, was "
                f"{type(discriminator_value).__name__}"
            )
            raise RegistrationError(msg)

        if not isinstance(cls, type) or not issubclass(cls, pydantic.BaseModel):
            msg = (
                "only pydantic BaseModel subclasses can be registered in a "
                f"TrackingGroup. Got {cls}, which was not."
            )
            raise RegistrationError(msg)

        if (dm := self._discriminated) is not None:
            disc = dm.discriminator_field
            field = cls.model_fields.get(disc)

            if field is None:
                # No discriminator field declared: inject one, preferring an
                # explicit value over the group's generator.
                if discriminator_value is not None:
                    _inject_discriminator_field(cls, disc, discriminator_value)
                elif dm.discriminator_value_generator is not None:
                    _inject_discriminator_field(
                        cls,
                        disc,
                        dm.discriminator_value_generator(cls),
                    )
                else:
                    msg = (
                        f"unable to determine a discriminator value for "
                        f'{cls.__name__} in tracking group "{self.name}". '
                        "No value was passed, discriminator_value_generator "
                        f'was None and the "{disc}" field was not defined.'
                    )
                    raise RegistrationError(msg)
            elif ty.get_origin(field.annotation) is not ty.Literal:
                msg = (
                    f'the discriminator field "{disc}" already existed in '
                    f"{cls.__name__}, but its type annotation was "
                    f"{field.annotation}, not Literal."
                )
                raise RegistrationError(msg)
            elif (
                discriminator_value is not None and field.default != discriminator_value
            ):
                msg = (
                    f"the discriminator value for {cls.__name__} was "
                    f'ambiguous, the passed value was "{discriminator_value}" '
                    f'and "{field.default}" via the discriminator '
                    f"field ({disc})."
                )
                raise AmbiguousDiscriminatorValueError(msg)

            self._register_with_discriminator_field(cls)
        else:
            if discriminator_value is not None:
                warnings.warn(
                    f'A discriminator_value of "{discriminator_value}" was '
                    f"explicitly passed for {cls.__name__}, but "
                    f'union_mode="{self.union_mode}" does not use a '
                    "discriminator. The value will be ignored.",
                    stacklevel=2,
                )
            self._register_plain(cls)

    def union(
        self,
        *,
        plain: bool | None = None,
    ) -> ty.Any:  # noqa: ANN401
        """Return the union of all registered models

        Parameters
        ----------
        plain
            If set to `True`, a plain union of all members will be returned.
            Otherwise, the returned union will be annotated in accordance with
            the union mode.

        Returns
        -------
        Any
            If there is 1 registered type, the type itself. If there is > 1, a
            union of all registered types. This union may be annotated if
            `plain` is not `True`.

        Raises
        ------
        NoRegisteredTypesError
            If no types have been registered yet.
        """
        n = len(self.models)
        if n == 0:
            msg = (
                "Unable to produce a union from the tracking group "
                f'"{self.name}", as no types have been registered yet.'
            )
            raise NoRegisteredTypesError(msg)
        if n == 1:
            return next(iter(self.models.values()))

        union_mode = "smart" if plain else self.union_mode

        if isinstance(union_mode, DiscriminatedConfig):
            # Tag each member with its registration key and let pydantic
            # dispatch on the configured discriminator field.
            return ty.Annotated[
                functools.reduce(
                    operator.or_,
                    tuple(
                        ty.Annotated[x, pydantic.Tag(v)] for v, x in self.models.items()
                    ),
                ),
                pydantic.Field(discriminator=union_mode.discriminator_field),
            ]

        plain_union = functools.reduce(operator.or_, self.models.values())
        if union_mode == "left_to_right":
            return ty.Annotated[plain_union, pydantic.Field(union_mode="left_to_right")]

        # "smart" mode is pydantic's default behavior on a plain union
        return plain_union

    @property
    def generation(self) -> int:
        """The generation of the tracking group.

        This is a counter that increments every time a new registration occurs
        """
        return self._generation

    @property
    def type_adapter(self) -> pydantic.TypeAdapter:
        """Get the pydantic TypeAdapter for the union of all group members

        The adapter is cached and rebuilt only when the generation changes
        (i.e. after a new model has been registered).
        """
        if self.generation != self._adapter_generation:
            self._adapter = pydantic.TypeAdapter(self.union())
            self._adapter_generation = self.generation

        # The cast is safe: _adapter_generation starts at -1 while generation
        # starts at 0 and only increases, so the branch above has assigned
        # _adapter before we can reach this line.
        return ty.cast("pydantic.TypeAdapter", self._adapter)

    def _register_with_discriminator_field(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model with the default of the discriminator field

        Parameters
        ----------
        cls
            The class to register, must have the discriminator field set with a
            unique default value in the group.

        Raises
        ------
        RegistrationError
            If the discriminator field has no default, its default is not a
            str, or the value is already used by another model.
        """
        disc = ty.cast("DiscriminatedConfig", self.union_mode).discriminator_field
        value = cls.model_fields[disc].default
        # PydanticUndefined is a singleton sentinel: compare by identity.
        if value is pydantic_core.PydanticUndefined:
            msg = (
                f"{cls.__name__}.{disc} had no default value, it must "
                "have one which is unique among all tracked models."
            )
            raise RegistrationError(msg)
        if not isinstance(value, str):
            msg = (
                f"{cls.__name__}.{disc} had a default value of {value}, which "
                f"was of type {type(value).__name__}, not str."
            )
            raise RegistrationError(msg)

        self._do_register(value, cls)

    def _register_plain(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model keyed by its object identity.

        Used for smart / left_to_right modes where no discriminator field
        is involved. The key is ``str(id(cls))``, so two distinct classes
        that share a ``__name__`` get separate entries, while registering
        the same class twice maps to the same key.

        Parameters
        ----------
        cls
            The model to register.
        """
        self._do_register(str(id(cls)), cls)

    def _do_register(self, key: str, cls: type[pydantic.BaseModel]) -> None:
        """Register the given model under the given key

        Re-registering the same class under the same key is a no-op and does
        not bump the generation.

        Parameters
        ----------
        key
            The key under which to register the model
        cls
            The model to register.

        Raises
        ------
        RegistrationError
            If ``key`` is already used by a different class.
        """
        if (other := self.models.get(key)) is not None:
            if other is not cls:
                msg = (
                    f'Cannot register {cls.__name__} under the "{key}" '
                    f"identifier, which is already in use by {other.__name__}."
                )
                raise RegistrationError(msg)
        else:
            self._generation += 1
            self.models[key] = cls