Coverage for dj/models/node.py: 100% (263 statements), coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1"""
2Model for nodes.
3"""
4# pylint: disable=too-many-instance-attributes
5import enum
6from dataclasses import dataclass
7from datetime import datetime, timezone
8from functools import partial
9from typing import Dict, List, Optional
11from pydantic import BaseModel, Extra
12from pydantic import Field as PydanticField
13from pydantic import root_validator
14from sqlalchemy import JSON, DateTime, String
15from sqlalchemy.sql.schema import Column as SqlaColumn
16from sqlalchemy.sql.schema import UniqueConstraint
17from sqlalchemy.types import Enum
18from sqlmodel import Field, Relationship, SQLModel
19from typing_extensions import TypedDict
21from dj.errors import DJInvalidInputException
22from dj.models.base import BaseSQLModel, generate_display_name
23from dj.models.catalog import Catalog
24from dj.models.column import Column, ColumnYAML
25from dj.models.database import Database
26from dj.models.engine import Dialect, Engine, EngineInfo
27from dj.models.tag import Tag, TagNodeRelationship
28from dj.sql.parse import is_metric
29from dj.sql.parsing.types import ColumnType
30from dj.typing import UTCDatetime
31from dj.utils import Version
33DEFAULT_DRAFT_VERSION = Version(major=0, minor=1)
34DEFAULT_PUBLISHED_VERSION = Version(major=1, minor=0)
@dataclass(frozen=True)
class BuildCriteria:
    """
    Criteria used for building
    - used to determine whether to use an availability state
    """

    timestamp: Optional[UTCDatetime] = None
    dialect: Dialect = Dialect.SPARK


class NodeRelationship(BaseSQLModel, table=True):  # type: ignore
    """
    Join table for self-referential many-to-many relationships between nodes.
    """

    parent_id: Optional[int] = Field(
        default=None,
        foreign_key="node.id",
        primary_key=True,
    )

    # This will default to `latest`, which points to the current version of the node,
    # or it can be a specific version.
    parent_version: Optional[str] = Field(
        default="latest",
    )

    child_id: Optional[int] = Field(
        default=None,
        foreign_key="noderevision.id",
        primary_key=True,
    )


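# Example sketch (illustrative only; the ids and version string below are made up):
# a parent link tracks the parent's current version by default, or it can pin an
# explicit version.
#
#     follows_latest = NodeRelationship(parent_id=1, child_id=10)
#     pinned = NodeRelationship(parent_id=2, child_id=10, parent_version="v1.0")

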
class CubeRelationship(BaseSQLModel, table=True):  # type: ignore
    """
    Join table for many-to-many relationships between cube nodes and metric/dimension nodes.
    """

    __tablename__ = "cube"

    cube_id: Optional[int] = Field(
        default=None,
        foreign_key="noderevision.id",
        primary_key=True,
    )

    cube_element_id: Optional[int] = Field(
        default=None,
        foreign_key="node.id",
        primary_key=True,
    )


class NodeColumns(BaseSQLModel, table=True):  # type: ignore
    """
    Join table for node columns.
    """

    node_id: Optional[int] = Field(
        default=None,
        foreign_key="noderevision.id",
        primary_key=True,
    )
    column_id: Optional[int] = Field(
        default=None,
        foreign_key="column.id",
        primary_key=True,
    )


class NodeType(str, enum.Enum):
    """
    Node type.

    A node can have one of the following types, currently:

    1. SOURCE nodes are root nodes in the DAG, and point to tables or views in a DB.
    2. TRANSFORM nodes are SQL transformations, reading from SOURCE/TRANSFORM nodes.
    3. METRIC nodes are leaves in the DAG, and have a single aggregation query.
    4. DIMENSION nodes are special SOURCE nodes that can be auto-joined with METRICS.
    5. CUBE nodes contain a reference to a set of METRICS and a set of DIMENSIONS.
    """

    SOURCE = "source"
    TRANSFORM = "transform"
    METRIC = "metric"
    DIMENSION = "dimension"
    CUBE = "cube"


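# Example sketch (illustrative only): because NodeType mixes in `str`, the values
# round-trip cleanly from plain strings such as request payloads or YAML:
#
#     assert NodeType("metric") is NodeType.METRIC
#     assert NodeType.METRIC.value == "metric"

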
class NodeMode(str, enum.Enum):
    """
    Node mode.

    A node can be in one of the following modes:

    1. PUBLISHED - Must be valid and not cause any child nodes to be invalid
    2. DRAFT - Can be invalid, have invalid parents, and include dangling references
    """

    PUBLISHED = "published"
    DRAFT = "draft"


class NodeStatus(str, enum.Enum):
    """
    Node status.

    A node can have one of the following statuses:

    1. VALID - All references to other nodes and node columns are valid
    2. INVALID - One or more parent nodes are incompatible or do not exist
    """

    VALID = "valid"
    INVALID = "invalid"


class NodeYAML(TypedDict, total=False):
    """
    Schema of a node in the YAML file.
    """

    description: str
    display_name: str
    type: NodeType
    query: str
    columns: Dict[str, ColumnYAML]


class NodeBase(BaseSQLModel):
    """
    A base node.
    """

    name: str = Field(sa_column=SqlaColumn("name", String, unique=True))
    type: NodeType = Field(sa_column=SqlaColumn(Enum(NodeType)))
    display_name: Optional[str] = Field(
        sa_column=SqlaColumn(
            "display_name",
            String,
            default=generate_display_name("name"),
        ),
        max_length=100,
    )


class NodeRevisionBase(BaseSQLModel):
    """
    A base node revision.
    """

    name: str = Field(
        sa_column=SqlaColumn("name", String, unique=False),
        foreign_key="node.name",
    )
    display_name: Optional[str] = Field(
        sa_column=SqlaColumn(
            "display_name",
            String,
            default=generate_display_name("name"),
        ),
    )
    type: NodeType = Field(sa_column=SqlaColumn(Enum(NodeType)))
    description: str = ""
    query: Optional[str] = None
    mode: NodeMode = NodeMode.PUBLISHED


class MissingParent(BaseSQLModel, table=True):  # type: ignore
    """
    A missing parent node
    """

    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(sa_column=SqlaColumn("name", String))
    created_at: UTCDatetime = Field(
        sa_column=SqlaColumn(DateTime(timezone=True)),
        default_factory=partial(datetime.now, timezone.utc),
    )


class NodeMissingParents(BaseSQLModel, table=True):  # type: ignore
    """
    Join table for missing parents
    """

    missing_parent_id: Optional[int] = Field(
        default=None,
        foreign_key="missingparent.id",
        primary_key=True,
    )
    referencing_node_id: Optional[int] = Field(
        default=None,
        foreign_key="noderevision.id",
        primary_key=True,
    )


class AvailabilityStateBase(BaseSQLModel):
    """
    An availability state base
    """

    catalog: str
    schema_: Optional[str] = Field(default=None)
    table: str
    valid_through_ts: int
    max_partition: List[str] = Field(sa_column=SqlaColumn(JSON))
    min_partition: List[str] = Field(sa_column=SqlaColumn(JSON))


class AvailabilityState(AvailabilityStateBase, table=True):  # type: ignore
    """
    The availability of materialized data for a node
    """

    id: Optional[int] = Field(default=None, primary_key=True)
    updated_at: UTCDatetime = Field(
        sa_column=SqlaColumn(DateTime(timezone=True)),
        default_factory=partial(datetime.now, timezone.utc),
    )

    def is_available(
        self,
        criteria: Optional[BuildCriteria] = None,  # pylint: disable=unused-argument
    ) -> bool:  # pragma: no cover
        """
        Determine whether an availability state is usable given the criteria.
        """
        # Criteria for deciding whether an availability state should be used still
        # need to be added; for now every availability state is treated as usable.
        return True


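    # Hypothetical sketch of what a criteria-based check might eventually look like,
    # assuming `valid_through_ts` is an epoch timestamp in seconds and
    # `BuildCriteria.timestamp` is the requested as-of time:
    #
    #     if criteria and criteria.timestamp:
    #         return criteria.timestamp.timestamp() <= self.valid_through_ts
    #     return True

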
class NodeAvailabilityState(BaseSQLModel, table=True):  # type: ignore
    """
    Join table for availability state
    """

    availability_id: Optional[int] = Field(
        default=None,
        foreign_key="availabilitystate.id",
        primary_key=True,
    )
    node_id: Optional[int] = Field(
        default=None,
        foreign_key="noderevision.id",
        primary_key=True,
    )


class NodeNamespace(SQLModel, table=True):  # type: ignore
    """
    A node namespace
    """

    namespace: str = Field(nullable=False, unique=True, primary_key=True)


class Node(NodeBase, table=True):  # type: ignore
    """
    Node that acts as an umbrella for all node revisions
    """

    __table_args__ = (
        UniqueConstraint("name", "namespace", name="unique_node_namespace_name"),
    )

    id: Optional[int] = Field(default=None, primary_key=True)
    namespace: Optional[str] = "default"
    current_version: str = Field(default=str(DEFAULT_DRAFT_VERSION))
    created_at: UTCDatetime = Field(
        sa_column=SqlaColumn(DateTime(timezone=True)),
        default_factory=partial(datetime.now, timezone.utc),
    )

    revisions: List["NodeRevision"] = Relationship(back_populates="node")
    cubes: List["NodeRevision"] = Relationship(back_populates="cube_elements")
    current: "NodeRevision" = Relationship(
        sa_relationship_kwargs={
            "primaryjoin": "and_(Node.id==NodeRevision.node_id, "
            "Node.current_version == NodeRevision.version)",
            "viewonly": True,
            "uselist": False,
        },
    )

    children: List["NodeRevision"] = Relationship(
        back_populates="parents",
        link_model=NodeRelationship,
        sa_relationship_kwargs={
            "primaryjoin": "Node.id==NodeRelationship.parent_id",
            "secondaryjoin": "NodeRevision.id==NodeRelationship.child_id",
        },
    )

    tags: List["Tag"] = Relationship(
        back_populates="nodes",
        link_model=TagNodeRelationship,
        sa_relationship_kwargs={
            "primaryjoin": "TagNodeRelationship.node_id==Node.id",
            "secondaryjoin": "TagNodeRelationship.tag_id==Tag.id",
        },
    )

    def __hash__(self) -> int:
        return hash(self.id)


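# Example sketch (illustrative only; `session` is an assumed SQLModel session): the
# view-only `current` relationship resolves to the revision whose `version` matches
# `current_version`, so publishing a new revision amounts to bumping that pointer:
#
#     node.current_version = new_revision.version
#     session.add(node)
#     session.commit()
#     session.refresh(node)
#     assert node.current.version == node.current_version

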
class MaterializationConfig(BaseSQLModel, table=True):  # type: ignore
    """
    Materialization configuration for a node and specific engines.
    """

    node_revision_id: int = Field(foreign_key="noderevision.id", primary_key=True)
    node_revision: "NodeRevision" = Relationship(
        back_populates="materialization_configs",
    )

    engine_id: int = Field(foreign_key="engine.id", primary_key=True)
    engine: Engine = Relationship()

    config: str = Field(nullable=False)


class NodeRevision(NodeRevisionBase, table=True):  # type: ignore
    """
    A node revision.
    """

    __table_args__ = (UniqueConstraint("version", "node_id"),)

    id: Optional[int] = Field(default=None, primary_key=True)
    version: Optional[str] = Field(default=str(DEFAULT_DRAFT_VERSION))
    node_id: Optional[int] = Field(foreign_key="node.id")
    node: Node = Relationship(back_populates="revisions")
    catalog_id: int = Field(default=None, foreign_key="catalog.id")
    catalog: Catalog = Relationship(
        back_populates="node_revisions",
        sa_relationship_kwargs={
            "lazy": "joined",
        },
    )
    schema_: Optional[str] = None
    table: Optional[str] = None
    cube_elements: List["Node"] = Relationship(  # Only used by cube nodes
        back_populates="cubes",
        link_model=CubeRelationship,
        sa_relationship_kwargs={
            "primaryjoin": "NodeRevision.id==CubeRelationship.cube_id",
            "secondaryjoin": "Node.id==CubeRelationship.cube_element_id",
            "lazy": "joined",
        },
    )
    status: NodeStatus = NodeStatus.INVALID
    updated_at: UTCDatetime = Field(
        sa_column=SqlaColumn(DateTime(timezone=True)),
        default_factory=partial(datetime.now, timezone.utc),
    )

    parents: List["Node"] = Relationship(
        back_populates="children",
        link_model=NodeRelationship,
        sa_relationship_kwargs={
            "primaryjoin": "NodeRevision.id==NodeRelationship.child_id",
            "secondaryjoin": "Node.id==NodeRelationship.parent_id",
        },
    )

    parent_links: List[NodeRelationship] = Relationship()

    missing_parents: List[MissingParent] = Relationship(
        link_model=NodeMissingParents,
        sa_relationship_kwargs={
            "primaryjoin": "NodeRevision.id==NodeMissingParents.referencing_node_id",
            "secondaryjoin": "MissingParent.id==NodeMissingParents.missing_parent_id",
            "cascade": "all, delete",
        },
    )

    columns: List[Column] = Relationship(
        link_model=NodeColumns,
        sa_relationship_kwargs={
            "primaryjoin": "NodeRevision.id==NodeColumns.node_id",
            "secondaryjoin": "Column.id==NodeColumns.column_id",
            "cascade": "all, delete",
        },
    )

    # The availability of materialized data needs to be stored on the NodeRevision
    # level in order to support pinned versions, where a node owner wants to pin
    # to a particular upstream node version.
    availability: Optional[AvailabilityState] = Relationship(
        link_model=NodeAvailabilityState,
        sa_relationship_kwargs={
            "primaryjoin": "NodeRevision.id==NodeAvailabilityState.node_id",
            "secondaryjoin": "AvailabilityState.id==NodeAvailabilityState.availability_id",
            "cascade": "all, delete",
            "uselist": False,
        },
    )

    # Nodes of type SOURCE will not have this property as their materialization
    # is not managed as a part of this service
    materialization_configs: List[MaterializationConfig] = Relationship(
        back_populates="node_revision",
    )

    def __hash__(self) -> int:
        return hash(self.id)

    def primary_key(self) -> List[Column]:
        """
        Returns the primary key columns of this node.
        """
        primary_key_columns = []
        for col in self.columns:  # pylint: disable=not-an-iterable
            if "primary_key" in {attr.attribute_type.name for attr in col.attributes}:
                primary_key_columns.append(col)
        return primary_key_columns

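    # Example sketch (illustrative only): a column counts as part of the primary key
    # when one of its attached attributes has an attribute type named "primary_key",
    # so for a revision whose columns carry that attribute:
    #
    #     pk_names = [col.name for col in revision.primary_key()]
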
    def extra_validation(self) -> None:
        """
        Extra validation for node data.
        """
        if self.type in (NodeType.SOURCE, NodeType.CUBE):
            if self.query:
                raise DJInvalidInputException(
                    f"Node {self.name} of type {self.type} should not have a query",
                )

        if self.type in {NodeType.TRANSFORM, NodeType.METRIC, NodeType.DIMENSION}:
            if not self.query:
                raise DJInvalidInputException(
                    f"Node {self.name} of type {self.type} needs a query",
                )

        if self.type == NodeType.METRIC:
            if not is_metric(self.query):
                raise DJInvalidInputException(
                    f"Node {self.name} of type metric has an invalid query, "
                    "should have a single aggregation",
                )

        if self.type == NodeType.CUBE:
            if not self.cube_elements:
                raise DJInvalidInputException(
                    f"Node {self.name} of type cube needs cube elements",
                )


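    # Example sketch (illustrative only; names and query are made up): a metric
    # revision without a query fails this validation, while one whose query is a
    # single aggregation passes, assuming `is_metric()` accepts it:
    #
    #     NodeRevision(name="num_repairs", type=NodeType.METRIC).extra_validation()
    #     # raises DJInvalidInputException
    #
    #     NodeRevision(
    #         name="num_repairs",
    #         type=NodeType.METRIC,
    #         query="SELECT count(repair_order_id) FROM repair_orders",
    #     ).extra_validation()  # passes

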
class ImmutableNodeFields(BaseSQLModel):
    """
    Node fields that cannot be changed
    """

    name: str
    namespace: str = "default"


class MutableNodeFields(BaseSQLModel):
    """
    Node fields that can be changed.
    """

    display_name: Optional[str]
    description: str
    mode: NodeMode
    primary_key: Optional[List[str]]


class MutableNodeQueryField(BaseSQLModel):
    """
    Query field for node.
    """

    query: str


class NodeNameOutput(SQLModel):
    """
    Node name only
    """

    name: str


class AttributeTypeName(BaseSQLModel):
    """
    Attribute type name.
    """

    namespace: str
    name: str


class AttributeOutput(BaseSQLModel):
    """
    Column attribute output.
    """

    attribute_type: AttributeTypeName


class ColumnOutput(SQLModel):
    """
    A simplified column schema, without ID or dimensions.
    """

    name: str
    type: ColumnType
    attributes: Optional[List[AttributeOutput]]
    dimension: Optional[NodeNameOutput]

    class Config:  # pylint: disable=too-few-public-methods
        """
        Should perform validation on assignment
        """

        validate_assignment = True

    @root_validator
    def type_string(cls, values):  # pylint: disable=no-self-argument
        """
        Extracts the type as a string
        """
        values["type"] = str(values.get("type"))
        return values


class SourceColumnOutput(SQLModel):
    """
    A column used in creation of a source node
    """

    name: str
    type: ColumnType
    attributes: Optional[List[AttributeOutput]]
    dimension: Optional[str]

    class Config:  # pylint: disable=too-few-public-methods
        """
        Should perform validation on assignment
        """

        validate_assignment = True

    @root_validator
    def type_string(cls, values):  # pylint: disable=no-self-argument
        """
        Extracts the type as a string
        """
        values["type"] = str(values.get("type"))
        return values


class SourceNodeFields(BaseSQLModel):
    """
    Source node fields that can be changed.
    """

    catalog: str
    schema_: str
    table: str
    columns: Optional[List["SourceColumnOutput"]] = []


class CubeNodeFields(BaseSQLModel):
    """
    Cube node fields that can be changed
    """

    display_name: Optional[str]
    cube_elements: List[str]
    description: str
    mode: NodeMode


#
# Create and Update objects
#


class CreateNode(ImmutableNodeFields, MutableNodeFields, MutableNodeQueryField):
    """
    Create non-source node object.
    """


class CreateSourceNode(ImmutableNodeFields, MutableNodeFields, SourceNodeFields):
    """
    A create object for source nodes
    """


class CreateCubeNode(ImmutableNodeFields, CubeNodeFields):
    """
    A create object for cube nodes
    """

    class Config:  # pylint: disable=too-few-public-methods
        """
        Do not allow extra fields in input
        """

        extra = Extra.forbid


class UpdateNode(MutableNodeFields, SourceNodeFields):
    """
    Update node object where all fields are optional
    """

    __annotations__ = {
        k: Optional[v]
        for k, v in {
            **SourceNodeFields.__annotations__,  # pylint: disable=E1101
            **MutableNodeFields.__annotations__,  # pylint: disable=E1101
            **MutableNodeQueryField.__annotations__,  # pylint: disable=E1101
        }.items()
    }

    class Config:  # pylint: disable=too-few-public-methods
        """
        Do not allow fields other than the ones defined here.
        """

        extra = Extra.forbid


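# Example sketch (illustrative only): the `__annotations__` override above re-declares
# every inherited field as Optional, so a partial payload validates, while unknown
# fields are still rejected by `Extra.forbid`:
#
#     UpdateNode(description="A new description")   # valid partial update
#     UpdateNode(not_a_field="x")                    # raises a validation error

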
class UpsertMaterializationConfig(BaseSQLModel):
    """
    An upsert object for materialization configs
    """

    engine_name: str
    engine_version: str
    config: str


#
# Response output objects
#


class OutputModel(BaseModel):
    """
    An output model with the ability to flatten fields. When fields are created with
    `Field(flatten=True)`, the field's values will be automatically flattened into the
    parent output model.
    """

    def _iter(self, *args, to_dict: bool = False, **kwargs):
        for dict_key, value in super()._iter(to_dict, *args, **kwargs):
            if to_dict and self.__fields__[dict_key].field_info.extra.get(
                "flatten",
                False,
            ):
                assert isinstance(value, dict)
                for key, val in value.items():
                    yield key, val
            else:
                yield dict_key, value

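# Example sketch (illustrative only; `Inner` and `Outer` are made-up models): fields
# declared with `Field(flatten=True)` are merged into the parent when serializing
# with `.dict()`, which is how `NodeOutput` below inlines its current revision:
#
#     class Inner(BaseModel):
#         version: str
#
#     class Outer(OutputModel):
#         namespace: str
#         inner: Inner = PydanticField(flatten=True)
#
#     Outer(namespace="default", inner=Inner(version="v1.0")).dict()
#     # -> {"namespace": "default", "version": "v1.0"}

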
class TableOutput(SQLModel):
    """
    Output for table information.
    """

    id: Optional[int]
    catalog: Optional[Catalog]
    schema_: Optional[str]
    table: Optional[str]
    database: Optional[Database]


class MaterializationConfigOutput(SQLModel):
    """
    Output for materialization config.
    """

    engine: EngineInfo
    config: str


class NodeRevisionOutput(SQLModel):
    """
    Output for a node revision with information about columns and if it is a metric.
    """

    id: int = Field(alias="node_revision_id")
    node_id: int
    type: NodeType
    name: str
    display_name: str
    version: str
    status: NodeStatus
    mode: NodeMode
    catalog: Optional[Catalog]
    schema_: Optional[str]
    table: Optional[str]
    description: str = ""
    query: Optional[str] = None
    availability: Optional[AvailabilityState] = None
    columns: List[ColumnOutput]
    updated_at: UTCDatetime
    materialization_configs: List[MaterializationConfigOutput]
    parents: List[NodeNameOutput]

    class Config:  # pylint: disable=missing-class-docstring,too-few-public-methods
        allow_population_by_field_name = True


class NodeOutput(OutputModel):
    """
    Output for a node that shows the current revision.
    """

    namespace: str
    current: NodeRevisionOutput = PydanticField(flatten=True)
    created_at: UTCDatetime
    tags: List["Tag"] = []


class NodeValidation(SQLModel):
    """
    A validation of a provided node definition
    """

    message: str
    status: NodeStatus
    node_revision: NodeRevision
    dependencies: List[NodeRevisionOutput]
    columns: List[Column]