Coverage for dj/construction/dj_query.py: 100%
52 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
1"""
2Functions for making queries directly against DJ
3"""
5from typing import List, Optional, Set, cast
7from sqlmodel import Session
9from dj.construction.build import build_ast
10from dj.construction.utils import amenable_name, get_dj_node
11from dj.errors import DJErrorException
12from dj.models.node import NodeRevision, NodeType
13from dj.sql.parsing.backends.antlr4 import ast, parse
14from dj.sql.parsing.backends.exceptions import DJParseException
def try_get_dj_node(
    session: Session,
    name: str,
    kinds: Set[NodeType],
) -> Optional[NodeRevision]:
    """
    Look a node up through `get_dj_node` without letting a miss propagate

    The arguments are forwarded unchanged to `get_dj_node`. If that call
    raises `DJErrorException` the error is swallowed and None is returned;
    any other exception is left to propagate.
    """
    try:
        node = get_dj_node(session, name, kinds)
    except DJErrorException:
        return None
    return node
def _resolve_metric_nodes(session, col):
    """
    Swap a column that names a metric for the metric's own query

    The column's identifier is looked up as a METRIC node; if nothing is
    found the column is left alone. Otherwise:

    - the nearest enclosing SELECT is checked, once per SELECT (tracked
      with a `_validated` attribute set on it), to source from exactly one
      relation named `metrics`, and its FROM is then emptied
    - the metric node's query is parsed and its SELECT becomes the primary
      relation of the enclosing SELECT
    - every table found in the metric's FROM is offered to
      `_hoist_metric_source_tables`, and the joins it returns are attached
      as extensions of that relation
    - the original column is replaced by a column named after the metric
      node's first column
    """
    identifier = col.identifier(False)
    metric_node = try_get_dj_node(session, identifier, {NodeType.METRIC})
    if not metric_node:
        # not a metric reference, nothing to rewrite
        return

    # a metric was found, so inspect the SELECT it is being asked from
    enclosing = cast(ast.Select, col.get_nearest_parent_of_type(ast.Select))
    already_checked = getattr(enclosing, "_validated", False)
    if not already_checked:  # pragma: no cover
        relations = enclosing.from_.relations
        if (
            len(relations) != 1
            or relations[0].primary.alias_or_name.name != "metrics"
        ):
            raise DJParseException(
                "Any SELECT referencing a Metric must source "
                "from a single unaliased Table named `metrics`.",
            )
        source_ref = relations[0].primary
        try:
            source_ref_name = source_ref.alias_or_name.identifier(False)
        except AttributeError:  # pragma: no cover
            source_ref_name = ""
        if source_ref_name != "metrics":
            raise DJParseException(
                "The name of the table in a Metric query must be `metrics`.",
            )
        # empty the FROM so the real tables can be put in its place
        enclosing.from_ = ast.From([])
        # remember the check so other metric columns in this SELECT skip it
        enclosing._validated = True  # pylint: disable=W0212

    # from here on the column is known to be a metric selected from `metrics`
    alias_name = amenable_name(metric_node.name)
    metric_query = parse(cast(str, metric_node.query))  # pylint: disable=W0212
    inner_select = metric_query.select
    upstream_tables = inner_select.from_.find_all(ast.Table)
    metric_alias = ast.Alias(ast.Name(alias_name), None, inner_select)

    extensions = []
    for upstream in upstream_tables:
        extensions.extend(
            _hoist_metric_source_tables(
                session,
                upstream,
                inner_select,
                metric_alias,
            ),
        )

    replacement = ast.Column(
        ast.Name(metric_node.columns[0].name),
        _table=metric_alias,
        as_=True,
    )

    metric_alias.child.parenthesized = True
    enclosing.replace(col, replacement)
    enclosing.from_.relations = [
        ast.Relation(primary=metric_alias.child, extensions=extensions),
    ]
def _hoist_metric_source_tables(
    session,
    table,
    metric_select,
    metric_table_expression,
) -> List[ast.Join]:
    """
    Surface one table from a metric's FROM so it can be joined outside

    A Select, or an Alias wrapping a Select, is skipped and yields no
    joins; any other Alias is unwrapped to its child. The table's
    identifier is then looked up as a SOURCE, TRANSFORM or DIMENSION node.
    When a node is found:

    - one aliased column per node column is appended to the projection of
      `metric_select`
    - an equality condition is built for each of those columns through
      `_source_column_join_on_expression`
    - if there is at least one condition, a single LEFT OUTER join onto a
      copy of the table is returned, with the conditions AND-ed together

    In every other case the returned list is empty.
    """
    if isinstance(table, ast.Select):
        return []  # pragma: no cover
    if isinstance(table, ast.Alias):
        if isinstance(table.child, ast.Select):  # pragma: no cover
            return []  # pragma: no cover
        table = table.child  # pragma: no cover

    hoisted: List[ast.Join] = []
    if upstream_node := try_get_dj_node(  # pragma: no cover
        session,
        table.identifier(False),
        {NodeType.SOURCE, NodeType.TRANSFORM, NodeType.DIMENSION},
    ):
        exposed_cols = [
            _make_source_columns(node_col, table) for node_col in upstream_node.columns
        ]
        # the metric select has to project the upstream columns itself,
        # otherwise there is nothing for the outer left join to compare on
        metric_select.projection += exposed_cols
        # one equality per exposed column makes up the ON clause
        conditions = [
            _source_column_join_on_expression(exposed, metric_table_expression)
            for exposed in exposed_cols
        ]
        if conditions:  # pragma: no cover
            left_join = ast.Join(
                join_type="LEFT OUTER",
                right=table.copy(),
                criteria=ast.JoinCriteria(on=ast.BinaryOp.And(*conditions)),  # type: ignore  # pylint: disable=no-value-for-parameter
            )
            hoisted.append(left_join)
    return hoisted
def _make_source_columns(tbl_col, table) -> ast.Alias[ast.Column]:
    """
    Build the aliased column used to hoist one source column

    A column named after `tbl_col.name` is created against `table`, then
    wrapped in an alias whose name is `amenable_name` applied to the
    column's string form.
    """
    column = ast.Column(ast.Name(tbl_col.name), _table=table, as_=True)
    alias_name = ast.Name(amenable_name(str(column)))
    return ast.Alias(alias_name, child=column)
def _source_column_join_on_expression(
    src_col,
    metric_table_expression,
) -> ast.BinaryOp:
    """
    Make the part of the ON for the source column

    Builds one equality comparison through `ast.BinaryOp.Eq`:

    - left side: a column named `src_col.alias_or_name`, bound to
      `metric_table_expression`
    - right side: a copy of `src_col.child`

    The result is a single expression, not a list; the return annotation
    previously said `List[ast.BinaryOp]`, which did not match what is
    returned.
    """
    # NOTE(review): the `type: ignore` is kept from the original; it may no
    # longer be needed now that the annotation matches - confirm with mypy
    return ast.BinaryOp.Eq(  # type: ignore
        ast.Column(
            src_col.alias_or_name,
            _table=metric_table_expression,
        ),
        src_col.child.copy(),
    )
def build_dj_metric_query(  # pylint: disable=R0914,R0912
    session: Session,
    query: str,
    dialect: Optional[str] = None,  # pylint: disable=unused-argument
) -> ast.Query:
    """
    Turn a SQL string that may reference dj metrics into a built query

    The string is parsed, every column found under its SELECT is passed to
    `_resolve_metric_nodes`, and the SELECT is then wrapped in a new
    `ast.Query` and handed to `build_ast` with no build criteria.
    `dialect` is accepted but not used.
    """
    select = parse(query).select
    # each column is a candidate metric reference
    for column in select.find_all(ast.Column):
        _resolve_metric_nodes(session, column)

    resolved_query = ast.Query(select=select)
    return build_ast(
        session,
        query=resolved_query,
        build_criteria=None,
    )