Coverage for src / dynapydantic / tracking_group.py: 100%

62 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-02-06 15:20 +0000

1"""Base class for dynamic pydantic models""" 

2 

3import contextlib 

4import typing as ty 

5 

6import pydantic 

7import pydantic.fields 

8import pydantic_core 

9 

10from .exceptions import AmbiguousDiscriminatorValueError, RegistrationError 

11 

12 

def _inject_discriminator_field(
    cls: type[pydantic.BaseModel],
    disc_field: str,
    value: str,
) -> pydantic.fields.FieldInfo:
    """Add a frozen ``Literal[value]`` discriminator field to a model class

    The class is mutated in place: a new field named ``disc_field`` is placed
    in ``cls.model_fields`` and the model is force-rebuilt so the new field
    takes effect. If the rebuild fails because an annotation cannot be
    resolved yet (``PydanticUndefinedAnnotation``), that error is suppressed
    and the rebuild is skipped.

    Parameters
    ----------
    cls
        The BaseModel subclass to modify
    disc_field
        Name of the discriminator field to add
    value
        Default (and only allowed) value of the discriminator field

    Returns
    -------
    pydantic.fields.FieldInfo
        The field info found under ``disc_field`` on ``cls`` after the
        rebuild attempt (this may be a different object from the one created
        here if the rebuild replaced it).
    """
    literal_type = ty.Literal[value]  # type: ignore[not-a-type]
    injected = pydantic.fields.FieldInfo(
        annotation=literal_type,
        default=value,
        frozen=True,
    )
    cls.model_fields[disc_field] = injected

    # Best effort: unresolved annotations must not make injection fail.
    with contextlib.suppress(pydantic.errors.PydanticUndefinedAnnotation):
        cls.model_rebuild(force=True)

    return cls.model_fields[disc_field]

37 

38 

class TrackingGroup(pydantic.BaseModel):
    """Registry of pydantic models keyed by their discriminator value

    Every registered model has a field (named by ``discriminator_field``)
    whose default value identifies it uniquely within the group. Models are
    stored in ``models`` under that value, and :meth:`union` builds a
    (discriminated) union of everything registered so far.
    """

    name: str = pydantic.Field(
        description=(
            "Name of the tracking group. This is for human display, so it "
            "doesn't technically need to be globally unique, but it should be "
            "meaningfully named, as it will be used in error messages."
        ),
    )
    discriminator_field: str = pydantic.Field(
        description="Name of the discriminator field",
    )
    plugin_entry_point: str | None = pydantic.Field(
        None,
        description=(
            "If given, then plugins packages will be supported through this "
            "Python entrypoint. The entrypoint can either be a function, "
            "which will be called, or simply a module, which will be "
            "imported. In either case, models found along the import path of "
            "the entrypoint will be registered. If the entrypoint is a "
            "function, additional models may be declared in the function."
        ),
    )
    discriminator_value_generator: ty.Callable[[type], str] | None = pydantic.Field(
        None,
        description=(
            "A callable that produces default values for the discriminator field"
        ),
    )
    # pydantic copies mutable defaults for each instance, so separate groups
    # do not share this dict.
    models: dict[str, type[pydantic.BaseModel]] = pydantic.Field(
        {},
        description="The tracked models",
    )

    def load_plugins(self) -> None:
        """Load plugins to discover/register additional models

        Every entry point in the ``plugin_entry_point`` group is loaded
        (which imports its target). If the loaded object is callable, it is
        also called with no arguments. Does nothing when
        ``plugin_entry_point`` is None.
        """
        if self.plugin_entry_point is None:
            return

        from importlib.metadata import entry_points  # noqa: PLC0415

        for ep in entry_points(group=self.plugin_entry_point):
            plugin = ep.load()
            if callable(plugin):
                plugin()

    def register(
        self,
        discriminator_value: str | None = None,
    ) -> ty.Callable[[type], type]:
        """Register a model into this group (decorator)

        Parameters
        ----------
        discriminator_value
            Value for the discriminator field. If not given, then
            discriminator_value_generator must be non-None or the
            discriminator field must be declared by hand.

        Returns
        -------
        Callable
            A class decorator that (if needed) injects the discriminator
            field, registers the class, and returns the same class.

        Raises
        ------
        RegistrationError
            (from the decorator) if no discriminator value can be determined,
            the field has no default, or the value is already used by another
            model in this group.
        AmbiguousDiscriminatorValueError
            (from the decorator) if the class declares the discriminator field
            with a default that differs from ``discriminator_value``.
        """

        def _wrapper(cls: type[pydantic.BaseModel]) -> type[pydantic.BaseModel]:
            disc = self.discriminator_field
            field = cls.model_fields.get(disc)
            if field is None:
                if discriminator_value is not None:
                    _inject_discriminator_field(cls, disc, discriminator_value)
                elif self.discriminator_value_generator is not None:
                    _inject_discriminator_field(
                        cls,
                        disc,
                        self.discriminator_value_generator(cls),
                    )
                else:
                    msg = (
                        f"unable to determine a discriminator value for "
                        f'{cls.__name__} in tracking group "{self.name}". No '
                        "value was passed to register(), "
                        "discriminator_value_generator was None and the "
                        f'"{disc}" field was not defined.'
                    )
                    raise RegistrationError(msg)
            elif (
                discriminator_value is not None and field.default != discriminator_value
            ):
                # The field was declared by hand and register() was given a
                # different value: refuse to guess which one is intended.
                msg = (
                    f"the discriminator value for {cls.__name__} was "
                    f'ambiguous, it was set to "{discriminator_value}" via '
                    f'register() and "{field.default}" via the discriminator '
                    f"field ({disc})."
                )
                raise AmbiguousDiscriminatorValueError(msg)

            self._register_with_discriminator_field(cls)
            return cls

        return _wrapper

    def register_model(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the given model into this group

        If the model does not declare the discriminator field, one is
        injected using ``discriminator_value_generator``.

        Parameters
        ----------
        cls
            The model to register

        Raises
        ------
        RegistrationError
            If the discriminator field is missing and there is no
            discriminator_value_generator, if the field has no default, or if
            its value is already used by another model in this group.
        """
        disc = self.discriminator_field
        if cls.model_fields.get(disc) is None:
            if self.discriminator_value_generator is not None:
                _inject_discriminator_field(
                    cls,
                    disc,
                    self.discriminator_value_generator(cls),
                )
            else:
                msg = (
                    f"unable to determine a discriminator value for "
                    f'{cls.__name__} in tracking group "{self.name}", '
                    "discriminator_value_generator was None and the "
                    f'"{disc}" field was not defined.'
                )
                raise RegistrationError(msg)

        self._register_with_discriminator_field(cls)

    def _register_with_discriminator_field(self, cls: type[pydantic.BaseModel]) -> None:
        """Register the model under the default of its discriminator field

        Re-registering the same class under the same value is a no-op.

        Parameters
        ----------
        cls
            The class to register, must have the discriminator field set with
            a unique default value in the group.

        Raises
        ------
        RegistrationError
            If the discriminator field has no default, or its default is
            already registered to a different class.
        """
        disc = self.discriminator_field
        value = cls.model_fields[disc].default
        # PydanticUndefined is a sentinel; compare by identity so defaults
        # with unusual __eq__ implementations cannot confuse the check.
        if value is pydantic_core.PydanticUndefined:
            msg = (
                f"{cls.__name__}.{disc} had no default value, it must "
                "have one which is unique among all tracked models."
            )
            raise RegistrationError(msg)

        if (other := self.models.get(value)) is not None and other is not cls:
            msg = (
                f'Cannot register {cls.__name__} under the "{value}" '
                f"identifier, which is already in use by {other.__name__}."
            )
            raise RegistrationError(msg)

        self.models[value] = cls

    def union(self, *, annotated: bool = True) -> ty.Any:  # noqa: ANN401
        """Return the union of all registered models

        Parameters
        ----------
        annotated
            If True (the default), each member is wrapped as
            ``Annotated[model, pydantic.Tag(value)]`` and the union itself is
            annotated with ``pydantic.Field(discriminator=...)`` so pydantic
            treats it as a discriminated union. If False, a plain ``Union`` of
            the registered model classes is returned.

        Raises
        ------
        TypeError
            If no models have been registered; a union needs at least one
            member.
        """
        if not self.models:
            msg = (
                f'tracking group "{self.name}" has no registered models, so '
                "no union can be built."
            )
            raise TypeError(msg)

        # This is fundamentally incompatible with static type checking, as
        # the union members are only known at runtime.
        if not annotated:
            return ty.Union[tuple(self.models.values())]  # type: ignore[not-a-type]  # noqa: UP007

        members = tuple(
            ty.Annotated[model, pydantic.Tag(value)]
            for value, model in self.models.items()
        )
        return ty.Annotated[
            ty.Union[members],  # type: ignore[not-a-type]  # noqa: UP007
            pydantic.Field(discriminator=self.discriminator_field),
        ]