Coverage for dj/construction/build.py: 100% — 148 statements (coverage.py v7.2.3, created at 2023-04-17 20:05 -0700)
1"""Functions to add to an ast DJ node queries"""
2import collections
4# pylint: disable=too-many-arguments,too-many-locals,too-many-nested-blocks,too-many-branches,R0401
5from typing import DefaultDict, Deque, Dict, List, Optional, Set, Tuple, Union, cast
7from sqlmodel import Session
9from dj.construction.utils import amenable_name, to_namespaced_name
10from dj.errors import DJException, DJInvalidInputException
11from dj.models.column import Column
12from dj.models.engine import Dialect
13from dj.models.node import BuildCriteria, NodeRevision, NodeType
14from dj.sql.parsing.ast import CompileContext
15from dj.sql.parsing.backends.antlr4 import ast, parse
def _get_tables_from_select(
    select: ast.SelectExpression,
) -> DefaultDict[NodeRevision, List[ast.Table]]:
    """
    Collect every table reference in the select (source, transform, dimension)
    that is backed by a DJ node, grouped by that node revision.
    """
    node_to_tables: DefaultDict[NodeRevision, List[ast.Table]] = (
        collections.defaultdict(list)
    )
    for table_ref in select.find_all(ast.Table):
        dj_node = table_ref.dj_node
        if dj_node:  # pragma: no cover
            node_to_tables[dj_node].append(table_ref)
    return node_to_tables
def _join_path(
    dimension_node: NodeRevision,
    initial_nodes: Set[NodeRevision],
) -> Tuple[NodeRevision, Dict[Tuple[NodeRevision, NodeRevision], List[Column]]]:
    """
    For a dimension node, we want to find a possible join path between it
    and any of the nodes that are directly referenced in the original query. If
    no join path exists, returns an empty dict.

    Performs a breadth-first search outward from ``initial_nodes``, following
    columns that carry dimension links and accumulating the
    ``(node, dimension) -> join columns`` edges traversed along the way.
    """
    # Nodes whose outgoing dimension links have already been expanded
    processed = set()

    # Queue of (node to expand, join edges accumulated so far)
    to_process: Deque[
        Tuple[NodeRevision, Dict[Tuple[NodeRevision], List[Column]]]
    ] = collections.deque([])
    join_info: Dict[Tuple[NodeRevision], List[Column]] = {}
    to_process.extend([(node, join_info.copy()) for node in initial_nodes])
    possible_join_paths = []

    while to_process:
        current_node, path = to_process.popleft()
        processed.add(current_node)
        dimensions_to_columns = collections.defaultdict(list)

        # From the columns on the current node, find the next layer of
        # dimension nodes that can be joined in
        for col in current_node.columns:
            if col.dimension and col.dimension.type == NodeType.DIMENSION:
                dimensions_to_columns[col.dimension.current].append(col)

        # Go through all potential dimensions and their join columns
        for joinable_dim, join_cols in dimensions_to_columns.items():
            next_join_path = {**path, **{(current_node, joinable_dim): join_cols}}
            full_join_path = (joinable_dim, next_join_path)
            if joinable_dim == dimension_node:
                # Reached the target dimension: each linking column must either
                # name an explicit dimension column or be able to fall back to
                # the dimension's default `id` key
                for col in join_cols:
                    if col.dimension_column is None and not any(
                        dim_col.name == "id" for dim_col in dimension_node.columns
                    ):
                        raise DJException(
                            f"Node {current_node.name} specifying dimension "
                            f"{joinable_dim.name} on column {col.name} does not"
                            f" specify a dimension column, but {dimension_node.name} "
                            f"does not have the default key `id`.",
                        )
                possible_join_paths.append(full_join_path)  # type: ignore
            if joinable_dim not in processed:  # pragma: no cover
                to_process.append(full_join_path)
                # Also explore the dimension's parents, carrying the same path
                for parent in joinable_dim.parents:
                    to_process.append((parent.current, next_join_path))
    # NOTE(review): every candidate is a 2-tuple, so ``len`` is always 2 and
    # ``min`` effectively returns the first path found (the shortest, given BFS
    # order). Also, when no path exists this raises ValueError rather than
    # returning an empty dict as the docstring suggests — confirm intent.
    return min(possible_join_paths, key=len)  # type: ignore
def _get_or_build_join_table(
    session: Session,
    table_node: NodeRevision,
    build_criteria: Optional[BuildCriteria],
):
    """
    Build the join table from a materialization if one is available, or recurse
    to build it from the dimension node's query if not.

    Returns an ``ast.Alias`` wrapping the table expression, aliased with a
    DJ-friendly name derived from the node's name.
    """
    table_node_alias = amenable_name(table_node.name)
    join_table = cast(
        Optional[ast.TableExpression],
        _get_node_table(table_node, build_criteria),
    )
    if not join_table:  # pragma: no cover
        # No materialized table available: build the node's own query instead
        join_query = parse(cast(str, table_node.query))
        join_table = build_ast(session, join_query)  # type: ignore
        join_table.parenthesized = True  # type: ignore

    # Re-point each column at the (possibly rebuilt) table expression so that
    # later column references resolve against it
    for col in join_table.columns:
        col._table = join_table  # pylint: disable=protected-access

    join_table = cast(ast.TableExpression, join_table)  # type: ignore
    right_alias = ast.Name(table_node_alias)
    # Alias the right-hand side of the join with the amenable name
    join_right = ast.Alias(  # type: ignore
        right_alias,
        child=join_table,
        as_=True,
    )
    join_table.set_alias(right_alias)  # type: ignore
    return join_right
def _build_joins_for_dimension(
    session: Session,
    dim_node: NodeRevision,
    initial_nodes: Set[NodeRevision],
    tables: DefaultDict[NodeRevision, List[ast.Table]],
    build_criteria: Optional[BuildCriteria],
    required_dimension_columns: List[ast.Column],
) -> List[ast.Join]:
    """
    Returns the join ASTs needed to bring in the dimension node from
    the set of initial nodes.

    Side effects: intermediate join nodes are added to ``initial_nodes`` and
    their table expressions appended to ``tables``, so callers see the
    expanded state after this returns.
    """
    _, paths = _join_path(dim_node, initial_nodes)
    asts = []
    # Each entry is an edge (start_node -> table_node) plus the columns that
    # link them; build one LEFT OUTER join per edge
    for connecting_nodes, join_columns in paths.items():
        start_node, table_node = connecting_nodes  # type: ignore
        join_on = []

        # Assemble table on left of join (unwrap an alias if present)
        left_table = (
            tables[start_node][0].child  # type: ignore
            if isinstance(tables[start_node][0], ast.Alias)
            else tables[start_node][0]
        )
        join_left_columns = {
            col.alias_or_name.name: col for col in left_table.columns  # type: ignore
        }

        # Assemble table on right of join
        join_right = _get_or_build_join_table(
            session,
            table_node,
            build_criteria,
        )

        # Optimize query by filtering down to only the necessary columns:
        # requested dimension columns, join keys, primary keys, and any
        # columns that link onward to further dimensions
        selected_columns = {col.name.name for col in required_dimension_columns}
        available_join_columns = {
            col.dimension_column for col in join_columns if col.dimension_column
        }
        primary_key_columns = {col.name for col in table_node.primary_key()}
        joinable_dim_columns = {
            col.name for col in table_node.columns if col.dimension_id
        }
        required_mapping = (
            selected_columns.union(available_join_columns)
            .union(primary_key_columns)
            .union(joinable_dim_columns)
        )
        join_right.child.select.projection = [
            col
            for col in join_right.child.select.projection
            if col.alias_or_name.name in required_mapping
        ]

        # Record the newly joined node so later dimensions can reuse it
        initial_nodes.add(table_node)
        tables[table_node].append(join_right)  # type: ignore
        join_right_columns = {
            col.alias_or_name.name: col  # type: ignore
            for col in join_right.child.columns
        }

        # Assemble join ON clause: match each linking column on the left with
        # its dimension column (or the dimension's primary key) on the right
        for join_col in join_columns:
            join_table_pk = table_node.primary_key()
            if join_col.name in join_left_columns and (
                join_col.dimension_column in join_right_columns
                or join_table_pk[0].name in join_right_columns
            ):
                left_table.add_ref_column(
                    cast(ast.Column, join_left_columns[join_col.name]),
                )
                join_on.append(
                    ast.BinaryOp.Eq(
                        join_left_columns[join_col.name],
                        join_right_columns[
                            join_col.dimension_column or join_table_pk[0].name
                        ],
                        use_alias_as_name=True,
                    ),
                )
            else:
                raise DJInvalidInputException(  # pragma: no cover
                    f"The specified join column {join_col.dimension_column} "
                    f"does not exist on {table_node.name}",
                )
        # Register the requested dimension columns against the right table
        for dim_col in required_dimension_columns:
            join_right.child.add_ref_column(dim_col)

        if join_on:  # pragma: no cover
            asts.append(
                ast.Join(
                    "LEFT OUTER",
                    join_right,  # type: ignore
                    ast.JoinCriteria(
                        on=ast.BinaryOp.And(*join_on),  # pylint: disable=E1120
                    ),
                ),
            )
    return asts
def join_tables_for_dimensions(
    session: Session,
    dimension_nodes_to_columns: Dict[NodeRevision, List[ast.Column]],
    tables: DefaultDict[NodeRevision, List[ast.Table]],
    build_criteria: Optional[BuildCriteria] = None,
):
    """
    Joins the tables necessary for a set of filter and group by dimensions
    onto the select expression.

    In some cases, the necessary tables will already be on the select and
    no additional joins will be needed. However, if the tables are not in
    the select, it will traverse through available linked tables (via dimension
    nodes) and join them in.

    Note: ``tables`` is mutated by ``_build_joins_for_dimension`` as joins are
    added, so dimensions processed later can reuse tables joined in earlier.
    """
    # Sort by node name for deterministic join ordering
    for dim_node, required_dimension_columns in sorted(
        dimension_nodes_to_columns.items(),
        key=lambda x: x[0].name,
    ):
        # Find all the selects that contain the different dimension columns
        selects_map = {
            cast(ast.Select, dim_col.get_nearest_parent_of_type(ast.Select))
            for dim_col in required_dimension_columns
        }

        # Join the source tables (if necessary) for these dimension columns
        # onto each select clause
        for select in selects_map:
            initial_nodes = set(tables)
            if dim_node not in initial_nodes:  # need to join dimension
                join_asts = _build_joins_for_dimension(
                    session,
                    dim_node,
                    initial_nodes,
                    tables,
                    build_criteria,
                    required_dimension_columns,
                )
                if join_asts and select.from_:
                    # Attach the joins to the last relation in the FROM clause
                    select.from_.relations[-1].extensions.extend(  # pragma: no cover
                        join_asts,
                    )
def _build_tables_on_select(
    session: Session,
    select: ast.SelectExpression,
    tables: Dict[NodeRevision, List[ast.Table]],
    build_criteria: Optional[BuildCriteria] = None,
):
    """
    Add all nodes not agg or filter dimensions to the select.

    For each DJ node referenced on the select, swaps the table reference for
    either its materialized table or its (recursively built) query AST.
    """
    for node, tbls in tables.items():
        node_table = cast(
            Optional[ast.Table],
            _get_node_table(node, build_criteria),
        )  # got a materialization
        if node_table is None:  # no materialization - recurse to node first
            node_query = parse(cast(str, node.query))
            node_table = build_ast(  # type: ignore
                session,
                node_query,
                build_criteria,
            ).select  # pylint: disable=W0212
            node_table.parenthesized = True  # type: ignore

        alias = amenable_name(node.node.name)
        context = CompileContext(session=session, exception=DJException())
        node_ast = ast.Alias(ast.Name(alias), child=node_table, as_=True)  # type: ignore
        for tbl in tbls:
            # Trim the built subquery's projection down to the columns the
            # referencing alias actually selects
            if isinstance(node_ast.child, ast.Select) and isinstance(tbl, ast.Alias):
                node_ast.child.projection = [
                    col
                    for col in node_ast.child.projection
                    if col in set(tbl.child.select.projection)
                ]
            # Compile before replacing so column references resolve correctly
            node_ast.compile(context)
            select.replace(tbl, node_ast)
def dimension_columns_mapping(
    select: ast.SelectExpression,
) -> Dict[NodeRevision, List[ast.Column]]:
    """
    Extract all dimension nodes referenced by columns on the select,
    mapped to the list of columns that reference each one.
    """
    dimension_nodes_to_columns: Dict[NodeRevision, List[ast.Column]] = {}
    for col in select.find_all(ast.Column):
        if isinstance(col.table, ast.Table):
            if node := col.table.dj_node:  # pragma: no cover
                if node.type == NodeType.DIMENSION:
                    # setdefault replaces the get-then-reassign-then-append
                    # pattern with a single lookup
                    dimension_nodes_to_columns.setdefault(node, []).append(col)
    return dimension_nodes_to_columns
# flake8: noqa: C901
def _build_select_ast(
    session: Session,
    select: ast.SelectExpression,
    build_criteria: Optional[BuildCriteria] = None,
):
    """
    Transforms a select ast by replacing dj node references with their asts.
    First extracts all dimension-backed columns from filters + group bys;
    columns that cannot be sourced from tables already on the select are
    joined in via dimension links, then every DJ-node table is built in place.
    """
    referenced_tables = _get_tables_from_select(select)
    dimension_columns = dimension_columns_mapping(select)
    join_tables_for_dimensions(
        session,
        dimension_columns,
        referenced_tables,
        build_criteria,
    )
    _build_tables_on_select(session, select, referenced_tables, build_criteria)
def add_filters_and_dimensions_to_query_ast(
    query: ast.Query,
    dialect: Optional[str] = None,  # pylint: disable=unused-argument
    filters: Optional[List[str]] = None,
    dimensions: Optional[List[str]] = None,
):
    """
    Add filters and dimensions to a query ast.

    Filters are ANDed onto any existing WHERE clause; dimensions are appended
    to GROUP BY and to the projection. When a GROUP BY exists, the projection
    is pruned to aggregations and grouped columns.
    """
    projection_addition = []
    if filters:
        # Proper conditional expression instead of the `and/or` ternary hack
        # (previously needed a pylint consider-using-ternary disable)
        filter_asts = [query.select.where] if query.select.where else []
        for filter_ in filters:
            # Use parse to get the filter asts from the strings we got
            temp_select = parse(f"select * where {filter_}").select
            filter_asts.append(
                temp_select.where,  # type:ignore
            )
        query.select.where = ast.BinaryOp.And(*filter_asts)

    if dimensions:
        for agg in dimensions:
            temp_select = parse(
                f"select * group by {agg}",
            ).select
            query.select.group_by += temp_select.group_by  # type:ignore
            projection_addition += list(temp_select.find_all(ast.Column))
        # No defensive copy needed: projection_addition is local and not reused
        query.select.projection += projection_addition

    # Cannot select for columns that aren't in GROUP BY and aren't aggregations
    if query.select.group_by:
        query.select.projection = [
            col
            for col in query.select.projection
            if col.is_aggregation()  # type: ignore
            or col.name.name in {gc.name.name for gc in query.select.group_by}  # type: ignore
        ]
def _get_node_table(
    node: NodeRevision,
    build_criteria: Optional[BuildCriteria] = None,
    as_select: bool = False,
) -> Optional[Union[ast.Select, ast.Table]]:
    """
    If a node has a materialization available, return the materialized table,
    optionally wrapped as a `SELECT * FROM <table>` when `as_select` is set.
    Returns None when no materialization exists.
    """
    materialized: Optional[ast.Table] = None
    if node.type == NodeType.SOURCE:
        # Source nodes always have a physical table backing them
        if node.table:
            table_name = ast.Name(
                node.table,
                namespace=ast.Name(node.schema_) if node.schema_ else None,
            )
        else:
            table_name = to_namespaced_name(node.name)
        materialized = ast.Table(table_name, _dj_node=node)
    elif node.availability and node.availability.is_available(
        criteria=build_criteria,
    ):  # pragma: no cover
        # Non-source nodes are materialized only if availability matches
        schema = node.availability.schema_
        materialized = ast.Table(
            ast.Name(
                node.availability.table,
                namespace=ast.Name(schema) if schema else None,
            ),
            _dj_node=node,
        )
    if materialized and as_select:  # pragma: no cover
        return ast.Select(
            projection=[ast.Wildcard()],
            from_=ast.From(relations=[ast.Relation(materialized)]),
        )
    return materialized
def build_node(  # pylint: disable=too-many-arguments
    session: Session,
    node: NodeRevision,
    filters: Optional[List[str]] = None,
    dimensions: Optional[List[str]] = None,
    build_criteria: Optional[BuildCriteria] = None,
) -> ast.Query:
    """
    Determines the optimal way to build the Node and does so.
    """
    # Set the dialect by finding available engines for this node, or default to Spark
    if not build_criteria:
        engines = node.catalog.engines if node.catalog else None
        dialect = (
            engines[0].dialect if engines and engines[0].dialect else Dialect.SPARK
        )
        build_criteria = BuildCriteria(dialect=dialect)

    # If no dimensions need to be added then we can see if the node is
    # directly materialized
    if not filters and not dimensions:
        materialized_select = cast(
            ast.Select,
            _get_node_table(node, build_criteria, as_select=True),
        )
        if materialized_select:
            return ast.Query(select=materialized_select)  # pragma: no cover

    query = parse(node.query) if node.query else build_source_node_query(node)

    add_filters_and_dimensions_to_query_ast(
        query,
        build_criteria.dialect,
        filters,
        dimensions,
    )
    return build_ast(session, query, build_criteria)
def build_source_node_query(node: NodeRevision):
    """
    Returns a query that selects each column explicitly in the source node.
    """
    source_table = ast.Table(to_namespaced_name(node.name), None, _dj_node=node)
    explicit_columns = [
        ast.Column(ast.Name(source_col.name), _table=source_table)
        for source_col in node.columns
    ]
    return ast.Query(
        select=ast.Select(
            projection=explicit_columns,
            from_=ast.From(relations=[ast.Relation(source_table)]),
        ),
    )
def build_ast(  # pylint: disable=too-many-arguments
    session: Session,
    query: ast.Query,
    build_criteria: Optional[BuildCriteria] = None,
) -> ast.Query:
    """
    Determines the optimal way to build the query AST and does so:
    compiles the query against DJ metadata, then builds it in place.
    """
    compile_context = CompileContext(session=session, exception=DJException())
    query.compile(compile_context)
    query.build(session, build_criteria)
    return query