Coverage for src / dynapydantic / annotations.py: 100%

34 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-13 20:14 +0000

"""Custom pydantic annotations for dynapydantic.

Defines ``Polymorphic`` (duck-typed fields over a ``SubclassTrackingModel``
hierarchy), ``Union`` (the union of models in a ``TrackingGroup`` or model
hierarchy), and the ``ModelConstructionTimeAdapter`` helper used by
``Polymorphic``.
"""

2 

3import typing as ty 

4 

5from pydantic import GetCoreSchemaHandler, PydanticSchemaGenerationError 

6from pydantic_core import core_schema 

7 

8from .free_funcs import union 

9from .subclass_tracking_model import ( 

10 SubclassTrackingModel, 

11 ValidationTimeAdapter, 

12) 

13from .tracking_group import TrackingGroup 

14from .union_mode import UnionRealization 

15 

# Type variable bound to SubclassTrackingModel, used so the annotation helpers
# below can be typed in terms of the concrete model class they are given.
ModelT = ty.TypeVar("ModelT", bound=SubclassTrackingModel)

17 

18 

class ModelConstructionTimeAdapter:
    """Pydantic annotation marker for SubclassTrackingModel fields.

    Intended to be attached as ``Annotated[SomeModel, ModelConstructionTimeAdapter]``.
    When pydantic generates the schema for such a field, the schema is built
    for ``union(SomeModel)`` rather than for ``SomeModel`` itself.
    """

    @staticmethod
    def __get_pydantic_core_schema__(
        source_type: type[SubclassTrackingModel],
        handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """Produce the core schema for the annotated model type.

        Args:
            source_type: The model class the annotation is attached to.
            handler: Pydantic's schema-generation callback.

        Returns:
            The core schema pydantic generates for the type returned by
            ``union(source_type)`` (presumably the union of the model's
            tracked subclasses).
        """
        # The union is computed here, i.e. when pydantic builds the schema,
        # and pydantic is then asked for the schema of that union type.
        substituted_type = union(source_type)
        return handler(substituted_type)

29 

30 

class Polymorphic:
    """Annotation used to mark a type as having duck-typing behavior.

    This annotation is only valid for ``SubclassTrackingModel`` subclasses.

    Similar to ``SerializeAsAny``, a field annotated with this serializes
    according to its actual type, not the field annotation type. In addition,
    parsing functions as if the field annotation type were the union of all
    tracked subclasses.

    If a ``UnionRealization`` (or the string value of one) is passed as the
    second argument, it overrides the default union realization that is
    stored in the class.
    """

    def __class_getitem__(
        cls,
        item: type[ModelT] | tuple[type[ModelT], UnionRealization | str],
    ) -> ty.Annotated[type[ModelT], ...]:
        """Get the annotation for the pydantic field.

        Args:
            item: Either a model class (``Polymorphic[Model]``) or a tuple of
                a model class and a union realization
                (``Polymorphic[Model, realization]``).

        Returns:
            The ``Annotated`` type produced by ``_polymorphic_cgi``.

        Raises:
            TypeError: If the subscript has zero or more than two arguments,
                or if the model argument is not a type.
            PydanticSchemaGenerationError: If the model argument is not a
                SubclassTrackingModel subclass.
        """
        if not isinstance(item, tuple):
            return _polymorphic_cgi(item)

        # ``Polymorphic[()]`` arrives as an empty tuple. Reject it with the
        # same message as the too-many case instead of letting ``*item``
        # unpacking fail with an opaque "missing positional argument" error.
        if not 1 <= len(item) <= 2:  # noqa: PLR2004
            msg = (
                "dynapydantic.Polymorphic takes 1 or 2 arguments "
                f"({len(item)} given)"
            )
            raise TypeError(msg)

        return _polymorphic_cgi(*item)

61 

62 

def _polymorphic_cgi(
    cls: type[ModelT],
    union_realization: UnionRealization | str | None = None,
) -> ty.Annotated[type[ModelT], ...]:
    """Build the ``Annotated`` type behind ``Polymorphic[...]``.

    Args:
        cls: The model class being annotated.
        union_realization: Optional override for the union realization, either
            a ``UnionRealization`` or a value accepted by its constructor.
            ``None`` means "use the value from the class's
            ``__DYNAPYDANTIC_STM_CONFIG__``".

    Returns:
        ``Annotated[cls, ValidationTimeAdapter]`` when the resolved realization
        equals ``UnionRealization.VALIDATION``, otherwise
        ``Annotated[cls, ModelConstructionTimeAdapter]``.

    Raises:
        TypeError: If ``cls`` is not a type.
        PydanticSchemaGenerationError: If ``cls`` is not a
            SubclassTrackingModel subclass.
    """
    if not isinstance(cls, type):
        msg = f"dynapydantic.Polymorphic must be given a type, not {cls}"
        raise TypeError(msg)

    if not issubclass(cls, SubclassTrackingModel):
        msg = f"Polymorphic was given {cls}, which was not a SubclassTrackingModel."
        raise PydanticSchemaGenerationError(msg)

    # The class config is read unconditionally, even when an explicit
    # override is supplied.
    stm_config = cls.__DYNAPYDANTIC_STM_CONFIG__
    if union_realization is None:
        resolved = stm_config.union_realization
    else:
        # Normalizes a string value into the enum member.
        resolved = UnionRealization(union_realization)

    if resolved == UnionRealization.VALIDATION:
        chosen_adapter = ValidationTimeAdapter
    else:
        chosen_adapter = ModelConstructionTimeAdapter
    return ty.Annotated[cls, chosen_adapter]  # type: ignore[bad-return]

87 

88 

class Union:
    """Annotation that yields the union of a dynapydantic entity's models.

    Mainly intended for using the union of every model in a ``TrackingGroup``
    as a field annotation. A ``SubclassTrackingModel`` subclass is accepted
    too, although ``Polymorphic`` is generally the better choice there.
    """

    @ty.overload
    def __class_getitem__(cls, item: type[ModelT]) -> type[ModelT]: ...

    @ty.overload
    def __class_getitem__(cls, item: TrackingGroup) -> ty.Any: ...  # noqa: ANN401

    def __class_getitem__(cls, item: TrackingGroup | type[ModelT]) -> object:
        """Return ``union(item)``, the union type for ``Union[item]``."""
        union_type = union(item)
        return union_type