Coverage for src / dynapydantic / subclass_tracking_model.py: 100%

49 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-04-06 21:55 +0000

1"""Base class for dynamic pydantic models""" 

2 

3import inspect 

4import typing as ty 

5 

6import pydantic 

7from pydantic import GetCoreSchemaHandler 

8from pydantic.errors import PydanticSchemaGenerationError 

9from pydantic_core import core_schema 

10 

11from .exceptions import ConfigurationError 

12from .tracking_group import TrackingGroup 

13 

14 

def direct_children_of_base_in_mro(derived: type, base: type) -> list[type]:
    """Collect the classes in ``derived``'s MRO that list ``base`` as a direct parent.

    Parameters
    ----------
    derived
        The class whose method resolution order is walked.
    base
        The class whose immediate subclasses are being looked for.

    Returns
    -------
    The classes from ``derived.__mro__`` (in MRO order, ``base`` itself
    excluded) whose ``__bases__`` contain ``base``.
    """
    children: list[type] = []
    for klass in derived.__mro__:
        if klass is base:
            # base never appears in its own __bases__, but skip it explicitly.
            continue
        if base in klass.__bases__:
            children.append(klass)
    return children

30 

31 

class SubclassTrackingModel(pydantic.BaseModel):
    """Subclass-tracking BaseModel

    This will inject a [`TrackingGroup`][dynapydantic.TrackingGroup] into your
    class and automate the registration of subclasses.

    A class that inherits *directly* from this one becomes the root of a
    tracking group. The group is taken from a
    `tracking_config: ClassVar[dynapydantic.TrackingGroup]` member if one is
    present (in that case any TrackingGroup keyword arguments in the class
    declaration are ignored). Otherwise it is built from the class-declaration
    keyword arguments, with `name` defaulting to `"<ClassName>-subclasses"`.
    Every further subclass is registered with each tracking root found in its
    MRO, unless it is declared with `exclude_from_union=True`.

    Inheriting from this class will augment your class with the following
    member functions:

    1. `registered_subclasses() -> dict[str, type[cls]]`:
        This will return a mapping of discriminator value to the corresponding
        subclass. See
        [`TrackingGroup.models`][dynapydantic.TrackingGroup.models] for details.
    2. `union() -> typing.Any`:
        This will return an (optionally) annotated subclass union. See
        [`TrackingGroup.union()`][dynapydantic.TrackingGroup.union] for details.
    3. `load_plugins() -> None`:
        If plugin_entry_point was specified, then this method will load plugin
        packages to discover additional subclasses. See
        [`TrackingGroup.load_plugins()`][dynapydantic.TrackingGroup.load_plugins]
        for more details.
    """

    def __init_subclass__(cls, *args, **kwargs) -> None:
        """Subclass hook

        Removes class-declaration keyword arguments that are meant for
        `TrackingGroup` (its model fields) or for `__pydantic_init_subclass__`
        (its parameter names) before delegating to the parent
        `__init_subclass__`. Those keywords therefore never reach the parent
        hook, which presumably would reject them.
        """
        # Intercept any kwargs that are intended for TrackingGroup or
        # __pydantic_init_subclass__.
        # NOTE(review): sig.parameters also holds the names "args" and "kwargs",
        # so class keywords literally named that are dropped here as well.
        sig = inspect.signature(SubclassTrackingModel.__pydantic_init_subclass__)
        super().__init_subclass__(
            *args,
            **{
                k: v
                for k, v in kwargs.items()
                if k not in TrackingGroup.model_fields and k not in sig.parameters
            },
        )

    @classmethod
    def __pydantic_init_subclass__(
        cls,
        *args,
        exclude_from_union: bool | None = None,
        **kwargs,
    ) -> None:
        """Pydantic subclass hook

        For a direct child of `SubclassTrackingModel`, attach a
        `TrackingGroup` as `__DYNAPYDANTIC__` and promote its methods onto
        the class. For any deeper subclass, register it with every tracking
        root in its MRO.

        Parameters
        ----------
        exclude_from_union
            If truthy on a non-root subclass, the subclass is not registered
            with any tracking group. Ignored for tracking roots.
        **kwargs
            Class-declaration keywords. On a tracking root without a
            `tracking_config` member, they are validated into a
            `TrackingGroup`; keys that are TrackingGroup fields are not
            forwarded to pydantic's hook.

        Raises
        ------
        ConfigurationError
            If a tracking root has no `tracking_config` and its keywords do
            not validate as a `TrackingGroup`.
        """
        if SubclassTrackingModel in cls.__bases__:
            # Intercept any kwargs that are intended for TrackingGroup
            super().__pydantic_init_subclass__(
                *args,
                **{
                    k: v
                    for k, v in kwargs.items()
                    if k not in TrackingGroup.model_fields
                },
            )

            if isinstance((tc := getattr(cls, "tracking_config", None)), TrackingGroup):
                # An explicit config wins; declaration kwargs are not used.
                cls.__DYNAPYDANTIC__ = tc
            else:
                try:
                    # Caller-supplied kwargs override the default group name.
                    cls.__DYNAPYDANTIC__: TrackingGroup = TrackingGroup.model_validate(
                        {"name": f"{cls.__name__}-subclasses"} | kwargs,
                    )
                except pydantic.ValidationError as e:
                    msg = (
                        "SubclassTrackingModel subclasses must either have a "
                        "tracking_config: ClassVar[dynapydantic.TrackingGroup] "
                        "member or pass kwargs sufficient to construct a "
                        "dynapydantic.TrackingGroup in the class declaration. "
                        "The latter approach produced the following "
                        f"ValidationError:\n{e}"
                    )
                    raise ConfigurationError(msg) from e

            # Promote the tracking group's methods to the parent class. The
            # closures capture `cls`, so subclasses that inherit these static
            # methods still operate on this root's tracking group.
            if cls.__DYNAPYDANTIC__.plugin_entry_point is not None:

                def _load_plugins() -> None:
                    """Load plugins to register more models"""
                    cls.__DYNAPYDANTIC__.load_plugins()

                cls.load_plugins = staticmethod(_load_plugins)

            def _union(
                *,
                plain: bool | None = None,
                annotated: bool | None = None,
            ) -> ty.Any:  # noqa: ANN401 - return type is runtime-determined
                """Get the union of all tracked subclasses

                Parameters
                ----------
                plain
                    If set to `True`, a plain union of all members will be returned.
                    Otherwise, the returned union will be annotated in accordance with
                    the union mode.
                annotated
                    Deprecated. Use `plain=True` when you would have used
                    `annotated=False`.
                """
                # deprecation warning for annotated is in TrackingGroup
                return cls.__DYNAPYDANTIC__.union(plain=plain, annotated=annotated)

            cls.union = staticmethod(_union)

            def _subclasses() -> dict[str, type[pydantic.BaseModel]]:
                """Return a mapping of discriminator values to registered model"""
                return cls.__DYNAPYDANTIC__.models

            cls.registered_subclasses = staticmethod(_subclasses)

            return

        super().__pydantic_init_subclass__(*args, **kwargs)

        if exclude_from_union:
            return

        # A class may sit under several tracking roots (multiple inheritance);
        # register it with each of them.
        supers = direct_children_of_base_in_mro(cls, SubclassTrackingModel)
        for base in supers:
            base.__DYNAPYDANTIC__.register_model(cls)

    class PydanticAdaptor:
        """Pydantic type adaptor for SubclassTrackingModel"""

        @staticmethod
        def __get_pydantic_core_schema__(
            source_type: ty.Any,  # noqa: ANN401
            handler: GetCoreSchemaHandler,
        ) -> core_schema.CoreSchema:
            """Get the pydantic schema for this type

            Builds the schema from `source_type.union()`.

            Raises
            ------
            PydanticSchemaGenerationError
                If `source_type` is not a subclass of `SubclassTrackingModel`,
                or is `SubclassTrackingModel` itself.
            """
            if not isinstance(source_type, type) or not issubclass(
                source_type,
                SubclassTrackingModel,
            ):
                msg = (
                    f"{source_type} was not a SubclassTrackingModel, "
                    "so it is incompatible with dynapydantic.Polymorphic"
                )
                raise PydanticSchemaGenerationError(msg)
            if source_type is SubclassTrackingModel:
                # `union` is only attached to subclasses by
                # __pydantic_init_subclass__; the base class itself has none, so
                # without this check the call below raises AttributeError.
                msg = (
                    f"{source_type} is the SubclassTrackingModel base itself; "
                    "dynapydantic.Polymorphic requires a subclass of it"
                )
                raise PydanticSchemaGenerationError(msg)
            return handler(source_type.union())