Coverage for src/dynapydantic/subclass_tracking_model.py: 100%

49 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-02-06 15:20 +0000

1"""Base class for dynamic pydantic models""" 

2 

3import inspect 

4import typing as ty 

5 

6import pydantic 

7from pydantic import GetCoreSchemaHandler 

8from pydantic.errors import PydanticSchemaGenerationError 

9from pydantic_core import core_schema 

10 

11from .exceptions import ConfigurationError 

12from .tracking_group import TrackingGroup 

13 

14 

def direct_children_of_base_in_mro(derived: type, base: type) -> list[type]:
    """Collect the classes in ``derived.__mro__`` that list ``base`` as a direct parent.

    Parameters
    ----------
    derived
        Class whose method resolution order is scanned.
    base
        Class whose immediate children are being looked for.

    Returns
    -------
    Every class from ``derived.__mro__`` (``base`` itself excluded) whose
    ``__bases__`` contains ``base``, in MRO order.
    """
    children: list[type] = []
    for klass in derived.__mro__:
        if klass is base:
            continue
        if base in klass.__bases__:
            children.append(klass)
    return children

30 

31 

class SubclassTrackingModel(pydantic.BaseModel):
    """Subclass-tracking BaseModel

    This will inject a TrackingGroup into your class and automate the
    registration of subclasses.

    Inheriting from this class will augment your class with the following
    member functions:
    1. registered_subclasses() -> dict[str, type[cls]]:
        This will return a mapping of discriminator value to the corresponding
        subclass. See TrackingGroup.models for details.
    2. union() -> typing.GenericAlias:
        This will return an (optionally) annotated subclass union. See
        TrackingGroup.union() for details.
    3. load_plugins() -> None:
        If plugin_entry_point was specified, then this method will load plugin
        packages to discover additional subclasses. See
        TrackingGroup.load_plugins for more details.

    Class keyword arguments
    -----------------------
    A direct subclass either defines a
    ``tracking_config: ClassVar[TrackingGroup]`` attribute or passes
    TrackingGroup fields as class keywords. A deeper subclass may pass
    ``exclude_from_union=True`` to skip registration.
    """

    def __init_subclass__(cls, *args, **kwargs) -> None:
        """Subclass hook.

        Strips the class keywords meant for ``__pydantic_init_subclass__``
        before delegating to ``super().__init_subclass__``. Those keywords are
        TrackingGroup fields and that hook's named keyword parameters, such as
        ``exclude_from_union``. The default hook chain, ending at
        ``object.__init_subclass__``, raises TypeError on keywords it does not
        know.
        """
        # Only the pydantic hook's *named* parameters are class keywords we
        # own. The names of its ``*args``/``**kwargs`` catch-alls are left out
        # on purpose. Otherwise a stray ``args=``/``kwargs=`` class keyword
        # would be silently swallowed here instead of being rejected.
        hook_params = inspect.signature(
            SubclassTrackingModel.__pydantic_init_subclass__,
        ).parameters
        hook_keywords = {
            name
            for name, param in hook_params.items()
            if param.kind
            in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        }
        super().__init_subclass__(
            *args,
            **{
                k: v
                for k, v in kwargs.items()
                if k not in TrackingGroup.model_fields and k not in hook_keywords
            },
        )

    @classmethod
    def __pydantic_init_subclass__(
        cls,
        *args,
        exclude_from_union: bool | None = None,
        **kwargs,
    ) -> None:
        """Pydantic subclass hook.

        A class that lists SubclassTrackingModel directly in its bases gets
        three things:

        - a TrackingGroup stored as ``__DYNAPYDANTIC__``, taken from a
          ``tracking_config`` attribute or built from the class keywords;
        - the static helpers ``union`` and ``registered_subclasses``;
        - ``load_plugins``, but only when ``plugin_entry_point`` is set.

        Any deeper subclass is instead registered with the TrackingGroup of
        every class in its MRO that directly subclasses SubclassTrackingModel.

        Parameters
        ----------
        exclude_from_union
            If truthy, a deeper subclass is not registered with any tracking
            group. Not consulted for direct subclasses.
        **kwargs
            For direct subclasses, merged into the TrackingGroup definition
            when no TrackingGroup ``tracking_config`` is present.

        Raises
        ------
        ConfigurationError
            A direct subclass has no TrackingGroup ``tracking_config`` and its
            class keywords do not validate as a TrackingGroup.
        """
        if SubclassTrackingModel in cls.__bases__:
            # Intercept any kwargs that are intended for TrackingGroup
            super().__pydantic_init_subclass__(
                *args,
                **{
                    k: v
                    for k, v in kwargs.items()
                    if k not in TrackingGroup.model_fields
                },
            )

            if isinstance((tc := getattr(cls, "tracking_config", None)), TrackingGroup):
                cls.__DYNAPYDANTIC__ = tc
            else:
                try:
                    # Default group name first. An explicit ``name=`` class
                    # keyword overrides it, because dict union keeps the
                    # right-hand value.
                    cls.__DYNAPYDANTIC__: TrackingGroup = TrackingGroup.model_validate(
                        {"name": f"{cls.__name__}-subclasses"} | kwargs,
                    )
                except pydantic.ValidationError as e:
                    msg = (
                        "SubclassTrackingModel subclasses must either have a "
                        "tracking_config: ClassVar[dynapydantic.TrackingGroup] "
                        "member or pass kwargs sufficient to construct a "
                        "dynapydantic.TrackingGroup in the class declaration. "
                        "The latter approach produced the following "
                        f"ValidationError:\n{e}"
                    )
                    raise ConfigurationError(msg) from e

            # Promote the tracking group's methods to the parent class
            if cls.__DYNAPYDANTIC__.plugin_entry_point is not None:

                def _load_plugins() -> None:
                    """Load plugins to register more models"""
                    cls.__DYNAPYDANTIC__.load_plugins()

                cls.load_plugins = staticmethod(_load_plugins)

            def _union(*, annotated: bool = True) -> ty.GenericAlias:
                """Get the union of all tracked subclasses

                Parameters
                ----------
                annotated
                    Whether this should be an annotated union for usage as a
                    pydantic field annotation, or a plain typing.Union for a
                    regular type annotation.
                """
                return cls.__DYNAPYDANTIC__.union(annotated=annotated)

            cls.union = staticmethod(_union)

            def _subclasses() -> dict[str, type[pydantic.BaseModel]]:
                """Return a mapping of discriminator values to registered model"""
                return cls.__DYNAPYDANTIC__.models

            cls.registered_subclasses = staticmethod(_subclasses)

            return

        super().__pydantic_init_subclass__(*args, **kwargs)

        if exclude_from_union:
            return

        # Register with every tracking root this class descends from
        # (normally exactly one).
        supers = direct_children_of_base_in_mro(cls, SubclassTrackingModel)
        for base in supers:
            base.__DYNAPYDANTIC__.register_model(cls)

142 

class PydanticAdaptor:
    """Pydantic type adaptor for SubclassTrackingModel

    Generates the schema for a SubclassTrackingModel subclass from that
    class's ``union()`` of tracked subclasses instead of from the class itself.
    """

    @staticmethod
    def __get_pydantic_core_schema__(
        source_type: ty.Any,  # noqa: ANN401
        handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """Get the pydantic schema for this type

        Parameters
        ----------
        source_type
            The annotated type. Must be a subclass of SubclassTrackingModel,
            not SubclassTrackingModel itself.
        handler
            Pydantic's schema handler, applied to ``source_type.union()``.

        Raises
        ------
        PydanticSchemaGenerationError
            If source_type is not a proper SubclassTrackingModel subclass.
        """
        if not isinstance(source_type, type) or not issubclass(
            source_type,
            SubclassTrackingModel,
        ):
            msg = (
                f"{source_type} was not a SubclassTrackingModel, "
                "so it is incompatible with dynapydantic.Polymorphic"
            )
            raise PydanticSchemaGenerationError(msg)
        # ``union`` is only attached by the subclass hook, never to the bare
        # base class. Report that case as a schema error rather than letting
        # it surface as an AttributeError.
        if source_type is SubclassTrackingModel:
            msg = (
                "SubclassTrackingModel itself tracks no subclasses, so it is "
                "incompatible with dynapydantic.Polymorphic; use a subclass"
            )
            raise PydanticSchemaGenerationError(msg)
        return handler(source_type.union())