Coverage for src / dynapydantic / subclass_tracking_model.py: 100%
49 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-02-06 15:20 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2026-02-06 15:20 +0000
1"""Base class for dynamic pydantic models"""
3import inspect
4import typing as ty
6import pydantic
7from pydantic import GetCoreSchemaHandler
8from pydantic.errors import PydanticSchemaGenerationError
9from pydantic_core import core_schema
11from .exceptions import ConfigurationError
12from .tracking_group import TrackingGroup
def direct_children_of_base_in_mro(derived: type, base: type) -> list[type]:
    """Collect the classes in derived's MRO that inherit directly from base.

    Parameters
    ----------
    derived
        The class whose method resolution order is walked.
    base
        The class whose immediate subclasses are wanted.

    Returns
    -------
    The entries of ``derived.__mro__`` (in MRO order) that list ``base``
    among their ``__bases__``. ``base`` itself is never included.
    """
    direct_children: list[type] = []
    for candidate in derived.__mro__:
        if candidate is base:
            continue
        if base in candidate.__bases__:
            direct_children.append(candidate)
    return direct_children
class SubclassTrackingModel(pydantic.BaseModel):
    """BaseModel that tracks its subclasses in a TrackingGroup

    A class that inherits directly from SubclassTrackingModel gets a
    TrackingGroup stored on it as ``__DYNAPYDANTIC__``. Classes further down
    the hierarchy are registered with that group automatically.

    Direct subclasses are also augmented with the following static methods:

    1. registered_subclasses() -> dict[str, type[cls]]:
        Returns the mapping of discriminator value to the corresponding
        subclass. See TrackingGroup.models for details.
    2. union() -> typing.GenericAlias:
        Returns an (optionally) annotated subclass union. See
        TrackingGroup.union() for details.
    3. load_plugins() -> None:
        Only present when plugin_entry_point was specified. Loads plugin
        packages to discover additional subclasses. See
        TrackingGroup.load_plugins for more details.
    """

    def __init_subclass__(cls, *args, **kwargs) -> None:
        """Subclass hook

        Strips the class keywords that are meant for TrackingGroup or for
        __pydantic_init_subclass__ and forwards only the remainder up the
        __init_subclass__ chain.
        """
        hook_params = inspect.signature(
            SubclassTrackingModel.__pydantic_init_subclass__,
        ).parameters

        passthrough = {}
        for key, value in kwargs.items():
            # These keywords are consumed later by __pydantic_init_subclass__
            if key in TrackingGroup.model_fields or key in hook_params:
                continue
            passthrough[key] = value

        super().__init_subclass__(*args, **passthrough)

    @classmethod
    def __pydantic_init_subclass__(
        cls,
        *args,
        exclude_from_union: bool | None = None,
        **kwargs,
    ) -> None:
        """Pydantic subclass hook

        Parameters
        ----------
        exclude_from_union
            If truthy, an indirect subclass is not registered with the
            tracking groups of its ancestors. Not consulted for direct
            subclasses of SubclassTrackingModel.
        **kwargs
            Class keywords. For a direct subclass without a
            ``tracking_config`` TrackingGroup these are used to construct
            one.

        Raises
        ------
        ConfigurationError
            If a direct subclass has no ``tracking_config`` TrackingGroup and
            the class keywords fail TrackingGroup validation.
        """
        if SubclassTrackingModel not in cls.__bases__:
            # Indirect subclass: register it with every tracked ancestor
            super().__pydantic_init_subclass__(*args, **kwargs)
            if not exclude_from_union:
                for tracked_base in direct_children_of_base_in_mro(
                    cls,
                    SubclassTrackingModel,
                ):
                    tracked_base.__DYNAPYDANTIC__.register_model(cls)
            return

        # Direct subclass: the TrackingGroup keywords stop here
        forwarded = {
            key: value
            for key, value in kwargs.items()
            if key not in TrackingGroup.model_fields
        }
        super().__pydantic_init_subclass__(*args, **forwarded)

        declared = getattr(cls, "tracking_config", None)
        if isinstance(declared, TrackingGroup):
            group = declared
        else:
            try:
                group = TrackingGroup.model_validate(
                    {"name": f"{cls.__name__}-subclasses"} | kwargs,
                )
            except pydantic.ValidationError as err:
                msg = (
                    "SubclassTrackingModel subclasses must either have a "
                    "tracking_config: ClassVar[dynapydantic.TrackingGroup] "
                    "member or pass kwargs sufficient to construct a "
                    "dynapydantic.TrackingGroup in the class declaration. "
                    "The latter approach produced the following "
                    f"ValidationError:\n{err}"
                )
                raise ConfigurationError(msg) from err
        cls.__DYNAPYDANTIC__ = group

        # Expose the tracking group's methods on the tracked base class. The
        # helpers look the group up on cls at call time.
        def _subclasses() -> dict[str, type[pydantic.BaseModel]]:
            """Return a mapping of discriminator values to registered model"""
            return cls.__DYNAPYDANTIC__.models

        def _union(*, annotated: bool = True) -> ty.GenericAlias:
            """Get the union of all tracked subclasses

            Parameters
            ----------
            annotated
                Whether this should be an annotated union for usage as a
                pydantic field annotation, or a plain typing.Union for a
                regular type annotation.
            """
            return cls.__DYNAPYDANTIC__.union(annotated=annotated)

        if cls.__DYNAPYDANTIC__.plugin_entry_point is not None:

            def _load_plugins() -> None:
                """Load plugins to register more models"""
                cls.__DYNAPYDANTIC__.load_plugins()

            cls.load_plugins = staticmethod(_load_plugins)

        cls.union = staticmethod(_union)
        cls.registered_subclasses = staticmethod(_subclasses)

    class PydanticAdaptor:
        """Pydantic type adaptor for SubclassTrackingModel"""

        @staticmethod
        def __get_pydantic_core_schema__(
            source_type: ty.Any,  # noqa: ANN401
            handler: GetCoreSchemaHandler,
        ) -> core_schema.CoreSchema:
            """Get the pydantic schema for this type

            Delegates to the handler with ``source_type.union()`` when
            source_type is a SubclassTrackingModel class.

            Raises
            ------
            PydanticSchemaGenerationError
                If source_type is not a class derived from
                SubclassTrackingModel.
            """
            is_tracking_model = isinstance(source_type, type) and issubclass(
                source_type,
                SubclassTrackingModel,
            )
            if is_tracking_model:
                return handler(source_type.union())

            msg = (
                f"{source_type} was not a SubclassTrackingModel, "
                "so it is incompatible with dynapydantic.Polymorphic"
            )
            raise PydanticSchemaGenerationError(msg)