Coverage for dj/api/nodes.py: 100%

298 statements  

coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1""" 

2Node related APIs. 

3""" 

4import http.client 

5import logging 

6import os 

7from collections import defaultdict 

8from http import HTTPStatus 

9from typing import List, Optional, Union 

10 

11from fastapi import APIRouter, Depends 

12from fastapi.responses import JSONResponse 

13from sqlalchemy.orm import joinedload 

14from sqlmodel import Session, select 

15from starlette.requests import Request 

16from starlette.responses import Response 

17 

18from dj.api.helpers import ( 

19 get_attribute_type, 

20 get_catalog, 

21 get_column, 

22 get_downstream_nodes, 

23 get_engine, 

24 get_node_by_name, 

25 get_node_namespace, 

26 propagate_valid_status, 

27 raise_if_node_exists, 

28 resolve_downstream_references, 

29 validate_node_data, 

30) 

31from dj.api.tags import get_tag_by_name 

32from dj.errors import DJDoesNotExistException, DJException, DJInvalidInputException 

33from dj.models import ColumnAttribute 

34from dj.models.attribute import UniquenessScope 

35from dj.models.base import generate_display_name 

36from dj.models.column import Column, ColumnAttributeInput 

37from dj.models.node import ( 

38 DEFAULT_DRAFT_VERSION, 

39 DEFAULT_PUBLISHED_VERSION, 

40 ColumnOutput, 

41 CreateCubeNode, 

42 CreateNode, 

43 CreateSourceNode, 

44 MaterializationConfig, 

45 MissingParent, 

46 Node, 

47 NodeMode, 

48 NodeOutput, 

49 NodeRevision, 

50 NodeRevisionBase, 

51 NodeRevisionOutput, 

52 NodeStatus, 

53 NodeType, 

54 NodeValidation, 

55 UpdateNode, 

56 UpsertMaterializationConfig, 

57) 

58from dj.service_clients import QueryServiceClient 

59from dj.sql.parsing.backends.antlr4 import parse 

60from dj.utils import ( 

61 Version, 

62 VersionUpgrade, 

63 get_namespace_from_name, 

64 get_query_service_client, 

65 get_session, 

66) 

67 

68_logger = logging.getLogger(__name__) 

69router = APIRouter() 

70 

71 

@router.post("/nodes/validate/", response_model=NodeValidation)
def validate_a_node(
    data: Union[NodeRevisionBase, NodeRevision],
    session: Session = Depends(get_session),
) -> NodeValidation:
    """
    Validate a node.
    """

    if data.type == NodeType.SOURCE:
        raise DJException(message="Source nodes cannot be validated")

    (
        validated_node,
        dependencies_map,
        missing_parents_map,
        type_inference_failed_columns,
    ) = validate_node_data(data, session)
    if missing_parents_map or type_inference_failed_columns:
        status = NodeStatus.INVALID
    else:
        status = NodeStatus.VALID

    return NodeValidation(
        message=f"Node `{validated_node.name}` is {status}",
        status=status,
        node_revision=validated_node,
        dependencies=set(dependencies_map.keys()),
        columns=validated_node.columns,
    )
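# Illustrative only: a hedged sketch of calling the validation endpoint with an
# HTTP client. The base URL, node name, and query are hypothetical; the JSON
# fields are assumed to follow the NodeRevisionBase request model used above.
#
#     import requests
#
#     response = requests.post(
#         "http://localhost:8000/nodes/validate/",
#         json={
#             "name": "default.num_repair_orders",
#             "type": "metric",
#             "query": "SELECT COUNT(*) FROM default.repair_orders",
#             "mode": "published",
#         },
#     )
#     print(response.json()["status"])  # "valid" or "invalid" (assumed serialization)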

def validate_and_build_attribute(
    session: Session,
    attribute_input: ColumnAttributeInput,
    node: Node,
) -> ColumnAttribute:
    """
    Run some validation and build a column attribute.
    """
    column_map = {column.name: column for column in node.current.columns}
    if attribute_input.column_name not in column_map:
        raise DJDoesNotExistException(
            message=f"Column `{attribute_input.column_name}` "
            f"does not exist on node `{node.name}`!",
        )
    column = column_map[attribute_input.column_name]
    existing_attributes = {attr.attribute_type.name: attr for attr in column.attributes}
    if attribute_input.attribute_type_name in existing_attributes:
        return existing_attributes[attribute_input.attribute_type_name]

    # Verify attribute type exists
    attribute_type = get_attribute_type(
        session,
        attribute_input.attribute_type_name,
        attribute_input.attribute_type_namespace,
    )
    if not attribute_type:
        raise DJDoesNotExistException(
            message=f"Attribute type `{attribute_input.attribute_type_namespace}"
            f".{attribute_input.attribute_type_name}` "
            f"does not exist!",
        )

    # Verify that the attribute type is allowed for this node
    if node.type not in attribute_type.allowed_node_types:
        raise DJException(
            message=f"Attribute type `{attribute_input.attribute_type_namespace}."
            f"{attribute_type.name}` not allowed on node "
            f"type `{node.type}`!",
        )

    return ColumnAttribute(
        attribute_type=attribute_type,
        column=column,
    )

def set_column_attributes_on_node(
    session: Session,
    attributes: List[ColumnAttributeInput],
    node: Node,
) -> List[Column]:
    """
    Sets the column attributes on the node if allowed.
    """
    modified_columns_map = {}
    for attribute_input in attributes:
        new_attribute = validate_and_build_attribute(session, attribute_input, node)
        # pylint: disable=no-member
        modified_columns_map[new_attribute.column.name] = new_attribute.column

    # Validate column attributes by building mapping between
    # attribute scope and columns
    attributes_columns_map = defaultdict(set)
    modified_columns = modified_columns_map.values()

    for column in modified_columns:
        for attribute in column.attributes:
            scopes_map = {
                UniquenessScope.NODE: attribute.attribute_type,
                UniquenessScope.COLUMN_TYPE: column.type,
            }
            attributes_columns_map[
                (  # type: ignore
                    attribute.attribute_type,
                    tuple(
                        scopes_map[item]
                        for item in attribute.attribute_type.uniqueness_scope
                    ),
                )
            ].add(column.name)

    for (attribute, _), columns in attributes_columns_map.items():
        if len(columns) > 1 and attribute.uniqueness_scope:
            for col in columns:
                modified_columns_map[col].attributes = []
            raise DJException(
                message=f"The column attribute `{attribute.name}` is scoped to be "
                f"unique to the `{attribute.uniqueness_scope}` level, but there "
                "is more than one column tagged with it: "
                f"`{', '.join(sorted(list(columns)))}`",
            )

    session.add_all(modified_columns)
    session.commit()
    for col in modified_columns:
        session.refresh(col)

    session.refresh(node)
    session.refresh(node.current)
    return list(modified_columns)

@router.post(
    "/nodes/{node_name}/attributes/",
    response_model=List[ColumnOutput],
    status_code=201,
)
def set_column_attributes(
    node_name: str,
    attributes: List[ColumnAttributeInput],
    *,
    session: Session = Depends(get_session),
) -> List[ColumnOutput]:
    """
    Set column attributes for the node.
    """
    node = get_node_by_name(session, node_name)
    modified_columns = set_column_attributes_on_node(session, attributes, node)
    return list(modified_columns)  # type: ignore
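# Illustrative only: a hedged sketch of attaching an attribute to a column via
# this endpoint. The node and column names are hypothetical; the payload fields
# mirror ColumnAttributeInput as used above.
#
#     import requests
#
#     requests.post(
#         "http://localhost:8000/nodes/default.repair_orders/attributes/",
#         json=[
#             {
#                 "attribute_type_namespace": "system",
#                 "attribute_type_name": "primary_key",
#                 "column_name": "repair_order_id",
#             },
#         ],
#     )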

@router.get("/nodes/", response_model=List[NodeOutput])
def list_nodes(*, session: Session = Depends(get_session)) -> List[NodeOutput]:
    """
    List the available nodes.
    """
    nodes = session.exec(select(Node).options(joinedload(Node.current))).unique().all()
    return nodes


@router.get("/nodes/{name}/", response_model=NodeOutput)
def get_a_node(name: str, *, session: Session = Depends(get_session)) -> NodeOutput:
    """
    Show the active version of the specified node.
    """
    node = get_node_by_name(session, name, with_current=True)
    return node  # type: ignore

@router.delete("/nodes/{name}/", status_code=204)
def delete_a_node(name: str, *, session: Session = Depends(get_session)):
    """
    Delete the specified node.
    """
    node = get_node_by_name(session, name, with_current=True)

    # Find all downstream nodes and mark them as invalid
    downstreams = get_downstream_nodes(session, node.name)
    for downstream in downstreams:
        downstream.current.status = NodeStatus.INVALID
        session.add(downstream)

    # If the node is a dimension, find all columns that
    # are linked to this dimension and remove the link
    if node.type == NodeType.DIMENSION:
        columns = (
            session.exec(select(Column).where(Column.dimension_id == node.id))
            .unique()
            .all()
        )
        for col in columns:
            col.dimension_id = None
            col.dimension_column = None
            session.add(col)
    session.delete(node)
    session.commit()
    return Response(status_code=HTTPStatus.NO_CONTENT.value)
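# Illustrative only: deleting a node over HTTP. The node name is hypothetical;
# a successful delete returns 204 with no body, as set up above. Downstream
# nodes are marked invalid rather than deleted.
#
#     import requests
#
#     response = requests.delete("http://localhost:8000/nodes/default.repair_orders/")
#     assert response.status_code == 204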

@router.post("/nodes/{name}/materialization/", status_code=201)
def upsert_a_materialization_config(
    name: str,
    data: UpsertMaterializationConfig,
    *,
    session: Session = Depends(get_session),
) -> JSONResponse:
    """
    Update the materialization config of the specified node.
    """
    node = get_node_by_name(session, name, with_current=True)
    if node.type == NodeType.SOURCE:
        raise DJException(
            http_status_code=HTTPStatus.BAD_REQUEST,
            message=f"Cannot set materialization config for source node `{name}`!",
        )
    current_revision = node.current

    # Check to see if a config for this engine already exists with the exact same config
    existing_config_for_engine = [
        config
        for config in node.current.materialization_configs
        if config.engine.name == data.engine_name
    ]
    if (
        existing_config_for_engine
        and existing_config_for_engine[0].config == data.config
    ):
        return JSONResponse(
            status_code=HTTPStatus.NO_CONTENT,
            content={
                "message": (
                    f"The same materialization config provided already exists for "
                    f"node `{name}` so no update was performed."
                ),
            },
        )

    # The materialization config changed, so create a new materialization config
    # and a new node revision that references it.
    engine = get_engine(session, data.engine_name, data.engine_version)
    new_node_revision = create_new_revision_from_existing(
        session,
        current_revision,
        node,
        version_upgrade=VersionUpgrade.MAJOR,
    )

    unchanged_existing_configs = [
        config
        for config in node.current.materialization_configs
        if config.engine.name != data.engine_name
    ]
    new_config = MaterializationConfig(
        node_revision=new_node_revision,
        engine=engine,
        config=data.config,
    )
    new_node_revision.materialization_configs = unchanged_existing_configs + [  # type: ignore
        new_config,
    ]
    node.current_version = new_node_revision.version  # type: ignore

    # This will add the materialization config, the new node revision, and update
    # the node's version.
    session.add(new_node_revision)
    session.add(node)
    session.commit()

    return JSONResponse(
        status_code=200,
        content={
            "message": (
                f"Successfully updated materialization config for node `{name}`"
                f" and engine `{engine.name}`."
            ),
        },
    )
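# Illustrative only: upserting a materialization config. The engine name,
# version, and the shape of the `config` value are all assumptions, not
# confirmed by this module; the top-level fields mirror
# UpsertMaterializationConfig as used above.
#
#     import requests
#
#     requests.post(
#         "http://localhost:8000/nodes/default.repair_orders_view/materialization/",
#         json={
#             "engine_name": "spark",
#             "engine_version": "3.1.1",
#             "config": "{\"partitions\": [\"country\"]}",  # hypothetical payload
#         },
#     )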

@router.get("/nodes/{name}/revisions/", response_model=List[NodeRevisionOutput])
def list_node_revisions(
    name: str, *, session: Session = Depends(get_session)
) -> List[NodeRevisionOutput]:
    """
    List all revisions for the node.
    """
    node = get_node_by_name(session, name, with_current=False)
    return node.revisions  # type: ignore

def create_node_revision(
    data: CreateNode,
    node_type: NodeType,
    session: Session,
) -> NodeRevision:
    """
    Create a non-source node revision.
    """
    node_revision = NodeRevision(
        name=data.name,
        namespace=data.namespace,
        display_name=data.display_name
        if data.display_name
        else generate_display_name(data.name),
        description=data.description,
        type=node_type,
        status=NodeStatus.VALID,
        query=data.query,
        mode=data.mode,
    )
    (
        validated_node,
        dependencies_map,
        missing_parents_map,
        type_inference_failed_columns,
    ) = validate_node_data(node_revision, session)
    if missing_parents_map or type_inference_failed_columns:
        node_revision.status = NodeStatus.INVALID
    else:
        node_revision.status = NodeStatus.VALID
    node_revision.missing_parents = [
        MissingParent(name=missing_parent) for missing_parent in missing_parents_map
    ]
    new_parents = [node.name for node in dependencies_map]
    catalog_ids = [node.catalog_id for node in dependencies_map]
    if node_revision.mode == NodeMode.PUBLISHED and len(set(catalog_ids)) != 1:
        raise DJException(
            f"Cannot create nodes with multi-catalog dependencies: {set(catalog_ids)}",
        )
    catalog_id = next(iter(catalog_ids), 0)
    parent_refs = session.exec(
        select(Node).where(
            # pylint: disable=no-member
            Node.name.in_(  # type: ignore
                new_parents,
            ),
        ),
    ).all()
    node_revision.parents = parent_refs

    _logger.info(
        "Parent nodes for %s (%s): %s",
        data.name,
        node_revision.version,
        [p.name for p in node_revision.parents],
    )
    node_revision.columns = validated_node.columns or []
    node_revision.catalog_id = catalog_id
    return node_revision

def create_cube_node_revision(
    session: Session,
    data: CreateCubeNode,
) -> NodeRevision:
    """
    Create a cube node revision.
    """
    metrics = []
    dimensions = []
    catalogs = []
    for node_name in data.cube_elements:
        cube_element = get_node_by_name(session=session, name=node_name)
        catalogs.append(cube_element.current.catalog.name)
        if cube_element.type == NodeType.METRIC:
            metrics.append(cube_element)
        elif cube_element.type == NodeType.DIMENSION:
            dimensions.append(cube_element)
        else:
            raise DJException(
                message=(
                    f"Node {cube_element.name} of type {cube_element.type} "
                    "cannot be added to a cube"
                ),
                http_status_code=http.client.UNPROCESSABLE_ENTITY,
            )
    if not metrics:
        raise DJException(
            message="At least one metric is required to create a cube node",
            http_status_code=http.client.UNPROCESSABLE_ENTITY,
        )
    if not dimensions:
        raise DJException(
            message="At least one dimension is required to create a cube node",
            http_status_code=http.client.UNPROCESSABLE_ENTITY,
        )
    if len(set(catalogs)) > 1:
        raise DJException(
            message=(
                f"Cannot create cube using nodes from multiple catalogs: {catalogs}"
            ),
        )
    if len(set(catalogs)) < 1:  # pragma: no cover
        raise DJException(
            message="Cube elements must contain a common catalog",
        )
    return NodeRevision(
        name=data.name,
        namespace=data.namespace,
        description=data.description,
        type=NodeType.CUBE,
        cube_elements=metrics + dimensions,
    )

def save_node(
    session: Session,
    node_revision: NodeRevision,
    node: Node,
    node_mode: NodeMode,
):
    """
    Links the node and node revision together and saves them.
    """
    node_revision.node = node
    node_revision.version = (
        str(DEFAULT_DRAFT_VERSION)
        if node_mode == NodeMode.DRAFT
        else str(DEFAULT_PUBLISHED_VERSION)
    )
    node.current_version = node_revision.version
    node_revision.extra_validation()

    session.add(node)
    session.commit()

    newly_valid_nodes = resolve_downstream_references(
        session=session,
        node_revision=node_revision,
    )
    propagate_valid_status(
        session=session,
        valid_nodes=newly_valid_nodes,
        catalog_id=node.current.catalog_id,  # pylint: disable=no-member
    )
    session.refresh(node.current)

@router.post("/nodes/source/", response_model=NodeOutput, status_code=201)
def create_a_source(
    data: CreateSourceNode,
    session: Session = Depends(get_session),
    query_service_client: QueryServiceClient = Depends(get_query_service_client),
) -> NodeOutput:
    """
    Create a source node. If columns are not provided, the source node's schema
    will be inferred using the configured query service.
    """
    raise_if_node_exists(session, data.name)

    # Extract and assign namespace if one exists
    namespace = get_namespace_from_name(data.name)
    get_node_namespace(
        session=session,
        namespace=namespace,
    )  # Will return 404 if namespace doesn't exist
    data.namespace = namespace

    node = Node(
        name=data.name,
        namespace=data.namespace,
        type=NodeType.SOURCE,
        current_version=0,
    )
    catalog = get_catalog(session=session, name=data.catalog)

    # When no columns are provided, attempt to find actual table columns
    # if a query service is set
    columns = (
        [
            Column(
                name=column_data.name,
                type=column_data.type,
                dimension=(
                    get_node_by_name(
                        session,
                        name=column_data.dimension,
                        node_type=NodeType.DIMENSION,
                        raise_if_not_exists=False,
                    )
                ),
            )
            for column_data in data.columns
        ]
        if data.columns
        else None
    )
    if not columns:
        if not query_service_client:
            raise DJException(
                message="No table columns were provided and no query "
                "service is configured for table columns inference!",
            )
        columns = query_service_client.get_columns_for_table(
            data.catalog,
            data.schema_,  # type: ignore
            data.table,
            catalog.engines[0] if len(catalog.engines) >= 1 else None,
        )

    node_revision = NodeRevision(
        name=data.name,
        namespace=data.namespace,
        display_name=data.display_name
        if data.display_name
        else generate_display_name(data.name),
        description=data.description,
        type=NodeType.SOURCE,
        status=NodeStatus.VALID,
        catalog_id=catalog.id,
        schema_=data.schema_,
        table=data.table,
        columns=columns,
        parents=[],
    )

    # Point the node to the new node revision.
    save_node(session, node_revision, node, data.mode)
    return node  # type: ignore
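# Illustrative only: registering a source node. The catalog, schema, table, and
# column names are hypothetical; the payload mirrors CreateSourceNode as used
# above, and `columns` may be omitted when a query service can infer them.
#
#     import requests
#
#     requests.post(
#         "http://localhost:8000/nodes/source/",
#         json={
#             "name": "default.repair_orders",
#             "description": "Repair orders fact table",
#             "mode": "published",
#             "catalog": "warehouse",
#             "schema_": "roads",
#             "table": "repair_orders",
#             "columns": [
#                 {"name": "repair_order_id", "type": "int"},
#                 {"name": "municipality_id", "type": "string"},
#             ],
#         },
#     )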

@router.post("/nodes/transform/", response_model=NodeOutput, status_code=201)
@router.post("/nodes/dimension/", response_model=NodeOutput, status_code=201)
@router.post("/nodes/metric/", response_model=NodeOutput, status_code=201)
def create_a_node(
    data: CreateNode,
    request: Request,
    *,
    session: Session = Depends(get_session),
) -> NodeOutput:
    """
    Create a transform, dimension, or metric node.
    """
    node_type = NodeType(os.path.basename(os.path.normpath(request.url.path)))

    if node_type == NodeType.DIMENSION and not data.primary_key:
        raise DJInvalidInputException("Dimension nodes must define a primary key!")

    raise_if_node_exists(session, data.name)

    namespace = get_namespace_from_name(data.name)
    get_node_namespace(
        session=session,
        namespace=namespace,
    )  # Will return 404 if namespace doesn't exist
    data.namespace = namespace

    node = Node(
        name=data.name,
        namespace=data.namespace,
        type=node_type,
        current_version=0,
    )
    node_revision = create_node_revision(data, node_type, session)
    save_node(session, node_revision, node, data.mode)
    session.refresh(node)

    column_names = {col.name for col in node_revision.columns}
    if data.primary_key and any(
        key_column not in column_names for key_column in data.primary_key
    ):
        raise DJInvalidInputException(
            f"Some columns in the primary key {','.join(data.primary_key)} "
            f"were not found in the list of available columns for the node {node.name}.",
        )
    if data.primary_key:
        attributes = [
            ColumnAttributeInput(
                attribute_type_namespace="system",
                attribute_type_name="primary_key",
                column_name=key_column,
            )
            for key_column in data.primary_key
            if key_column in column_names
        ]
        set_column_attributes_on_node(session, attributes, node)
    session.refresh(node)
    session.refresh(node.current)
    return node  # type: ignore
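# Illustrative only: creating a dimension node with a primary key. The node
# name and query are hypothetical; the payload mirrors CreateNode as used
# above, and the final path segment selects the node type.
#
#     import requests
#
#     requests.post(
#         "http://localhost:8000/nodes/dimension/",
#         json={
#             "name": "default.municipality",
#             "description": "Municipality dimension",
#             "mode": "published",
#             "query": "SELECT municipality_id, name FROM default.municipalities",
#             "primary_key": ["municipality_id"],
#         },
#     )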

@router.post("/nodes/cube/", response_model=NodeOutput, status_code=201)
def create_a_cube(
    data: CreateCubeNode,
    session: Session = Depends(get_session),
) -> NodeOutput:
    """
    Create a cube node.
    """
    raise_if_node_exists(session, data.name)
    node = Node(
        name=data.name,
        namespace=data.namespace,
        type=NodeType.CUBE,
        current_version=0,
    )
    node_revision = create_cube_node_revision(session=session, data=data)
    save_node(session, node_revision, node, data.mode)
    return node  # type: ignore
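# Illustrative only: creating a cube from existing metric and dimension nodes.
# The element names are hypothetical; the payload mirrors CreateCubeNode as
# used above, and all cube elements must share a single catalog.
#
#     import requests
#
#     requests.post(
#         "http://localhost:8000/nodes/cube/",
#         json={
#             "name": "default.repairs_cube",
#             "description": "Repairs metrics by municipality",
#             "mode": "published",
#             "cube_elements": [
#                 "default.num_repair_orders",
#                 "default.municipality",
#             ],
#         },
#     )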

@router.post("/nodes/{name}/columns/{column}/", status_code=201)
def link_a_dimension(
    name: str,
    column: str,
    dimension: Optional[str] = None,
    dimension_column: Optional[str] = None,
    session: Session = Depends(get_session),
) -> JSONResponse:
    """
    Link a dimension to a node column.
    """
    if not dimension:  # If no dimension is set, assume it matches the column name
        dimension = column

    node = get_node_by_name(session=session, name=name)
    dimension_node = get_node_by_name(
        session=session,
        name=dimension,
        node_type=NodeType.DIMENSION,
    )
    if node.current.catalog.name != dimension_node.current.catalog.name:
        raise DJException(
            message=(
                "Cannot add dimension to column, because catalogs do not match: "
                f"{node.current.catalog.name}, {dimension_node.current.catalog.name}"
            ),
        )

    target_column = get_column(node.current, column)
    if dimension_column:
        # Check that the dimension column exists
        column_from_dimension = get_column(dimension_node.current, dimension_column)

        # Check that the dimension column's type is compatible with the target column's type
        if not column_from_dimension.type.is_compatible(target_column.type):
            raise DJInvalidInputException(
                f"The column {target_column.name} has type {target_column.type} "
                f"and is being linked to the dimension {dimension} via the dimension"
                f" column {dimension_column}, which has type {column_from_dimension.type}."
                " These column types are incompatible and the dimension cannot be linked!",
            )

    target_column.dimension = dimension_node
    target_column.dimension_id = dimension_node.id
    target_column.dimension_column = dimension_column

    session.add(node)
    session.commit()
    session.refresh(node)
    return JSONResponse(
        status_code=201,
        content={
            "message": (
                f"Dimension node {dimension} has been successfully "
                f"linked to column {column} on node {name}"
            ),
        },
    )
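# Illustrative only: linking a column to a dimension node. The node, column,
# and dimension names are hypothetical; `dimension` and `dimension_column` are
# query parameters, as declared above.
#
#     import requests
#
#     requests.post(
#         "http://localhost:8000/nodes/default.repair_orders"
#         "/columns/municipality_id/",
#         params={
#             "dimension": "default.municipality",
#             "dimension_column": "municipality_id",
#         },
#     )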

@router.post("/nodes/{name}/tag/", status_code=201)
def tag_a_node(
    name: str, tag_name: str, *, session: Session = Depends(get_session)
) -> JSONResponse:
    """
    Add a tag to a node.
    """
    node = get_node_by_name(session=session, name=name)
    tag = get_tag_by_name(session, name=tag_name, raise_if_not_exists=True)
    node.tags.append(tag)

    session.add(node)
    session.commit()
    session.refresh(node)
    session.refresh(tag)

    return JSONResponse(
        status_code=201,
        content={
            "message": (
                f"Node `{name}` has been successfully tagged with tag `{tag_name}`"
            ),
        },
    )
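# Illustrative only: tagging a node. Names are hypothetical; `tag_name` is a
# query parameter, as declared above, and the tag must already exist.
#
#     import requests
#
#     requests.post(
#         "http://localhost:8000/nodes/default.repair_orders/tag/",
#         params={"tag_name": "deprecated"},
#     )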

def create_new_revision_from_existing(  # pylint: disable=too-many-locals
    session: Session,
    old_revision: NodeRevision,
    node: Node,
    data: Optional[UpdateNode] = None,
    version_upgrade: Optional[VersionUpgrade] = None,
) -> Optional[NodeRevision]:
    """
    Creates a new revision from an existing node revision.
    """
    minor_changes = (
        (data and data.description and old_revision.description != data.description)
        or (data and data.mode and old_revision.mode != data.mode)
        or (
            data
            and data.display_name
            and old_revision.display_name != data.display_name
        )
    )
    query_changes = (
        old_revision.type != NodeType.SOURCE
        and data
        and data.query
        and old_revision.query != data.query
    )
    column_changes = (
        old_revision.type == NodeType.SOURCE
        and data is not None
        and data.columns is not None
        and ({col.identifier() for col in old_revision.columns} != data.columns)
    )
    major_changes = query_changes or column_changes

    # If nothing has changed, do not create a new node revision
    if not minor_changes and not major_changes and not version_upgrade:
        return None

    old_version = Version.parse(node.current_version)
    new_revision = NodeRevision(
        name=old_revision.name,
        node_id=node.id,
        version=str(
            old_version.next_major_version()
            if major_changes or version_upgrade == VersionUpgrade.MAJOR
            else old_version.next_minor_version(),
        ),
        display_name=(
            data.display_name
            if data and data.display_name
            else old_revision.display_name
        ),
        description=(
            data.description if data and data.description else old_revision.description
        ),
        query=(data.query if data and data.query else old_revision.query),
        type=old_revision.type,
        columns=[
            Column(
                name=column_data.name,
                type=column_data.type,
                dimension_column=column_data.dimension,
            )
            for column_data in data.columns
        ]
        if data and data.columns
        else old_revision.columns,
        catalog=old_revision.catalog,
        schema_=old_revision.schema_,
        table=old_revision.table,
        parents=[],
        mode=data.mode if data and data.mode else old_revision.mode,
        materialization_configs=old_revision.materialization_configs,
    )

    # Link the new revision to its parents if the query has changed
    if (
        new_revision.type != NodeType.SOURCE
        and new_revision.query != old_revision.query
    ):
        (
            validated_node,
            dependencies_map,
            missing_parents_map,
            type_inference_failed_columns,
        ) = validate_node_data(new_revision, session)
        new_parents = [n.name for n in dependencies_map]
        parent_refs = session.exec(
            select(Node).where(
                # pylint: disable=no-member
                Node.name.in_(  # type: ignore
                    new_parents,
                ),
            ),
        ).all()
        new_revision.parents = list(parent_refs)
        if missing_parents_map or type_inference_failed_columns:
            new_revision.status = NodeStatus.INVALID
        else:
            new_revision.status = NodeStatus.VALID
        new_revision.missing_parents = [
            MissingParent(name=missing_parent) for missing_parent in missing_parents_map
        ]
        _logger.info(
            "Parent nodes for %s (v%s): %s",
            new_revision.name,
            new_revision.version,
            [p.name for p in new_revision.parents],
        )
        new_revision.columns = validated_node.columns or []
    return new_revision
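# Illustrative only: a hedged sketch of the versioning behavior above.
# Version.parse, next_minor_version, and next_major_version are used by this
# module; the "v<major>.<minor>" string format shown is an assumption, not
# confirmed here.
#
#     old = Version.parse("v1.3")
#     str(old.next_minor_version())  # e.g. "v1.4" -- description/mode/display changes
#     str(old.next_major_version())  # e.g. "v2.0" -- query or column changes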

@router.patch("/nodes/{name}/", response_model=NodeOutput)
def update_a_node(
    name: str,
    data: UpdateNode,
    *,
    session: Session = Depends(get_session),
) -> NodeOutput:
    """
    Update a node.
    """

    query = (
        select(Node)
        .where(Node.name == name)
        .with_for_update()
        .execution_options(populate_existing=True)
    )
    node = session.exec(query).one_or_none()
    if not node:
        raise DJException(
            message=f"A node with name `{name}` does not exist.",
            http_status_code=404,
        )

    old_revision = node.current
    new_revision = create_new_revision_from_existing(session, old_revision, node, data)

    if not new_revision:
        return node  # type: ignore

    node.current_version = new_revision.version

    new_revision.extra_validation()

    session.add(new_revision)
    session.add(node)
    session.commit()
    session.refresh(node.current)
    return node  # type: ignore
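# Illustrative only: updating a node's query, which triggers a major version
# bump via create_new_revision_from_existing. The node name and query are
# hypothetical; the payload mirrors UpdateNode as used above.
#
#     import requests
#
#     requests.patch(
#         "http://localhost:8000/nodes/default.num_repair_orders/",
#         json={"query": "SELECT COUNT(repair_order_id) FROM default.repair_orders"},
#     )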

@router.get("/nodes/similarity/{node1_name}/{node2_name}")
def calculate_node_similarity(
    node1_name: str, node2_name: str, *, session: Session = Depends(get_session)
) -> JSONResponse:
    """
    Compare two nodes by how similar their queries are.
    """
    node1 = get_node_by_name(session=session, name=node1_name)
    node2 = get_node_by_name(session=session, name=node2_name)
    if NodeType.SOURCE in (node1.type, node2.type):
        raise DJException(
            message="Cannot determine similarity of source nodes",
            http_status_code=HTTPStatus.CONFLICT,
        )
    node1_ast = parse(node1.current.query)  # type: ignore
    node2_ast = parse(node2.current.query)  # type: ignore
    similarity = node1_ast.similarity_score(node2_ast)
    return JSONResponse(status_code=200, content={"similarity": similarity})
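# Illustrative only: comparing two non-source nodes. The node names are
# hypothetical; the response carries a similarity score computed from the
# parsed queries, as shown above.
#
#     import requests
#
#     response = requests.get(
#         "http://localhost:8000/nodes/similarity"
#         "/default.num_repair_orders/default.total_repair_cost",
#     )
#     print(response.json()["similarity"])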

@router.get("/nodes/{name}/downstream/", response_model=List[NodeOutput])
def list_downstream_nodes(
    name: str, *, node_type: Optional[NodeType] = None, session: Session = Depends(get_session)
) -> List[NodeOutput]:
    """
    List all nodes that are downstream from the given node, filterable by type.
    """
    return get_downstream_nodes(session, name, node_type)  # type: ignore