Coverage for dj/typing.py: 100%
152 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
1"""
2Custom types for annotations.
3"""
5# pylint: disable=missing-class-docstring
7from __future__ import annotations
9import datetime
10from enum import Enum
11from types import ModuleType
12from typing import Any, Iterator, List, Literal, Optional, Tuple, TypedDict, Union
14from pydantic.datetime_parse import parse_datetime
15from typing_extensions import Protocol
class SQLADialect(Protocol):  # pylint: disable=too-few-public-methods
    """
    Structural type describing a SQLAlchemy dialect.

    Only the attribute declared below is required; for static type checking
    any object that exposes it satisfies this protocol.
    """

    # The module object the dialect exposes as ``dbapi``.
    dbapi: ModuleType
# The ``type_code`` in a cursor description -- can really be anything
TypeCode = Any


# Cursor description (PEP 249 ``cursor.description``): either ``None`` or one
# 7-item tuple per result column, laid out as
# ``(name, type_code, display_size, internal_size, precision, scale, null_ok)``.
# Only ``name`` and ``type_code`` are mandatory; the size/precision/scale
# entries are integers when the driver reports them and ``None`` otherwise.
Description = Optional[
    List[
        Tuple[
            str,  # name
            TypeCode,  # type_code
            Optional[int],  # display_size
            Optional[int],  # internal_size
            Optional[int],  # precision
            Optional[int],  # scale
            Optional[bool],  # null_ok
        ]
    ]
]
# One result row: a tuple of any length holding values of any type.
Row = Tuple[Any, ...]
# A stream of data: an iterator over result rows.
Stream = Iterator[Row]
class TypeEnum(str, Enum):
    """
    Basic column types, modelled on PEP 249.

    SQLAlchemy does not seem to offer an API for finding the column types of a
    (SQL Core) query, and a DB API 2.0 cursor reports only very coarse types,
    so this enum stays deliberately coarse as well.

    Members subclass ``str`` and each value equals its member name.
    """

    STRING = "STRING"
    BINARY = "BINARY"
    NUMBER = "NUMBER"
    TIMESTAMP = "TIMESTAMP"
    UNKNOWN = "UNKNOWN"
class QueryState(str, Enum):
    """
    The states a query can be in.

    Members subclass ``str`` and each value equals its member name, so a
    member compares equal to the plain string of the same spelling.
    """

    UNKNOWN = "UNKNOWN"
    ACCEPTED = "ACCEPTED"
    SCHEDULED = "SCHEDULED"
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"
    CANCELED = "CANCELED"
    FAILED = "FAILED"
81# sqloxide type hints
82# Reference: https://github.com/sqlparser-rs/sqlparser-rs/blob/main/src/ast/query.rs
class Value(TypedDict, total=False):
    """
    A literal value node in the sqloxide parse tree.

    Declared ``total=False``: every key is optional. Presumably only the key
    naming the literal's kind is present (TODO confirm against sqloxide output).
    """

    # Numeric literal carried as a (text, flag) pair; the flag's meaning is
    # defined by sqlparser-rs, not here.
    Number: Tuple[str, bool]
    SingleQuotedString: str
    Boolean: bool
class Limit(TypedDict):
    """
    The ``LIMIT`` clause of a query in the sqloxide parse tree.

    Only the literal-value form is modelled here.
    """

    Value: Value
class Identifier(TypedDict):
    """
    A SQL identifier together with how it was quoted.
    """

    # Quoting style as a string, or ``None`` when none is recorded.
    quote_style: Optional[str]
    # The identifier text itself.
    value: str
class Bound(TypedDict, total=False):
    """
    One edge of a window frame, given as a ``FOLLOWING`` or ``PRECEDING`` offset.

    Both keys are optional (``total=False``).
    """

    Following: int
    Preceding: int
class WindowFrame(TypedDict):
    """
    The frame specification of a window (the part after ``ORDER BY`` in ``OVER``).
    """

    end_bound: Bound
    start_bound: Bound
    # Frame units as a string (e.g. rows vs. range; exact spelling set by sqloxide).
    units: str
class Expression(TypedDict, total=False):
    """
    A SQL expression node in the sqloxide parse tree.

    Every key is optional (``total=False``); each key names one expression
    variant, and presumably exactly one is set per node (TODO confirm).
    The ``type: ignore`` markers sit on references to node types that are
    declared later in this module.
    """

    CompoundIdentifier: List["Identifier"]
    Identifier: Identifier
    Value: Value
    Function: Function  # type: ignore
    UnaryOp: UnaryOp  # type: ignore
    BinaryOp: BinaryOp  # type: ignore
    Case: Case  # type: ignore
class Case(TypedDict):
    """
    A ``CASE`` expression.
    """

    # ``WHEN`` conditions; presumably paired index-by-index with ``results``.
    conditions: List[Expression]
    # The ``ELSE`` branch, if any.
    else_result: Optional[Expression]
    # The operand of the ``CASE <operand> WHEN ...`` form, if used.
    operand: Optional[Expression]
    results: List[Expression]
class UnnamedArgument(TypedDict):
    """
    A positional function argument wrapping a single expression.
    """

    Expr: Expression
class Argument(TypedDict, total=False):
    """
    A function (or table-function) argument.

    Only the unnamed form is modelled; it holds either an expression wrapper
    or the ``"Wildcard"`` marker.
    """

    Unnamed: Union[UnnamedArgument, Wildcard]
class Over(TypedDict):
    """
    The window specification of an ``OVER (...)`` clause.
    """

    order_by: List[Expression]
    partition_by: List[Expression]
    window_frame: WindowFrame
class Function(TypedDict):
    """
    A function call, optionally carrying an ``OVER`` window.
    """

    args: List[Argument]
    # Whether the call was written with ``DISTINCT``.
    distinct: bool
    # Name as a list of identifier parts (presumably for qualified names).
    name: List[Identifier]
    over: Optional[Over]
class ExpressionWithAlias(TypedDict):
    """
    An aliased expression, i.e. ``<expr> AS <alias>``.
    """

    alias: Identifier
    expr: Expression
class Offset(TypedDict):
    """
    The ``OFFSET`` clause of a query.
    """

    # Row-keyword marker as a string (values defined by sqloxide).
    rows: str
    value: Expression
class OrderBy(TypedDict, total=False):
    """
    One ``ORDER BY`` item; all keys are optional (``total=False``).
    """

    # Presumably ``True`` for ASC, ``False`` for DESC, ``None`` when unspecified.
    asc: Optional[bool]
    expr: Expression
    # Presumably ``None`` when no ``NULLS FIRST/LAST`` was written.
    nulls_first: Optional[bool]
class Projection(TypedDict, total=False):
    """
    One item of a ``SELECT`` list: an aliased or a bare expression.
    """

    ExprWithAlias: ExpressionWithAlias
    UnnamedExpr: Expression
# A ``*`` appears in the parse tree as this bare string marker.
Wildcard = Literal["Wildcard"]
class Fetch(TypedDict):
    """
    A ``FETCH`` clause: a row quantity plus its ``PERCENT`` / ``WITH TIES`` flags.
    """

    percent: bool
    quantity: Value
    with_ties: bool


# ``TOP`` shares the exact shape of ``FETCH``.
Top = Fetch
class UnaryOp(TypedDict):
    """
    A unary operator applied to one operand expression.
    """

    # Operator name as a string.
    op: str
    expr: Expression
class BinaryOp(TypedDict):
    """
    A binary operator with left and right operand expressions.
    """

    left: Expression
    # Operator name as a string.
    op: str
    right: Expression
class LateralView(TypedDict):
    """
    A ``LATERAL VIEW`` item attached to a ``SELECT``.
    """

    lateral_col_alias: List[Identifier]
    lateral_view: Expression
    lateral_view_name: List[Identifier]
    # Whether ``OUTER`` was specified.
    outer: bool
class TableAlias(TypedDict):
    """
    A table alias, optionally renaming the table's columns.
    """

    columns: List[Identifier]
    name: Identifier
class Table(TypedDict):
    """
    A named table reference in a ``FROM`` clause.
    """

    alias: Optional[TableAlias]
    # Arguments, for table-valued function style references.
    args: List[Argument]
    # Table name as a list of identifier parts.
    name: List[Identifier]
    with_hints: List[Expression]
class Derived(TypedDict):
    """
    A derived table, i.e. a subquery used as a relation.
    """

    # Whether the subquery was marked ``LATERAL``.
    lateral: bool
    subquery: "Body"  # type: ignore
    alias: Optional[TableAlias]
class Relation(TypedDict, total=False):
    """
    A relation in ``FROM``: either a named table or a derived table.
    """

    Table: Table
    Derived: Derived
class JoinConstraint(TypedDict, total=False):
    """
    The constraint of a join: ``ON <expr>`` or ``USING (<columns>)``.

    The two forms are mutually exclusive in SQL, so a constraint carries only
    one of these keys. ``total=False`` (as on the other variant-style node
    types in this module) keeps the type from demanding both.
    """

    # ``ON <expression>``
    On: Expression
    # ``USING (<column>, ...)``
    Using: List[Identifier]
class JoinOperator(TypedDict, total=False):
    """
    A constrained join operator; the key names the join kind.

    Unconstrained joins (cross/apply) are typed separately as string literals.
    """

    Inner: JoinConstraint
    LeftOuter: JoinConstraint
    RightOuter: JoinConstraint
    FullOuter: JoinConstraint
# Join operators that take no constraint appear in the parse tree as bare
# strings (see ``Join.join_operator``), so they are typed as string literals.
CrossJoin = Literal["CrossJoin"]
CrossApply = Literal["CrossApply"]
# Must match the sqlparser-rs ``JoinOperator::OuterApply`` variant name exactly
# (was previously misspelled "Outerapply").
OuterApply = Literal["OuterApply"]
class Join(TypedDict):
    """
    One join: the operator used and the relation being joined.
    """

    join_operator: Union[JoinOperator, CrossJoin, CrossApply, OuterApply]
    relation: Relation
class From(TypedDict):
    """
    One ``FROM`` item: a base relation followed by any joins.
    """

    joins: List[Join]
    relation: Relation
# A ``SELECT`` body. The functional ``TypedDict`` form is required here because
# ``from`` is a Python keyword and cannot be declared as a class-syntax field.
Select = TypedDict(
    "Select",
    {
        "cluster_by": List[Expression],
        "distinct": bool,
        "distribute_by": List[Expression],
        "from": List[From],
        "group_by": List[Expression],
        "having": Optional[BinaryOp],
        "lateral_views": List[LateralView],
        # The SELECT list; a bare ``*`` is the "Wildcard" marker.
        "projection": List[Union[Projection, Wildcard]],
        # Presumably the ``WHERE`` condition -- TODO confirm.
        "selection": Optional[BinaryOp],
        "sort_by": List[Expression],
        "top": Optional[Top],
    },
)
class Body(TypedDict):
    """
    The body of a query; only ``SELECT`` bodies are modelled.
    """

    Select: Select
# One common table expression of a ``WITH`` clause. Functional form is needed
# because ``from`` is a Python keyword; ``Query`` is quoted as it is declared
# later in this module.
CTETable = TypedDict(
    "CTETable",
    {
        "alias": TableAlias,
        "from": Optional[Identifier],
        "query": "Query",  # type: ignore
    },
)
class With(TypedDict):
    """
    A ``WITH`` clause: the list of common table expressions it defines.
    """

    cte_tables: List[CTETable]
# A full query. Functional form is needed because ``with`` is a Python keyword.
Query = TypedDict(
    "Query",
    {
        "body": Body,
        "fetch": Optional[Fetch],
        "limit": Optional[Limit],
        # Row-locking mode, if any (``FOR SHARE`` / ``FOR UPDATE``).
        "lock": Optional[Literal["Share", "Update"]],
        "offset": Optional[Offset],
        "order_by": List[OrderBy],
        "with": Optional[With],
    },
)
# A parsed SQL statement. Only queries are modelled for now -- statement kinds
# other than ``SELECT`` could be added here.
class Statement(TypedDict):
    Query: Query


# A parse tree, i.e. the result of ``sqloxide.parse_sql``: one entry per statement.
ParseTree = List[Statement]  # type: ignore
class UTCDatetime(datetime.datetime):
    """
    A UTC extension of pydantic's normal datetime handling.

    Input is first parsed by pydantic's ``parse_datetime`` and then normalized
    to UTC by ``validate``.
    """

    @classmethod
    def __get_validators__(cls) -> Iterator[Any]:
        """
        Extend the builtin pydantic datetime parser with a custom validate method.

        ``parse_datetime`` runs first so that ``validate`` receives an already
        parsed datetime.
        """
        yield parse_datetime
        yield cls.validate

    @classmethod
    def validate(cls, value: datetime.datetime) -> datetime.datetime:
        """
        Convert ``value`` to UTC.

        A naive ``value`` is taken to already be in UTC: its wall-clock time is
        kept and the UTC tzinfo is attached. An aware ``value`` is converted
        with ``astimezone``.

        Returns:
            An aware ``datetime.datetime`` whose tzinfo is ``datetime.timezone.utc``.
        """
        # Use utcoffset() rather than ``tzinfo is None``: a datetime whose
        # tzinfo reports no offset is still naive, and ``astimezone`` would
        # otherwise interpret it as system-local time.
        if value.utcoffset() is None:
            return value.replace(tzinfo=datetime.timezone.utc)
        return value.astimezone(datetime.timezone.utc)