Coverage for dj/construction/build.py: 100%

148 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1"""Functions to add to an ast DJ node queries""" 

2import collections 

3 

4# pylint: disable=too-many-arguments,too-many-locals,too-many-nested-blocks,too-many-branches,R0401 

5from typing import DefaultDict, Deque, Dict, List, Optional, Set, Tuple, Union, cast 

6 

7from sqlmodel import Session 

8 

9from dj.construction.utils import amenable_name, to_namespaced_name 

10from dj.errors import DJException, DJInvalidInputException 

11from dj.models.column import Column 

12from dj.models.engine import Dialect 

13from dj.models.node import BuildCriteria, NodeRevision, NodeType 

14from dj.sql.parsing.ast import CompileContext 

15from dj.sql.parsing.backends.antlr4 import ast, parse 

16 

17 

def _get_tables_from_select(
    select: ast.SelectExpression,
) -> DefaultDict[NodeRevision, List[ast.Table]]:
    """
    Collect every table referenced directly on the select (source,
    transform, dimension) that carries an attached DJ node, grouped
    by that node.
    """
    node_to_tables: DefaultDict[NodeRevision, List[ast.Table]] = (
        collections.defaultdict(list)
    )
    for table_expr in select.find_all(ast.Table):
        dj_node = table_expr.dj_node
        if dj_node:  # pragma: no cover
            node_to_tables[dj_node].append(table_expr)
    return node_to_tables

31 

32 

def _join_path(
    dimension_node: NodeRevision,
    initial_nodes: Set[NodeRevision],
) -> Tuple[NodeRevision, Dict[Tuple[NodeRevision, NodeRevision], List[Column]]]:
    """
    For a dimension node, find a possible join path between it
    and any of the nodes that are directly referenced in the original query.

    Runs a breadth-first search outward from the initial nodes through their
    dimension links. Returns the matched dimension node together with a
    mapping of (left node, right node) pairs to the columns joining them.

    Raises:
        DJException: if a join column is missing a dimension column and the
            dimension has no default `id` key, or if no join path exists.
    """
    processed = set()

    to_process: Deque[
        Tuple[NodeRevision, Dict[Tuple[NodeRevision], List[Column]]]
    ] = collections.deque([])
    join_info: Dict[Tuple[NodeRevision], List[Column]] = {}
    to_process.extend([(node, join_info.copy()) for node in initial_nodes])
    possible_join_paths = []

    while to_process:
        current_node, path = to_process.popleft()
        processed.add(current_node)
        dimensions_to_columns = collections.defaultdict(list)

        # From the columns on the current node, find the next layer of
        # dimension nodes that can be joined in
        for col in current_node.columns:
            if col.dimension and col.dimension.type == NodeType.DIMENSION:
                dimensions_to_columns[col.dimension.current].append(col)

        # Go through all potential dimensions and their join columns
        for joinable_dim, join_cols in dimensions_to_columns.items():
            next_join_path = {**path, **{(current_node, joinable_dim): join_cols}}
            full_join_path = (joinable_dim, next_join_path)
            if joinable_dim == dimension_node:
                for col in join_cols:
                    if col.dimension_column is None and not any(
                        dim_col.name == "id" for dim_col in dimension_node.columns
                    ):
                        raise DJException(
                            f"Node {current_node.name} specifying dimension "
                            f"{joinable_dim.name} on column {col.name} does not"
                            f" specify a dimension column, but {dimension_node.name} "
                            f"does not have the default key `id`.",
                        )
                possible_join_paths.append(full_join_path)  # type: ignore
            if joinable_dim not in processed:  # pragma: no cover
                to_process.append(full_join_path)
                for parent in joinable_dim.parents:
                    to_process.append((parent.current, next_join_path))

    if not possible_join_paths:
        # Surface a clear DJ error instead of letting min() raise an
        # opaque ValueError on an empty sequence
        raise DJException(
            f"Cannot find a join path between {dimension_node.name} "
            f"and the nodes referenced in the query",
        )
    # Shortest path = fewest joins. Keying on len() of the (node, path)
    # 2-tuple is constant, so the key must measure the path mapping itself.
    return min(possible_join_paths, key=lambda candidate: len(candidate[1]))  # type: ignore

83 

84 

def _get_or_build_join_table(
    session: Session,
    table_node: NodeRevision,
    build_criteria: Optional[BuildCriteria],
):
    """
    Produce the right-hand table of a join: a materialized table when one
    is available, otherwise an AST built recursively from the node's own
    query. The result is wrapped in an alias derived from the node's name.
    """
    alias_name = amenable_name(table_node.name)
    table_expr = cast(
        Optional[ast.TableExpression],
        _get_node_table(table_node, build_criteria),
    )
    if not table_expr:  # pragma: no cover
        # No materialization: recurse and build the node's query instead
        parsed_query = parse(cast(str, table_node.query))
        table_expr = build_ast(session, parsed_query)  # type: ignore
        table_expr.parenthesized = True  # type: ignore

    for column in table_expr.columns:
        column._table = table_expr  # pylint: disable=protected-access

    table_expr = cast(ast.TableExpression, table_expr)  # type: ignore
    alias_ident = ast.Name(alias_name)
    aliased_right = ast.Alias(  # type: ignore
        alias_ident,
        child=table_expr,
        as_=True,
    )
    table_expr.set_alias(alias_ident)  # type: ignore
    return aliased_right

116 

117 

def _build_joins_for_dimension(
    session: Session,
    dim_node: NodeRevision,
    initial_nodes: Set[NodeRevision],
    tables: DefaultDict[NodeRevision, List[ast.Table]],
    build_criteria: Optional[BuildCriteria],
    required_dimension_columns: List[ast.Column],
) -> List[ast.Join]:
    """
    Returns the join ASTs needed to bring in the dimension node from
    the set of initial nodes.

    NOTE: this mutates its inputs as a side effect — each intermediate
    dimension table joined along the path is added to both
    ``initial_nodes`` and ``tables`` so later calls can reuse it.
    """
    # Find the join path from the initial nodes to the target dimension;
    # each entry maps a (left node, right node) pair to its join columns
    _, paths = _join_path(dim_node, initial_nodes)
    asts = []
    for connecting_nodes, join_columns in paths.items():
        start_node, table_node = connecting_nodes  # type: ignore
        join_on = []

        # Assemble table on left of join
        # (unwrap an alias so we work with the underlying table expression)
        left_table = (
            tables[start_node][0].child  # type: ignore
            if isinstance(tables[start_node][0], ast.Alias)
            else tables[start_node][0]
        )
        join_left_columns = {
            col.alias_or_name.name: col for col in left_table.columns  # type: ignore
        }

        # Assemble table on right of join
        join_right = _get_or_build_join_table(
            session,
            table_node,
            build_criteria,
        )

        # Optimize query by filtering down to only the necessary columns:
        # the requested dimension columns, the columns used to join,
        # the dimension table's primary key, and any columns that link
        # onward to further dimensions
        selected_columns = {col.name.name for col in required_dimension_columns}
        available_join_columns = {
            col.dimension_column for col in join_columns if col.dimension_column
        }
        primary_key_columns = {col.name for col in table_node.primary_key()}
        joinable_dim_columns = {
            col.name for col in table_node.columns if col.dimension_id
        }
        required_mapping = (
            selected_columns.union(available_join_columns)
            .union(primary_key_columns)
            .union(joinable_dim_columns)
        )
        join_right.child.select.projection = [
            col
            for col in join_right.child.select.projection
            if col.alias_or_name.name in required_mapping
        ]

        # Record the joined-in node so subsequent dimension joins can
        # start from it instead of re-joining
        initial_nodes.add(table_node)
        tables[table_node].append(join_right)  # type: ignore
        join_right_columns = {
            col.alias_or_name.name: col  # type: ignore
            for col in join_right.child.columns
        }

        # Assemble join ON clause: equate each link column on the left with
        # either its declared dimension column or the dimension's primary key
        for join_col in join_columns:
            join_table_pk = table_node.primary_key()
            if join_col.name in join_left_columns and (
                join_col.dimension_column in join_right_columns
                or join_table_pk[0].name in join_right_columns
            ):
                left_table.add_ref_column(
                    cast(ast.Column, join_left_columns[join_col.name]),
                )
                join_on.append(
                    ast.BinaryOp.Eq(
                        join_left_columns[join_col.name],
                        join_right_columns[
                            join_col.dimension_column or join_table_pk[0].name
                        ],
                        use_alias_as_name=True,
                    ),
                )
            else:
                raise DJInvalidInputException(  # pragma: no cover
                    f"The specified join column {join_col.dimension_column} "
                    f"does not exist on {table_node.name}",
                )
        # Point the requested dimension columns at the joined-in table
        for dim_col in required_dimension_columns:
            join_right.child.add_ref_column(dim_col)

        if join_on:  # pragma: no cover
            asts.append(
                ast.Join(
                    "LEFT OUTER",
                    join_right,  # type: ignore
                    ast.JoinCriteria(
                        on=ast.BinaryOp.And(*join_on),  # pylint: disable=E1120
                    ),
                ),
            )
    return asts

218 

219 

def join_tables_for_dimensions(
    session: Session,
    dimension_nodes_to_columns: Dict[NodeRevision, List[ast.Column]],
    tables: DefaultDict[NodeRevision, List[ast.Table]],
    build_criteria: Optional[BuildCriteria] = None,
):
    """
    Joins the tables necessary for a set of filter and group by dimensions
    onto the select expression.

    When the needed tables already appear on the select, nothing extra is
    joined. Otherwise the available dimension links are traversed and the
    missing tables are joined in.
    """
    ordered_dimensions = sorted(
        dimension_nodes_to_columns.items(),
        key=lambda item: item[0].name,
    )
    for dim_node, dim_columns in ordered_dimensions:
        # Every select that mentions one of these dimension columns
        parent_selects = {
            cast(ast.Select, col.get_nearest_parent_of_type(ast.Select))
            for col in dim_columns
        }

        # Join the source tables (if necessary) for these dimension columns
        # onto each select clause
        for select in parent_selects:
            nodes_on_select = set(tables)
            if dim_node in nodes_on_select:
                continue  # dimension is already available on this select
            join_asts = _build_joins_for_dimension(
                session,
                dim_node,
                nodes_on_select,
                tables,
                build_criteria,
                dim_columns,
            )
            if join_asts and select.from_:
                select.from_.relations[-1].extensions.extend(  # pragma: no cover
                    join_asts,
                )

262 

263 

def _build_tables_on_select(
    session: Session,
    select: ast.SelectExpression,
    tables: Dict[NodeRevision, List[ast.Table]],
    build_criteria: Optional[BuildCriteria] = None,
):
    """
    Replace each DJ-node table reference on the select with its
    materialized table or, when none exists, with the node's built query.
    """
    for node, table_refs in tables.items():
        materialized = cast(
            Optional[ast.Table],
            _get_node_table(node, build_criteria),
        )
        if materialized is None:
            # No materialization available - build the node's query first
            built = build_ast(  # type: ignore
                session,
                parse(cast(str, node.query)),
                build_criteria,
            ).select  # pylint: disable=W0212
            built.parenthesized = True  # type: ignore
            materialized = built

        safe_alias = amenable_name(node.node.name)
        context = CompileContext(session=session, exception=DJException())

        aliased = ast.Alias(ast.Name(safe_alias), child=materialized, as_=True)  # type: ignore
        for table_ref in table_refs:
            if isinstance(aliased.child, ast.Select) and isinstance(
                table_ref,
                ast.Alias,
            ):
                kept = set(table_ref.child.select.projection)
                aliased.child.projection = [
                    col for col in aliased.child.projection if col in kept
                ]
            aliased.compile(context)
            select.replace(table_ref, aliased)

300 

301 

def dimension_columns_mapping(
    select: ast.SelectExpression,
) -> Dict[NodeRevision, List[ast.Column]]:
    """
    Extract all dimension nodes referenced by columns.

    Returns a mapping from each referenced dimension node to the list of
    column expressions on the select that point at it.
    """
    dimension_nodes_to_columns: Dict[NodeRevision, List[ast.Column]] = {}

    for col in select.find_all(ast.Column):
        if isinstance(col.table, ast.Table):
            if node := col.table.dj_node:  # pragma: no cover
                if node.type == NodeType.DIMENSION:
                    # setdefault replaces the get-then-reassign-then-append
                    # dance with a single idiomatic lookup
                    dimension_nodes_to_columns.setdefault(node, []).append(col)
    return dimension_nodes_to_columns

320 

321 

# flake8: noqa: C901
def _build_select_ast(
    session: Session,
    select: ast.SelectExpression,
    build_criteria: Optional[BuildCriteria] = None,
):
    """
    Transform a select AST in place by replacing DJ node references with
    their own ASTs.

    First gathers the DJ-backed tables and the dimension-backed columns
    (from filters and group bys), then joins in any dimension tables that
    are not already present, and finally swaps each node reference for its
    built table.
    """
    node_tables = _get_tables_from_select(select)
    dim_columns = dimension_columns_mapping(select)
    join_tables_for_dimensions(session, dim_columns, node_tables, build_criteria)
    _build_tables_on_select(session, select, node_tables, build_criteria)

338 

339 

def add_filters_and_dimensions_to_query_ast(
    query: ast.Query,
    dialect: Optional[str] = None,  # pylint: disable=unused-argument
    filters: Optional[List[str]] = None,
    dimensions: Optional[List[str]] = None,
):
    """
    Add filters and dimensions to a query ast.

    Filter strings are parsed and ANDed onto the existing WHERE clause.
    Dimension strings are parsed and appended to the GROUP BY and to the
    projection. When a GROUP BY is present, projection columns that are
    neither aggregations nor grouped columns are dropped.
    """
    projection_addition = []
    if filters:
        # Start from the existing WHERE clause, if there is one
        # (conditional expression replaces the old `and/or` hack)
        filter_asts = [query.select.where] if query.select.where else []

        for filter_ in filters:
            # use parse to get the asts from the strings we got
            temp_select = parse(f"select * where {filter_}").select
            filter_asts.append(
                temp_select.where,  # type:ignore
            )
        query.select.where = ast.BinaryOp.And(*filter_asts)

    if dimensions:
        for agg in dimensions:
            temp_select = parse(
                f"select * group by {agg}",
            ).select
            query.select.group_by += temp_select.group_by  # type:ignore
            projection_addition += list(temp_select.find_all(ast.Column))
        query.select.projection += list(projection_addition)

    # Cannot select for columns that aren't in GROUP BY and aren't aggregations
    if query.select.group_by:
        # Hoisted out of the comprehension: the set was previously rebuilt
        # once per projection column
        group_by_names = {gc.name.name for gc in query.select.group_by}  # type: ignore
        query.select.projection = [
            col
            for col in query.select.projection
            if col.is_aggregation()  # type: ignore
            or col.name.name in group_by_names  # type: ignore
        ]

380 

381 

def _get_node_table(
    node: NodeRevision,
    build_criteria: Optional[BuildCriteria] = None,
    as_select: bool = False,
) -> Optional[Union[ast.Select, ast.Table]]:
    """
    Return the node's materialized table if one is available, else None.

    Source nodes always resolve to their physical table; other node types
    resolve only when an availability state satisfies the build criteria.
    With ``as_select`` the table is wrapped in a ``SELECT *`` query.
    """
    materialized = None
    if node.type == NodeType.SOURCE:
        if node.table:
            table_name = ast.Name(
                node.table,
                namespace=ast.Name(node.schema_) if node.schema_ else None,
            )
        else:
            table_name = to_namespaced_name(node.name)
        materialized = ast.Table(table_name, _dj_node=node)
    elif node.availability and node.availability.is_available(
        criteria=build_criteria,
    ):  # pragma: no cover
        availability = node.availability
        schema_name = (
            ast.Name(availability.schema_) if availability.schema_ else None
        )
        materialized = ast.Table(
            ast.Name(availability.table, namespace=schema_name),
            _dj_node=node,
        )
    if materialized and as_select:  # pragma: no cover
        return ast.Select(
            projection=[ast.Wildcard()],
            from_=ast.From(relations=[ast.Relation(materialized)]),
        )
    return materialized

420 

421 

def build_node(  # pylint: disable=too-many-arguments
    session: Session,
    node: NodeRevision,
    filters: Optional[List[str]] = None,
    dimensions: Optional[List[str]] = None,
    build_criteria: Optional[BuildCriteria] = None,
) -> ast.Query:
    """
    Determines the optimal way to build the Node and does so.

    When no build criteria are supplied, the dialect is taken from the
    first engine on the node's catalog (defaulting to Spark). A directly
    materialized table is used when no filters or dimensions are requested;
    otherwise the node's query is parsed, augmented, and built.
    """
    # Set the dialect by finding available engines for this node, or default to Spark
    if not build_criteria:
        engines = node.catalog.engines if node.catalog else None
        if engines and engines[0].dialect:
            dialect = engines[0].dialect
        else:
            dialect = Dialect.SPARK
        build_criteria = BuildCriteria(dialect=dialect)

    # if no dimensions need to be added then we can see if the node is directly materialized
    if not (filters or dimensions):
        materialized_select = cast(
            ast.Select,
            _get_node_table(node, build_criteria, as_select=True),
        )
        if materialized_select:
            return ast.Query(select=materialized_select)  # pragma: no cover

    query = parse(node.query) if node.query else build_source_node_query(node)

    add_filters_and_dimensions_to_query_ast(
        query,
        build_criteria.dialect,
        filters,
        dimensions,
    )

    return build_ast(session, query, build_criteria)

465 

466 

def build_source_node_query(node: NodeRevision):
    """
    Build a query that selects every column of the source node explicitly.
    """
    source_table = ast.Table(to_namespaced_name(node.name), None, _dj_node=node)
    explicit_columns = [
        ast.Column(ast.Name(col.name), _table=source_table) for col in node.columns
    ]
    return ast.Query(
        select=ast.Select(
            projection=explicit_columns,
            from_=ast.From(relations=[ast.Relation(source_table)]),
        ),
    )

479 

480 

def build_ast(
    session: Session,
    query: ast.Query,
    build_criteria: Optional[BuildCriteria] = None,
) -> ast.Query:
    """
    Compile the query AST against the session, then build it under the
    given criteria, returning the (mutated) query.
    """
    compile_context = CompileContext(session=session, exception=DJException())
    query.compile(compile_context)
    query.build(session, build_criteria)
    return query