Coverage for src / dynapydantic / tracking_group.py: 100%

136 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-04-06 21:55 +0000

1"""Base class for dynamic pydantic models""" 

2 

3import contextlib 

4import functools 

5import operator 

6import typing as ty 

7import warnings 

8 

9import pydantic 

10import pydantic.fields 

11import pydantic_core 

12 

13from .exceptions import ( 

14 AmbiguousDiscriminatorValueError, 

15 NoRegisteredTypesError, 

16 RegistrationError, 

17) 

18from .union_mode import DiscriminatedConfig, UnionMode 

19 

20 

def _inject_discriminator_field(
    cls: type[pydantic.BaseModel],
    disc_field: str,
    value: str,
) -> pydantic.fields.FieldInfo:
    """Add a frozen ``Literal[value]`` discriminator field to ``cls`` in place.

    Parameters
    ----------
    cls
        The BaseModel subclass to modify.
    disc_field
        Name of the discriminator field to create.
    value
        Value of the discriminator field. It is used both as the field's
        default and as the only member of its ``Literal`` annotation.

    Returns
    -------
    pydantic.fields.FieldInfo
        The entry for ``disc_field`` in ``cls.model_fields``, read after the
        model has been rebuilt.

    Raises
    ------
    RegistrationError
        If ``cls`` already has an attribute called ``disc_field``.
    """
    # Refuse to shadow anything already on the class (methods, properties,
    # class variables, ...).
    if hasattr(cls, disc_field):
        conflict = (
            f'Cannot inject discriminator field "{disc_field}" into '
            f"{cls.__name__}: an attribute with that name already exists. "
            "Rename either the attribute or the discriminator_field to avoid "
            "the conflict."
        )
        raise RegistrationError(conflict)

    discriminator_info = pydantic.fields.FieldInfo(
        default=value,
        annotation=ty.Literal[value],  # type: ignore[not-a-type]
        frozen=True,
    )
    cls.model_fields[disc_field] = discriminator_info

    # Regenerate the schema so the new field takes part in validation. If some
    # annotation on cls cannot be resolved yet, the rebuild is skipped here.
    # Presumably pydantic rebuilds later once it resolves; TODO confirm.
    with contextlib.suppress(pydantic.errors.PydanticUndefinedAnnotation):
        cls.model_rebuild(force=True)

    # Look the field up again: the rebuild may have replaced the dict.
    return cls.model_fields[disc_field]

54 

55 

class TrackingGroup(pydantic.BaseModel):
    """Tracker for pydantic models

    Models are added with :meth:`register` (decorator) or
    :meth:`register_model`. :meth:`union` builds a union type over all of
    them, annotated according to ``union_mode``.
    """

    name: str = pydantic.Field(
        description=(
            "Name of the tracking group. This is for human display, so it "
            "doesn't technically need to be globally unique, but it should be "
            "meaningfully named, as it will be used in error messages."
        ),
    )
    union_mode: UnionMode | None = pydantic.Field(
        None,
        description=(
            "Union validation strategy. Pass a DiscriminatedConfig instance "
            'or one of the plain strings "smart" or "left_to_right". You can '
            "also just pass the fields for DiscriminatedConfig to this "
            "model and they will be forwarded."
        ),
    )
    discriminator_field: str | None = pydantic.Field(
        None,
        description=(
            "Name of the discriminator field. NOTE: This field is "
            "here as an alias for union_mode.discriminator_field. Passing "
            "both a discriminator_field and a union_mode will result in an "
            "error."
        ),
    )
    discriminator_value_generator: ty.Callable[[type], str] | None = pydantic.Field(
        None,
        description=(
            "A callable that produces default values for the discriminator field"
        ),
    )
    plugin_entry_point: str | None = pydantic.Field(
        None,
        description=(
            "If given, then plugins packages will be supported through this "
            "Python entrypoint. The entrypoint can either be a function, "
            "which will be called, or simply a module, which will be "
            "imported. In either case, models found along the import path of "
            "the entrypoint will be registered. If the entrypoint is a "
            "function, additional models may be declared in the function."
        ),
    )
    # Keys are discriminator values in discriminated mode and
    # ``str(id(model))`` otherwise (see _register_plain).
    models: dict[str, type[pydantic.BaseModel]] = pydantic.Field(
        {},
        description="The tracked models",
    )

    @pydantic.model_validator(mode="after")
    def _ensure_union_mode(self) -> "TrackingGroup":
        """Require a union_mode and sync the flat discriminator fields to it.

        ``_coerce_union_mode`` rejects dict input that has neither
        ``union_mode`` nor ``discriminator_field``. So ``union_mode`` should
        already be set when this runs; this validator guards against a
        subclass breaking that invariant.

        Afterwards, ``discriminator_field`` and
        ``discriminator_value_generator`` are overwritten to mirror
        ``union_mode``. Both become None for non-discriminated modes, and any
        values passed for them alongside ``union_mode`` are discarded.

        Raises
        ------
        ValueError
            If ``union_mode`` is None.
        """
        if self.union_mode is None:
            msg = (
                "union_mode is required. This normally indicates that you "
                "subclassed TrackingGroup and wrote an invalid validator, but "
                "could also be a bug with dynapydantic, so please file a bug "
                "report with a reproducer on how you got here if you suspect "
                "a bug."
            )
            raise ValueError(msg)

        # Ensure the top-level fields are in-sync
        if isinstance(self.union_mode, DiscriminatedConfig):
            self.discriminator_field = self.union_mode.discriminator_field
            self.discriminator_value_generator = (
                self.union_mode.discriminator_value_generator
            )
        else:
            self.discriminator_field = None
            self.discriminator_value_generator = None

        return self

    @pydantic.model_validator(mode="before")
    @classmethod
    def _coerce_union_mode(cls, data: ty.Any) -> ty.Any:  # noqa: ANN401
        """Coerce flat discriminator kwargs into a DiscriminatedConfig.

        Callers may pass ``discriminator_field`` and
        ``discriminator_value_generator`` at the top level. They are
        assembled into a ``DiscriminatedConfig`` (as the ``union_mode``
        entry), which saves the user an extra import and nesting layer.

        The input mapping is never modified; a new dict is returned when
        coercion happens.

        Raises
        ------
        ValueError
            If both or neither of ``union_mode`` and ``discriminator_field``
            are given (non-None).
        """
        if not isinstance(data, dict):
            return data

        disc_field = data.get("discriminator_field")
        union_mode = data.get("union_mode")

        if disc_field is not None and union_mode is not None:
            msg = (
                "Received both union_mode and discriminator_field; pass one "
                "or the other."
            )
            raise ValueError(msg)
        if disc_field is None and union_mode is None:
            msg = "Either union_mode or discriminator_field must be given"
            raise ValueError(msg)

        if disc_field is not None:
            # Build a new dict rather than writing into ``data``, which may be
            # the caller's own mapping (e.g. via ``model_validate``). Mutating
            # it would make validating the same dict a second time fail with
            # "Received both union_mode and discriminator_field".
            data = {
                **data,
                "union_mode": {
                    "discriminator_field": disc_field,
                    "discriminator_value_generator": data.get(
                        "discriminator_value_generator",
                    ),
                },
            }

        return data

    @property
    def _discriminated(self) -> DiscriminatedConfig | None:
        """Return the DiscriminatedConfig, or None if not discriminated."""
        return (
            self.union_mode
            if isinstance(self.union_mode, DiscriminatedConfig)
            else None
        )

    def load_plugins(self) -> None:
        """Load plugins to discover/register additional models

        Does nothing when ``plugin_entry_point`` is None. Otherwise every
        entry point in the group named by ``plugin_entry_point`` is loaded,
        and called if the loaded object is callable.
        """
        if self.plugin_entry_point is None:
            return

        from importlib.metadata import entry_points  # noqa: PLC0415

        for ep in entry_points(group=self.plugin_entry_point):
            plugin = ep.load()
            if callable(plugin):
                plugin()

    @ty.overload
    def register(self, value: str | None = None) -> ty.Callable[[type], type]: ...

    @ty.overload
    def register(self, value: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]: ...

    def register(
        self,
        value: str | type[pydantic.BaseModel] | None = None,
    ) -> ty.Callable[[type], type] | type[pydantic.BaseModel]:
        """Register a model into this group (decorator)

        Parameters
        ----------
        value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand. Can also be the type
            itself to register (if the ()'s are omitted from the decorator).

        Returns
        -------
        type or Callable
            The class itself when used as ``@group.register``; otherwise a
            decorator that registers and returns the decorated class.
        """
        if isinstance(value, type):
            self.register_model(value)
            return value

        def _wrapper(cls: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]:
            self.register_model(cls, ty.cast("str | None", value))
            return cls

        return _wrapper

    def register_model(
        self,
        cls: type[pydantic.BaseModel],
        discriminator_value: str | None = None,
    ) -> None:
        """Register the given model into this group

        In discriminated mode, a missing discriminator field is injected into
        ``cls``. The value is validated before ``cls`` is modified, so a
        failed registration leaves the class untouched.

        Parameters
        ----------
        cls
            The model to register
        discriminator_value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand.

        Raises
        ------
        RegistrationError
            If the arguments are of the wrong type, no discriminator value can
            be determined, the existing discriminator field is not a Literal,
            or the value is already used by another model.
        AmbiguousDiscriminatorValueError
            If ``discriminator_value`` disagrees with the default of an
            existing discriminator field.
        """
        if discriminator_value is not None and not isinstance(discriminator_value, str):
            msg = (
                "discriminator_value must be a str if given, was "
                f"{type(discriminator_value).__name__}"
            )
            raise RegistrationError(msg)

        if not isinstance(cls, type) or not issubclass(cls, pydantic.BaseModel):
            msg = (
                "only pydantic BaseModel subclasses can be registered in a "
                f"TrackingGroup. Got {cls}, which was not."
            )
            raise RegistrationError(msg)

        dm = self._discriminated
        if dm is None:
            if discriminator_value is not None:
                warnings.warn(
                    f'A discriminator_value of "{discriminator_value}" was '
                    f"explicitly passed for {cls.__name__}, but "
                    f'union_mode="{self.union_mode}" does not use a '
                    "discriminator. The value will be ignored.",
                    stacklevel=2,
                )
            self._register_plain(cls)
            return

        disc = dm.discriminator_field
        field = cls.model_fields.get(disc)

        if field is None:
            value = self._resolve_injected_value(cls, dm, discriminator_value)
            # Check for a collision *before* mutating cls, so a rejected
            # registration does not leave the class with a stray field.
            self._check_key_available(cls, value)
            _inject_discriminator_field(cls, disc, value)
        elif ty.get_origin(field.annotation) is not ty.Literal:
            msg = (
                f'the discriminator field "{disc}" already existed in '
                f"{cls.__name__}, but its type annotation was "
                f"{field.annotation}, not Literal."
            )
            raise RegistrationError(msg)
        elif discriminator_value is not None and field.default != discriminator_value:
            msg = (
                f"the discriminator value for {cls.__name__} was "
                f'ambiguous, the passed value was "{discriminator_value}" '
                f' and "{field.default}" via the discriminator '
                f"field ({disc})."
            )
            raise AmbiguousDiscriminatorValueError(msg)

        self._register_with_discriminator_field(cls)

    def union(
        self,
        *,
        plain: bool | None = None,
        annotated: bool | None = None,
    ) -> ty.Any:  # noqa: ANN401
        """Return the union of all registered models

        Parameters
        ----------
        plain
            If set to `True`, a plain union of all members will be returned.
            Otherwise, the returned union will be annotated in accordance with
            the union mode.
        annotated
            Deprecated. Use `plain=True` when you would have used
            `annotated=False`.

        Returns
        -------
        Any
            If there is 1 registered type, the type itself. If there is > 1, a
            union of all registered types. This union may be annotated if
            `plain` is not `True`.

        Raises
        ------
        NoRegisteredTypesError
            If no types have been registered yet.
        """
        if annotated is not None:
            warnings.warn(
                "The `annotated` parameter is deprecated. Use `plain=True` to "
                "get a plain union. By default, behavior is governed by "
                "`union_mode`. Will be removed in a future version.",
                DeprecationWarning,
                stacklevel=2,
            )
            if not annotated:
                plain = True

        n = len(self.models)
        if n == 0:
            msg = (
                "Unable to produce a union from the tracking group "
                f'"{self.name}", as no types have been registered yet.'
            )
            raise NoRegisteredTypesError(msg)
        if n == 1:
            return next(iter(self.models.values()))

        union_mode = "smart" if plain else self.union_mode

        if isinstance(union_mode, DiscriminatedConfig):
            # Tag each member with its key (the discriminator value).
            tagged = (
                ty.Annotated[model, pydantic.Tag(tag)]
                for tag, model in self.models.items()
            )
            return ty.Annotated[
                functools.reduce(operator.or_, tagged),
                pydantic.Field(discriminator=union_mode.discriminator_field),
            ]

        plain_union = functools.reduce(operator.or_, self.models.values())
        if union_mode == "left_to_right":
            return ty.Annotated[plain_union, pydantic.Field(union_mode="left_to_right")]

        # "smart" mode is pydantic's default behavior on a plain union
        return plain_union

    def _resolve_injected_value(
        self,
        cls: type[pydantic.BaseModel],
        dm: DiscriminatedConfig,
        discriminator_value: str | None,
    ) -> str:
        """Pick the discriminator value to inject into ``cls``.

        An explicitly passed value wins. Otherwise the group's
        discriminator_value_generator is called once with ``cls``.

        Raises
        ------
        RegistrationError
            If no value was passed and there is no generator, or the
            generator returned something other than a str.
        """
        if discriminator_value is not None:
            return discriminator_value

        disc = dm.discriminator_field
        generator = dm.discriminator_value_generator
        if generator is None:
            msg = (
                "unable to determine a discriminator value for "
                f'{cls.__name__} in tracking group "{self.name}". '
                "No value was passed, discriminator_value_generator "
                f'was None and the "{disc}" field was not defined.'
            )
            raise RegistrationError(msg)

        value = generator(cls)
        if not isinstance(value, str):
            msg = (
                f"discriminator_value_generator returned {value!r} for "
                f"{cls.__name__}, which was of type {type(value).__name__}, "
                "not str."
            )
            raise RegistrationError(msg)
        return value

    def _check_key_available(self, cls: type[pydantic.BaseModel], key: str) -> None:
        """Raise RegistrationError if ``key`` is held by a different model.

        Re-registering the same class under the same key is allowed.
        """
        other = self.models.get(key)
        if other is not None and other is not cls:
            msg = (
                f'Cannot register {cls.__name__} under the "{key}" '
                f"identifier, which is already in use by {other.__name__}."
            )
            raise RegistrationError(msg)

    def _register_with_discriminator_field(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model with the default of the discriminator field

        Parameters
        ----------
        cls
            The class to register. It must have the discriminator field set
            with a default value that is unique in the group.

        Raises
        ------
        RegistrationError
            If the default is missing, not a str, or already used by another
            model.
        """
        disc = ty.cast("DiscriminatedConfig", self.union_mode).discriminator_field
        value = cls.model_fields[disc].default
        # PydanticUndefined is a singleton sentinel: compare by identity.
        if value is pydantic_core.PydanticUndefined:
            msg = (
                f"{cls.__name__}.{disc} had no default value, it must "
                "have one which is unique among all tracked models."
            )
            raise RegistrationError(msg)
        if not isinstance(value, str):
            msg = (
                f"{cls.__name__}.{disc} had a default value of {value}, which "
                f"was of type {type(value).__name__}, not str."
            )
            raise RegistrationError(msg)

        self._check_key_available(cls, value)
        self.models[value] = cls

    def _register_plain(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model keyed by ``str(id(cls))``.

        Used for smart / left_to_right modes where no discriminator field
        is involved. Keying by identity, not by name, lets same-named classes
        from different modules coexist.

        Parameters
        ----------
        cls
            The model to register.
        """
        key = str(id(cls))
        # ``models`` keeps every registered class alive, and ids are unique
        # among live objects. So this check should never fire; it is kept as
        # a cheap safety net.
        self._check_key_available(cls, key)
        self.models[key] = cls