Coverage for src / dynapydantic / subclass_tracking_model.py: 99%
73 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-13 20:14 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-13 20:14 +0000
1"""Base class for dynamic pydantic models"""
3import dataclasses
4import inspect
5import typing as ty
7import pydantic
8from pydantic import BaseModel, GetCoreSchemaHandler
9from pydantic_core import PydanticCustomError, core_schema
11from .exceptions import ConfigurationError, Error
12from .tracking_group import TrackingGroup
13from .union_mode import UnionRealization
class SubclassTrackingModel(pydantic.BaseModel):
    """BaseModel that automatically tracks its own subclasses.

    Every subclass gets a [`TrackingGroup`][dynapydantic.TrackingGroup]
    attached as `__DYNAPYDANTIC__`. Newly declared subclasses are registered
    with that group and with the group of every `SubclassTrackingModel`
    ancestor.

    Like `BaseModel`, this class accepts keyword arguments in the class
    statement. Arguments meant for `BaseModel` are forwarded to pydantic, and
    fields of `TrackingGroup` configure the embedded tracking group. Two
    further options are recognized:

    1. `exclude_from_union`: if `True`, the subclass is not registered with
       any tracking group. When omitted, it defaults to `True` for classes
       that list `SubclassTrackingModel` directly among their bases
       (typically the abstract roots) and `False` for all other descendants.
    2. `union_realization`: when the tracked union is realized; see
       [`UnionRealization`][dynapydantic.UnionRealization] for the options.
       When omitted, the parent's setting is inherited, falling back to
       realizing unions at model construction time.
    """

    def __init_subclass__(cls, *args, **kwargs) -> None:
        """Strip dynapydantic-specific class kwargs before the standard hook.

        Keywords that name a `TrackingGroup` field, or a parameter of
        `__pydantic_init_subclass__`, are withheld so that they do not reach
        the parent `__init_subclass__` chain, which would typically reject
        unknown keywords. All other keywords are passed through unchanged.
        """
        hook_params = inspect.signature(
            SubclassTrackingModel.__pydantic_init_subclass__
        ).parameters
        forwarded: dict[str, ty.Any] = {}
        for key, value in kwargs.items():
            if key in TrackingGroup.model_fields or key in hook_params:
                continue
            forwarded[key] = value
        super().__init_subclass__(*args, **forwarded)

    @classmethod
    def __pydantic_init_subclass__(
        cls,
        *args,
        exclude_from_union: bool | None = None,
        union_realization: str | UnionRealization | None = None,
        **kwargs,
    ) -> None:
        """Attach tracking state to a newly created subclass.

        This sets `__DYNAPYDANTIC__` (the tracking group) and
        `__DYNAPYDANTIC_STM_CONFIG__` (the resolved `_StmConfig`) on `cls`.
        Unless `cls` is excluded, it then registers `cls` with every tracking
        group found along its MRO.
        """
        # Pass up the pydantic hook chain only the keywords that are not
        # TrackingGroup fields.
        parent_kwargs = {
            name: val
            for name, val in kwargs.items()
            if name not in TrackingGroup.model_fields
        }
        super().__pydantic_init_subclass__(*args, **parent_kwargs)

        # NOTE(review): kwargs is passed unfiltered here, so any keyword meant
        # for another parent's hook also reaches _init_tracking_group; confirm
        # TrackingGroup tolerates such extras.
        cls.__DYNAPYDANTIC__: ty.ClassVar[TrackingGroup] = _init_tracking_group(
            cls, **kwargs
        )

        # The inherited config is read before this class's own one is set, so
        # getattr resolves to the nearest ancestor's config (if any).
        stm_config = _StmConfig.create(
            cls,
            exclude_from_union=exclude_from_union,
            union_realization=union_realization,
            inherited=getattr(cls, "__DYNAPYDANTIC_STM_CONFIG__", None),
        )
        cls.__DYNAPYDANTIC_STM_CONFIG__: ty.ClassVar[_StmConfig] = stm_config
        if stm_config.exclude_from_union:
            return

        # Register with cls itself and with every tracked ancestor, so that
        # each level of a multi-level hierarchy sees the new class.
        tracked = (
            klass
            for klass in cls.__mro__
            if issubclass(klass, SubclassTrackingModel)
            and klass is not SubclassTrackingModel
        )
        for klass in tracked:
            klass.__DYNAPYDANTIC__.register_model(cls)
def _init_tracking_group(
    cls: type[SubclassTrackingModel],
    **kwargs,
) -> TrackingGroup:
    """Return the TrackingGroup to embed in a SubclassTrackingModel subclass.

    Resolution order:

    1. If ``cls`` resolves a ``tracking_config`` attribute that is a
       ``TrackingGroup``, that instance is returned as-is and ``kwargs`` are
       ignored.
    2. Otherwise a new ``TrackingGroup`` is constructed. If ``cls`` already
       resolves a ``__DYNAPYDANTIC__`` group (normally inherited from a
       parent, because this runs before the subclass's own group is
       assigned), that group's settings are used as defaults. The excluded
       settings are ``name``, ``models``, ``discriminator_field`` and
       ``discriminator_value_generator``. Explicit ``kwargs`` override the
       defaults. When no ``name`` is supplied, ``"<ClassName>-subclasses"``
       is used.

    Args:
        cls: The subclass being initialized.
        **kwargs: Keyword arguments from the class declaration, forwarded to
            the ``TrackingGroup`` constructor.

    Returns:
        The tracking group to attach to ``cls``.

    Raises:
        ConfigurationError: If a ``TrackingGroup`` cannot be constructed from
            the inherited and supplied arguments.
    """
    # If the user already defined one, use it.
    # NOTE(review): getattr also finds a tracking_config declared on a parent
    # class, so subclasses of such a class reuse the parent's instance rather
    # than getting their own. Confirm this is intended.
    if isinstance((tc := getattr(cls, "tracking_config", None)), TrackingGroup):
        return tc

    # Otherwise build one. Inherit settings from the parent's group (if any)
    # and let kwargs passed directly in the class declaration override them.
    if isinstance(parent_tg := getattr(cls, "__DYNAPYDANTIC__", None), TrackingGroup):
        inherited = parent_tg.model_dump(
            exclude={
                "name",
                "models",
                "discriminator_field",
                "discriminator_value_generator",
            }
        )
        # When a new discriminator_field is declared, the parent's union_mode
        # is not carried over (presumably because it was tied to the parent's
        # union shape; confirm). Only the *inherited* value is dropped. This
        # must happen before merging so that a union_mode passed explicitly
        # alongside discriminator_field is preserved.
        if "discriminator_field" in kwargs:
            inherited.pop("union_mode", None)
        tg_kwargs = inherited | kwargs
    else:
        tg_kwargs = kwargs
    tg_kwargs.setdefault("name", f"{cls.__name__}-subclasses")

    try:
        return TrackingGroup(**tg_kwargs)
    except pydantic.ValidationError as e:
        msg = (
            "SubclassTrackingModel subclasses must either have a "
            "tracking_config: ClassVar[dynapydantic.TrackingGroup] "
            "member or pass kwargs sufficient to construct a "
            "dynapydantic.TrackingGroup in the class declaration. "
            "The latter approach produced the following "
            f"ValidationError:\n{e}"
        )
        raise ConfigurationError(msg) from e
@dataclasses.dataclass(frozen=True)
class _StmConfig:
    """Resolved per-class settings for a SubclassTrackingModel subclass."""

    # When the tracked union for this class is realized.
    union_realization: UnionRealization
    # Whether this class is left out of tracking-group registration.
    exclude_from_union: bool

    @classmethod
    def create(
        cls,
        model_t: type[SubclassTrackingModel],
        *,
        exclude_from_union: bool | None,
        union_realization: str | UnionRealization | None = None,
        inherited: "_StmConfig | None" = None,
    ) -> "_StmConfig":
        """Resolve the config for ``model_t`` from class-declaration kwargs.

        Args:
            model_t: The SubclassTrackingModel subclass being configured.
            exclude_from_union: Explicit exclusion flag. ``None`` means it
                defaults to ``True`` exactly when ``model_t`` lists
                ``SubclassTrackingModel`` among its direct bases, since
                direct descendants tend to be abstract base classes.
            union_realization: Explicit realization time, given either as a
                ``UnionRealization`` member or as a value convertible to one.
                ``None`` means inherit.
            inherited: The parent's config. It supplies the realization time
                when ``union_realization`` is ``None``; without it, model
                construction time is used.

        Returns:
            A new, frozen ``_StmConfig``.

        Raises:
            ConfigurationError: If ``union_realization`` cannot be converted
                to ``UnionRealization``.
        """
        if isinstance(union_realization, UnionRealization):
            realization = union_realization
        elif union_realization is not None:
            try:
                realization = UnionRealization(union_realization)
            except (ValueError, TypeError) as err:
                msg = f"invalid union_realization: {err}"
                raise ConfigurationError(msg) from err
        elif inherited is not None:
            realization = inherited.union_realization
        else:
            realization = UnionRealization.MODEL_CONSTRUCTION

        excluded = (
            SubclassTrackingModel in model_t.__bases__
            if exclude_from_union is None
            else exclude_from_union
        )

        return cls(
            union_realization=realization,
            exclude_from_union=excluded,
        )
179_UNSET = object()
class ValidationTimeAdapter:
    """Pydantic type adapter for a dynapydantic-tracked field.

    The core schema produced here evaluates the tracked union at validation
    time. Each validation looks up the tracking group's ``type_adapter``
    instead of embedding a fixed union in the schema.
    """

    @staticmethod
    def __get_pydantic_core_schema__(
        source_type: type[SubclassTrackingModel],
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """Build a plain-validator schema with a dict-producing serializer.

        Args:
            source_type: The tracked model type whose group drives validation.
            _handler: Pydantic's schema handler (unused).

        Returns:
            A core schema that validates through the tracking group's type
            adapter and serializes non-None values via ``model_dump``.
        """

        def _validate(value: ty.Any) -> ty.Any:  # noqa: ANN401
            # Project-level errors raised while obtaining the adapter are
            # converted into a pydantic error so they surface as a
            # ValidationError.
            try:
                adapter = source_type.__DYNAPYDANTIC__.type_adapter
            except Error as exc:
                raise PydanticCustomError(
                    "dynapydantic_error", "{e}", {"e": str(exc)}
                ) from exc
            return adapter.validate_python(value)

        def _serialize(
            value: BaseModel,
            info: core_schema.SerializationInfo,
        ) -> dict[str, ty.Any]:
            # Options that were added after pydantic 2.0 (the pin is >= 2).
            # Each is forwarded only if this SerializationInfo exposes it.
            optional_names = (
                "context",
                "exclude_computed_fields",
                "serialize_as_any",
                "polymorphic_serialization",
            )
            # SerializationInfo does not expose the caller's warnings setting,
            # so a fixed value is used.
            extra: dict[str, ty.Any] = {"warnings": False}
            extra.update(
                {
                    name: found
                    for name in optional_names
                    if (found := getattr(info, name, _UNSET)) is not _UNSET
                }
            )

            return value.model_dump(
                mode=info.mode,
                # Pydantic's types on SerializationInfo's include/exclude don't
                # match up with the corresponding parameter types on model_dump
                include=info.include,  # type: ignore[bad-argument-type]
                exclude=info.exclude,  # type: ignore[bad-argument-type]
                by_alias=info.by_alias,
                exclude_unset=info.exclude_unset,
                exclude_defaults=info.exclude_defaults,
                exclude_none=info.exclude_none,
                round_trip=info.round_trip,
                **extra,
            )

        dumped_shape = core_schema.dict_schema(
            core_schema.str_schema(), core_schema.any_schema()
        )
        serializer = core_schema.plain_serializer_function_ser_schema(
            _serialize,
            info_arg=True,
            when_used="unless-none",
            return_schema=dumped_shape,
        )
        return core_schema.no_info_plain_validator_function(
            _validate,
            serialization=serializer,
        )