Coverage for dj/api/helpers.py: 100%
146 statements
coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
1"""
2Helpers for API endpoints
3"""
4from http import HTTPStatus
5from typing import Dict, List, Optional, Tuple, Union
7from fastapi import HTTPException
8from sqlalchemy.exc import NoResultFound
9from sqlalchemy.orm import joinedload
10from sqlmodel import Session, select
12from dj.construction.build import build_node
13from dj.construction.dj_query import build_dj_metric_query
14from dj.errors import DJError, DJException, DJInvalidInputException, ErrorCode
15from dj.models import AttributeType, Catalog, Column, Engine
16from dj.models.attribute import RESERVED_ATTRIBUTE_NAMESPACE
17from dj.models.engine import Dialect
18from dj.models.node import (
19 BuildCriteria,
20 MissingParent,
21 Node,
22 NodeMissingParents,
23 NodeMode,
24 NodeNamespace,
25 NodeRelationship,
26 NodeRevision,
27 NodeRevisionBase,
28 NodeStatus,
29 NodeType,
30)
31from dj.sql.parsing import ast
32from dj.sql.parsing.backends.antlr4 import SqlSyntaxError, parse
33from dj.sql.parsing.backends.exceptions import DJParseException


def get_node_namespace(  # pylint: disable=too-many-arguments
    session: Session,
    namespace: str,
    raise_if_not_exists: bool = True,
) -> NodeNamespace:
    """
    Get a node namespace
    """
    statement = select(NodeNamespace).where(NodeNamespace.namespace == namespace)
    node_namespace = session.exec(statement).one_or_none()
    if raise_if_not_exists:  # pragma: no cover
        if not node_namespace:
            raise DJException(
                message=f"node namespace `{namespace}` does not exist.",
                http_status_code=404,
            )
    return node_namespace


def get_node_by_name(  # pylint: disable=too-many-arguments
    session: Session,
    name: Optional[str],
    node_type: Optional[NodeType] = None,
    with_current: bool = False,
    raise_if_not_exists: bool = True,
) -> Node:
    """
    Get a node by name
    """
    statement = select(Node).where(Node.name == name)
    if node_type:
        statement = statement.where(Node.type == node_type)
    if with_current:
        statement = statement.options(joinedload(Node.current))
        node = session.exec(statement).unique().one_or_none()
    else:
        node = session.exec(statement).one_or_none()
    if raise_if_not_exists:
        if not node:
            raise DJException(
                message=(
                    f"A {'' if not node_type else node_type + ' '}"
                    f"node with name `{name}` does not exist."
                ),
                http_status_code=404,
            )
    return node


def raise_if_node_exists(session: Session, name: str) -> None:
    """
    Raise an error if the node with the given name already exists.
    """
    node = get_node_by_name(session, name, raise_if_not_exists=False)
    if node:
        raise DJException(
            message=f"A node with name `{name}` already exists.",
            http_status_code=HTTPStatus.CONFLICT,
        )


def get_column(node: NodeRevision, column_name: str) -> Column:
    """
    Get a column from a node revision
    """
    requested_column = None
    for node_column in node.columns:
        if node_column.name == column_name:
            requested_column = node_column
            break

    if not requested_column:
        raise DJException(
            message=f"Column {column_name} does not exist on node {node.name}",
            http_status_code=404,
        )
    return requested_column


def get_attribute_type(
    session: Session,
    name: str,
    namespace: Optional[str] = RESERVED_ATTRIBUTE_NAMESPACE,
) -> Optional[AttributeType]:
    """
    Gets an attribute type by name.
    """
    statement = (
        select(AttributeType)
        .where(AttributeType.name == name)
        .where(AttributeType.namespace == namespace)
    )
    return session.exec(statement).one_or_none()


def get_catalog(session: Session, name: str) -> Catalog:
    """
    Get a catalog by name
    """
    statement = select(Catalog).where(Catalog.name == name)
    catalog = session.exec(statement).one_or_none()
    if not catalog:
        raise DJException(
            message=f"Catalog with name `{name}` does not exist.",
            http_status_code=404,
        )
    return catalog


def get_query(  # pylint: disable=too-many-arguments
    session: Session,
    node_name: str,
    dimensions: List[str],
    filters: List[str],
    engine: Optional[Engine],
) -> ast.Query:
    """
    Get a query for a metric, dimensions, and filters
    """
    node = get_node_by_name(session=session, name=node_name)

    if node.type in (NodeType.DIMENSION, NodeType.SOURCE):
        if dimensions:
            raise DJInvalidInputException(
                message=f"Cannot set dimensions for node type {node.type}!",
            )

    # Build the node for the engine's dialect if one is set, otherwise default to Spark
    if (
        not engine
        and node.current
        and node.current.catalog
        and node.current.catalog.engines
    ):
        engine = node.current.catalog.engines[0]
    build_criteria = BuildCriteria(
        dialect=(engine.dialect if engine and engine.dialect else Dialect.SPARK),
    )

    query_ast = build_node(
        session=session,
        node=node.current,
        filters=filters,
        dimensions=dimensions,
        build_criteria=build_criteria,
    )
    return query_ast
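
# A minimal usage sketch (not part of the original module): it shows how `get_query`
# is typically called from an API endpoint. The node name, dimension, and filter
# strings are hypothetical placeholders; passing `engine=None` falls back to the
# catalog's first engine, or the Spark dialect, as implemented above.
#
#     query_ast = get_query(
#         session=session,
#         node_name="default.num_repair_orders",      # hypothetical metric node
#         dimensions=["default.hard_hat.city"],       # hypothetical dimension attribute
#         filters=["default.hard_hat.state = 'NY'"],  # hypothetical filter expression
#         engine=None,
#     )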


def get_dj_query(
    session: Session,
    query: str,
) -> ast.Query:
    """
    Build a query AST for a DJ metrics query
    """
    query_ast = build_dj_metric_query(
        session=session,
        query=query,
    )
    return query_ast


def get_engine(session: Session, name: str, version: str) -> Engine:
    """
    Return an Engine instance given an engine name and version
    """
    statement = (
        select(Engine).where(Engine.name == name).where(Engine.version == version)
    )
    try:
        engine = session.exec(statement).one()
    except NoResultFound as exc:
        raise HTTPException(
            status_code=HTTPStatus.NOT_FOUND,
            detail=f"Engine not found: `{name}` version `{version}`",
        ) from exc
    return engine


def get_downstream_nodes(
    session: Session,
    node_name: str,
    node_type: Optional[NodeType] = None,
) -> List[Node]:
    """
    Gets all downstream children of the given node, filterable by node type.
    Uses a recursive CTE query to build out all descendants from the node.
    """
    node = get_node_by_name(session=session, name=node_name)

    dag = (
        select(
            NodeRelationship.parent_id,
            NodeRevision.node_id,
        )
        .where(NodeRelationship.parent_id == node.id)
        .join(NodeRevision, NodeRelationship.child_id == NodeRevision.id)
        .join(Node, Node.id == NodeRevision.node_id)
    ).cte("dag", recursive=True)

    paths = dag.union_all(
        select(
            dag.c.parent_id,
            NodeRevision.node_id,
        )
        .join(NodeRelationship, dag.c.node_id == NodeRelationship.parent_id)
        .join(NodeRevision, NodeRelationship.child_id == NodeRevision.id)
        .join(Node, Node.id == NodeRevision.node_id),
    )

    statement = (
        select(Node)
        .join(paths, paths.c.node_id == Node.id)
        .options(joinedload(Node.current))
    )

    results = session.exec(statement).unique().all()

    return [
        downstream
        for downstream in results
        if downstream.type == node_type or node_type is None
    ]
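
# Rough shape of the SQL generated by the recursive CTE above (illustrative only;
# actual table and column names depend on the ORM models and database dialect).
# The anchor member selects the direct children of the starting node, and the
# recursive member joins the CTE back onto the relationship table to pick up
# children of children until no new rows are produced:
#
#     WITH RECURSIVE dag AS (
#         SELECT noderelationship.parent_id, noderevision.node_id
#         FROM noderelationship
#         JOIN noderevision ON noderelationship.child_id = noderevision.id
#         JOIN node ON node.id = noderevision.node_id
#         WHERE noderelationship.parent_id = :node_id
#         UNION ALL
#         SELECT dag.parent_id, noderevision.node_id
#         FROM dag
#         JOIN noderelationship ON dag.node_id = noderelationship.parent_id
#         JOIN noderevision ON noderelationship.child_id = noderevision.id
#         JOIN node ON node.id = noderevision.node_id
#     )
#     SELECT node.*
#     FROM node
#     JOIN dag ON dag.node_id = node.id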


def validate_node_data(
    data: Union[NodeRevisionBase, NodeRevision],
    session: Session,
) -> Tuple[
    NodeRevision,
    Dict[NodeRevision, List[ast.Table]],
    Dict[str, List[ast.Table]],
    List[str],
]:
    """
    Validate a node.
    """

    if isinstance(data, NodeRevision):
        validated_node = data
    else:
        node = Node(name=data.name, type=data.type)
        validated_node = NodeRevision.parse_obj(data)
        validated_node.node = node
    validated_node.status = NodeStatus.VALID

    # Try to parse the node's query and extract dependencies
    try:
        query_ast = parse(validated_node.query)  # type: ignore
        exc = DJException()
        ctx = ast.CompileContext(session=session, exception=exc)
        dependencies_map, missing_parents_map = query_ast.extract_dependencies(ctx)
    except (ValueError, SqlSyntaxError) as exc:
        raise DJException(message=str(exc)) from exc

    # Only raise on missing parents if the node mode is set to published
    if missing_parents_map:
        if validated_node.mode == NodeMode.DRAFT:
            validated_node.status = NodeStatus.INVALID
        else:
            raise DJException(
                errors=[
                    DJError(
                        code=ErrorCode.MISSING_PARENT,
                        message="Node definition contains references to nodes that do not exist",
                        debug={"missing_parents": list(missing_parents_map.keys())},
                    ),
                ],
            )

    # Add aliases for any unnamed columns and confirm that all column types can be inferred
    query_ast.select.add_aliases_to_unnamed_columns()

    validated_node.columns = []
    type_inference_failed_columns = []
    for col in query_ast.select.projection:
        try:
            column_type = col.type  # type: ignore
            validated_node.columns.append(
                Column(name=col.alias_or_name.name, type=column_type),  # type: ignore
            )
        except DJParseException:
            type_inference_failed_columns.append(col.alias_or_name.name)  # type: ignore
            validated_node.status = NodeStatus.INVALID
    return (
        validated_node,
        dependencies_map,
        missing_parents_map,
        type_inference_failed_columns,
    )
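
# For reference, the tuple returned by `validate_node_data` above is, in order: the
# validated node revision, a map from upstream node revisions to the AST tables that
# reference them, a map from missing parent names to the AST tables that reference
# them, and the names of any columns whose types could not be inferred. The helpers
# below check the last two elements to decide whether a revision can be marked valid.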


def resolve_downstream_references(
    session: Session,
    node_revision: NodeRevision,
) -> List[NodeRevision]:
    """
    Find all node revisions with missing parent references to `node_revision` and resolve them
    """
    missing_parents = session.exec(
        select(MissingParent).where(MissingParent.name == node_revision.name),
    ).all()
    newly_valid_nodes = []
    for missing_parent in missing_parents:
        missing_parent_links = session.exec(
            select(NodeMissingParents).where(
                NodeMissingParents.missing_parent_id == missing_parent.id,
            ),
        ).all()
        # Remove from missing parents and add to parents
        for link in missing_parent_links:
            downstream_node_id = link.referencing_node_id
            downstream_node_revision = (
                session.exec(
                    select(NodeRevision).where(NodeRevision.id == downstream_node_id),
                )
                .unique()
                .one()
            )
            downstream_node_revision.parents.append(node_revision.node)
            downstream_node_revision.missing_parents.remove(missing_parent)
            (
                _,
                _,
                missing_parents_map,
                type_inference_failed_columns,
            ) = validate_node_data(data=downstream_node_revision, session=session)
            if not missing_parents_map and not type_inference_failed_columns:
                newly_valid_nodes.append(downstream_node_revision)
            session.add(downstream_node_revision)
            session.commit()

        session.delete(missing_parent)  # Remove missing parent reference to node
    return newly_valid_nodes


def propagate_valid_status(
    session: Session,
    valid_nodes: List[NodeRevision],
    catalog_id: int,
) -> None:
    """
    Propagate a valid status by revalidating all downstream nodes
    """
    while valid_nodes:
        resolved_nodes = []
        for node_revision in valid_nodes:
            if node_revision.status != NodeStatus.VALID:
                raise DJException(
                    f"Cannot propagate valid status: Node `{node_revision.name}` is not valid",
                )
            downstream_nodes = get_downstream_nodes(
                session=session,
                node_name=node_revision.name,
            )
            newly_valid_nodes = []
            for node in downstream_nodes:
                (
                    validated_node,
                    _,
                    missing_parents_map,
                    type_inference_failed_columns,
                ) = validate_node_data(data=node.current, session=session)
                if not missing_parents_map and not type_inference_failed_columns:
                    node.current.columns = validated_node.columns or []
                    node.current.status = NodeStatus.VALID
                    node.current.catalog_id = catalog_id
                    session.add(node.current)
                    session.commit()
                    newly_valid_nodes.append(node.current)
            resolved_nodes.extend(newly_valid_nodes)
        valid_nodes = resolved_nodes
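
# Illustrative call sequence (a sketch with hypothetical variable names): after a node
# that other revisions were missing as a parent is created, the two helpers above are
# typically chained so that every downstream revision that becomes resolvable is
# revalidated and its valid status is propagated further downstream:
#
#     newly_valid = resolve_downstream_references(
#         session=session,
#         node_revision=new_node.current,
#     )
#     propagate_valid_status(
#         session=session,
#         valid_nodes=newly_valid,
#         catalog_id=new_node.current.catalog_id,
#     )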