Coverage for src / dynapydantic / subclass_tracking_model.py: 100%

71 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 17:07 +0000

1"""Base class for dynamic pydantic models""" 

2 

3import inspect 

4import typing as ty 

5import warnings 

6 

7import pydantic 

8from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler 

9from pydantic.json_schema import JsonSchemaValue 

10from pydantic_core import PydanticCustomError, core_schema 

11 

12from .exceptions import ConfigurationError, Error 

13from .tracking_group import TrackingGroup 

14 

15 

def direct_children_of_base_in_mro(derived: type, base: type) -> list[type]:
    """Collect the classes in ``derived``'s MRO that inherit directly from ``base``.

    Parameters
    ----------
    derived
        Class whose method resolution order is scanned.
    base
        Class whose immediate subclasses are being looked for. ``base``
        itself is never part of the result.

    Returns
    -------
    Every class ``k`` in ``derived.__mro__`` (in MRO order) for which
    ``base`` appears in ``k.__bases__``.
    """
    children: list[type] = []
    for klass in derived.__mro__:
        if klass is base:
            continue
        if base in klass.__bases__:
            children.append(klass)
    return children

31 

32 

class SubclassTrackingModel(pydantic.BaseModel):
    """Subclass-tracking BaseModel

    This will inject a [`TrackingGroup`][dynapydantic.TrackingGroup] into your
    class and automate the registration of subclasses.

    Similar to `BaseModel`, `SubclassTrackingModel` can take arguments in the
    class declaration. Arguments from `BaseModel` will be forwarded.
    Additionally, any fields from `TrackingGroup` will be forwarded to the
    internal `TrackingGroup` instance. Alternatively, a direct subclass may
    declare a `tracking_config: ClassVar[dynapydantic.TrackingGroup]` member;
    that instance is then used as-is and no group is built from the class
    arguments. The following additional arguments are supported:

    1. `exclude_from_union`: This flag is intended to be used with indirect
       descendants of `SubclassTrackingModel`. If `True`, this subclass will
       be omitted from tracking. It has no effect on direct subclasses.
    2. `implicit_polymorphic`: This flag is intended to be used with direct
       descendants of `SubclassTrackingModel`. If `True`, then the core
       schema of this class will be overridden. This allows polymorphic
       parsing to occur without the use of
       [`Polymorphic`][dynapydantic.Polymorphic]. In addition, it is not
       necessary to call `model_rebuild` on recursive models. This feature
       is currently **EXPERIMENTAL** and does incur a runtime penalty. It
       has no effect on indirect descendants.

    **DEPRECATED:**

    Inheriting from this class will augment your class with the following
    member functions:

    1. `registered_subclasses() -> dict[str, type[cls]]`:
       This will return a mapping of discriminator value to the corresponding
       subclass. See
       [`TrackingGroup.models`][dynapydantic.TrackingGroup.models] for details.
    2. `union() -> typing.Any`:
       This will return an (optionally) annotated subclass union. See
       [`TrackingGroup.union()`][dynapydantic.TrackingGroup.union] for details.
    3. `load_plugins() -> None`:
       If plugin_entry_point was specified, then this method will load plugin
       packages to discover additional subclasses. See
       [`TrackingGroup.load_plugins()`][dynapydantic.TrackingGroup.load_plugins]
       for more details.

    These methods will be removed in 0.5.0, please migrate to their
    corresponding free functions:

    1. `registered_subclasses()` ->
       [`registered_models()`][dynapydantic.registered_models]
    2. `union()` -> [`union()`][dynapydantic.union] or
       [`Union[T]`][dynapydantic.Union]
    3. `load_plugins()` -> [`load_plugins()`][dynapydantic.load_plugins]
    """

    def __init_subclass__(cls, *args, **kwargs) -> None:
        """Subclass hook

        Drops the class-declaration kwargs that are consumed elsewhere
        (`TrackingGroup` fields and the keyword parameters of
        `__pydantic_init_subclass__`) so that they are not forwarded to the
        parent `__init_subclass__`.
        """
        # Intercept any kwargs that are intended for TrackingGroup or
        # __pydantic_init_subclass__
        sig = inspect.signature(SubclassTrackingModel.__pydantic_init_subclass__)
        super().__init_subclass__(
            *args,
            **{
                k: v
                for k, v in kwargs.items()
                if k not in TrackingGroup.model_fields and k not in sig.parameters
            },
        )

    # This method is too complex, here's the plan to simplify it:
    # We're polluting this models attributes by injecting and forwarding methods
    # from tracking group. As a result, we're limiting the possible field names
    # that these models can have. These should be free functions. We're going
    # to deprecate the methods to give people a release cycle to migrate off.
    # We should be able to remove the noqa after these are removed.
    @classmethod
    def __pydantic_init_subclass__(  # noqa: C901
        cls,
        *args,
        exclude_from_union: bool | None = None,
        implicit_polymorphic: bool | None = None,
        **kwargs,
    ) -> None:
        """Pydantic subclass hook

        * A direct subclass of `SubclassTrackingModel` becomes a tracking
          root. It gets a `TrackingGroup` stored as ``__DYNAPYDANTIC__``,
          the deprecated helper methods, and, when `implicit_polymorphic`
          is true, the polymorphic schema hooks. `exclude_from_union` is
          ignored for these classes.
        * Any other descendant is registered with every tracking root in
          its MRO unless `exclude_from_union` is true. `implicit_polymorphic`
          is ignored for these classes.

        Raises
        ------
        ConfigurationError
            If a direct subclass has no `tracking_config` `TrackingGroup` and
            its class kwargs do not validate as a `TrackingGroup`.
        """
        if SubclassTrackingModel in cls.__bases__:
            # Intercept any kwargs that are intended for TrackingGroup
            super().__pydantic_init_subclass__(
                *args,
                **{
                    k: v
                    for k, v in kwargs.items()
                    if k not in TrackingGroup.model_fields
                },
            )

            cls.__DYNAPYDANTIC_IMPLICIT_POLYMORPHIC__: ty.ClassVar[bool] = (
                implicit_polymorphic if implicit_polymorphic is not None else False
            )

            # An explicitly declared tracking_config takes precedence; otherwise
            # build the group from the class kwargs with a default name.
            if isinstance((tc := getattr(cls, "tracking_config", None)), TrackingGroup):
                cls.__DYNAPYDANTIC__: ty.ClassVar[TrackingGroup] = tc
            else:
                try:
                    cls.__DYNAPYDANTIC__: ty.ClassVar[TrackingGroup] = (
                        TrackingGroup.model_validate(
                            {"name": f"{cls.__name__}-subclasses"} | kwargs,
                        )
                    )
                except pydantic.ValidationError as e:
                    msg = (
                        "SubclassTrackingModel subclasses must either have a "
                        "tracking_config: ClassVar[dynapydantic.TrackingGroup] "
                        "member or pass kwargs sufficient to construct a "
                        "dynapydantic.TrackingGroup in the class declaration. "
                        "The latter approach produced the following "
                        f"ValidationError:\n{e}"
                    )
                    raise ConfigurationError(msg) from e

            # Promote the tracking group's methods to the parent class
            # (deprecated; each wrapper emits a DeprecationWarning and then
            # delegates to the group captured via `cls`).
            if cls.__DYNAPYDANTIC__.plugin_entry_point is not None:

                def _load_plugins() -> None:
                    """Load plugins to register more models

                    DEPRECATED: use
                    [`dynapydantic.load_plugins`][dynapydantic.load_plugins]
                    """
                    msg = (
                        "SubclassTrackingModel.load_plugins() is deprecated, "
                        "please swap to dynapydantic.load_plugins()."
                    )
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
                    cls.__DYNAPYDANTIC__.load_plugins()

                cls.load_plugins = staticmethod(_load_plugins)

            def _union(
                *,
                plain: bool | None = None,
                annotated: bool | None = None,
            ) -> ty.Any:  # noqa: ANN401 - return type is runtime-determined
                """Get the union of all tracked subclasses

                DEPRECATED: use [`Union[T]`][dynapydantic.Union] or
                [`union()`][dynapydantic.union] instead.

                Parameters
                ----------
                plain
                    If set to `True`, a plain union of all members will be returned.
                    Otherwise, the returned union will be annotated in accordance with
                    the union mode.
                annotated
                    Deprecated. Use `plain=True` when you would have used
                    `annotated=False`.
                """
                msg = (
                    "SubclassTrackingModel.union() is deprecated, please swap "
                    "to dynapydantic.Union[T] (for annotations) or "
                    "dynapydantic.union() (for runtime calls)."
                )
                warnings.warn(msg, DeprecationWarning, stacklevel=2)

                # deprecation warning for annotated is in TrackingGroup
                return cls.__DYNAPYDANTIC__.union(plain=plain, annotated=annotated)

            cls.union = staticmethod(_union)

            def _subclasses() -> dict[str, type[pydantic.BaseModel]]:
                """Return a mapping of discriminator values to registered models

                DEPRECATED: use dynapydantic.registered_models().
                """
                msg = (
                    "SubclassTrackingModel.registered_subclasses() is "
                    "deprecated, please swap to "
                    "dynapydantic.registered_models()."
                )
                warnings.warn(msg, DeprecationWarning, stacklevel=2)

                return cls.__DYNAPYDANTIC__.models

            cls.registered_subclasses = staticmethod(_subclasses)

            if implicit_polymorphic:
                # Override schema generation so that parsing this class
                # dispatches over the tracked subclasses.
                cls.__get_pydantic_core_schema__ = classmethod(  # type: ignore[bad-assignment]
                    _get_pydantic_core_schema
                )

                cls.__get_pydantic_json_schema__ = classmethod(  # type: ignore[bad-assignment]
                    _get_pydantic_json_schema
                )

            # Tracking roots are never registered in their own (or any) group.
            return

        super().__pydantic_init_subclass__(*args, **kwargs)

        if exclude_from_union:
            return

        # A class may sit under several tracking roots (multiple
        # inheritance); register it with each of them.
        supers = direct_children_of_base_in_mro(cls, SubclassTrackingModel)
        for base in supers:
            base.__DYNAPYDANTIC__.register_model(cls)

233 

234 

def _get_adapter(
    source_type: type[SubclassTrackingModel],
) -> pydantic.TypeAdapter:
    """Fetch the `TypeAdapter` held by ``source_type``'s tracking group.

    A dynapydantic ``Error`` raised while obtaining the adapter is converted
    into a ``PydanticCustomError`` with error type ``"dynapydantic_error"``
    and the original message. When this is raised from a validator, pydantic
    reports it as a validation error. The original exception is chained as
    the cause.
    """
    try:
        adapter = source_type.__DYNAPYDANTIC__.type_adapter
    except Error as exc:
        error_type = "dynapydantic_error"
        raise PydanticCustomError(error_type, "{e}", {"e": str(exc)}) from exc
    return adapter

243 

244 

def _get_pydantic_core_schema(
    cls: type[SubclassTrackingModel],
    source_type: type[pydantic.BaseModel],
    handler: GetCoreSchemaHandler,
    /,
) -> core_schema.CoreSchema:
    """Get the pydantic core schema for this type

    Installed (as a classmethod) on direct subclasses of
    `SubclassTrackingModel` declared with ``implicit_polymorphic=True``, and
    inherited by their descendants.

    Parameters
    ----------
    cls
        Class the hook was invoked on.
    source_type
        Type pydantic is generating a schema for.
    handler
        Pydantic's default schema generator.

    Returns
    -------
    For anything but a direct subclass of `SubclassTrackingModel`, the
    default schema from ``handler``. For a direct subclass, a plain-validator
    schema that delegates validation to the tracking group's `TypeAdapter`
    and serializes via the concrete instance's ``model_dump``.
    """
    if SubclassTrackingModel not in cls.__bases__:
        return handler(source_type)

    def _validate(value: ty.Any) -> ty.Any:  # noqa: ANN401
        # The adapter is looked up on every call rather than once at
        # schema-build time.
        return _get_adapter(
            ty.cast("type[SubclassTrackingModel]", source_type)
        ).validate_python(value)

    def _serialize(
        value: pydantic.BaseModel,
        info: core_schema.SerializationInfo,
    ) -> dict[str, ty.Any]:
        # Forward the caller's dump options so that a parent dumped with,
        # e.g., by_alias=True or exclude_none=True applies them to this
        # nested value too, as it would for an ordinary nested model.
        # NOTE(review): include/exclude are not forwarded.
        return value.model_dump(
            mode=info.mode,
            by_alias=info.by_alias,
            exclude_unset=info.exclude_unset,
            exclude_defaults=info.exclude_defaults,
            exclude_none=info.exclude_none,
            round_trip=info.round_trip,
        )

    return core_schema.no_info_plain_validator_function(
        _validate,
        serialization=core_schema.plain_serializer_function_ser_schema(
            _serialize,
            info_arg=True,
            when_used="unless-none",
            return_schema=core_schema.dict_schema(
                core_schema.str_schema(), core_schema.any_schema()
            ),
        ),
    )

277 

278 

def _get_pydantic_json_schema(
    cls: type[SubclassTrackingModel],
    cs: core_schema.CoreSchema,
    handler: GetJsonSchemaHandler,
    /,
) -> JsonSchemaValue:
    """Produce the JSON schema for an implicitly polymorphic class.

    Classes that are not direct subclasses of `SubclassTrackingModel` get the
    JSON schema of the core schema pydantic handed in (``cs``). A direct
    subclass instead reports the JSON schema of its tracking group's
    `TypeAdapter` core schema (presumably the union of tracked models).
    """
    if SubclassTrackingModel not in cls.__bases__:
        return handler(cs)
    adapter_schema = _get_adapter(cls).core_schema
    return handler(adapter_schema)