Coverage for src / dynapydantic / tracking_group.py: 100%
136 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-04-06 21:55 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2026-04-06 21:55 +0000
1"""Base class for dynamic pydantic models"""
3import contextlib
4import functools
5import operator
6import typing as ty
7import warnings
9import pydantic
10import pydantic.fields
11import pydantic_core
13from .exceptions import (
14 AmbiguousDiscriminatorValueError,
15 NoRegisteredTypesError,
16 RegistrationError,
17)
18from .union_mode import DiscriminatedConfig, UnionMode
def _inject_discriminator_field(
    cls: type[pydantic.BaseModel],
    disc_field: str,
    value: str,
) -> pydantic.fields.FieldInfo:
    """Add a frozen ``Literal[value]`` discriminator field to an existing model.

    The model's field table is patched in place and the model is force-rebuilt
    so its validation schema includes the new field.

    Parameters
    ----------
    cls
        The pydantic model class to modify (mutated in place).
    disc_field
        Name of the discriminator field to add.
    value
        The single allowed value of the new field; also used as its default.

    Returns
    -------
    pydantic.fields.FieldInfo
        The ``FieldInfo`` stored under ``disc_field`` after the rebuild.

    Raises
    ------
    RegistrationError
        If ``cls`` already has an attribute named ``disc_field``.
    """
    if not hasattr(cls, disc_field):
        new_field = pydantic.fields.FieldInfo(
            default=value,
            annotation=ty.Literal[value],  # type: ignore[not-a-type]
            frozen=True,
        )
        cls.model_fields[disc_field] = new_field
        try:
            cls.model_rebuild(force=True)
        except pydantic.errors.PydanticUndefinedAnnotation:
            # Unresolvable (e.g. forward-referenced) annotations are tolerated
            # here; the rebuild error is intentionally swallowed.
            pass
        return cls.model_fields[disc_field]

    conflict_msg = (
        f'Cannot inject discriminator field "{disc_field}" into '
        f"{cls.__name__}: an attribute with that name already exists. "
        "Rename either the attribute or the discriminator_field to avoid "
        "the conflict."
    )
    raise RegistrationError(conflict_msg)
class TrackingGroup(pydantic.BaseModel):
    """Registry that tracks a group of pydantic models and builds unions of them.

    Models are added with :meth:`register` / :meth:`register_model`, and the
    combined type is obtained from :meth:`union`. How that union validates is
    governed by ``union_mode``:

    * a :class:`DiscriminatedConfig` produces a discriminated union keyed on a
      ``Literal`` discriminator field (injected into models if necessary);
    * ``"smart"`` or ``"left_to_right"`` produce ordinary unions with the
      corresponding pydantic union-validation strategy.
    """

    name: str = pydantic.Field(
        description=(
            "Name of the tracking group. This is for human display, so it "
            "doesn't technically need to be globally unique, but it should be "
            "meaningfully named, as it will be used in error messages."
        ),
    )
    union_mode: UnionMode | None = pydantic.Field(
        None,
        description=(
            "Union validation strategy. Pass a DiscriminatedConfig instance "
            'or one of the plain strings "smart" or "left_to_right". You can '
            "also just pass the fields for DiscriminatedConfig to this "
            "model and they will be forwarded."
        ),
    )
    discriminator_field: str | None = pydantic.Field(
        None,
        description=(
            "Name of the discriminator field. NOTE: This field is "
            "here as an alias for union_mode.discriminator_field. Passing "
            "both a discriminator_field and a union_mode will result in an "
            "error."
        ),
    )
    discriminator_value_generator: ty.Callable[[type], str] | None = pydantic.Field(
        None,
        description=(
            "A callable that produces default values for the discriminator field"
        ),
    )
    plugin_entry_point: str | None = pydantic.Field(
        None,
        description=(
            "If given, then plugins packages will be supported through this "
            "Python entrypoint. The entrypoint can either be a function, "
            "which will be called, or simply a module, which will be "
            "imported. In either case, models found along the import path of "
            "the entrypoint will be registered. If the entrypoint is a "
            "function, additional models may be declared in the function."
        ),
    )
    # default_factory gives every instance its own registry dict instead of
    # relying on pydantic copying a shared mutable ``{}`` default.
    models: dict[str, type[pydantic.BaseModel]] = pydantic.Field(
        default_factory=dict,
        description="The tracked models",
    )

    @pydantic.model_validator(mode="after")
    def _ensure_union_mode(self) -> "TrackingGroup":
        """Guarantee ``union_mode`` is set and mirror it into the alias fields.

        ``_coerce_union_mode`` (a "before" validator) normally ensures that
        ``union_mode`` is present by the time this runs; this validator is a
        guard against that invariant being broken, e.g. by a subclass
        validator. It also keeps the top-level ``discriminator_field`` and
        ``discriminator_value_generator`` aliases in sync with ``union_mode``:
        they mirror the ``DiscriminatedConfig`` values, or are reset to
        ``None`` for non-discriminated modes.

        Raises
        ------
        ValueError
            If ``union_mode`` is ``None`` after field validation.
        """
        if self.union_mode is None:
            msg = (
                "union_mode is required. This normally indicates that you "
                "subclassed TrackingGroup and wrote an invalid validator, but "
                "could also be a bug with dynapydantic, so please file a bug "
                "report with a reproducer on how you got here if you suspect "
                "a bug."
            )
            raise ValueError(msg)

        # Ensure the top-level alias fields are in sync with union_mode.
        if isinstance(self.union_mode, DiscriminatedConfig):
            self.discriminator_field = self.union_mode.discriminator_field
            self.discriminator_value_generator = (
                self.union_mode.discriminator_value_generator
            )
        else:
            self.discriminator_field = None
            self.discriminator_value_generator = None

        return self

    @pydantic.model_validator(mode="before")
    @classmethod
    def _coerce_union_mode(cls, data: ty.Any) -> ty.Any:  # noqa: ANN401
        """Coerce flat discriminator kwargs into a DiscriminatedConfig.

        Allows callers to pass ``discriminator_field`` and
        ``discriminator_value_generator`` at the top level and transparently
        assembles a ``DiscriminatedConfig`` from them. This avoids an extra
        import/nesting layer for the user.

        Non-dict input is passed through untouched. The input dict is never
        mutated; a new dict is returned when ``union_mode`` has to be filled in.

        Raises
        ------
        ValueError
            If both ``union_mode`` and ``discriminator_field`` are given, or if
            neither is.
        """
        if not isinstance(data, dict):
            return data

        disc_field = data.get("discriminator_field")
        has_disc_field = disc_field is not None
        has_union_mode = data.get("union_mode") is not None

        if has_disc_field and has_union_mode:
            msg = (
                "Received both union_mode and discriminator_field; pass one "
                "or the other."
            )
            raise ValueError(msg)

        if not has_disc_field and not has_union_mode:
            msg = "Either union_mode or discriminator_field must be given"
            raise ValueError(msg)

        if has_disc_field:
            # Forward arguments to DiscriminatedConfig. Copy rather than
            # mutate: ``data`` may be a mapping the caller still owns (e.g. the
            # dict handed to ``model_validate``).
            data = {
                **data,
                "union_mode": {
                    "discriminator_field": disc_field,
                    "discriminator_value_generator": data.get(
                        "discriminator_value_generator",
                    ),
                },
            }

        return data

    @property
    def _discriminated(self) -> DiscriminatedConfig | None:
        """Return the DiscriminatedConfig, or None if not discriminated."""
        return (
            self.union_mode
            if isinstance(self.union_mode, DiscriminatedConfig)
            else None
        )

    def load_plugins(self) -> None:
        """Load plugins to discover/register additional models.

        Every entry point in the ``plugin_entry_point`` group is loaded, which
        imports its target (so models registered along that import path get
        registered). If the loaded object is a non-class callable (a hook
        function), it is then called so it can register further models. Does
        nothing when ``plugin_entry_point`` is ``None``.
        """
        if self.plugin_entry_point is None:
            return

        from importlib.metadata import entry_points  # noqa: PLC0415

        for ep in entry_points().select(group=self.plugin_entry_point):
            plugin = ep.load()
            # Classes are callable too, but "calling" an entry point that
            # names a model class would instantiate it rather than run a
            # registration hook, so only non-class callables are invoked.
            if callable(plugin) and not isinstance(plugin, type):
                plugin()

    @ty.overload
    def register(self, value: str | None = None) -> ty.Callable[[type], type]: ...

    @ty.overload
    def register(self, value: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]: ...

    def register(
        self,
        value: str | type[pydantic.BaseModel] | None = None,
    ) -> ty.Callable[[type], type] | type[pydantic.BaseModel]:
        """Register a model into this group (decorator).

        Usable both bare (``@group.register``) and called
        (``@group.register()`` / ``@group.register("value")``).

        Parameters
        ----------
        value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand. Can also be the type
            itself to register (if the ()'s are omitted from the decorator).

        Returns
        -------
        type or Callable
            The class itself when used bare; otherwise a decorator that
            registers its argument and returns it unchanged.
        """
        if isinstance(value, type):
            self.register_model(value)
            return value

        def _wrapper(cls: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]:
            self.register_model(cls, ty.cast("str | None", value))
            return cls

        return _wrapper

    def register_model(
        self,
        cls: type[pydantic.BaseModel],
        discriminator_value: str | None = None,
    ) -> None:
        """Register the given model into this group.

        In discriminated mode, a ``Literal`` discriminator field is injected
        into ``cls`` (mutating the class) when it does not declare one.

        Parameters
        ----------
        cls
            The model to register
        discriminator_value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand.

        Raises
        ------
        RegistrationError
            If ``discriminator_value`` is not a str, ``cls`` is not a
            ``BaseModel`` subclass, no discriminator value can be determined,
            an existing discriminator field is not ``Literal``-typed, or the
            resulting key is unusable or already taken.
        AmbiguousDiscriminatorValueError
            If ``discriminator_value`` disagrees with the default of an
            existing discriminator field.
        """
        if discriminator_value is not None and not isinstance(discriminator_value, str):
            msg = (
                "discriminator_value must be a str if given, was "
                f"{type(discriminator_value).__name__}"
            )
            raise RegistrationError(msg)

        if not isinstance(cls, type) or not issubclass(cls, pydantic.BaseModel):
            msg = (
                "only pydantic BaseModel subclasses can be registered in a "
                f"TrackingGroup. Got {cls}, which was not."
            )
            raise RegistrationError(msg)

        if (dm := self._discriminated) is not None:
            disc = dm.discriminator_field
            field = cls.model_fields.get(disc)

            if field is None:
                if discriminator_value is not None:
                    _inject_discriminator_field(cls, disc, discriminator_value)
                elif dm.discriminator_value_generator is not None:
                    _inject_discriminator_field(
                        cls,
                        disc,
                        dm.discriminator_value_generator(cls),
                    )
                else:
                    msg = (
                        f"unable to determine a discriminator value for "
                        f'{cls.__name__} in tracking group "{self.name}". '
                        "No value was passed, discriminator_value_generator "
                        f'was None and the "{disc}" field was not defined.'
                    )
                    raise RegistrationError(msg)
            elif ty.get_origin(field.annotation) is not ty.Literal:
                msg = (
                    f'the discriminator field "{disc}" already existed in '
                    f"{cls.__name__}, but its type annotation was "
                    f"{field.annotation}, not Literal."
                )
                raise RegistrationError(msg)
            elif (
                discriminator_value is not None and field.default != discriminator_value
            ):
                msg = (
                    f"the discriminator value for {cls.__name__} was "
                    f'ambiguous, the passed value was "{discriminator_value}" '
                    f'and "{field.default}" via the discriminator '
                    f"field ({disc})."
                )
                raise AmbiguousDiscriminatorValueError(msg)

            self._register_with_discriminator_field(cls)
        else:
            if discriminator_value is not None:
                warnings.warn(
                    f'A discriminator_value of "{discriminator_value}" was '
                    f"explicitly passed for {cls.__name__}, but "
                    f'union_mode="{self.union_mode}" does not use a '
                    "discriminator. The value will be ignored.",
                    stacklevel=2,
                )
            self._register_plain(cls)

    def union(
        self,
        *,
        plain: bool | None = None,
        annotated: bool | None = None,
    ) -> ty.Any:  # noqa: ANN401
        """Return the union of all registered models.

        Parameters
        ----------
        plain
            If set to `True`, a plain union of all members will be returned.
            Otherwise, the returned union will be annotated in accordance with
            the union mode.
        annotated
            Deprecated. Use `plain=True` when you would have used
            `annotated=False`.

        Returns
        -------
        Any
            If there is 1 registered type, the type itself. If there is > 1, a
            union of all registered types. This union may be annotated if
            `plain` is not `True`.

        Raises
        ------
        NoRegisteredTypesError
            If no types have been registered yet.
        """
        if annotated is not None:
            warnings.warn(
                "The `annotated` parameter is deprecated. Use `plain=True` to "
                "get a plain union. By default, behavior is governed by "
                "`union_mode`. Will be removed in a future version.",
                DeprecationWarning,
                stacklevel=2,
            )
            if not annotated:
                plain = True

        n = len(self.models)
        if n == 0:
            msg = (
                "Unable to produce a union from the tracking group "
                f'"{self.name}", as no types have been registered yet.'
            )
            raise NoRegisteredTypesError(msg)
        if n == 1:
            return next(iter(self.models.values()))

        union_mode = "smart" if plain else self.union_mode

        if isinstance(union_mode, DiscriminatedConfig):
            # Tag each member with its registry key (its discriminator value).
            tagged = (
                ty.Annotated[model, pydantic.Tag(tag)]
                for tag, model in self.models.items()
            )
            return ty.Annotated[
                functools.reduce(operator.or_, tagged),
                pydantic.Field(discriminator=union_mode.discriminator_field),
            ]

        plain_union = functools.reduce(operator.or_, self.models.values())
        if union_mode == "left_to_right":
            return ty.Annotated[plain_union, pydantic.Field(union_mode="left_to_right")]

        # "smart" mode is pydantic's default behavior on a plain union
        return plain_union

    def _register_with_discriminator_field(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model keyed by the default of its discriminator field.

        Parameters
        ----------
        cls
            The class to register, must have the discriminator field set with a
            unique default value in the group.

        Raises
        ------
        RegistrationError
            If the field has no default, the default is not a str, or the key
            is already used by a different model.
        """
        disc = ty.cast("DiscriminatedConfig", self.union_mode).discriminator_field
        value = cls.model_fields[disc].default
        # PydanticUndefined is a sentinel: compare by identity.
        if value is pydantic_core.PydanticUndefined:
            msg = (
                f"{cls.__name__}.{disc} had no default value, it must "
                "have one which is unique among all tracked models."
            )
            raise RegistrationError(msg)
        if not isinstance(value, str):
            msg = (
                f"{cls.__name__}.{disc} had a default value of {value}, which "
                f"was of type {type(value).__name__}, not str."
            )
            raise RegistrationError(msg)

        self._store_model(value, cls)

    def _register_plain(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model keyed by ``str(id(cls))``.

        Used for smart / left_to_right modes where no discriminator field
        is involved. Re-registering the same class is a no-op.

        Parameters
        ----------
        cls
            The model to register.

        Raises
        ------
        RegistrationError
            If the key is already used by a different model (only possible if
            ``models`` was pre-populated by the caller).
        """
        self._store_model(str(id(cls)), cls)

    def _store_model(self, key: str, cls: type[pydantic.BaseModel]) -> None:
        """Store ``cls`` under ``key``, rejecting a different model already there."""
        other = self.models.get(key)
        if other is not None and other is not cls:
            msg = (
                f'Cannot register {cls.__name__} under the "{key}" '
                f"identifier, which is already in use by {other.__name__}."
            )
            raise RegistrationError(msg)
        self.models[key] = cls