Coverage for src / dynapydantic / subclass_tracking_model.py: 100%
71 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 17:07 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 17:07 +0000
1"""Base class for dynamic pydantic models"""
3import inspect
4import typing as ty
5import warnings
7import pydantic
8from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
9from pydantic.json_schema import JsonSchemaValue
10from pydantic_core import PydanticCustomError, core_schema
12from .exceptions import ConfigurationError, Error
13from .tracking_group import TrackingGroup
def direct_children_of_base_in_mro(derived: type, base: type) -> list[type]:
    """Collect the members of ``derived``'s MRO that inherit *directly* from ``base``.

    Parameters
    ----------
    derived
        The class whose method resolution order is scanned.
    base
        The class whose immediate subclasses are being looked for. ``base``
        itself is never included in the result.

    Returns
    -------
    The matching classes, in the same order they appear in
    ``derived.__mro__``.
    """
    children: list[type] = []
    for candidate in derived.__mro__:
        if candidate is base:
            continue
        if base in candidate.__bases__:
            children.append(candidate)
    return children
class SubclassTrackingModel(pydantic.BaseModel):
    """Subclass-tracking BaseModel

    This will inject a [`TrackingGroup`][dynapydantic.TrackingGroup] into your
    class and automate the registration of subclasses.

    Similar to `BaseModel`, `SubclassTrackingModel` can take arguments in the
    class declaration. Arguments from `BaseModel` will be forwarded.
    Additionally, any fields from `TrackingGroup` will be forwarded to the
    internal `TrackingGroup` instance. The following additional arguments are
    supported:

    1. `exclude_from_union`: This flag is intended to be used with indirect
       descendants of `SubclassTrackingModel` (it has no effect on direct
       subclasses). If `True`, this subclass will be omitted from tracking.
    2. `implicit_polymorphic`: This flag is intended to be used with direct
       descendants of `SubclassTrackingModel` (it has no effect further down
       the hierarchy). If `True`, then the core schema of this class will be
       overridden. This allows polymorphic parsing to occur without the use
       of [`Polymorphic`][dynapydantic.Polymorphic]. In addition, it is not
       necessary to call `model_rebuild` on recursive models. This feature
       is currently **EXPERIMENTAL** and does incur a runtime penalty.

    **DEPRECATED:**

    Inheriting from this class will augment your class with the following
    member functions:

    1. `registered_subclasses() -> dict[str, type[cls]]`:
       This will return a mapping of discriminator value to the corresponding
       subclass. See
       [`TrackingGroup.models`][dynapydantic.TrackingGroup.models] for details.
    2. `union() -> typing.Any`:
       This will return an (optionally) annotated subclass union. See
       [`TrackingGroup.union()`][dynapydantic.TrackingGroup.union] for details.
    3. `load_plugins() -> None`:
       If plugin_entry_point was specified, then this method will load plugin
       packages to discover additional subclasses. See
       [`TrackingGroup.load_plugins()`][dynapydantic.TrackingGroup.load_plugins]
       for more details.

    These methods will be removed in 0.5.0, please migrate to their
    corresponding free functions:

    1. `registered_subclasses()` ->
       [`registered_models()`][dynapydantic.registered_models]
    2. `union()` -> [`union()`][dynapydantic.union] or
       [`Union[T]`][dynapydantic.Union]
    3. `load_plugins()` -> [`load_plugins()`][dynapydantic.load_plugins]
    """

    def __init_subclass__(cls, *args, **kwargs) -> None:
        """Subclass hook

        Drops the class-declaration kwargs that this library consumes itself
        (the fields of `TrackingGroup` and the named parameters of
        `__pydantic_init_subclass__`) before delegating to the parent
        `__init_subclass__`. Those kwargs are handled later by
        `__pydantic_init_subclass__`, and the default
        `object.__init_subclass__` raises on unexpected keyword arguments.
        """
        # Intercept any kwargs that are intended for TrackingGroup or
        # __pydantic_init_subclass__
        sig = inspect.signature(SubclassTrackingModel.__pydantic_init_subclass__)
        super().__init_subclass__(
            *args,
            **{
                k: v
                for k, v in kwargs.items()
                if k not in TrackingGroup.model_fields and k not in sig.parameters
            },
        )

    # This method is too complex, here's the plan to simplify it:
    # We're polluting this models attributes by injecting and forwarding methods
    # from tracking group. As a result, we're limiting the possible field names
    # that these models can have. These should be free functions. We're going
    # to deprecate the methods to give people a release cycle to migrate off.
    # We should be able to remove the noqa after these are removed.
    @classmethod
    def __pydantic_init_subclass__(  # noqa: C901
        cls,
        *args,
        exclude_from_union: bool | None = None,
        implicit_polymorphic: bool | None = None,
        **kwargs,
    ) -> None:
        """Pydantic subclass hook

        Direct subclasses of `SubclassTrackingModel` become tracking roots.
        Each root receives its own `TrackingGroup`, taken from a
        `tracking_config` class attribute or else built from the
        class-declaration kwargs, plus the deprecated helper methods. Every
        other descendant is registered with each tracking root found in its
        MRO, unless `exclude_from_union` is set.

        Parameters
        ----------
        exclude_from_union
            If truthy, do not register this (indirect) descendant. Ignored
            for direct subclasses.
        implicit_polymorphic
            If truthy on a direct subclass, replace its core/JSON schema
            hooks so that it validates polymorphically. Ignored for indirect
            descendants.

        Raises
        ------
        ConfigurationError
            If a direct subclass has no `TrackingGroup` `tracking_config`
            attribute and its declaration kwargs do not validate as a
            `TrackingGroup`.
        """
        if SubclassTrackingModel in cls.__bases__:
            # Intercept any kwargs that are intended for TrackingGroup
            super().__pydantic_init_subclass__(
                *args,
                **{
                    k: v
                    for k, v in kwargs.items()
                    if k not in TrackingGroup.model_fields
                },
            )

            cls.__DYNAPYDANTIC_IMPLICIT_POLYMORPHIC__: ty.ClassVar[bool] = (
                implicit_polymorphic if implicit_polymorphic is not None else False
            )

            # An explicit tracking_config wins; declaration kwargs are only
            # used when no TrackingGroup instance is provided that way.
            if isinstance((tc := getattr(cls, "tracking_config", None)), TrackingGroup):
                cls.__DYNAPYDANTIC__: ty.ClassVar[TrackingGroup] = tc
            else:
                try:
                    # The default name can be overridden by a "name" kwarg
                    # because the right-hand operand of `|` takes precedence.
                    cls.__DYNAPYDANTIC__: ty.ClassVar[TrackingGroup] = (
                        TrackingGroup.model_validate(
                            {"name": f"{cls.__name__}-subclasses"} | kwargs,
                        )
                    )
                except pydantic.ValidationError as e:
                    msg = (
                        "SubclassTrackingModel subclasses must either have a "
                        "tracking_config: ClassVar[dynapydantic.TrackingGroup] "
                        "member or pass kwargs sufficient to construct a "
                        "dynapydantic.TrackingGroup in the class declaration. "
                        "The latter approach produced the following "
                        f"ValidationError:\n{e}"
                    )
                    raise ConfigurationError(msg) from e

            # Promote the tracking group's methods to the parent class
            if cls.__DYNAPYDANTIC__.plugin_entry_point is not None:

                def _load_plugins() -> None:
                    """Load plugins to register more models

                    DEPRECATED: use
                    [`dynapydantic.load_plugins`][dynapydantic.load_plugins]
                    """
                    msg = (
                        "SubclassTrackingModel.load_plugins() is deprecated, "
                        "please swap to dynapydantic.load_plugins()."
                    )
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
                    cls.__DYNAPYDANTIC__.load_plugins()

                cls.load_plugins = staticmethod(_load_plugins)

            def _union(
                *,
                plain: bool | None = None,
                annotated: bool | None = None,
            ) -> ty.Any:  # noqa: ANN401 - return type is runtime-determined
                """Get the union of all tracked subclasses

                DEPRECATED: use [`Union[T]`][dynapydantic.Union] or
                [`union()`][dynapydantic.union] instead.

                Parameters
                ----------
                plain
                    If set to `True`, a plain union of all members will be returned.
                    Otherwise, the returned union will be annotated in accordance with
                    the union mode.
                annotated
                    Deprecated. Use `plain=True` when you would have used
                    `annotated=False`.
                """
                msg = (
                    "SubclassTrackingModel.union() is deprecated, please swap "
                    "to dynapydantic.Union[T] (for annotations) or "
                    "dynapydantic.union() (for runtime calls)."
                )
                warnings.warn(msg, DeprecationWarning, stacklevel=2)

                # deprecation warning for annotated is in TrackingGroup
                return cls.__DYNAPYDANTIC__.union(plain=plain, annotated=annotated)

            cls.union = staticmethod(_union)

            def _subclasses() -> dict[str, type[pydantic.BaseModel]]:
                """Return a mapping of discriminator values to registered models

                DEPRECATED: use dynapydantic.registered_models().
                """
                msg = (
                    "SubclassTrackingModel.registered_subclasses() is "
                    "deprecated, please swap to "
                    "dynapydantic.registered_models()."
                )
                warnings.warn(msg, DeprecationWarning, stacklevel=2)

                return cls.__DYNAPYDANTIC__.models

            cls.registered_subclasses = staticmethod(_subclasses)

            if implicit_polymorphic:
                # Descendants inherit these hooks; the hook functions check
                # whether they are running for the root and otherwise defer
                # to pydantic's default behavior.
                cls.__get_pydantic_core_schema__ = classmethod(  # type: ignore[bad-assignment]
                    _get_pydantic_core_schema
                )

                cls.__get_pydantic_json_schema__ = classmethod(  # type: ignore[bad-assignment]
                    _get_pydantic_json_schema
                )

            return

        super().__pydantic_init_subclass__(*args, **kwargs)

        if exclude_from_union:
            return

        # Register with every tracking root in the MRO; under multiple
        # inheritance there can be more than one.
        supers = direct_children_of_base_in_mro(cls, SubclassTrackingModel)
        for base in supers:
            base.__DYNAPYDANTIC__.register_model(cls)
def _get_adapter(
    source_type: type[SubclassTrackingModel],
) -> pydantic.TypeAdapter:
    """Fetch the type adapter of ``source_type``'s tracking group.

    A dynapydantic `Error` raised while obtaining the adapter is converted
    into a `PydanticCustomError` of type ``"dynapydantic_error"`` whose
    message is the original error text; the original error is chained as
    the cause.
    """
    try:
        adapter = source_type.__DYNAPYDANTIC__.type_adapter
    except Error as exc:
        raise PydanticCustomError(
            "dynapydantic_error", "{e}", {"e": str(exc)}
        ) from exc
    return adapter
def _get_pydantic_core_schema(
    cls: type[SubclassTrackingModel],
    source_type: type[pydantic.BaseModel],
    handler: GetCoreSchemaHandler,
    /,
) -> core_schema.CoreSchema:
    """Get the pydantic core schema for this type

    Installed as `__get_pydantic_core_schema__` by
    `SubclassTrackingModel.__pydantic_init_subclass__` on direct subclasses
    declared with `implicit_polymorphic=True`. For such a root class, the
    schema is replaced by a plain validator that delegates to the tracking
    group's type adapter, plus a serializer that dumps the concrete model
    instance. Classes further down the hierarchy inherit the hook but get
    pydantic's default schema from `handler`.

    Parameters
    ----------
    cls
        The class the hook is invoked on.
    source_type
        The type pydantic is generating a schema for.
    handler
        Pydantic's default schema generator.

    Returns
    -------
    The core schema to use for `source_type`.
    """
    if SubclassTrackingModel not in cls.__bases__:
        return handler(source_type)

    def _validate(value: ty.Any) -> ty.Any:  # noqa: ANN401
        # The adapter is looked up on every call rather than captured once,
        # presumably so that models registered after schema generation are
        # picked up.
        return _get_adapter(
            ty.cast("type[SubclassTrackingModel]", source_type)
        ).validate_python(value)

    def _serialize(
        value: pydantic.BaseModel,
        info: core_schema.SerializationInfo,
    ) -> dict[str, ty.Any]:
        # Forward the caller's dump options so that nested polymorphic models
        # honor exclude_none / exclude_unset / etc. like ordinary nested
        # models do, instead of silently dumping with defaults.
        options: dict[str, ty.Any] = {
            "mode": info.mode,
            "exclude_unset": info.exclude_unset,
            "exclude_defaults": info.exclude_defaults,
            "exclude_none": info.exclude_none,
            "round_trip": info.round_trip,
            "context": info.context,
        }
        # Only ever force aliases on; never force them off, so that a model's
        # own alias-serialization default is not overridden.
        if info.by_alias:
            options["by_alias"] = True
        return value.model_dump(**options)

    return core_schema.no_info_plain_validator_function(
        _validate,
        serialization=core_schema.plain_serializer_function_ser_schema(
            _serialize,
            info_arg=True,
            when_used="unless-none",
            return_schema=core_schema.dict_schema(
                core_schema.str_schema(), core_schema.any_schema()
            ),
        ),
    )
def _get_pydantic_json_schema(
    cls: type[SubclassTrackingModel],
    cs: core_schema.CoreSchema,
    handler: GetJsonSchemaHandler,
    /,
) -> JsonSchemaValue:
    """Produce the JSON schema for this type.

    For a direct subclass of `SubclassTrackingModel` (a tracking root), the
    JSON schema is generated from the tracking group's type adapter core
    schema instead of the incoming one. Every other class uses the core
    schema pydantic passed in.
    """
    is_tracking_root = SubclassTrackingModel in cls.__bases__
    schema = _get_adapter(cls).core_schema if is_tracking_root else cs
    return handler(schema)