Coverage for src / invariant / graph.py: 88.00%

75 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-08 09:24 +0000

1"""GraphResolver for parsing, validating, and sorting DAGs.""" 

2 

3from typing import TYPE_CHECKING 

4 

5from invariant.node import Node, SubGraphNode, SwitchNode 

6 

7if TYPE_CHECKING: 

8 from invariant.registry import OpRegistry 

9 

# A vertex is one of three node kinds: a regular node, a subgraph node, or a
# lazy switch node.
GraphVertex = Node | SubGraphNode | SwitchNode

# A graph is a dictionary keyed by node ID whose values are vertices.
Graph = dict[str, GraphVertex]

13 

14 

15def _switch_targets(node: SwitchNode) -> list[str]: 

16 """Return switch branch targets in deterministic order.""" 

17 targets = [node.cases[key] for key in sorted(node.cases)] 

18 if node.default is not None: 

19 targets.append(node.default) 

20 return targets 

21 

22 

23def _graph_deps(node: GraphVertex) -> list[str]: 

24 """Return declared dependency edges for a vertex.""" 

25 return list(node.deps) 

26 

27 

class GraphResolver:
    """Parse graph definitions and ensure they form valid DAGs.

    Handles:
    - Cycle detection
    - Validation (missing dependencies, missing ops)
    - Topological sorting

    Cycle detection and sorting follow only the edges listed in each vertex's
    ``deps``. Switch branch targets are checked for existence by ``validate``
    but are not traversed as edges.
    """

    def __init__(self, registry: "OpRegistry | None" = None) -> None:
        """Initialize GraphResolver.

        Args:
            registry: Optional OpRegistry for validating that ops exist.
                If None, op validation is skipped.
        """
        self.registry = registry

    def validate(self, graph: Graph, context_keys: set[str] | None = None) -> None:
        """Validate a graph definition.

        Checks, in this order:
        - All node dependencies exist in the graph or in context
        - All switch branch targets exist in the graph
        - All referenced ops are registered (if registry provided; Node only)
        - No cycles exist across declared dependencies

        Args:
            graph: Dictionary mapping node IDs to graph vertices.
            context_keys: Optional set of external dependency keys (from context).
                Dependencies not in the graph are allowed if they are in
                context.

        Raises:
            ValueError: If validation fails (missing deps, missing ops, cycles).
        """
        node_ids = set(graph)
        context_keys = context_keys or set()

        for node_id, node in graph.items():
            for dep in node.deps:
                if dep not in node_ids and dep not in context_keys:
                    raise ValueError(
                        f"Node '{node_id}' has dependency '{dep}' that doesn't "
                        "exist in graph "
                        f"or context. Available: graph={sorted(node_ids)}, "
                        f"context={sorted(context_keys)}"
                    )
            if isinstance(node, SwitchNode):
                # Branch targets must be graph nodes; context keys do not count.
                for target in _switch_targets(node):
                    if target not in node_ids:
                        raise ValueError(
                            f"SwitchNode '{node_id}' targets '{target}' which "
                            f"doesn't exist in graph. Available: {sorted(node_ids)}"
                        )

        # Only Node has op_name. Compare against None explicitly so that a
        # registry object that happens to evaluate falsy still validates ops.
        if self.registry is not None:
            for node_id, node in graph.items():
                if isinstance(node, Node) and not self.registry.has(node.op_name):
                    raise ValueError(
                        f"Node '{node_id}' references unregistered op "
                        f"'{node.op_name}'"
                    )

        if self._has_cycle(graph, context_keys=context_keys):
            raise ValueError("Graph contains cycles")

    @staticmethod
    def _dfs_postorder(graph: Graph) -> list[str] | None:
        """Walk declared dependency edges depth-first without recursion.

        Roots are taken in sorted node-ID order and each vertex's ``deps`` are
        followed in their declared order. Dependencies that are not keys of
        ``graph`` are skipped.

        Args:
            graph: Dictionary mapping node IDs to graph vertices.

        Returns:
            Node IDs in post-order (dependencies before dependents), or None
            if a cycle is reachable through declared dependencies.
        """
        unvisited, in_progress, done = 0, 1, 2
        state = dict.fromkeys(graph, unvisited)
        order: list[str] = []

        for root in sorted(graph):
            if state[root] != unvisited:
                continue
            state[root] = in_progress
            # Each frame pairs a node with an iterator over its remaining
            # dependencies, so the walk resumes where it left off.
            stack = [(root, iter(graph[root].deps))]
            while stack:
                current, remaining = stack[-1]
                for dep in remaining:
                    if dep not in state:
                        # Not a graph node (e.g. supplied by context).
                        continue
                    dep_state = state[dep]
                    if dep_state == in_progress:
                        # Back edge to a node still on the stack: cycle.
                        return None
                    if dep_state == unvisited:
                        state[dep] = in_progress
                        stack.append((dep, iter(graph[dep].deps)))
                        break
                else:
                    # Every dependency is finished, so this node is finished.
                    state[current] = done
                    order.append(current)
                    stack.pop()

        return order

    def _has_cycle(self, graph: Graph, context_keys: set[str] | None = None) -> bool:
        """Detect cycles in the graph using an iterative DFS.

        Args:
            graph: Dictionary mapping node IDs to graph vertices.
            context_keys: Optional set of external dependency keys (from context).
                Not consulted by this method: a dependency is followed only
                when its name is a key of ``graph``.

        Returns:
            True if cycle exists, False otherwise.
        """
        return self._dfs_postorder(graph) is None

    def topological_sort(
        self, graph: Graph, context_keys: set[str] | None = None
    ) -> list[str]:
        """Topologically sort the graph's declared dependency edges using DFS.

        The traversal is iterative, so the depth of a dependency chain is not
        bounded by the interpreter's recursion limit.

        Args:
            graph: Dictionary mapping node IDs to graph vertices.
            context_keys: Optional set of external dependency keys (from context).
                Not consulted by this method: a dependency is followed only
                when its name is a key of ``graph``.

        Returns:
            List of node IDs in topological order (dependencies before dependents).

        Raises:
            ValueError: If graph contains cycles.
        """
        order = self._dfs_postorder(graph)
        if order is None:
            raise ValueError("Graph contains cycles (topological sort impossible)")
        return order

    def resolve(self, graph: Graph, context_keys: set[str] | None = None) -> list[str]:
        """Validate and topologically sort a graph.

        Convenience method that validates then sorts.

        Args:
            graph: Dictionary mapping node IDs to graph vertices.
            context_keys: Optional set of external dependency keys (from context).

        Returns:
            List of node IDs in topological order.

        Raises:
            ValueError: If validation fails or cycles exist.
        """
        self.validate(graph, context_keys=context_keys)
        return self.topological_sort(graph, context_keys=context_keys)