Coverage for dj/api/helpers.py: 100%

146 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1""" 

2Helpers for API endpoints 

3""" 

4from http import HTTPStatus 

5from typing import Dict, List, Optional, Tuple, Union 

6 

7from fastapi import HTTPException 

8from sqlalchemy.exc import NoResultFound 

9from sqlalchemy.orm import joinedload 

10from sqlmodel import Session, select 

11 

12from dj.construction.build import build_node 

13from dj.construction.dj_query import build_dj_metric_query 

14from dj.errors import DJError, DJException, DJInvalidInputException, ErrorCode 

15from dj.models import AttributeType, Catalog, Column, Engine 

16from dj.models.attribute import RESERVED_ATTRIBUTE_NAMESPACE 

17from dj.models.engine import Dialect 

18from dj.models.node import ( 

19 BuildCriteria, 

20 MissingParent, 

21 Node, 

22 NodeMissingParents, 

23 NodeMode, 

24 NodeNamespace, 

25 NodeRelationship, 

26 NodeRevision, 

27 NodeRevisionBase, 

28 NodeStatus, 

29 NodeType, 

30) 

31from dj.sql.parsing import ast 

32from dj.sql.parsing.backends.antlr4 import SqlSyntaxError, parse 

33from dj.sql.parsing.backends.exceptions import DJParseException 

34 

35 

def get_node_namespace(
    session: Session,
    namespace: str,
    raise_if_not_exists: bool = True,
) -> Optional[NodeNamespace]:
    """
    Get a node namespace by its exact name.

    Args:
        session: database session used for the lookup.
        namespace: the namespace string to look up.
        raise_if_not_exists: when True (the default), raise instead of
            returning ``None`` for an unknown namespace.

    Returns:
        The matching ``NodeNamespace`` row, or ``None`` when no row matches
        and ``raise_if_not_exists`` is False.

    Raises:
        DJException: with HTTP status 404 if the namespace does not exist and
            ``raise_if_not_exists`` is True.
    """
    statement = select(NodeNamespace).where(NodeNamespace.namespace == namespace)
    node_namespace = session.exec(statement).one_or_none()
    if raise_if_not_exists and not node_namespace:  # pragma: no cover
        raise DJException(
            message=f"node namespace `{namespace}` does not exist.",
            http_status_code=404,
        )
    return node_namespace

53 

54 

def get_node_by_name(  # pylint: disable=too-many-arguments
    session: Session,
    name: Optional[str],
    node_type: Optional[NodeType] = None,
    with_current: bool = False,
    raise_if_not_exists: bool = True,
) -> Node:
    """
    Look up a node by its name, optionally restricted to one node type.

    Args:
        session: database session used for the lookup.
        name: the node name to match.
        node_type: if truthy, only a node of this type matches.
        with_current: if True, the node's current revision is eagerly loaded
            in the same query.
        raise_if_not_exists: when True (the default), raise instead of
            returning ``None`` if no node matches.

    Returns:
        The matching node, or ``None`` when nothing matches and
        ``raise_if_not_exists`` is False.

    Raises:
        DJException: with HTTP status 404 if no node matches and
            ``raise_if_not_exists`` is True.
    """
    query = select(Node).where(Node.name == name)
    if node_type:
        query = query.where(Node.type == node_type)

    if with_current:
        # The eager-loaded variant has its rows de-duplicated with unique()
        # before a single result is taken.
        result = session.exec(query.options(joinedload(Node.current))).unique()
    else:
        result = session.exec(query)
    found = result.one_or_none()

    if raise_if_not_exists and not found:
        type_prefix = node_type + " " if node_type else ""
        raise DJException(
            message=f"A {type_prefix}node with name `{name}` does not exist.",
            http_status_code=404,
        )
    return found

83 

84 

def raise_if_node_exists(session: Session, name: str) -> None:
    """
    Ensure that no node called ``name`` exists yet.

    Raises:
        DJException: with HTTP status 409 (Conflict) when the name is taken.
    """
    existing = get_node_by_name(session, name, raise_if_not_exists=False)
    if not existing:
        return
    raise DJException(
        message=f"A node with name `{name}` already exists.",
        http_status_code=HTTPStatus.CONFLICT,
    )

95 

96 

def get_column(node: NodeRevision, column_name: str) -> Column:
    """
    Return the column called ``column_name`` from a node revision.

    If several columns share the name, the first one is returned.

    Raises:
        DJException: with HTTP status 404 if no column has that name.
    """
    matches = (column for column in node.columns if column.name == column_name)
    requested_column = next(matches, None)
    if not requested_column:
        raise DJException(
            message=f"Column {column_name} does not exist on node {node.name}",
            http_status_code=404,
        )
    return requested_column

113 

114 

def get_attribute_type(
    session: Session,
    name: str,
    namespace: Optional[str] = RESERVED_ATTRIBUTE_NAMESPACE,
) -> Optional[AttributeType]:
    """
    Find an attribute type by name within a namespace.

    The namespace defaults to ``RESERVED_ATTRIBUTE_NAMESPACE``. Returns
    ``None`` rather than raising when no attribute type matches.
    """
    name_matches = AttributeType.name == name
    namespace_matches = AttributeType.namespace == namespace
    return session.exec(
        select(AttributeType).where(name_matches).where(namespace_matches),
    ).one_or_none()

129 

130 

def get_catalog(session: Session, name: str) -> Catalog:
    """
    Look up a catalog by its name.

    Raises:
        DJException: with HTTP status 404 when no catalog has that name.
    """
    catalog = session.exec(select(Catalog).where(Catalog.name == name)).one_or_none()
    if catalog:
        return catalog
    raise DJException(
        message=f"Catalog with name `{name}` does not exist.",
        http_status_code=404,
    )

143 

144 

def get_query(  # pylint: disable=too-many-arguments
    session: Session,
    node_name: str,
    dimensions: List[str],
    filters: List[str],
    engine: Optional[Engine],
) -> ast.Query:
    """
    Build the query AST for a node, applying the requested dimensions and filters.

    Args:
        session: database session used to load and build the node.
        node_name: name of the node to build.
        dimensions: dimensions to group by; must be empty for dimension and
            source nodes.
        filters: filter expressions passed through to the builder.
        engine: engine whose dialect drives the build. When omitted, the first
            engine of the node's current catalog is used, if there is one.

    Returns:
        The built query AST. Its dialect is Spark unless a dialect-bearing
        engine was supplied or found.

    Raises:
        DJInvalidInputException: if dimensions are requested for a dimension or
            source node.
        DJException: (via ``get_node_by_name``) if the node does not exist.
    """
    node = get_node_by_name(session=session, name=node_name)

    if dimensions and node.type in (NodeType.DIMENSION, NodeType.SOURCE):
        raise DJInvalidInputException(
            message=f"Cannot set dimensions for node type {node.type}!",
        )

    # Fall back to the first engine configured on the node's catalog.
    current = node.current
    if not engine and current and current.catalog and current.catalog.engines:
        engine = current.catalog.engines[0]

    dialect = engine.dialect if engine and engine.dialect else Dialect.SPARK
    return build_node(
        session=session,
        node=current,
        filters=filters,
        dimensions=dimensions,
        build_criteria=BuildCriteria(dialect=dialect),
    )

183 

184 

def get_dj_query(
    session: Session,
    query: str,
) -> ast.Query:
    """
    Build a query AST from a DJ SQL query string.

    The string is passed unchanged to ``build_dj_metric_query``, and that
    function's result is returned as-is.
    """
    return build_dj_metric_query(session=session, query=query)

198 

199 

def get_engine(session: Session, name: str, version: str) -> Engine:
    """
    Fetch the engine registered under the given name and version.

    Raises:
        HTTPException: 404 Not Found when no engine has that name and version.
    """
    by_name_and_version = (
        select(Engine)
        .where(Engine.name == name)
        .where(Engine.version == version)
    )
    try:
        return session.exec(by_name_and_version).one()
    except NoResultFound as exc:
        raise HTTPException(
            status_code=HTTPStatus.NOT_FOUND,
            detail=f"Engine not found: `{name}` version `{version}`",
        ) from exc

215 

216 

def get_downstream_nodes(
    session: Session,
    node_name: str,
    node_type: Optional[NodeType] = None,
) -> List[Node]:
    """
    Gets all downstream children of the given node, filterable by node type.

    Uses a recursive CTE over ``NodeRelationship`` to collect every descendant
    of the node, not just its direct children.

    Args:
        session: database session used for the lookups.
        node_name: name of the node whose descendants are wanted.
        node_type: if given, only descendants of this type are returned;
            ``None`` (the default) returns descendants of every type.

    Returns:
        The distinct descendant nodes, each with its current revision eagerly
        loaded.

    Raises:
        DJException: (via ``get_node_by_name``) if ``node_name`` does not exist.
    """
    node = get_node_by_name(session=session, name=node_name)

    # Anchor: direct children of the node. A relationship row links a parent
    # node id to a child *revision* id, so each child revision is mapped back
    # to the id of the node that owns it.
    dag = (
        select(
            NodeRelationship.parent_id,
            NodeRevision.node_id,
        )
        .where(NodeRelationship.parent_id == node.id)
        .join(NodeRevision, NodeRelationship.child_id == NodeRevision.id)
        .join(Node, Node.id == NodeRevision.node_id)
    ).cte("dag", recursive=True)

    # Recursive step: every node found so far acts as a parent, and its
    # relationships are followed one more level down.
    paths = dag.union_all(
        select(
            dag.c.parent_id,
            NodeRevision.node_id,
        )
        .join(NodeRelationship, dag.c.node_id == NodeRelationship.parent_id)
        .join(NodeRevision, NodeRelationship.child_id == NodeRevision.id)
        .join(Node, Node.id == NodeRevision.node_id),
    )

    statement = (
        select(Node)
        .join(paths, paths.c.node_id == Node.id)
        .options(joinedload(Node.current))
    )

    # union_all keeps duplicate rows, so a node reachable through several
    # paths can appear more than once; unique() collapses those duplicates.
    results = session.exec(statement).unique().all()

    return [
        downstream
        for downstream in results
        if node_type is None or downstream.type == node_type
    ]

261 

262 

def validate_node_data(
    data: Union[NodeRevisionBase, NodeRevision],
    session: Session,
) -> Tuple[
    NodeRevision,
    Dict[NodeRevision, List[ast.Table]],
    Dict[str, List[ast.Table]],
    List[str],
]:
    """
    Validate a node's query and derive its columns.

    An existing ``NodeRevision`` is validated in place, so it is mutated. Any
    other input is first converted into a new ``NodeRevision`` attached to a
    new ``Node``. The revision's status starts as ``VALID`` and is downgraded
    to ``INVALID`` in two cases: a draft node has missing parents, or the type
    of any projected column cannot be inferred.

    Returns:
        A tuple of (validated revision, dependencies map, missing-parents map,
        names of columns whose type inference failed).

    Raises:
        DJException: if the query fails to parse or its dependencies cannot be
            extracted (``ValueError``/``SqlSyntaxError``), or if a non-draft
            node references missing parents.
    """
    if isinstance(data, NodeRevision):
        validated_node = data
    else:
        new_node = Node(name=data.name, type=data.type)
        validated_node = NodeRevision.parse_obj(data)
        validated_node.node = new_node
    validated_node.status = NodeStatus.VALID

    # Parse the query and work out which nodes it references.
    try:
        query_ast = parse(validated_node.query)  # type: ignore
        compile_context = ast.CompileContext(
            session=session,
            exception=DJException(),
        )
        dependencies_map, missing_parents_map = query_ast.extract_dependencies(
            compile_context,
        )
    except (ValueError, SqlSyntaxError) as exc:
        raise DJException(message=str(exc)) from exc

    # Missing parents are tolerated, as an INVALID status, only for drafts.
    # Any other mode rejects the node outright.
    if missing_parents_map and validated_node.mode != NodeMode.DRAFT:
        raise DJException(
            errors=[
                DJError(
                    code=ErrorCode.MISSING_PARENT,
                    message="Node definition contains references to nodes that do not exist",
                    debug={"missing_parents": list(missing_parents_map.keys())},
                ),
            ],
        )
    if missing_parents_map:
        validated_node.status = NodeStatus.INVALID

    # Give unnamed projections an alias, then build one column per projection,
    # recording the names of any whose type cannot be inferred.
    query_ast.select.add_aliases_to_unnamed_columns()

    validated_node.columns = []
    type_inference_failed_columns = []
    for projection in query_ast.select.projection:
        try:
            inferred_type = projection.type  # type: ignore
            validated_node.columns.append(
                Column(name=projection.alias_or_name.name, type=inferred_type),  # type: ignore
            )
        except DJParseException:
            type_inference_failed_columns.append(projection.alias_or_name.name)  # type: ignore
            validated_node.status = NodeStatus.INVALID

    return (
        validated_node,
        dependencies_map,
        missing_parents_map,
        type_inference_failed_columns,
    )

328 

329 

def resolve_downstream_references(
    session: Session,
    node_revision: NodeRevision,
) -> List[NodeRevision]:
    """
    Resolve references from other revisions that were waiting on this node.

    For every ``MissingParent`` entry named after ``node_revision``, each
    revision that referenced it goes through the same steps: the node is added
    as a real parent, the missing-parent entry is removed, and the revision is
    revalidated, added to the session and committed. The ``MissingParent`` row
    is then passed to ``session.delete``. This function does not commit after
    the final deletion (presumably the caller does — TODO confirm).

    Returns:
        The revisions that, after revalidation, have no remaining missing
        parents and no columns whose type inference failed.
    """
    missing_parents = session.exec(
        select(MissingParent).where(MissingParent.name == node_revision.name),
    ).all()
    newly_valid_nodes: List[NodeRevision] = []

    for missing_parent in missing_parents:
        links_statement = select(NodeMissingParents).where(
            NodeMissingParents.missing_parent_id == missing_parent.id,
        )
        # Materialize the links first: the loop body commits.
        for link in session.exec(links_statement).all():
            downstream = (
                session.exec(
                    select(NodeRevision).where(
                        NodeRevision.id == link.referencing_node_id,
                    ),
                )
                .unique()
                .one()
            )
            # Replace the placeholder reference with a real parent link.
            downstream.parents.append(node_revision.node)
            downstream.missing_parents.remove(missing_parent)

            _, _, remaining_missing, failed_columns = validate_node_data(
                data=downstream,
                session=session,
            )
            if not (remaining_missing or failed_columns):
                newly_valid_nodes.append(downstream)
            session.add(downstream)
            session.commit()

        session.delete(missing_parent)
    return newly_valid_nodes

373 

374 

def propagate_valid_status(
    session: Session,
    valid_nodes: List[NodeRevision],
    catalog_id: int,
) -> None:
    """
    Propagate a valid status by revalidating all downstream nodes.

    Works in rounds. In each round, every descendant of each valid revision is
    revalidated. A descendant that now validates cleanly has its columns,
    status (VALID) and ``catalog_id`` updated and committed, and is expanded in
    the next round.

    ``get_downstream_nodes`` already returns the *transitive* descendants, so
    a node whose descendants have been revalidated once is not expanded again.
    Without this guard, a subtree is re-expanded once for every path that
    reaches it, and the work grows exponentially with DAG depth.

    Args:
        session: database session used for lookups and commits.
        valid_nodes: revisions that just became valid; each must already have
            status ``NodeStatus.VALID``.
        catalog_id: catalog id assigned to every newly validated revision.

    Raises:
        DJException: if a revision about to be expanded is not VALID.
    """
    expanded = set()  # names of nodes whose descendants were already revalidated
    while valid_nodes:
        resolved_nodes = []
        for node_revision in valid_nodes:
            if node_revision.status != NodeStatus.VALID:
                raise DJException(
                    f"Cannot propagate valid status: Node `{node_revision.name}` is not valid",
                )
            if node_revision.name in expanded:
                continue
            expanded.add(node_revision.name)

            downstream_nodes = get_downstream_nodes(
                session=session,
                node_name=node_revision.name,
            )
            for node in downstream_nodes:
                (
                    validated_node,
                    _,
                    missing_parents_map,
                    type_inference_failed_columns,
                ) = validate_node_data(data=node.current, session=session)
                if not missing_parents_map and not type_inference_failed_columns:
                    node.current.columns = validated_node.columns or []
                    node.current.status = NodeStatus.VALID
                    node.current.catalog_id = catalog_id
                    session.add(node.current)
                    session.commit()
                    resolved_nodes.append(node.current)
        valid_nodes = resolved_nodes