Coverage for dj/api/data.py: 100%
43 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-17 20:05 -0700
1"""
2Data related APIs.
3"""
5import logging
6from typing import List, Optional
8from fastapi import APIRouter, Depends, Query
9from fastapi.responses import JSONResponse
10from sqlmodel import Session
12from dj.api.helpers import get_engine, get_node_by_name, get_query
13from dj.errors import DJException, DJInvalidInputException
14from dj.models.metric import TranslatedSQL
15from dj.models.node import AvailabilityState, AvailabilityStateBase, NodeType
16from dj.models.query import ColumnMetadata, QueryCreate, QueryWithResults
17from dj.service_clients import QueryServiceClient
18from dj.utils import get_query_service_client, get_session
# Router collecting the data-related endpoints registered below.
router = APIRouter()
# Module-level logger named after this module.
_logger = logging.getLogger(__name__)
@router.post("/data/{node_name}/availability/")
def add_an_availability_state(
    node_name: str,
    data: AvailabilityStateBase,
    *,
    session: Session = Depends(get_session),
) -> JSONResponse:
    """
    Add an availability state to a node.

    The posted state must reference the catalog of the node's current
    revision. For source nodes it must additionally reference the same
    schema and table as the node. If the current revision already has an
    availability state for the same catalog/schema/table, the posted
    partition bounds are widened to cover both the existing and the new
    range before the new state replaces the old one.

    Raises:
        DJException: if the posted catalog differs from the node's catalog,
            or (for source nodes) if the posted schema/table differ from the
            node's schema/table.
    """
    node = get_node_by_name(session, node_name)
    node_revision = node.current

    if data.catalog != node_revision.catalog.name:
        raise DJException(
            message=(
                "Cannot set availability state in different catalog: "
                f"{data.catalog}, {node_revision.catalog.name}"
            ),
        )

    # Source nodes require that any availability states set are for the
    # table the source node itself points at.
    if node_revision.type == NodeType.SOURCE:
        if node_revision.schema_ != data.schema_ or node_revision.table != data.table:
            raise DJException(
                message=(
                    "Cannot set availability state, "
                    "source nodes require availability "
                    "states to match the set table: "
                    f"{data.catalog}."
                    f"{data.schema_}."
                    f"{data.table} "
                    "does not match "
                    f"{node_revision.catalog.name}."
                    f"{node_revision.schema_}."
                    f"{node_revision.table} "
                ),
            )

    # Merge the new availability state with the current one if it refers to
    # the same table. data.catalog has already been validated above to equal
    # the node's catalog name.
    existing = node_revision.availability
    if (
        existing
        and existing.catalog == data.catalog
        and existing.schema_ == data.schema_
        and existing.table == data.table
    ):
        # Currently, we do not consider type information. We should eventually
        # check the type of the partition values in order to cast them before
        # comparing; for now the values are compared with Python's default
        # ordering.
        data.max_partition = max(existing.max_partition, data.max_partition)
        data.min_partition = min(existing.min_partition, data.min_partition)

    node_revision.availability = AvailabilityState.from_orm(data)
    session.add(node_revision)
    session.commit()
    return JSONResponse(
        status_code=200,
        content={"message": "Availability state successfully posted"},
    )
@router.get("/data/{node_name}/")
def get_data(  # pylint: disable=too-many-locals
    node_name: str,
    *,
    dimensions: List[str] = Query([]),
    filters: List[str] = Query([]),
    async_: bool = False,
    session: Session = Depends(get_session),
    query_service_client: QueryServiceClient = Depends(get_query_service_client),
    engine_name: Optional[str] = None,
    engine_version: Optional[str] = None,
) -> QueryWithResults:
    """
    Gets data for a node.

    Builds the SQL for ``node_name`` (with the requested dimensions and
    filters), submits it to the query service on the selected engine, and
    returns the query service's result. If the result contains result sets,
    the first one is annotated with the column metadata derived from the
    generated query.

    When ``engine_name`` is omitted, the first engine of the node's catalog
    is used.

    Raises:
        DJInvalidInputException: if no engine was requested and the node's
            catalog has no engines, or if the selected engine is not one of
            the catalog's engines.
    """
    node = get_node_by_name(session, node_name)

    available_engines = node.current.catalog.engines
    if engine_name:
        engine = get_engine(session, engine_name, engine_version)  # type: ignore
    elif available_engines:
        engine = available_engines[0]
    else:
        # Without this guard, indexing an empty engine list would surface as
        # an unhandled IndexError instead of a client-facing error.
        raise DJInvalidInputException(  # pragma: no cover
            f"No engines are available for the node {node_name}.",
        )
    if engine not in available_engines:
        raise DJInvalidInputException(  # pragma: no cover
            f"The selected engine is not available for the node {node_name}. "
            "Available engines include: "
            f"{', '.join(available.name for available in available_engines)}",
        )

    query_ast = get_query(
        session=session,
        node_name=node_name,
        dimensions=dimensions,
        filters=filters,
        engine=engine,
    )
    columns = [
        ColumnMetadata(name=col.alias_or_name.name, type=str(col.type))  # type: ignore
        for col in query_ast.select.projection
    ]
    query = TranslatedSQL(
        sql=str(query_ast),
        columns=columns,
    )

    query_create = QueryCreate(
        engine_name=engine.name,
        catalog_name=node.current.catalog.name,
        engine_version=engine.version,
        submitted_query=query.sql,
        async_=async_,
    )
    result = query_service_client.submit_query(query_create)
    # Attach the column metadata computed above to the first result set,
    # if the query service returned any.
    if result.results.__root__:  # pragma: no cover
        result.results.__root__[0].columns = columns
    return result