Coverage for dj/models/query.py: 100%
72 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
1"""
2Models for queries.
3"""
5import uuid
6from datetime import datetime
7from enum import Enum
8from typing import Any, List, Optional
10import msgpack
11from pydantic import AnyHttpUrl, validator
12from sqlmodel import Field, SQLModel
14from dj.models.base import BaseSQLModel
15from dj.typing import QueryState, Row
class BaseQuery(SQLModel):
    """
    Common fields shared by the query models.

    Carries the target catalog name plus an optional engine name and
    engine version.
    """

    catalog_name: Optional[str]
    engine_name: Optional[str] = None
    engine_version: Optional[str] = None

    class Config:  # pylint: disable=too-few-public-methods, missing-class-docstring
        # pydantic v1: allow aliased fields to be populated by their Python
        # attribute name as well as by their alias.
        allow_population_by_field_name = True
class QueryCreate(BaseQuery):
    """
    Payload for a submitted query.

    Adds the SQL text being submitted and an ``async_`` flag (default
    ``False``) on top of the ``BaseQuery`` fields.
    """

    submitted_query: str
    # ``async`` is a reserved keyword in Python, hence the trailing underscore.
    async_: bool = False
class ColumnMetadata(BaseSQLModel):
    """
    Name and type of one result column.
    """

    name: str
    type: str  # the column type, stored as a plain string
class StatementResults(BaseSQLModel):
    """
    Output of a single SQL statement.

    Bundles the statement's SQL text, the metadata (name and type) of each
    column, and the returned rows.
    """

    sql: str
    columns: List[ColumnMetadata]
    rows: List[Row]

    # Total number of rows. Useful for paginated requests, where ``rows``
    # presumably holds only the current page.
    row_count: int = 0
class QueryResults(BaseSQLModel):
    """
    Results of a query, as a list of ``StatementResults``.
    """

    # pydantic v1 custom root type: the model (de)serializes as a bare list.
    __root__: List[StatementResults]
class TableRef(BaseSQLModel):
    """
    Reference to a table by catalog, schema and table name.
    """

    catalog: str
    # The attribute is ``schema_`` so it does not shadow pydantic's
    # ``BaseModel.schema()``. Its external (aliased) name is ``schema``.
    schema_: str = Field(alias="schema")
    table: str
class QueryWithResults(BaseSQLModel):
    """
    Model for a query together with its execution status and results.

    The ``scheduled``, ``started`` and ``finished`` timestamps may be given as
    ISO 8601 strings; ``parse_date_string`` converts them before pydantic's
    own validation runs.
    """

    id: uuid.UUID
    engine_name: Optional[str] = None
    engine_version: Optional[str] = None
    submitted_query: str
    executed_query: Optional[str] = None

    scheduled: Optional[datetime] = None
    started: Optional[datetime] = None
    finished: Optional[datetime] = None

    state: QueryState = QueryState.UNKNOWN
    progress: float = 0.0

    output_table: Optional[TableRef] = None
    results: QueryResults
    next: Optional[AnyHttpUrl] = None
    previous: Optional[AnyHttpUrl] = None
    errors: List[str]

    @validator("scheduled", "started", "finished", pre=True)
    def parse_date_string(cls, value):  # pylint: disable=no-self-argument
        """
        Convert string timestamps to ``datetime``.

        Non-string values are returned unchanged. A string that
        ``datetime.fromisoformat`` rejects (e.g. a trailing ``Z`` before
        Python 3.11) is also returned unchanged, so that pydantic's own
        datetime parser can still accept it instead of the field failing
        outright.
        """
        if not isinstance(value, str):
            return value
        try:
            return datetime.fromisoformat(value)
        except ValueError:
            # Defer to pydantic's more lenient datetime parsing.
            return value
class QueryExtType(int, Enum):
    """
    msgpack extension-type codes for values msgpack cannot encode natively.

    The ``int`` mixin makes each member usable directly as the ``code`` of a
    ``msgpack.ExtType`` and comparable with the plain integer codes seen
    while decoding.
    """

    UUID = 1  # payload: str(uuid) encoded as UTF-8
    TIMESTAMP = 2  # payload: datetime.isoformat() encoded as UTF-8
def encode_results(obj: Any) -> Any:
    """
    Custom msgpack encoder for ``QueryWithResults`` payloads.

    Suitable as msgpack's ``default`` callback. A ``uuid.UUID`` or
    ``datetime`` is wrapped in a ``msgpack.ExtType`` whose data is its
    UTF-8 encoded string form (``str(obj)`` or ``obj.isoformat()``). Any
    other object is returned untouched.
    """
    if isinstance(obj, uuid.UUID):
        ext_code, text = QueryExtType.UUID, str(obj)
    elif isinstance(obj, datetime):
        ext_code, text = QueryExtType.TIMESTAMP, obj.isoformat()
    else:
        return obj
    return msgpack.ExtType(ext_code, text.encode("utf-8"))
def decode_results(code: int, data: bytes) -> Any:
    """
    Custom msgpack decoder for ``QueryWithResults`` payloads.

    Suitable as msgpack's ``ext_hook`` and the inverse of ``encode_results``:
    - ``QueryExtType.UUID`` data is decoded back to a ``uuid.UUID``.
    - ``QueryExtType.TIMESTAMP`` data is decoded back to a ``datetime``.
    - Any other code is returned as a raw ``msgpack.ExtType``; its bytes are
      never decoded.
    """
    if code not in (QueryExtType.UUID, QueryExtType.TIMESTAMP):
        # Unknown extension: hand it back without interpreting the payload.
        return msgpack.ExtType(code, data)
    text = data.decode()
    if code == QueryExtType.UUID:
        return uuid.UUID(text)
    return datetime.fromisoformat(text)