Coverage for src / invariant / node.py: 92.86%

84 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-06 12:18 +0000

1"""Graph vertex classes.""" 

2 

3from __future__ import annotations 

4 

5from dataclasses import dataclass 

6from typing import Any 

7 

8from invariant.params import ref 

9 

10 

def _collect_refs(value: Any) -> list[ref]:
    """Gather every ``ref`` marker nested anywhere inside ``value``.

    A ``ref`` instance contributes itself. Dictionaries are searched through
    their values only (keys are not inspected); lists and tuples are searched
    through their items. Any other kind of value contributes nothing.

    Args:
        value: Arbitrary, possibly nested, parameter value.

    Returns:
        The markers found, in depth-first traversal order. Duplicates are kept.
    """
    if isinstance(value, ref):
        return [value]

    # Choose what to descend into; anything that is not a container is a leaf.
    if isinstance(value, dict):
        children = value.values()
    elif isinstance(value, (list, tuple)):
        children = value
    else:
        return []

    # Flatten the markers found below each child, preserving encounter order.
    return [marker for child in children for marker in _collect_refs(child)]

23 

24 

@dataclass(frozen=True)
class Node:
    """A DAG vertex describing a single operation to run.

    Attributes:
        op_name: Name of the operation to execute (must be registered).
        params: Static parameters for the node, mapping parameter name to
            value. Values may contain ref() and cel() markers, and ${...}
            string interpolation.
        deps: IDs of the upstream nodes this node depends on.
        cache: When True (default), the node's result is cached unless it
            depends on an ephemeral upstream node. When False, the op is
            always executed, the result is never stored, and cache bypass
            cascades to downstream nodes.
    """

    op_name: str
    params: dict[str, Any]
    deps: list[str]
    cache: bool = True

    def __post_init__(self) -> None:
        """Check the node's configuration right after construction.

        Raises:
            ValueError: If ``op_name`` is empty, ``params`` is not a dict,
                ``deps`` is not a list, or ``params`` holds a ref() marker
                whose dependency is not listed in ``deps``.
        """
        if not self.op_name:
            raise ValueError("op_name cannot be empty")

        # Container type checks, performed in declaration order.
        type_checks = (
            ("params", dict, "params must be a dictionary"),
            ("deps", list, "deps must be a list"),
        )
        for attr_name, required_type, message in type_checks:
            if not isinstance(getattr(self, attr_name), required_type):
                raise ValueError(message)

        # Every ref() marker must point at a declared dependency.
        self._validate_refs()

    def _validate_refs(self) -> None:
        """Ensure each ref() marker in ``params`` names a declared dependency.

        Raises:
            ValueError: On the first marker whose ``dep`` is not in ``deps``.
        """
        declared = set(self.deps)

        for marker in _collect_refs(self.params):
            if marker.dep in declared:
                continue
            raise ValueError(
                f"ref('{marker.dep}') in params references undeclared dependency. "
                f"Declared deps: {self.deps}. "
                f"Add '{marker.dep}' to deps list."
            )

69 

70 

@dataclass(frozen=True)
class SubGraphNode:
    """A vertex that expands into an internal DAG at execution time.

    It carries ``deps`` and ``params`` like Node, but instead of an op_name it
    holds an internal graph plus the ID of that graph's output node. The
    executor runs the internal graph with the resolved params as context and
    returns the artifact of the designated output node.
    """

    params: dict[str, Any]
    deps: list[str]
    graph: dict[str, Node | SubGraphNode | SwitchNode]
    output: str

    def __post_init__(self) -> None:
        """Check the sub-graph configuration right after construction.

        Raises:
            ValueError: If ``params`` or ``graph`` is not a dict, ``deps`` is
                not a list, ``output`` is not a key of ``graph``, or ``params``
                holds a ref() marker whose dependency is not in ``deps``.
        """
        # Container type checks, performed in the same order as the fields.
        type_checks = (
            ("params", dict, "params must be a dictionary"),
            ("deps", list, "deps must be a list"),
            ("graph", dict, "graph must be a dictionary"),
        )
        for attr_name, required_type, message in type_checks:
            if not isinstance(getattr(self, attr_name), required_type):
                raise ValueError(message)

        # The designated output must exist inside the internal graph.
        if self.output not in self.graph:
            raise ValueError(
                f"output '{self.output}' must be a key in graph. "
                f"Graph keys: {list(self.graph.keys())}."
            )

        self._validate_refs()

    def _validate_refs(self) -> None:
        """Ensure each ref() marker in ``params`` names a declared dependency.

        Raises:
            ValueError: On the first marker whose ``dep`` is not in ``deps``.
        """
        declared = set(self.deps)

        for marker in _collect_refs(self.params):
            if marker.dep in declared:
                continue
            raise ValueError(
                f"ref('{marker.dep}') in params references undeclared dependency. "
                f"Declared deps: {self.deps}. "
                f"Add '{marker.dep}' to deps list."
            )

112 

113 

@dataclass(frozen=True)
class SwitchNode:
    """A lazy conditional vertex that picks one graph-local branch target.

    The selector is resolved from the declared deps. Its normalized value
    selects a node ID from ``cases``, falling back to ``default`` when no case
    matches. Branch targets are not dependencies: they are graph-local
    execution targets resolved by the graph executor.
    """

    selector: Any
    deps: list[str]
    cases: dict[str, str]
    default: str | None = None

    def __post_init__(self) -> None:
        """Check the switch configuration right after construction.

        Raises:
            ValueError: If ``deps`` is not a list; ``cases`` is not a
                non-empty dict of string keys to non-empty string targets;
                ``default`` is given but is not a non-empty string; or the
                selector holds a ref() marker whose dependency is not in
                ``deps``.
        """
        if not isinstance(self.deps, list):
            raise ValueError("deps must be a list")

        branches = self.cases
        if not isinstance(branches, dict):
            raise ValueError("cases must be a dictionary")
        if not branches:
            raise ValueError("cases must not be empty")

        # Each entry is checked key first, then target, in mapping order.
        for label, destination in branches.items():
            if not isinstance(label, str):
                raise ValueError("cases keys must be strings")
            if not (isinstance(destination, str) and destination):
                raise ValueError("cases values must be non-empty strings")

        # ``None`` means "no default"; anything else must be a usable node ID.
        fallback = self.default
        if fallback is not None:
            if not (isinstance(fallback, str) and fallback):
                raise ValueError("default must be a non-empty string when present")

        self._validate_refs()

    def _validate_refs(self) -> None:
        """Ensure each ref() marker in the selector names a declared dependency.

        Raises:
            ValueError: On the first marker whose ``dep`` is not in ``deps``.
        """
        declared = set(self.deps)

        for marker in _collect_refs(self.selector):
            if marker.dep in declared:
                continue
            raise ValueError(
                f"ref('{marker.dep}') in selector references undeclared dependency. "
                f"Declared deps: {self.deps}. "
                f"Add '{marker.dep}' to deps list."
            )