Coverage for src / dynapydantic / tracking_group.py: 100%
62 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-02-06 15:20 +0000
"""Tracking groups that register pydantic models by discriminator value

A TrackingGroup collects pydantic model classes keyed by the default value of
a shared discriminator field, and can build a (discriminated) union of all of
the registered models for use as a field annotation.
"""
3import contextlib
4import typing as ty
6import pydantic
7import pydantic.fields
8import pydantic_core
10from .exceptions import AmbiguousDiscriminatorValueError, RegistrationError
def _inject_discriminator_field(
    cls: type[pydantic.BaseModel],
    disc_field: str,
    value: str,
) -> pydantic.fields.FieldInfo:
    """Add a frozen ``Literal`` discriminator field to a model class in place

    The new field is written straight into ``cls.model_fields`` and the model
    is then force-rebuilt. A rebuild that fails with
    ``PydanticUndefinedAnnotation`` (an annotation that cannot be resolved
    yet) is suppressed, so the injected field is kept regardless.

    Parameters
    ----------
    cls
        The BaseModel subclass to modify
    disc_field
        Name of the discriminator field
    value
        Value of the discriminator field; used both as the field's default and
        as the single member of its ``Literal`` annotation

    Returns
    -------
    pydantic.fields.FieldInfo
        Whatever is stored under ``disc_field`` in ``cls.model_fields`` after
        the rebuild attempt
    """
    injected = pydantic.fields.FieldInfo(
        default=value,
        annotation=ty.Literal[value],  # type: ignore[not-a-type]
        frozen=True,
    )
    cls.model_fields[disc_field] = injected

    # An unresolved annotation should not abort registration; the field has
    # already been stored above.
    with contextlib.suppress(pydantic.errors.PydanticUndefinedAnnotation):
        cls.model_rebuild(force=True)

    return cls.model_fields[disc_field]
class TrackingGroup(pydantic.BaseModel):
    """Tracker for pydantic models

    Registered models are stored in ``models`` keyed by the default value of
    their ``discriminator_field``. ``union()`` builds a union of every
    registered model, optionally annotated as a discriminated union.
    """

    name: str = pydantic.Field(
        description=(
            "Name of the tracking group. This is for human display, so it "
            "doesn't technically need to be globally unique, but it should be "
            "meaningfully named, as it will be used in error messages."
        ),
    )
    discriminator_field: str = pydantic.Field(
        description="Name of the discriminator field",
    )
    plugin_entry_point: str | None = pydantic.Field(
        None,
        description=(
            "If given, then plugins packages will be supported through this "
            "Python entrypoint. The entrypoint can either be a function, "
            "which will be called, or simply a module, which will be "
            "imported. In either case, models found along the import path of "
            "the entrypoint will be registered. If the entrypoint is a "
            "function, additional models may be declared in the function."
        ),
    )
    discriminator_value_generator: ty.Callable[[type], str] | None = pydantic.Field(
        None,
        description=(
            "A callable that produces default values for the discriminator field"
        ),
    )
    # A factory gives each group its own dict explicitly, rather than relying
    # on a shared literal ``{}`` default being copied per instance.
    models: dict[str, type[pydantic.BaseModel]] = pydantic.Field(
        default_factory=dict,
        description="The tracked models",
    )

    def load_plugins(self) -> None:
        """Load plugins to discover/register additional models

        Every entry point in the ``plugin_entry_point`` group is loaded, and
        the loaded object is called if it is callable. Does nothing when
        ``plugin_entry_point`` is None.
        """
        if self.plugin_entry_point is None:
            return

        from importlib.metadata import entry_points  # noqa: PLC0415

        # entry_points(group=...) selects directly (Python 3.10+), avoiding
        # materializing every group first.
        for ep in entry_points(group=self.plugin_entry_point):
            plugin = ep.load()
            if callable(plugin):
                plugin()

    def register(
        self,
        discriminator_value: str | None = None,
    ) -> ty.Callable[[type], type]:
        """Register a model into this group (decorator)

        Parameters
        ----------
        discriminator_value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand.

        Returns
        -------
        Callable
            A class decorator. If the decorated model lacks the discriminator
            field, the field is injected (from ``discriminator_value`` or, failing
            that, ``discriminator_value_generator``). The model is then added
            to ``models`` and the same class object is returned.

        Raises
        ------
        RegistrationError
            Raised when the decorator is applied, if no discriminator value
            can be determined, or if the value is missing or already used by
            another model.
        AmbiguousDiscriminatorValueError
            Raised when the decorator is applied, if ``discriminator_value``
            differs from the default of a hand-declared discriminator field.
        """

        def _wrapper(cls: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]:
            disc = self.discriminator_field
            field = cls.model_fields.get(disc)
            if field is None:
                if discriminator_value is not None:
                    _inject_discriminator_field(cls, disc, discriminator_value)
                elif self.discriminator_value_generator is not None:
                    _inject_discriminator_field(
                        cls,
                        disc,
                        self.discriminator_value_generator(cls),
                    )
                else:
                    msg = (
                        f"unable to determine a discriminator value for "
                        f'{cls.__name__} in tracking group "{self.name}". No '
                        "value was passed to register(), "
                        "discriminator_value_generator was None and the "
                        f'"{disc}" field was not defined.'
                    )
                    raise RegistrationError(msg)
            elif (
                discriminator_value is not None and field.default != discriminator_value
            ):
                msg = (
                    f"the discriminator value for {cls.__name__} was "
                    f'ambiguous, it was set to "{discriminator_value}" via '
                    f'register() and "{field.default}" via the discriminator '
                    f"field ({disc})."
                )
                raise AmbiguousDiscriminatorValueError(msg)

            self._register_with_discriminator_field(cls)
            return cls

        return _wrapper

    def register_model(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the given model into this group

        If the model lacks the discriminator field, it is injected using
        ``discriminator_value_generator``.

        Parameters
        ----------
        cls
            The model to register

        Raises
        ------
        RegistrationError
            If the discriminator field is absent and no
            ``discriminator_value_generator`` is set, or if the discriminator
            value is missing or already used by another model.
        """
        disc = self.discriminator_field
        if cls.model_fields.get(disc) is None:
            if self.discriminator_value_generator is not None:
                _inject_discriminator_field(
                    cls,
                    disc,
                    self.discriminator_value_generator(cls),
                )
            else:
                msg = (
                    f"unable to determine a discriminator value for "
                    f'{cls.__name__} in tracking group "{self.name}", '
                    "discriminator_value_generator was None and the "
                    f'"{disc}" field was not defined.'
                )
                raise RegistrationError(msg)

        self._register_with_discriminator_field(cls)

    def _register_with_discriminator_field(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model with the default of the discriminator field

        Parameters
        ----------
        cls
            The class to register, must have the discriminator field set with a
            unique default value in the group.

        Raises
        ------
        RegistrationError
            If the discriminator field has no default, or if its default is
            already registered to a different class. Re-registering the same
            class is a no-op.
        """
        disc = self.discriminator_field
        value = cls.model_fields[disc].default
        # PydanticUndefined is a sentinel singleton, so compare by identity.
        if value is pydantic_core.PydanticUndefined:
            msg = (
                f"{cls.__name__}.{disc} had no default value, it must "
                "have one which is unique among all tracked models."
            )
            raise RegistrationError(msg)

        if (other := self.models.get(value)) is not None and other is not cls:
            msg = (
                f'Cannot register {cls.__name__} under the "{value}" '
                f"identifier, which is already in use by {other.__name__}."
            )
            raise RegistrationError(msg)

        self.models[value] = cls

    def union(self, *, annotated: bool = True) -> ty.Any:  # noqa: ANN401
        """Return the union of all registered models

        Parameters
        ----------
        annotated
            If True (the default), return
            ``Annotated[Union[...], Field(discriminator=...)]`` with every member
            tagged by its discriminator value. If False, return the bare
            ``Union`` of the registered model classes.

        Raises
        ------
        TypeError
            If no models have been registered (an empty union cannot exist).
        """
        if not self.models:
            msg = (
                f'cannot build a union for tracking group "{self.name}": no '
                "models have been registered."
            )
            raise TypeError(msg)

        # This is fundamentally incompatible with static type checking, as
        # the union members are only known at runtime.
        if not annotated:
            return ty.Union[tuple(self.models.values())]  # type: ignore[not-a-type]  # noqa: UP007

        members = tuple(
            ty.Annotated[model, pydantic.Tag(value)]
            for value, model in self.models.items()
        )
        return ty.Annotated[
            ty.Union[members],  # type: ignore[not-a-type]  # noqa: UP007
            pydantic.Field(discriminator=self.discriminator_field),
        ]