Coverage for src / dynapydantic / subclass_tracking_model.py: 100%
67 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 21:42 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 21:42 +0000
1"""Base class for dynamic pydantic models"""
3import dataclasses
4import inspect
5import typing as ty
7import pydantic
8from pydantic import BaseModel, GetCoreSchemaHandler
9from pydantic_core import PydanticCustomError, core_schema
11from .exceptions import ConfigurationError, Error
12from .tracking_group import TrackingGroup
13from .union_mode import UnionRealization
class SubclassTrackingModel(pydantic.BaseModel):
    """Subclass-tracking BaseModel

    This will inject a [`TrackingGroup`][dynapydantic.TrackingGroup] into your
    class and automate the registration of subclasses.

    Similar to `BaseModel`, `SubclassTrackingModel` can take arguments in the
    class declaration. Arguments from `BaseModel` will be forwarded.
    Additionally, any fields from `TrackingGroup` will be forwarded to the
    internal `TrackingGroup` instance. The following additional arguments are
    supported:

    1. `exclude_from_union`: This flag is intended to be used with descendants
       of `SubclassTrackingModel`. If `True`, this subclass will be omitted
       from tracking. The default for this flag is `True` for direct
       descendants of `SubclassTrackingModel` and `False` otherwise.
    2. `union_realization`: When the union should be realized. See
       [`UnionRealization`][dynapydantic.UnionRealization] for more details
       on the various options. The default is inherited from the nearest
       configured parent class, falling back to realizing unions at model
       construction time.
    """

    def __init_subclass__(cls, *args, **kwargs) -> None:
        """Subclass hook

        Removes the class-declaration keyword arguments consumed by this
        class (`TrackingGroup` fields and the keyword-only options of
        `__pydantic_init_subclass__`) before forwarding the remainder to the
        parent `__init_subclass__`. The default `object.__init_subclass__`
        raises `TypeError` on unknown keyword arguments.
        """
        # Only the keyword-only parameters are real options. The names of the
        # *args/**kwargs placeholders must not swallow same-named user kwargs.
        hook_params = inspect.signature(
            SubclassTrackingModel.__pydantic_init_subclass__
        ).parameters
        consumed = {
            name
            for name, param in hook_params.items()
            if param.kind is inspect.Parameter.KEYWORD_ONLY
        }
        consumed.update(TrackingGroup.model_fields)
        super().__init_subclass__(
            *args,
            **{k: v for k, v in kwargs.items() if k not in consumed},
        )

    @classmethod
    def __pydantic_init_subclass__(
        cls,
        *args,
        exclude_from_union: bool | None = None,
        union_realization: str | UnionRealization | None = None,
        **kwargs,
    ) -> None:
        """Pydantic subclass hook

        Builds the class's `TrackingGroup` and `_StmConfig`, then (unless the
        class is excluded) registers the class with the tracking group of
        every `SubclassTrackingModel` subclass in its MRO, including itself.

        Args:
            exclude_from_union: See the class docstring.
            union_realization: See the class docstring.
            **kwargs: `TrackingGroup` fields are used to build the tracking
                group; everything else is forwarded to the parent hook.

        Raises:
            ConfigurationError: Propagated from the tracking-group or config
                construction helpers on invalid settings.
        """
        # Partition the kwargs: TrackingGroup fields configure our tracking
        # group, everything else belongs to other base classes.
        tg_fields = TrackingGroup.model_fields
        tg_kwargs = {k: v for k, v in kwargs.items() if k in tg_fields}
        other_kwargs = {k: v for k, v in kwargs.items() if k not in tg_fields}

        super().__pydantic_init_subclass__(*args, **other_kwargs)

        # Initialize the tracking group. Only TrackingGroup fields are passed,
        # so kwargs aimed at other bases are never handed to TrackingGroup.
        cls.__DYNAPYDANTIC__: ty.ClassVar[TrackingGroup] = _init_tracking_group(
            cls, **tg_kwargs
        )

        # Initialize our SubclassTrackingModel-specific config. The getattr is
        # evaluated before the assignment, so it reads the inherited config.
        cls.__DYNAPYDANTIC_STM_CONFIG__: ty.ClassVar[_StmConfig] = _StmConfig.create(
            cls,
            exclude_from_union=exclude_from_union,
            union_realization=union_realization,
            inherited=getattr(cls, "__DYNAPYDANTIC_STM_CONFIG__", None),
        )

        if cls.__DYNAPYDANTIC_STM_CONFIG__.exclude_from_union:
            return

        # We are tracked: walk the entire MRO (to support multi-level trees)
        # and register ourselves with each one. Several classes may share one
        # TrackingGroup object (e.g. an inherited `tracking_config`), so each
        # distinct group is only registered with once.
        seen_groups: set[int] = set()
        for base in cls.__mro__:
            if (
                not issubclass(base, SubclassTrackingModel)
                or base is SubclassTrackingModel
            ):
                continue
            group = base.__DYNAPYDANTIC__
            if id(group) in seen_groups:
                continue
            seen_groups.add(id(group))
            group.register_model(cls)
def _init_tracking_group(
    cls: type[SubclassTrackingModel],
    **kwargs,
) -> TrackingGroup:
    """Initialize the tracking group embedded in ``cls``

    If ``cls`` (directly or through inheritance) has a ``tracking_config``
    attribute holding a ``TrackingGroup``, that instance is returned as-is and
    ``kwargs`` are ignored.

    Otherwise a new ``TrackingGroup`` is built. If ``cls.__DYNAPYDANTIC__``
    currently resolves to a ``TrackingGroup`` (a parent's, when called during
    subclass initialization), its settings seed the arguments. The exceptions
    are ``name``, ``models``, ``discriminator_field`` and
    ``discriminator_value_generator``, which are never inherited. Explicit
    ``kwargs`` always win. When ``kwargs`` sets ``discriminator_field``, the
    parent's ``union_mode`` is not carried over, but a ``union_mode`` passed
    explicitly in ``kwargs`` is kept. ``name`` defaults to
    ``"<ClassName>-subclasses"``.

    Args:
        cls: The model class whose tracking group is being initialized.
        **kwargs: ``TrackingGroup`` arguments from the class declaration.

    Returns:
        The ``TrackingGroup`` to attach to ``cls``.

    Raises:
        ConfigurationError: If ``TrackingGroup`` validation fails for the
            resolved arguments.
    """
    # If the user already defined one, use it
    if isinstance((tc := getattr(cls, "tracking_config", None)), TrackingGroup):
        return tc

    # Otherwise, we need to make it. We can inherit arguments from our
    # parent class(es) if they have TrackingGroup's and then allow any
    # kwargs directly passed here to override.
    if isinstance(parent_tg := getattr(cls, "__DYNAPYDANTIC__", None), TrackingGroup):
        tg_kwargs = parent_tg.model_dump(
            exclude={
                "name",
                "models",
                "discriminator_field",
                "discriminator_value_generator",
            }
        )
        # Drop only the *inherited* union_mode when the discriminator is
        # overridden. Doing this before the merge ensures a union_mode passed
        # explicitly alongside discriminator_field is honored, matching the
        # non-inherited branch below.
        if "discriminator_field" in kwargs:
            tg_kwargs.pop("union_mode", None)
        tg_kwargs |= kwargs
    else:
        tg_kwargs = dict(kwargs)
    tg_kwargs.setdefault("name", f"{cls.__name__}-subclasses")

    try:
        return TrackingGroup(**tg_kwargs)
    except pydantic.ValidationError as e:
        msg = (
            "SubclassTrackingModel subclasses must either have a "
            "tracking_config: ClassVar[dynapydantic.TrackingGroup] "
            "member or pass kwargs sufficient to construct a "
            "dynapydantic.TrackingGroup in the class declaration. "
            "The latter approach produced the following "
            f"ValidationError:\n{e}"
        )
        raise ConfigurationError(msg) from e
@dataclasses.dataclass(frozen=True)
class _StmConfig:
    """Resolved, immutable SubclassTrackingModel settings for one class"""

    # When the class's tracking union should be realized.
    union_realization: UnionRealization
    # Whether the class is skipped when registering with tracking groups.
    exclude_from_union: bool

    @classmethod
    def create(
        cls,
        model_t: type[SubclassTrackingModel],
        *,
        exclude_from_union: bool | None,
        union_realization: str | UnionRealization | None = None,
        inherited: "_StmConfig | None" = None,
    ) -> "_StmConfig":
        """Resolve user-supplied class keyword arguments into a config

        Args:
            model_t: The class being configured.
            exclude_from_union: Explicit flag, or ``None`` to default to
                ``True`` when ``SubclassTrackingModel`` is a direct base of
                ``model_t`` (such direct children are typically abstract
                bases) and ``False`` otherwise.
            union_realization: A ``UnionRealization``, a value accepted by
                ``UnionRealization(...)``, or ``None`` to reuse the value from
                ``inherited`` and fall back to
                ``UnionRealization.MODEL_CONSTRUCTION``.
            inherited: Config inherited from a parent class, if any.

        Returns:
            The resolved config.

        Raises:
            ConfigurationError: If ``union_realization`` cannot be converted
                into a ``UnionRealization``.
        """
        # Realization time: explicit argument > inherited value > default.
        if union_realization is None:
            if inherited is None:
                realization = UnionRealization.MODEL_CONSTRUCTION
            else:
                realization = inherited.union_realization
        elif isinstance(union_realization, UnionRealization):
            realization = union_realization
        else:
            try:
                realization = UnionRealization(union_realization)
            except (ValueError, TypeError) as e:
                msg = f"invalid union_realization: {e}"
                raise ConfigurationError(msg) from e

        # Exclusion: explicit argument, else "is a direct child of the root".
        excluded = (
            SubclassTrackingModel in model_t.__bases__
            if exclude_from_union is None
            else exclude_from_union
        )

        return cls(union_realization=realization, exclude_from_union=excluded)
class ValidationTimeAdapter:
    """Pydantic type adapter for a dynapydantic-tracked field

    This adapter returns a validator that evaluates the union at validation
    time. The tracking group's ``type_adapter`` is fetched on every call to
    the validator rather than being captured when the schema is generated.
    """

    @staticmethod
    def __get_pydantic_core_schema__(
        source_type: type[SubclassTrackingModel],
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """Get the pydantic schema for this type

        Args:
            source_type: The tracked model class annotated on the field.
            _handler: Unused pydantic schema handler.

        Returns:
            A plain-validator schema that validates through the tracking
            group's type adapter and serializes via ``model_dump``.
        """

        def _validate(value: ty.Any) -> ty.Any:  # noqa: ANN401
            # Convert dynapydantic errors raised while obtaining the adapter
            # into a pydantic error so they surface as validation errors.
            try:
                adapter = source_type.__DYNAPYDANTIC__.type_adapter
            except Error as e:
                err_t = "dynapydantic_error"
                raise PydanticCustomError(err_t, "{e}", {"e": str(e)}) from e
            return adapter.validate_python(value)

        def _serialize(
            value: BaseModel,
            info: core_schema.SerializationInfo,
        ) -> dict[str, ty.Any]:
            # Forward the caller's dump options so that, for example,
            # ``parent.model_dump(exclude_none=True)`` also applies to the
            # nested tracked model instead of only ``mode`` being honored.
            options: dict[str, ty.Any] = {
                "mode": info.mode,
                "exclude_unset": info.exclude_unset,
                "exclude_defaults": info.exclude_defaults,
                "exclude_none": info.exclude_none,
                "round_trip": info.round_trip,
            }
            # Only forward by_alias when requested; otherwise leave
            # model_dump's own default in place (the previous behavior).
            if info.by_alias:
                options["by_alias"] = True
            return value.model_dump(**options)

        return core_schema.no_info_plain_validator_function(
            _validate,
            serialization=core_schema.plain_serializer_function_ser_schema(
                _serialize,
                info_arg=True,
                when_used="unless-none",
                return_schema=core_schema.dict_schema(
                    core_schema.str_schema(), core_schema.any_schema()
                ),
            ),
        )