Coverage for src / dynapydantic / subclass_tracking_model.py: 100%
49 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-04-06 21:55 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2026-04-06 21:55 +0000
1"""Base class for dynamic pydantic models"""
3import inspect
4import typing as ty
6import pydantic
7from pydantic import GetCoreSchemaHandler
8from pydantic.errors import PydanticSchemaGenerationError
9from pydantic_core import core_schema
11from .exceptions import ConfigurationError
12from .tracking_group import TrackingGroup
def direct_children_of_base_in_mro(derived: type, base: type) -> list[type]:
    """Collect the classes in derived's MRO that directly inherit from base.

    Parameters
    ----------
    derived
        The class whose method resolution order is walked.
    base
        The class whose immediate subclasses are wanted.

    Returns
    -------
    The matching classes, in MRO order. `base` itself is never included.
    """
    matches: list[type] = []
    for candidate in derived.__mro__:
        if candidate is base:
            continue
        # Only immediate parents count; deeper ancestors do not.
        if base in candidate.__bases__:
            matches.append(candidate)
    return matches
class SubclassTrackingModel(pydantic.BaseModel):
    """Subclass-tracking BaseModel

    A class that directly inherits from this one gets a
    [`TrackingGroup`][dynapydantic.TrackingGroup] attached to it, and its
    descendants are registered with that group automatically.

    Direct subclasses are augmented with the following member functions:

    1. `registered_subclasses() -> dict[str, type[cls]]`:
        Returns a mapping of discriminator value to the corresponding
        subclass. See
        [`TrackingGroup.models`][dynapydantic.TrackingGroup.models] for details.
    2. `union() -> typing.Any`:
        Returns an (optionally) annotated subclass union. See
        [`TrackingGroup.union()`][dynapydantic.TrackingGroup.union] for details.
    3. `load_plugins() -> None`:
        Only added when the tracking group's plugin_entry_point is not None.
        Loads plugin packages to discover additional subclasses. See
        [`TrackingGroup.load_plugins()`][dynapydantic.TrackingGroup.load_plugins]
        for more details.
    """

    def __init_subclass__(cls, *args, **kwargs) -> None:
        """Subclass hook

        Forwards to the parent hook only the keyword arguments that are
        neither TrackingGroup fields nor parameters of
        `__pydantic_init_subclass__`.
        """
        hook_params = inspect.signature(
            SubclassTrackingModel.__pydantic_init_subclass__,
        ).parameters
        group_fields = TrackingGroup.model_fields

        passthrough = {}
        for key, value in kwargs.items():
            if key in group_fields or key in hook_params:
                continue
            passthrough[key] = value

        super().__init_subclass__(*args, **passthrough)

    @classmethod
    def __pydantic_init_subclass__(
        cls,
        *args,
        exclude_from_union: bool | None = None,
        **kwargs,
    ) -> None:
        """Pydantic subclass hook

        Direct subclasses of SubclassTrackingModel receive a tracking group
        and the helper functions that expose it. Every other descendant is
        registered with the tracking group of each direct subclass found in
        its MRO.

        Parameters
        ----------
        exclude_from_union
            If truthy, a descendant that is not a direct subclass is not
            registered with any tracking group. Not consulted for direct
            subclasses.
        **kwargs
            For direct subclasses, used to build the TrackingGroup when no
            `tracking_config` TrackingGroup attribute is present.

        Raises
        ------
        ConfigurationError
            If a direct subclass has no `tracking_config` TrackingGroup and
            the class kwargs fail TrackingGroup validation.
        """
        if SubclassTrackingModel not in cls.__bases__:
            # Indirect descendant: register it with every tracked ancestor.
            super().__pydantic_init_subclass__(*args, **kwargs)
            if not exclude_from_union:
                tracked_bases = direct_children_of_base_in_mro(
                    cls,
                    SubclassTrackingModel,
                )
                for tracked_base in tracked_bases:
                    tracked_base.__DYNAPYDANTIC__.register_model(cls)
            return

        # Direct subclass: keep the TrackingGroup kwargs away from the parent hook.
        group_fields = TrackingGroup.model_fields
        super().__pydantic_init_subclass__(
            *args,
            **{
                key: value
                for key, value in kwargs.items()
                if key not in group_fields
            },
        )

        configured_group = getattr(cls, "tracking_config", None)
        if isinstance(configured_group, TrackingGroup):
            cls.__DYNAPYDANTIC__ = configured_group
        else:
            group_spec = {"name": f"{cls.__name__}-subclasses"} | kwargs
            try:
                cls.__DYNAPYDANTIC__: TrackingGroup = TrackingGroup.model_validate(
                    group_spec,
                )
            except pydantic.ValidationError as e:
                msg = (
                    "SubclassTrackingModel subclasses must either have a "
                    "tracking_config: ClassVar[dynapydantic.TrackingGroup] "
                    "member or pass kwargs sufficient to construct a "
                    "dynapydantic.TrackingGroup in the class declaration. "
                    "The latter approach produced the following "
                    f"ValidationError:\n{e}"
                )
                raise ConfigurationError(msg) from e

        # Expose the tracking group's methods on the direct subclass. The
        # closures look the group up through cls on every call.
        if cls.__DYNAPYDANTIC__.plugin_entry_point is not None:

            def _load_plugins() -> None:
                """Load plugins so that they can register more models"""
                cls.__DYNAPYDANTIC__.load_plugins()

            cls.load_plugins = staticmethod(_load_plugins)

        def _union(
            *,
            plain: bool | None = None,
            annotated: bool | None = None,
        ) -> ty.Any:  # noqa: ANN401 - return type is runtime-determined
            """Get the union of all tracked subclasses

            Parameters
            ----------
            plain
                If set to `True`, a plain union of all members will be
                returned. Otherwise, the returned union will be annotated in
                accordance with the union mode.
            annotated
                Deprecated. Use `plain=True` when you would have used
                `annotated=False`.
            """
            # The deprecation warning for `annotated` is issued by TrackingGroup.
            return cls.__DYNAPYDANTIC__.union(plain=plain, annotated=annotated)

        cls.union = staticmethod(_union)

        def _subclasses() -> dict[str, type[pydantic.BaseModel]]:
            """Return a mapping of discriminator values to registered model"""
            return cls.__DYNAPYDANTIC__.models

        cls.registered_subclasses = staticmethod(_subclasses)

    class PydanticAdaptor:
        """Pydantic type adaptor for SubclassTrackingModel"""

        @staticmethod
        def __get_pydantic_core_schema__(
            source_type: ty.Any,  # noqa: ANN401
            handler: GetCoreSchemaHandler,
        ) -> core_schema.CoreSchema:
            """Get the pydantic schema for this type

            Delegates to `handler` with the union of `source_type`, and
            raises PydanticSchemaGenerationError when `source_type` is not a
            SubclassTrackingModel subclass.
            """
            is_tracking_model = isinstance(source_type, type) and issubclass(
                source_type,
                SubclassTrackingModel,
            )
            if is_tracking_model:
                return handler(source_type.union())

            msg = (
                f"{source_type} was not a SubclassTrackingModel, "
                "so it is incompatible with dynapydantic.Polymorphic"
            )
            raise PydanticSchemaGenerationError(msg)