Coverage for src / invariant / graph.py: 88.00%
75 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-06 12:18 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-06 12:18 +0000
1"""GraphResolver for parsing, validating, and sorting DAGs."""
3from typing import TYPE_CHECKING
5from invariant.node import Node, SubGraphNode, SwitchNode
7if TYPE_CHECKING:
8 from invariant.registry import OpRegistry
# A graph may contain regular nodes, subgraph nodes, or lazy switch nodes;
# GraphVertex is the union of every vertex type a graph definition can hold.
GraphVertex = Node | SubGraphNode | SwitchNode
# A graph definition: maps each node ID to its vertex.
Graph = dict[str, GraphVertex]
def _switch_targets(node: SwitchNode) -> list[str]:
    """List every branch target a switch node can route to.

    Case targets come first, ordered by their case key (so the result does
    not depend on the insertion order of ``node.cases``), followed by the
    default target when one is set.
    """
    ordered_keys = sorted(node.cases)
    branches = [node.cases[case_key] for case_key in ordered_keys]
    if node.default is None:
        return branches
    return [*branches, node.default]
def _graph_deps(node: GraphVertex) -> list[str]:
    """Return a fresh list of the dependency IDs declared on ``node.deps``.

    Only declared dependencies are reported; switch branch targets are not
    included here.
    """
    return [*node.deps]
class GraphResolver:
    """Responsible for parsing graph definitions and ensuring valid DAGs.

    Handles:
    - Cycle detection
    - Validation (missing dependencies, missing switch targets, missing ops)
    - Topological sorting

    Only declared dependency edges (``node.deps``) are traversed for cycle
    detection and ordering. Switch branch targets are checked for existence
    but are not treated as edges. Traversal uses an explicit stack rather
    than recursion, so long dependency chains cannot exceed Python's
    recursion limit.
    """

    def __init__(self, registry: "OpRegistry | None" = None) -> None:
        """Initialize GraphResolver.

        Args:
            registry: Optional OpRegistry for validating that ops exist.
                If None, op validation is skipped.
        """
        self.registry = registry

    def validate(self, graph: Graph, context_keys: set[str] | None = None) -> None:
        """Validate a graph definition.

        Checks, in this order:
        - All node dependencies exist in the graph or in context
        - All switch branch targets exist in the graph
        - All referenced ops are registered (if registry provided; Node only)
        - No cycles exist across declared dependencies

        Args:
            graph: Dictionary mapping node IDs to graph vertices.
            context_keys: Optional set of external dependency keys (from context).
                Dependencies not in the graph are allowed if they are in
                context.

        Raises:
            ValueError: If validation fails (missing deps, missing switch
                targets, missing ops, cycles).
        """
        node_ids = set(graph)
        context_keys = context_keys or set()

        # Structural checks: every dependency and every switch target must resolve.
        for node_id, node in graph.items():
            for dep in node.deps:
                if dep not in node_ids and dep not in context_keys:
                    raise ValueError(
                        f"Node '{node_id}' has dependency '{dep}' that doesn't "
                        "exist in graph "
                        f"or context. Available: graph={sorted(node_ids)}, "
                        f"context={sorted(context_keys)}"
                    )
            if isinstance(node, SwitchNode):
                # Switch targets must be graph nodes; context keys do not count.
                for target in _switch_targets(node):
                    if target not in node_ids:
                        raise ValueError(
                            f"SwitchNode '{node_id}' targets '{target}' which "
                            f"doesn't exist in graph. Available: {sorted(node_ids)}"
                        )

        # Op checks apply only to Node (the only vertex type with op_name).
        # Compare with None explicitly, as documented in __init__: if OpRegistry
        # defines __len__/__bool__ (not visible here), an empty registry would
        # be falsy and a bare truthiness test would silently skip validation.
        if self.registry is not None:
            for node_id, node in graph.items():
                if isinstance(node, Node) and not self.registry.has(node.op_name):
                    raise ValueError(
                        f"Node '{node_id}' references unregistered op "
                        f"'{node.op_name}'"
                    )

        # Check for cycles (dependencies outside the graph are ignored).
        if self._has_cycle(graph, context_keys=context_keys):
            raise ValueError("Graph contains cycles")

    def _has_cycle(self, graph: Graph, context_keys: set[str] | None = None) -> bool:
        """Detect cycles in the graph's declared dependency edges.

        Args:
            graph: Dictionary mapping node IDs to graph vertices.
            context_keys: Optional set of external dependency keys (from context).
                Accepted for interface compatibility. Dependencies that are not
                graph nodes are always ignored, whatever this value is.

        Returns:
            True if cycle exists, False otherwise.
        """
        return self._postorder(graph) is None

    @staticmethod
    def _postorder(graph: Graph) -> list[str] | None:
        """Return node IDs dependencies-first, or None if a cycle exists.

        Roots are visited in sorted ID order and each node's dependencies in
        declared order, so the result is deterministic. Dependencies that are
        not graph nodes (e.g. context keys) are skipped. An explicit stack of
        (node, remaining-deps iterator) pairs replaces recursion.
        """
        white, gray, black = 0, 1, 2  # unvisited / on current path / finished
        color = dict.fromkeys(graph, white)
        order: list[str] = []

        for root in sorted(graph):
            if color[root] != white:
                continue
            color[root] = gray
            stack = [(root, iter(_graph_deps(graph[root])))]
            while stack:
                node_id, pending = stack[-1]
                # Resume this node's dependency scan where it last stopped.
                for dep in pending:
                    if dep not in color:
                        continue  # not a graph node (external/context key)
                    if color[dep] == gray:
                        return None  # back edge onto the current path: cycle
                    if color[dep] == white:
                        color[dep] = gray
                        stack.append((dep, iter(_graph_deps(graph[dep]))))
                        break  # descend into dep before finishing node_id
                else:
                    # All dependencies finished: emit node_id in post-order.
                    stack.pop()
                    color[node_id] = black
                    order.append(node_id)

        return order

    def topological_sort(
        self, graph: Graph, context_keys: set[str] | None = None
    ) -> list[str]:
        """Topologically sort the graph's declared dependency edges using DFS.

        Args:
            graph: Dictionary mapping node IDs to graph vertices.
            context_keys: Optional set of external dependency keys (from context).
                Accepted for interface compatibility. Dependencies that are not
                graph nodes are always excluded from sorting.

        Returns:
            List of node IDs in topological order (dependencies before dependents).

        Raises:
            ValueError: If graph contains cycles.
        """
        order = self._postorder(graph)
        if order is None:
            raise ValueError("Graph contains cycles (topological sort impossible)")
        return order

    def resolve(self, graph: Graph, context_keys: set[str] | None = None) -> list[str]:
        """Validate and topologically sort a graph.

        Convenience method that validates then sorts.

        Args:
            graph: Dictionary mapping node IDs to graph vertices.
            context_keys: Optional set of external dependency keys (from context).

        Returns:
            List of node IDs in topological order.

        Raises:
            ValueError: If validation fails or cycles exist.
        """
        self.validate(graph, context_keys=context_keys)
        return self.topological_sort(graph, context_keys=context_keys)