Coverage for src / dynapydantic / tracking_group.py: 100%

147 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-13 20:14 +0000

1"""Base class for dynamic pydantic models""" 

2 

3import contextlib 

4import functools 

5import operator 

6import typing as ty 

7import warnings 

8 

9import pydantic 

10import pydantic.fields 

11import pydantic_core 

12 

13from .exceptions import ( 

14 AmbiguousDiscriminatorValueError, 

15 NoRegisteredTypesError, 

16 RegistrationError, 

17) 

18from .union_mode import DiscriminatedConfig, UnionMode 

19 

20 

def _inject_discriminator_field(
    cls: type[pydantic.BaseModel],
    disc_field: str,
    value: str,
) -> pydantic.fields.FieldInfo:
    """Add a frozen ``Literal`` discriminator field to an existing model

    The model is modified in place: a new entry is written into
    ``cls.model_fields`` and a forced ``model_rebuild`` is attempted.

    Parameters
    ----------
    cls
        The BaseModel subclass to modify
    disc_field
        Name of the discriminator field to add
    value
        The default value of the new field, also used as its only permitted
        ``Literal`` value

    Returns
    -------
    pydantic.fields.FieldInfo
        Whatever is stored under ``disc_field`` in ``cls.model_fields`` after
        the rebuild attempt.

    Raises
    ------
    RegistrationError
        If ``cls`` already exposes an attribute named ``disc_field``.
    """
    name_taken = hasattr(cls, disc_field)
    if name_taken:
        msg = (
            f'Cannot inject discriminator field "{disc_field}" into '
            f"{cls.__name__}: an attribute with that name already exists. "
            "Rename either the attribute or the discriminator_field to avoid "
            "the conflict."
        )
        raise RegistrationError(msg)

    literal_type = ty.Literal[value]  # type: ignore[not-a-type]
    injected = pydantic.fields.FieldInfo(
        default=value,
        annotation=literal_type,
        frozen=True,
    )
    cls.model_fields[disc_field] = injected

    # An undefined-annotation error during the rebuild is deliberately ignored;
    # any other error propagates to the caller.
    with contextlib.suppress(pydantic.errors.PydanticUndefinedAnnotation):
        cls.model_rebuild(force=True)

    # Look the field up again instead of returning ``injected`` so the caller
    # gets whatever the model holds once the rebuild has been attempted.
    return cls.model_fields[disc_field]

54 

55 

class TrackingGroup(pydantic.BaseModel):
    """Tracker for pydantic models

    Models are registered under a string key (see ``register`` and
    ``register_model``) and can then be combined into a single union type
    (``union``) or a cached ``pydantic.TypeAdapter`` (``type_adapter``).
    ``union_mode`` decides whether registration involves a discriminator
    field and how the resulting union is annotated.
    """

    name: str = pydantic.Field(
        description=(
            "Name of the tracking group. This is for human display, so it "
            "doesn't technically need to be globally unique, but it should be "
            "meaningfully named, as it will be used in error messages."
        ),
    )
    union_mode: UnionMode | None = pydantic.Field(
        None,
        description=(
            "Union validation strategy. Pass a DiscriminatedConfig instance "
            'or one of the plain strings "smart" or "left_to_right". You can '
            "also just pass the fields for DiscriminatedConfig to this "
            "model and they will be forwarded."
        ),
    )
    discriminator_field: str | None = pydantic.Field(
        None,
        description=(
            "Name of the discriminator field. NOTE: This field is "
            "here as an alias for union_mode.discriminator_field. Passing "
            "both a discriminator_field and a union_mode will result in an "
            "error."
        ),
    )
    discriminator_value_generator: ty.Callable[[type], str] | None = pydantic.Field(
        None,
        description=(
            "A callable that produces default values for the discriminator field"
        ),
    )
    plugin_entry_point: str | None = pydantic.Field(
        None,
        description=(
            "If given, then plugins packages will be supported through this "
            "Python entrypoint. The entrypoint can either be a function, "
            "which will be called, or simply a module, which will be "
            "imported. In either case, models found along the import path of "
            "the entrypoint will be registered. If the entrypoint is a "
            "function, additional models may be declared in the function."
        ),
    )
    models: dict[str, type[pydantic.BaseModel]] = pydantic.Field(
        {},
        description="The tracked models",
    )

    # Bumped by _do_register each time a model is stored under a new key.
    _generation: int = pydantic.PrivateAttr(default=0)
    # Cached adapter for union(), together with the generation it was built at.
    _adapter: pydantic.TypeAdapter | None = pydantic.PrivateAttr(default=None)
    _adapter_generation: int = pydantic.PrivateAttr(default=-1)

    @pydantic.model_validator(mode="after")
    def _ensure_union_mode(self) -> "TrackingGroup":
        """There must be a union_mode

        This validator works as a guard on _coerce_union_mode to make sure
        that a union_mode is actually set once validation has finished. It
        then mirrors the discriminator settings of a DiscriminatedConfig onto
        the top-level ``discriminator_field`` and
        ``discriminator_value_generator`` fields, or clears both when the
        union mode is not a DiscriminatedConfig.

        Raises
        ------
        ValueError
            If ``union_mode`` is still None.
        """
        if self.union_mode is None:
            msg = (
                "union_mode is required. This normally indicates that you "
                "subclasses TrackingGroup and wrote an invalid validator, but "
                "could also be a bug with dynapydantic, so please file a bug "
                "report with a reproducer on how you got here if you suspect "
                "a bug."
            )
            raise ValueError(msg)

        # Ensure the top-level fields are in-sync
        if isinstance(self.union_mode, DiscriminatedConfig):
            self.discriminator_field = self.union_mode.discriminator_field
            self.discriminator_value_generator = (
                self.union_mode.discriminator_value_generator
            )
        else:
            self.discriminator_field = None
            self.discriminator_value_generator = None

        return self

    @pydantic.model_validator(mode="before")
    @classmethod
    def _coerce_union_mode(cls, data: ty.Any) -> ty.Any:  # noqa: ANN401
        """Coerce flat discriminator kwargs into a DiscriminatedConfig.

        Allows callers to pass ``discriminator_field`` and
        ``discriminator_value_generator`` at the top level and transparently
        assembles the ``union_mode`` entry from them. This avoids an extra
        import/nesting layer for the user.

        Parameters
        ----------
        data
            The raw input being validated. Anything that is not a dict is
            returned untouched.

        Returns
        -------
        Any
            ``data`` itself when nothing has to be added, otherwise a new dict
            with a ``union_mode`` entry built from the flat arguments. The
            input dict is never modified.

        Raises
        ------
        ValueError
            If neither ``union_mode`` nor ``discriminator_field`` is given, or
            if both are given and do not agree with each other.
        """
        if not isinstance(data, dict):
            return data

        disc_field = data.get("discriminator_field", None)
        has_disc_field = disc_field is not None
        union_mode = data.get("union_mode", None)
        has_union_mode = union_mode is not None

        # If the user passed us both a discriminator field and a union_mode,
        # things must be perfectly consistent
        if has_disc_field and has_union_mode:
            consistent = (
                isinstance(union_mode, DiscriminatedConfig)
                and disc_field == union_mode.discriminator_field
                and data.get("discriminator_value_generator")
                is union_mode.discriminator_value_generator
            ) or (
                isinstance(union_mode, dict)
                and disc_field == union_mode.get("discriminator_field")
                and data.get("discriminator_value_generator")
                is union_mode.get("discriminator_value_generator")
            )
            if not consistent:
                msg = (
                    "Received both union_mode and discriminator_field; pass one "
                    "or the other."
                )
                raise ValueError(msg)

        if has_disc_field and not has_union_mode:
            # Forward arguments to DiscriminatedConfig. A new dict is built
            # instead of assigning into ``data`` so that a mapping the caller
            # handed to validation is not modified as a side effect.
            return {
                **data,
                "union_mode": {
                    "discriminator_field": disc_field,
                    "discriminator_value_generator": data.get(
                        "discriminator_value_generator",
                    ),
                },
            }

        if not has_disc_field and not has_union_mode:
            msg = "Either union_mode or discriminator_field must be given"
            raise ValueError(msg)

        return data

    @property
    def _discriminated(self) -> DiscriminatedConfig | None:
        """Return the DiscriminatedConfig, or None if not discriminated."""
        return (
            self.union_mode
            if isinstance(self.union_mode, DiscriminatedConfig)
            else None
        )

    def load_plugins(self) -> None:
        """Load plugins to discover/register additional models

        Does nothing when ``plugin_entry_point`` is None. Otherwise every
        entry point in that group is loaded, and the loaded object is called
        with no arguments if it is callable.
        """
        if self.plugin_entry_point is None:
            return

        from importlib.metadata import entry_points  # noqa: PLC0415

        for ep in entry_points().select(group=self.plugin_entry_point):
            plugin = ep.load()
            if callable(plugin):
                plugin()

    @ty.overload
    def register(self, value: str | None = None) -> ty.Callable[[type], type]: ...

    @ty.overload
    def register(self, value: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]: ...

    def register(
        self,
        value: str | type[pydantic.BaseModel] | None = None,
    ) -> ty.Callable[[type], type] | type[pydantic.BaseModel]:
        """Register a model into this group (decorator)

        Parameters
        ----------
        value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand. Can also be the type
            itself to register (if the ()'s are omitted from the decorator).

        Returns
        -------
        Callable or type
            The class itself when a class was passed directly, otherwise a
            decorator that registers the class it is applied to and returns
            that class unchanged.
        """
        # Bare ``@group.register`` usage: ``value`` is the decorated class.
        if isinstance(value, type):
            self.register_model(value)
            return value

        def _wrapper(cls: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]:
            self.register_model(cls, value)
            return cls

        return _wrapper

    def register_model(
        self,
        cls: type[pydantic.BaseModel],
        discriminator_value: str | None = None,
    ) -> None:
        """Register the given model into this group

        Parameters
        ----------
        cls
            The model to register
        discriminator_value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand.

        Raises
        ------
        RegistrationError
            If ``discriminator_value`` is not a str, if ``cls`` is not a
            pydantic BaseModel subclass, if no discriminator value can be
            determined, or if an existing discriminator field is not
            annotated as a Literal.
        AmbiguousDiscriminatorValueError
            If ``discriminator_value`` differs from the default of a
            discriminator field the model already declares.

        Warns
        -----
        UserWarning
            If ``discriminator_value`` is passed while the union mode does
            not use a discriminator; the value is ignored.
        """
        if discriminator_value is not None and not isinstance(discriminator_value, str):
            msg = (
                "discriminator_value must be a str if given, was "
                f"{type(discriminator_value).__name__}"
            )
            raise RegistrationError(msg)

        if not isinstance(cls, type) or not issubclass(cls, pydantic.BaseModel):
            msg = (
                "only pydantic BaseModel subclasses can be registered in a "
                f"TrackingGroup. Got {cls}, which was not."
            )
            raise RegistrationError(msg)

        if (dm := self._discriminated) is not None:
            disc = dm.discriminator_field
            field = cls.model_fields.get(disc)

            if field is None:
                # The model does not declare the discriminator itself, so one
                # is injected: an explicit value wins over the generator.
                if discriminator_value is not None:
                    _inject_discriminator_field(cls, disc, discriminator_value)
                elif dm.discriminator_value_generator is not None:
                    _inject_discriminator_field(
                        cls,
                        disc,
                        dm.discriminator_value_generator(cls),
                    )
                else:
                    msg = (
                        f"unable to determine a discriminator value for "
                        f'{cls.__name__} in tracking group "{self.name}". '
                        "No value was passed, discriminator_value_generator "
                        f'was None and the "{disc}" field was not defined.'
                    )
                    raise RegistrationError(msg)
            elif ty.get_origin(field.annotation) is not ty.Literal:
                msg = (
                    f'the discriminator field "{disc}" already existed in '
                    f"{cls.__name__}, but its type annotation was "
                    f"{field.annotation}, not Literal."
                )
                raise RegistrationError(msg)
            elif (
                discriminator_value is not None and field.default != discriminator_value
            ):
                msg = (
                    f"the discriminator value for {cls.__name__} was "
                    f'ambiguous, the passed value was "{discriminator_value}" '
                    f' and "{field.default}" via the discriminator '
                    f"field ({disc})."
                )
                raise AmbiguousDiscriminatorValueError(msg)

            self._register_with_discriminator_field(cls)
        else:
            if discriminator_value is not None:
                warnings.warn(
                    f'A discriminator_value of "{discriminator_value}" was '
                    f"explicitly passed for {cls.__name__}, but "
                    f'union_mode="{self.union_mode}" does not use a '
                    "discriminator. The value will be ignored.",
                    stacklevel=2,
                )
            self._register_plain(cls)

    def union(
        self,
        *,
        plain: bool | None = None,
    ) -> ty.Any:  # noqa: ANN401
        """Return the union of all registered models

        Parameters
        ----------
        plain
            If set to `True`, a plain union of all members will be returned.
            Otherwise, the returned union will be annotated in accordance with
            the union mode.

        Returns
        -------
        Any
            If there is 1 registered type, the type itself. If there is > 1, a
            union of all registered types. This union may be annotated if
            `plain` is not `True`.

        Raises
        ------
        NoRegisteredTypesError
            If no types have been registered yet.
        """
        n = len(self.models)
        if n == 0:
            msg = (
                "Unable to produce a union from the tracking group "
                f'"{self.name}", as no types have been registered yet.'
            )
            raise NoRegisteredTypesError(msg)
        if n == 1:
            return next(iter(self.models.values()))

        # ``plain`` overrides the configured mode with the un-annotated form.
        union_mode = "smart" if plain else self.union_mode

        if isinstance(union_mode, DiscriminatedConfig):
            # Each member is tagged with the key it was registered under.
            tagged = (
                ty.Annotated[model, pydantic.Tag(key)]
                for key, model in self.models.items()
            )
            return ty.Annotated[
                functools.reduce(operator.or_, tagged),
                pydantic.Field(discriminator=union_mode.discriminator_field),
            ]

        plain_union = functools.reduce(operator.or_, self.models.values())
        if union_mode == "left_to_right":
            return ty.Annotated[plain_union, pydantic.Field(union_mode="left_to_right")]

        # "smart" mode is pydantic's default behavior on a plain union
        return plain_union

    @property
    def generation(self) -> int:
        """The generation of the tracking group.

        This is a counter that increments every time a new registration occurs
        """
        return self._generation

    @property
    def type_adapter(self) -> pydantic.TypeAdapter:
        """Get the pydantic TypeAdapter for the union of all group members

        The adapter is cached and only rebuilt when ``generation`` differs
        from the generation the cached adapter was built at. Building it calls
        ``union()``, so any error raised there propagates from this property.
        """
        if self.generation != self._adapter_generation:
            self._adapter = pydantic.TypeAdapter(self.union())
            self._adapter_generation = self.generation

        # Casting because the if statement ensures it is non-None:
        # _adapter_generation starts at -1 and generation starts at 0, so the
        # first access always builds the adapter.
        return ty.cast("pydantic.TypeAdapter", self._adapter)

    def _register_with_discriminator_field(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model with the default of the discriminator field

        Parameters
        ----------
        cls
            The class to register, must have the discriminator field set with
            a unique default value in the group.

        Raises
        ------
        RegistrationError
            If the discriminator field has no default value or its default is
            not a str.
        """
        disc = ty.cast("DiscriminatedConfig", self.union_mode).discriminator_field
        value = cls.model_fields[disc].default
        # Identity check against the sentinel: ``==`` would dispatch to the
        # ``__eq__`` of whatever object the default happens to be.
        if value is pydantic_core.PydanticUndefined:
            msg = (
                f"{cls.__name__}.{disc} had no default value, it must "
                "have one which is unique among all tracked models."
            )
            raise RegistrationError(msg)
        if not isinstance(value, str):
            msg = (
                f"{cls.__name__}.{disc} had a default value of {value}, which "
                f"was of type {type(value).__name__}, not str."
            )
            raise RegistrationError(msg)

        self._do_register(value, cls)

    def _register_plain(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model keyed by the string form of ``id(cls)``.

        Used for smart / left_to_right modes where no discriminator field
        is involved.

        Parameters
        ----------
        cls
            The model to register.
        """
        self._do_register(str(id(cls)), cls)

    def _do_register(self, key: str, cls: type[pydantic.BaseModel]) -> None:
        """Register the given model under the given key

        Registering the same class under the same key again is a no-op and
        does not change ``generation``.

        Parameters
        ----------
        key
            The key under which to register the model
        cls
            The model to register.

        Raises
        ------
        RegistrationError
            If ``key`` is already in use by a different class.
        """
        if (other := self.models.get(key)) is not None:
            if other is not cls:
                msg = (
                    f'Cannot register {cls.__name__} under the "{key}" '
                    f"identifier, which is already in use by {other.__name__}."
                )
                raise RegistrationError(msg)
        else:
            self._generation += 1
            self.models[key] = cls