Coverage for dj/models/query.py: 100%

72 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1""" 

2Models for queries. 

3""" 

4 

5import uuid 

6from datetime import datetime 

7from enum import Enum 

8from typing import Any, List, Optional 

9 

10import msgpack 

11from pydantic import AnyHttpUrl, validator 

12from sqlmodel import Field, SQLModel 

13 

14from dj.models.base import BaseSQLModel 

15from dj.typing import QueryState, Row 

16 

17 

class BaseQuery(SQLModel):
    """
    Fields shared by every query model: the target catalog and the engine
    (name and version) the query should run on. All three may be omitted.
    """

    # Optional[...] fields default to None; the default is spelled out here.
    catalog_name: Optional[str] = None
    engine_name: Optional[str] = None
    engine_version: Optional[str] = None

    class Config:  # pylint: disable=too-few-public-methods, missing-class-docstring
        # Allow fields to be populated by their attribute name even when an
        # alias is declared for them.
        allow_population_by_field_name = True

29 

30 

class QueryCreate(BaseQuery):
    """
    Payload for submitting a new query.

    Carries the SQL text to run plus a flag requesting asynchronous execution,
    on top of the catalog/engine fields inherited from ``BaseQuery``.
    """

    # The raw query text as submitted by the client.
    submitted_query: str
    # ``async`` is a reserved Python keyword, hence the trailing underscore.
    async_: bool = False

38 

39 

class ColumnMetadata(BaseSQLModel):
    """
    Name and type of a single result column.

    Both are plain strings; the type is not parsed or validated here.
    """

    name: str
    type: str  # column type as a string

47 

48 

class StatementResults(BaseSQLModel):
    """
    Output of one executed statement.

    Bundles the SQL that produced the results, the column metadata
    (names and types), and the returned rows.
    """

    sql: str
    columns: List[ColumnMetadata]
    rows: List[Row]

    # Total number of rows the statement produced (not just the ones in
    # ``rows``); useful for paginated requests.
    row_count: int = 0

62 

63 

class QueryResults(BaseSQLModel):
    """
    Results of a whole query, one ``StatementResults`` entry per statement.

    Declared with a custom ``__root__`` type, so the model wraps a bare list
    rather than an object with named fields.
    """

    __root__: List[StatementResults]

70 

71 

class TableRef(BaseSQLModel):
    """
    Fully-qualified reference to a table: catalog, schema and table name.
    """

    catalog: str
    # Exposed externally under the alias "schema". The attribute name carries a
    # trailing underscore, presumably to avoid clashing with the model's own
    # ``schema`` attribute -- TODO confirm.
    schema_: str = Field(alias="schema")
    table: str

80 

81 

def _parse_iso_datetime(value: Any) -> Any:
    """
    Convert an ISO 8601 string to ``datetime``; return any other value unchanged.

    ``datetime.fromisoformat`` only accepts the ``Z`` UTC designator starting
    with Python 3.11. A trailing ``Z`` is therefore rewritten as ``+00:00``
    first, so strings such as ``"2023-04-17T20:05:00Z"`` also parse on older
    interpreters.

    Raises:
        ValueError: if ``value`` is a string that is not a valid ISO 8601
            datetime.
    """
    if not isinstance(value, str):
        return value
    if value.endswith("Z"):
        value = value[:-1] + "+00:00"
    return datetime.fromisoformat(value)


class QueryWithResults(BaseSQLModel):
    """
    Model for a query together with its execution metadata and results.

    The ``scheduled``/``started``/``finished`` timestamps accept either
    ``datetime`` objects or ISO 8601 strings; the validators below convert
    strings via ``_parse_iso_datetime``.
    """

    id: uuid.UUID
    engine_name: Optional[str] = None
    engine_version: Optional[str] = None
    submitted_query: str
    executed_query: Optional[str] = None

    scheduled: Optional[datetime] = None
    started: Optional[datetime] = None
    finished: Optional[datetime] = None

    state: QueryState = QueryState.UNKNOWN
    progress: float = 0.0

    # The explicit default matches the other Optional fields; pydantic v1
    # already treated an un-defaulted Optional field as defaulting to None.
    output_table: Optional[TableRef] = None
    results: QueryResults
    # Links to neighbouring result pages, presumably for pagination -- confirm
    # against the API that populates them.
    next: Optional[AnyHttpUrl] = None
    previous: Optional[AnyHttpUrl] = None
    errors: List[str]

    @validator("scheduled", pre=True)
    def parse_scheduled_date_string(cls, value):  # pylint: disable=no-self-argument
        """
        Convert string date values to datetime
        """
        return _parse_iso_datetime(value)

    @validator("started", pre=True)
    def parse_started_date_string(cls, value):  # pylint: disable=no-self-argument
        """
        Convert string date values to datetime
        """
        return _parse_iso_datetime(value)

    # NOTE: the method name (missing underscore) is kept as-is for backward
    # compatibility.
    @validator("finished", pre=True)
    def parse_finisheddate_string(cls, value):  # pylint: disable=no-self-argument
        """
        Convert string date values to datetime
        """
        return _parse_iso_datetime(value)

126 

127 

class QueryExtType(int, Enum):
    """
    Ext type codes used by the custom msgpack hooks in this module.

    Members are ints, so they can be passed directly as a msgpack ext code
    and compared against the integer code msgpack hands back.
    """

    UUID = 0x01  # payload: UTF-8 string form of a uuid.UUID
    TIMESTAMP = 0x02  # payload: UTF-8 ISO 8601 string of a datetime

135 

136 

def encode_results(obj: Any) -> Any:
    """
    Custom msgpack encoder for ``QueryWithResults`` values msgpack can't pack.

    Presumably supplied as msgpack's ``default`` callback -- confirm at the
    call site.

    ``uuid.UUID`` objects become ext type ``QueryExtType.UUID`` and ``datetime``
    objects become ext type ``QueryExtType.TIMESTAMP``. In both cases the
    payload is the value's text form (``str()`` / ``isoformat()``) encoded as
    UTF-8. Any other object is returned as-is.
    """
    if isinstance(obj, uuid.UUID):
        ext_code, text = QueryExtType.UUID, str(obj)
    elif isinstance(obj, datetime):
        ext_code, text = QueryExtType.TIMESTAMP, obj.isoformat()
    else:
        return obj

    return msgpack.ExtType(ext_code, text.encode("utf-8"))

148 

149 

def decode_results(code: int, data: bytes) -> Any:
    """
    Custom msgpack decoder for ``QueryWithResults``; the inverse of
    ``encode_results``.

    Presumably supplied as msgpack's ``ext_hook`` -- confirm at the call site.

    The payload is decoded to text and turned back into a ``uuid.UUID``
    (``QueryExtType.UUID``) or a ``datetime`` (``QueryExtType.TIMESTAMP``).
    Unrecognised codes are returned as a plain ``msgpack.ExtType``.
    """
    parsers = (
        (QueryExtType.UUID, uuid.UUID),
        (QueryExtType.TIMESTAMP, datetime.fromisoformat),
    )
    for ext_code, parse in parsers:
        if code == ext_code:
            return parse(data.decode())

    return msgpack.ExtType(code, data)