Coverage for smartmdao / visualization.py: 100%
141 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-02 20:01 +0200
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-02 20:01 +0200
1import os
2import logging
3from typing import List, Set, Dict, Literal, Optional, Tuple
5from .models import Step
7# Initialize module-level logger
8logger = logging.getLogger(__name__)
10# Try importing graphviz; handle absence gracefully
11try:
12 import graphviz
13except ImportError:
14 graphviz = None
15 logger.warning("Graphviz not found. Visualization features will be unavailable.")
class PipelineVisualizer:
    """
    A modern, modular visualizer for the Pipeline using Graphviz.
    Focuses on standardizing workflow visualization with clear separation of concerns.

    Typical use: construct with the pipeline steps and the set of externally
    provided input names, call :meth:`build` to populate the graph, then
    :meth:`render` to write and/or open it.
    """

    # --- Standard Palette (Material Design Pastels) ---
    STYLE_INPUT = {
        "shape": "parallelogram",
        "style": "filled",
        "fillcolor": "#E3F2FD",  # Blue 50
        "color": "#1565C0",      # Blue 800
        "penwidth": "1.5",
        "margin": "0.2",
    }
    STYLE_STEP = {
        "shape": "component",
        "style": "filled",
        "fillcolor": "#FFF3E0",  # Orange 50
        "color": "#EF6C00",      # Orange 800
        "penwidth": "1.5",
        "margin": "0.3",
    }
    STYLE_INTERMEDIATE = {
        "shape": "ellipse",
        "style": "filled",
        "fillcolor": "#F5F5F5",  # Grey 100
        "color": "#757575",      # Grey 600
        "penwidth": "1.0",
        "height": "0.4",
    }
    STYLE_FINAL = {
        "shape": "parallelogram",
        "style": "filled",
        "fillcolor": "#E8F5E9",  # Green 50
        "color": "#2E7D32",      # Green 800
        "penwidth": "2.0",       # Thicker border for emphasis
        "peripheries": "2",      # Double border
        "margin": "0.2",
    }
    STYLE_MISSING = {
        "shape": "hexagon",
        "style": "filled",
        "fillcolor": "#FFEBEE",  # Red 50
        "color": "#C62828",      # Red 800
        "penwidth": "2.0",
    }

    def __init__(
        self,
        steps: List["Step"],
        input_keys: Set[str],
        orientation: Literal['TB', 'LR'] = 'TB'
    ):
        """
        Args:
            steps: Pipeline steps to draw. They are stored sorted by ``name``.
            input_keys: Names of the variables supplied from outside the pipeline.
            orientation: Value passed to Graphviz as the graph's ``rankdir``.

        Raises:
            ImportError: If the ``graphviz`` Python package could not be imported.
        """
        if graphviz is None:
            raise ImportError(
                "The 'graphviz' library is required for visualization. "
                "Please install it using: pip install graphviz"
            )

        self.steps = sorted(steps, key=lambda s: s.name)
        self.input_keys = input_keys
        self.orientation = orientation

        # Initialize the graph
        self.dot = graphviz.Digraph(comment='Pipeline Graph')
        self._setup_graph_attributes()

    def _setup_graph_attributes(self):
        """Configures global graph styling for a professional look."""
        self.dot.attr(rankdir=self.orientation)
        self.dot.attr(compound='true')  # Allow edges between clusters

        # Global typography
        self.dot.attr('node', fontname='Helvetica', fontsize='11')
        self.dot.attr('edge', fontname='Helvetica', fontsize='9', color='#616161')

        # 'ortho' provides clean, rect-linear lines suitable for technical diagrams
        # 'splines'='polyline' is also a good option if ortho gets messy.
        # NOTE(review): the flow view attaches labels to edges; Graphviz's ortho
        # routing has limited edge-label support -- confirm the labels render as
        # intended, or consider 'xlabel' / 'polyline'.
        self.dot.attr(splines='ortho')

    def build(self, graph_type: Literal["flow", "bipartite"] = "flow") -> "PipelineVisualizer":
        """
        Builds the nodes and edges.
        Note: The 'bipartite' (Data Flow) view is recommended for detailed analysis
        of Inputs vs Intermediates vs Finals.

        Any value other than ``"bipartite"`` selects the flow view.

        Returns:
            ``self``, so calls can be chained (``viz.build(...).render(...)``).
        """
        if graph_type == "bipartite":
            self._build_bipartite_standard()
        else:
            self._build_flow_standard()
        return self

    def render(self, output_path: Optional[str] = None, view: bool = True):
        """
        Renders the graph to a file or temporary view.

        Args:
            output_path: Destination file. The extension selects the output
                format; without an extension the format defaults to PDF. When
                omitted, the graph is only opened in a viewer.
            view: When ``output_path`` is given, also open the rendered file.

        Failures are logged rather than raised (best-effort behaviour), but a
        rendering failure and a viewer failure are reported separately so the
        log never claims a file exists when rendering itself failed.
        """
        if not output_path:
            try:
                self.dot.view(cleanup=True)
            except Exception as e:  # best-effort: report, do not crash the caller
                logger.error("Failed to render or open pipeline diagram: %s", e)
            else:
                logger.info("Pipeline diagram opened in viewer.")
            return

        filename, ext = os.path.splitext(output_path)
        fmt = ext.lstrip('.').lower() if ext else 'pdf'

        # Phase 1: render to disk only, so a later viewer problem cannot be
        # confused with a rendering problem.
        try:
            out_file = self.dot.render(filename, format=fmt, cleanup=True, view=False)
        except Exception as e:  # best-effort: report, do not crash the caller
            logger.error("Failed to render pipeline diagram: %s", e)
            return

        if not view:
            logger.info("Pipeline diagram saved to: %s", out_file)
            return

        # Phase 2: open the file that is now known to exist.
        try:
            graphviz.view(out_file)
        except Exception as e:  # best-effort: the file is still usable
            logger.error("Graph rendered successfully, but viewer failed: %s", e)
            logger.info("File saved at: %s", out_file)

    # --- Classification Logic ---

    def _analyze_variables(self) -> Tuple[Set[str], Set[str], Set[str], Set[str], Dict[str, "Step"]]:
        """
        Categorizes all variables in the pipeline.
        Returns: (inputs, intermediates, finals, missing, producer_map)

        ``producer_map`` maps each output name to the step producing it; if
        several steps declare the same output, the last one in ``self.steps``
        wins.
        """
        producers_map = {}
        consumed = set()
        produced = set()

        # 1. Collect what every step produces and consumes
        for step in self.steps:
            # Outputs
            for out in step.resolve_output_names():
                producers_map[out] = step
                produced.add(out)

            # Inputs (Use unwrapped signature)
            consumed.update(step.get_signature().parameters)

        # 2. Categorize
        # Inputs: Variables consumed but NOT produced internally.
        # (We strictly use input_keys to validate, but graph logic relies on structural dependency)
        real_inputs = consumed - produced

        # Missing: Required inputs that are NOT in the provided input_keys
        missing = {v for v in real_inputs if v not in self.input_keys}

        # Valid Inputs: Real inputs that exist in input_keys
        valid_inputs = real_inputs.intersection(self.input_keys)

        # Intermediates: Produced AND Consumed
        intermediates = produced.intersection(consumed)

        # Finals: Produced but NEVER Consumed
        finals = produced - consumed

        return valid_inputs, intermediates, finals, missing, producers_map

    def _find_feedback_edges(self, producers: Dict[str, "Step"]) -> Set[Tuple[int, int]]:
        """
        Finds the step-to-step edges that close a cycle.

        The step dependency graph (producer -> consumer) is walked depth-first;
        an edge pointing at a step that is still on the DFS stack is a back edge
        and is reported as feedback. An acyclic pipeline therefore yields an
        empty set regardless of how the steps are named or ordered.

        Args:
            producers: Mapping of variable name to producing step, as returned
                by :meth:`_analyze_variables`.

        Returns:
            Set of ``(producer_index, consumer_index)`` pairs, indices referring
            to positions in ``self.steps``.
        """
        count = len(self.steps)
        index_of = {id(step): i for i, step in enumerate(self.steps)}

        successors: Dict[int, Set[int]] = {i: set() for i in range(count)}
        has_predecessor: Set[int] = set()
        for consumer_idx, step in enumerate(self.steps):
            for param in step.get_signature().parameters:
                producer = producers.get(param)
                if producer is None:
                    continue
                successors[index_of[id(producer)]].add(consumer_idx)
                has_predecessor.add(consumer_idx)

        # Start from steps nothing feeds into, so that in a loop entered from
        # upstream the edge returning to the entry point is the one flagged.
        roots = [i for i in range(count) if i not in has_predecessor]
        roots += [i for i in range(count) if i in has_predecessor]

        unvisited, on_stack, done = 0, 1, 2
        state = [unvisited] * count
        feedback: Set[Tuple[int, int]] = set()

        for root in roots:
            if state[root] != unvisited:
                continue
            state[root] = on_stack
            # Iterative DFS; each frame keeps its own successor iterator so the
            # walk resumes where it left off after returning from a child.
            stack = [(root, iter(sorted(successors[root])))]
            while stack:
                node, pending = stack[-1]
                for nxt in pending:
                    if state[nxt] == on_stack:
                        feedback.add((node, nxt))
                    elif state[nxt] == unvisited:
                        state[nxt] = on_stack
                        stack.append((nxt, iter(sorted(successors[nxt]))))
                        break
                else:
                    state[node] = done
                    stack.pop()

        return feedback

    # --- Bipartite (Data Flow) Strategy ---

    def _build_bipartite_standard(self):
        """
        Constructs a Data Flow Diagram (DFD).
        Strictly separates: Input Nodes -> Step Nodes -> Intermediate Nodes -> Step Nodes -> Final Nodes.

        Variables are emitted in sorted order so the generated DOT source is
        stable between runs.
        """
        inputs, intermediates, finals, missing, _producers = self._analyze_variables()

        # 1. Draw Inputs (Rank Source to force top/left)
        with self.dot.subgraph(name='cluster_inputs') as c:
            c.attr(rank='source', style='invis')  # Invisible container for grouping
            for var in sorted(inputs):
                self._add_node(c, f"Var_{var}", var, self.STYLE_INPUT)
            for var in sorted(missing):
                self._add_node(c, f"Missing_{var}", f"{var} (?)", self.STYLE_MISSING)

        # 2. Draw Finals (Rank Sink to force bottom/right)
        with self.dot.subgraph(name='cluster_finals') as c:
            c.attr(rank='sink', style='invis')
            for var in sorted(finals):
                self._add_node(c, f"Var_{var}", var, self.STYLE_FINAL)

        # 3. Draw Intermediates
        for var in sorted(intermediates):
            self._add_node(self.dot, f"Var_{var}", var, self.STYLE_INTERMEDIATE)

        # 4. Draw Steps
        for step in self.steps:
            self._add_step_node(self.dot, step)

        # 5. Draw Edges
        for step in self.steps:
            step_id = self._node_id(step)

            # Inputs to Step
            for param in step.get_signature().parameters:
                if param in missing:
                    self.dot.edge(f"Missing_{param}", step_id, style="dotted", color="#D32F2F")
                else:
                    # It's either a valid input or an intermediate/produced var
                    self.dot.edge(f"Var_{param}", step_id)

            # Step to Outputs
            for out in step.resolve_output_names():
                self.dot.edge(step_id, f"Var_{out}")

    # --- Flow Strategy (Process Flow) ---

    def _build_flow_standard(self):
        """
        Constructs a Process Flow Diagram.
        Focuses on Steps. Data is shown as explicit nodes ONLY if it is an Input or Final Output.
        Intermediates are labels on edges.

        Edges that close a dependency cycle are drawn dashed and red; all other
        step-to-step edges are solid grey.
        """
        inputs, _intermediates, finals, missing, producers = self._analyze_variables()
        index_of = {id(step): i for i, step in enumerate(self.steps)}
        feedback_edges = self._find_feedback_edges(producers)

        # 1. Draw Inputs
        with self.dot.subgraph(name='cluster_inputs') as c:
            c.attr(rank='source', style='invis')
            for var in sorted(inputs):
                self._add_node(c, f"Input_{var}", var, self.STYLE_INPUT)
            for var in sorted(missing):
                self._add_node(c, f"Missing_{var}", f"{var} (?)", self.STYLE_MISSING)

        # 2. Draw Finals
        with self.dot.subgraph(name='cluster_finals') as c:
            c.attr(rank='sink', style='invis')
            for var in sorted(finals):
                # Note: In flow view, we link the step directly to this final node
                self._add_node(c, f"Final_{var}", var, self.STYLE_FINAL)

        # 3. Draw Steps
        for step in self.steps:
            self._add_step_node(self.dot, step)

        # 4. Draw Edges
        for consumer_idx, step in enumerate(self.steps):
            step_id = self._node_id(step)

            for param in step.get_signature().parameters:
                # Case A: Missing
                if param in missing:
                    self.dot.edge(f"Missing_{param}", step_id, style="dotted", color="#D32F2F")

                # Case B: External Input
                elif param in inputs:
                    self.dot.edge(f"Input_{param}", step_id)

                # Case C: Produced by another step (Intermediate)
                elif param in producers:
                    producer = producers[param]
                    prod_id = self._node_id(producer)

                    # Cycle Detection: based on the real dependency graph, not
                    # on the (alphabetical) position of the steps.
                    is_feedback = (index_of[id(producer)], consumer_idx) in feedback_edges
                    style = "dashed" if is_feedback else "solid"
                    color = "#D32F2F" if is_feedback else "#616161"

                    self.dot.edge(prod_id, step_id, label=param, style=style, color=color)

        # 5. Link Steps to Final Outputs
        for step in self.steps:
            step_id = self._node_id(step)
            for out in step.resolve_output_names():
                if out in finals:
                    self.dot.edge(step_id, f"Final_{out}")

    # --- Helpers ---

    def _node_id(self, step: "Step") -> str:
        """Returns the graph node identifier for ``step``.

        Built from ``id(step)``: distinct for distinct step objects held in
        ``self.steps``, but not stable from one interpreter run to the next.
        """
        return f"Step_{id(step)}"

    def _add_node(self, graph, node_id: str, label: str, style_dict: Dict[str, str]):
        """Generic node adder using a style dictionary."""
        # Make a copy to avoid mutating the class constant
        attrs = style_dict.copy()
        attrs['label'] = label
        graph.node(node_id, **attrs)

    @staticmethod
    def _step_label(name: str) -> str:
        """Returns a bold HTML-like Graphviz label for a step name.

        ``&``, ``<`` and ``>`` are escaped so that names such as ``<lambda>``
        cannot break the surrounding label markup.
        """
        import html  # local import: keeps this block self-contained

        return f"<<b>{html.escape(str(name), quote=False)}</b>>"

    def _add_step_node(self, graph, step: "Step"):
        """Adds a function/step node."""
        attrs = self.STYLE_STEP.copy()
        # HTML label for bold text
        attrs['label'] = self._step_label(step.name)
        graph.node(self._node_id(step), **attrs)
303# API adapter
def visualize_pipeline(
    steps: List["Step"],
    inputs: Set[str],
    output_path: Optional[str] = None,
    orientation: str = "TD",
    graph_type: Literal["flow", "bipartite"] = "flow",
    view: bool = True
):
    """
    Builds and renders a pipeline diagram in one call.

    Args:
        steps: Pipeline steps to draw.
        inputs: Names of the variables supplied from outside the pipeline.
        output_path: Destination file, forwarded to ``PipelineVisualizer.render``.
        orientation: Layout direction. ``"TD"`` (top-down) is accepted as an
            alias of Graphviz's ``"TB"``; any other value is forwarded as-is.
        graph_type: ``"flow"`` or ``"bipartite"``, forwarded to
            ``PipelineVisualizer.build``.
        view: Forwarded to ``PipelineVisualizer.render``.
    """
    # Graphviz's rankdir only knows TB/BT/LR/RL. "TD" is kept as the default
    # for backward compatibility and translated to the equivalent "TB".
    if isinstance(orientation, str) and orientation.upper() == "TD":
        rankdir = "TB"
    else:
        rankdir = orientation

    viz = PipelineVisualizer(steps, inputs, rankdir)
    viz.build(graph_type).render(output_path, view=view)