Coverage for src / dynapydantic / subclass_tracking_model.py: 99%

73 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-13 20:14 +0000

1"""Base class for dynamic pydantic models""" 

2 

3import dataclasses 

4import inspect 

5import typing as ty 

6 

7import pydantic 

8from pydantic import BaseModel, GetCoreSchemaHandler 

9from pydantic_core import PydanticCustomError, core_schema 

10 

11from .exceptions import ConfigurationError, Error 

12from .tracking_group import TrackingGroup 

13from .union_mode import UnionRealization 

14 

15 

class SubclassTrackingModel(pydantic.BaseModel):
    """Subclass-tracking BaseModel

    This will inject a [`TrackingGroup`][dynapydantic.TrackingGroup] into your
    class and automate the registration of subclasses.

    Similar to `BaseModel`, `SubclassTrackingModel` can take arguments in the
    class declaration. Arguments from `BaseModel` will be forwarded.
    Additionally, any fields from `TrackingGroup` will be forwarded to the
    internal `TrackingGroup` instance. The following additional arguments are
    supported:

    1. `exclude_from_union`: This flag is intended to be used with descendants
       of `SubclassTrackingModel`. If `True`, this subclass will be omitted
       from tracking. The default for this flag is `True` for direct
       descendants of `SubclassTrackingModel` and `False` otherwise.
    2. `union_realization`: When the union should be realized. See
       [`UnionRealization`][dynapydantic.UnionRealization] for more details
       on the various options. The default is to realize unions at model
       construction time.
    """

    def __init_subclass__(cls, *args, **kwargs) -> None:
        """Subclass hook

        Removes every class-declaration keyword that names a ``TrackingGroup``
        field or a parameter of ``__pydantic_init_subclass__`` before
        delegating, so the parent ``__init_subclass__`` never receives them.
        """
        hook_params = inspect.signature(
            SubclassTrackingModel.__pydantic_init_subclass__
        ).parameters
        tg_fields = TrackingGroup.model_fields
        passthrough = {
            key: val
            for key, val in kwargs.items()
            if not (key in tg_fields or key in hook_params)
        }
        super().__init_subclass__(*args, **passthrough)

    @classmethod
    def __pydantic_init_subclass__(
        cls,
        *args,
        exclude_from_union: bool | None = None,
        union_realization: str | UnionRealization | None = None,
        **kwargs,
    ) -> None:
        """Pydantic subclass hook

        Builds this class's ``TrackingGroup`` (``__DYNAPYDANTIC__``) and its
        ``_StmConfig`` (``__DYNAPYDANTIC_STM_CONFIG__``). Unless the class is
        excluded from unions, it is then registered with the tracking group
        of every ``SubclassTrackingModel`` subclass in its MRO, including
        itself.
        """
        tg_fields = TrackingGroup.model_fields
        # The parent hook only receives keywords that are not TrackingGroup
        # fields; the full set still goes to the tracking-group factory.
        super().__pydantic_init_subclass__(
            *args,
            **{key: val for key, val in kwargs.items() if key not in tg_fields},
        )

        cls.__DYNAPYDANTIC__: ty.ClassVar[TrackingGroup] = _init_tracking_group(
            cls, **kwargs
        )

        # ``getattr`` here still sees the parent's config, because the
        # assignment to ``cls`` only happens after ``create`` returns.
        cls.__DYNAPYDANTIC_STM_CONFIG__: ty.ClassVar[_StmConfig] = _StmConfig.create(
            cls,
            exclude_from_union=exclude_from_union,
            union_realization=union_realization,
            inherited=getattr(cls, "__DYNAPYDANTIC_STM_CONFIG__", None),
        )
        if cls.__DYNAPYDANTIC_STM_CONFIG__.exclude_from_union:
            return

        # Walk the whole MRO so multi-level hierarchies register this class
        # with every tracked ancestor, not just the immediate parent.
        tracked_ancestors = (
            base
            for base in cls.__mro__
            if base is not SubclassTrackingModel
            and issubclass(base, SubclassTrackingModel)
        )
        for ancestor in tracked_ancestors:
            ancestor.__DYNAPYDANTIC__.register_model(cls)

90 

91 

def _init_tracking_group(
    cls: type[SubclassTrackingModel],
    **kwargs,
) -> TrackingGroup:
    """Initialize the tracking group embedded in this model

    Resolution order:

    1. If ``getattr(cls, "tracking_config")`` is already a ``TrackingGroup``,
       it is returned as-is and ``kwargs`` are not used.
    2. Otherwise a new ``TrackingGroup`` is constructed. If ``cls`` already
       sees a parent ``TrackingGroup`` via ``__DYNAPYDANTIC__``, that group's
       dumped settings are the defaults. ``name``, ``models``,
       ``discriminator_field`` and ``discriminator_value_generator`` are not
       inherited. Explicit ``kwargs`` override the defaults. ``name`` falls
       back to ``"<ClassName>-subclasses"``.

    Args:
        cls: The class whose tracking group is being created.
        **kwargs: Class-declaration keywords intended for ``TrackingGroup``.

    Returns:
        The ``TrackingGroup`` to attach to ``cls``.

    Raises:
        ConfigurationError: If constructing the ``TrackingGroup`` fails
            pydantic validation.
    """
    # If the user already defined one, use it.
    # NOTE(review): ``getattr`` also finds a ``tracking_config`` inherited
    # from a parent class, so such a child shares the parent's group object.
    # Confirm that sharing is intended.
    if isinstance(tc := getattr(cls, "tracking_config", None), TrackingGroup):
        return tc

    # Otherwise, build one. Settings may be inherited from a parent's
    # TrackingGroup, with kwargs passed directly here taking precedence.
    if isinstance(parent_tg := getattr(cls, "__DYNAPYDANTIC__", None), TrackingGroup):
        tg_kwargs = parent_tg.model_dump(
            exclude={
                "name",
                "models",
                "discriminator_field",
                "discriminator_value_generator",
            }
        )
        if "discriminator_field" in kwargs:
            # A new discriminator means the parent's union_mode is not
            # inherited (presumably the two are not meant to be combined;
            # confirm against TrackingGroup). Drop only the *inherited*
            # value, before merging, so a union_mode the caller passed
            # explicitly is still honoured, as it is in the no-parent branch.
            tg_kwargs.pop("union_mode", None)
        tg_kwargs |= kwargs
    else:
        tg_kwargs = kwargs
    tg_kwargs.setdefault("name", f"{cls.__name__}-subclasses")

    try:
        return TrackingGroup(**tg_kwargs)
    except pydantic.ValidationError as e:
        msg = (
            "SubclassTrackingModel subclasses must either have a "
            "tracking_config: ClassVar[dynapydantic.TrackingGroup] "
            "member or pass kwargs sufficient to construct a "
            "dynapydantic.TrackingGroup in the class declaration. "
            "The latter approach produced the following "
            f"ValidationError:\n{e}"
        )
        raise ConfigurationError(msg) from e

132 

133 

@dataclasses.dataclass(frozen=True)
class _StmConfig:
    """Config for SubclassTrackingModel

    Immutable, per-class settings resolved once when a subclass is declared.
    """

    # When the tracked union is realized.
    union_realization: UnionRealization
    # Whether this class is left out of the tracked unions.
    exclude_from_union: bool

    @classmethod
    def create(
        cls,
        model_t: type[SubclassTrackingModel],
        *,
        exclude_from_union: bool | None,
        union_realization: str | UnionRealization | None = None,
        inherited: "_StmConfig | None" = None,
    ) -> "_StmConfig":
        """Resolve the user's class-declaration arguments into a config

        Args:
            model_t: The class being configured.
            exclude_from_union: Explicit flag, or ``None`` to default to
                ``True`` when ``SubclassTrackingModel`` is a direct base of
                ``model_t`` (such classes tend to be abstract bases) and
                ``False`` otherwise.
            union_realization: Explicit realization (enum member or a value
                accepted by ``UnionRealization(...)``), or ``None`` to use
                ``inherited``'s value, else
                ``UnionRealization.MODEL_CONSTRUCTION``.
            inherited: The parent class's config, if any.

        Raises:
            ConfigurationError: If ``union_realization`` cannot be converted
                to a ``UnionRealization``.
        """
        if union_realization is None:
            if inherited is None:
                resolved_realization = UnionRealization.MODEL_CONSTRUCTION
            else:
                resolved_realization = inherited.union_realization
        elif isinstance(union_realization, UnionRealization):
            resolved_realization = union_realization
        else:
            try:
                resolved_realization = UnionRealization(union_realization)
            except (ValueError, TypeError) as e:
                msg = f"invalid union_realization: {e}"
                raise ConfigurationError(msg) from e

        resolved_exclusion = (
            SubclassTrackingModel in model_t.__bases__
            if exclude_from_union is None
            else exclude_from_union
        )

        return cls(
            union_realization=resolved_realization,
            exclude_from_union=resolved_exclusion,
        )

177 

178 

179_UNSET = object() 

180 

181 

class ValidationTimeAdapter:
    """Pydantic type adapter for a dynapydantic-tracked field

    The schema returned here looks up the tracked union's type adapter each
    time a value is validated, so the union is evaluated at validation time.
    """

    @staticmethod
    def __get_pydantic_core_schema__(
        source_type: type[SubclassTrackingModel],
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """Get the pydantic schema for this type"""

        def _validate(value: ty.Any) -> ty.Any:  # noqa: ANN401
            # Resolved on every call rather than once at schema build time.
            try:
                union_adapter = source_type.__DYNAPYDANTIC__.type_adapter
            except Error as exc:
                # Re-raise dynapydantic failures as a pydantic custom error.
                raise PydanticCustomError(
                    "dynapydantic_error", "{e}", {"e": str(exc)}
                ) from exc
            return union_adapter.validate_python(value)

        def _serialize(
            value: BaseModel,
            info: core_schema.SerializationInfo,
        ) -> dict[str, ty.Any]:
            # model_dump options that were added after pydantic 2.0 (we pin
            # >= 2). They are forwarded only when SerializationInfo has them.
            optional_names = (
                "context",
                "exclude_computed_fields",
                "serialize_as_any",
                "polymorphic_serialization",
            )
            probed = {name: getattr(info, name, _UNSET) for name in optional_names}

            # SerializationInfo doesn't expose the caller's ``warnings``
            # choice, so a fixed value is used.
            dump_kwargs: dict[str, ty.Any] = {"warnings": False}
            dump_kwargs.update(
                (name, val) for name, val in probed.items() if val is not _UNSET
            )

            return value.model_dump(
                mode=info.mode,
                # SerializationInfo's include/exclude types don't line up with
                # model_dump's parameter types.
                include=info.include,  # type: ignore[bad-argument-type]
                exclude=info.exclude,  # type: ignore[bad-argument-type]
                by_alias=info.by_alias,
                exclude_unset=info.exclude_unset,
                exclude_defaults=info.exclude_defaults,
                exclude_none=info.exclude_none,
                round_trip=info.round_trip,
                **dump_kwargs,
            )

        serialization_schema = core_schema.plain_serializer_function_ser_schema(
            _serialize,
            info_arg=True,
            when_used="unless-none",
            return_schema=core_schema.dict_schema(
                core_schema.str_schema(), core_schema.any_schema()
            ),
        )
        return core_schema.no_info_plain_validator_function(
            _validate,
            serialization=serialization_schema,
        )