Coverage for smartmdao / visualization.py: 100%

141 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-02 20:01 +0200

import html
import logging
import os
from typing import Dict, List, Literal, Optional, Set, Tuple

from .models import Step


# Module-level logger; every message emitted by this file goes through it.
logger = logging.getLogger(__name__)

# graphviz is an optional dependency. Importing this module must not fail when
# it is absent, so the name starts out bound to None and is only rebound if the
# import succeeds; PipelineVisualizer checks for None and raises ImportError.
graphviz = None
try:
    import graphviz
except ImportError:
    logger.warning("Graphviz not found. Visualization features will be unavailable.")


class PipelineVisualizer:
    """
    A modern, modular visualizer for the Pipeline using Graphviz.

    Two views are available through :meth:`build`:

    * ``"flow"`` (process flow): steps are the main nodes; only external
      inputs, missing inputs and final outputs are drawn as data nodes, and
      intermediate variables annotate the step-to-step edges.
    * ``"bipartite"`` (data flow): every variable is a node, so edges alternate
      variable -> step -> variable.

    Usage::

        PipelineVisualizer(steps, input_keys).build("flow").render("pipeline.png")
    """

    # --- Standard Palette (Material Design Pastels) ---
    STYLE_INPUT = {
        "shape": "parallelogram",
        "style": "filled",
        "fillcolor": "#E3F2FD",  # Blue 50
        "color": "#1565C0",  # Blue 800
        "penwidth": "1.5",
        "margin": "0.2",
    }
    STYLE_STEP = {
        "shape": "component",
        "style": "filled",
        "fillcolor": "#FFF3E0",  # Orange 50
        "color": "#EF6C00",  # Orange 800
        "penwidth": "1.5",
        "margin": "0.3",
    }
    STYLE_INTERMEDIATE = {
        "shape": "ellipse",
        "style": "filled",
        "fillcolor": "#F5F5F5",  # Grey 100
        "color": "#757575",  # Grey 600
        "penwidth": "1.0",
        "height": "0.4",
    }
    STYLE_FINAL = {
        "shape": "parallelogram",
        "style": "filled",
        "fillcolor": "#E8F5E9",  # Green 50
        "color": "#2E7D32",  # Green 800
        "penwidth": "2.0",  # Thicker border for emphasis
        "peripheries": "2",  # Double border
        "margin": "0.2",
    }
    STYLE_MISSING = {
        "shape": "hexagon",
        "style": "filled",
        "fillcolor": "#FFEBEE",  # Red 50
        "color": "#C62828",  # Red 800
        "penwidth": "2.0",
    }

    # Graphviz ``rankdir`` only understands TB, LR, BT and RL; "TD" (the
    # Mermaid spelling of top-down) is translated instead of passed through.
    _ORIENTATION_ALIASES = {"TD": "TB"}

    def __init__(
        self,
        steps: List[Step],
        input_keys: Set[str],
        orientation: Literal['TB', 'LR', 'BT', 'RL', 'TD'] = 'TB'
    ):
        """
        Args:
            steps: Pipeline steps to draw. They are sorted by ``name`` so the
                generated graph does not depend on the caller's ordering.
            input_keys: Variables the caller will provide. Parameters that no
                step produces and that are absent from this set are drawn as
                "missing".
            orientation: Graphviz ``rankdir`` ('TB', 'LR', 'BT' or 'RL');
                'TD' is accepted as an alias for 'TB'.

        Raises:
            ImportError: If the ``graphviz`` package is not installed.
        """
        if graphviz is None:
            raise ImportError(
                "The 'graphviz' library is required for visualization. "
                "Please install it using: pip install graphviz"
            )

        self.steps = sorted(steps, key=lambda s: s.name)
        self.input_keys = input_keys
        self.orientation = self._ORIENTATION_ALIASES.get(orientation, orientation)

        # Initialize the graph
        self.dot = graphviz.Digraph(comment='Pipeline Graph')
        self._setup_graph_attributes()

    def _setup_graph_attributes(self):
        """Configure global graph, node and edge styling."""
        self.dot.attr(rankdir=self.orientation)
        self.dot.attr(compound='true')  # Allow edges between clusters

        # Global typography
        self.dot.attr('node', fontname='Helvetica', fontsize='11')
        self.dot.attr('edge', fontname='Helvetica', fontsize='9', color='#616161')

        # 'ortho' gives clean, rectilinear lines suitable for technical diagrams
        # ('polyline' is an alternative if ortho gets messy). Graphviz's ortho
        # router cannot place ordinary edge labels, so edge annotations in this
        # class use ``xlabel`` instead of ``label``.
        self.dot.attr(splines='ortho')

    def build(self, graph_type: Literal["flow", "bipartite"] = "flow") -> "PipelineVisualizer":
        """
        Add the nodes and edges for the requested view to ``self.dot``.

        Args:
            graph_type: ``"bipartite"`` for the data-flow view (recommended for
                a detailed look at inputs vs intermediates vs finals); any other
                value builds the process-flow view.

        Returns:
            ``self``, so the call can be chained with :meth:`render`.

        Note:
            Elements are appended to the existing graph; calling ``build`` twice
            on the same instance draws both sets of nodes and edges.
        """
        if graph_type == "bipartite":
            self._build_bipartite_standard()
        else:
            self._build_flow_standard()
        return self

    def render(self, output_path: Optional[str] = None, view: bool = True) -> Optional[str]:
        """
        Render the built graph and optionally open it in the system viewer.

        Best effort: failures are logged, never raised.

        Args:
            output_path: Destination file. Its extension selects the Graphviz
                output format (``.png``, ``.svg``, ...); without one, PDF is
                used. Graphviz appends the format to the extension-less name,
                so ``"out.png"`` produces ``out.png`` and ``"out"`` produces
                ``out.pdf``. When omitted, the graph is rendered through
                ``Digraph.view`` and opened directly; with ``view=False``
                there is nothing to do and only a warning is logged.
            view: Open the rendered file in the system viewer.

        Returns:
            Path of the rendered file, or ``None`` when nothing was rendered
            (no path and ``view=False``) or rendering failed.
        """
        if not output_path:
            if not view:
                logger.warning("No output_path given and view=False: nothing was rendered.")
                return None
            try:
                out_file = self.dot.view(cleanup=True)
            except Exception:
                logger.exception("Failed to render or open the pipeline diagram.")
                return None
            logger.info("Pipeline diagram opened in viewer.")
            return out_file

        filename, ext = os.path.splitext(output_path)
        # `or` also covers a bare trailing dot ("diagram."), which would
        # otherwise yield an empty format string.
        fmt = ext.lstrip('.').lower() or 'pdf'
        try:
            out_file = self.dot.render(filename, format=fmt, cleanup=True)
        except Exception:
            # e.g. the Graphviz `dot` executable is missing or the format is unknown.
            logger.exception("Failed to render pipeline diagram to %s", output_path)
            return None
        logger.info("Pipeline diagram saved to: %s", out_file)

        # Viewing is a separate step so a viewer failure is reported as such
        # and never hides the fact that the file was written.
        if view:
            try:
                graphviz.view(out_file)
            except Exception as e:
                logger.error("Graph rendered successfully, but viewer failed: %s", e)
        return out_file

    # --- Classification Logic ---

    def _analyze_variables(self) -> Tuple[Set[str], Set[str], Set[str], Set[str], Dict[str, Step]]:
        """
        Categorize every variable that appears in the pipeline.

        Returns:
            ``(inputs, intermediates, finals, missing, producer_map)`` where

            * ``inputs``: consumed, produced by no step, and listed in
              ``self.input_keys``;
            * ``missing``: consumed, produced by no step, and NOT listed in
              ``self.input_keys``;
            * ``intermediates``: produced by a step and consumed by a step;
            * ``finals``: produced but never consumed;
            * ``producer_map``: variable name -> step producing it. If several
              steps output the same name, the last one in ``self.steps``
              (name order) wins.
        """
        producers_map: Dict[str, Step] = {}
        consumed: Set[str] = set()
        produced: Set[str] = set()

        for step in self.steps:
            for out in step.resolve_output_names():
                producers_map[out] = step
                produced.add(out)
            # Parameter names of the (unwrapped) signature are the consumed variables
            consumed.update(step.get_signature().parameters)

        # Graph structure decides what is an input; input_keys only separates
        # provided inputs from missing ones.
        real_inputs = consumed - produced
        missing = {v for v in real_inputs if v not in self.input_keys}
        valid_inputs = real_inputs.intersection(self.input_keys)
        intermediates = produced & consumed
        finals = produced - consumed

        return valid_inputs, intermediates, finals, missing, producers_map

    def _find_feedback_edges(self, producers: Dict[str, Step]) -> Set[Tuple[int, int]]:
        """
        Return the step-to-step edges that close a dependency cycle.

        The step graph has an edge ``producer -> consumer`` whenever
        ``consumer`` takes a parameter found in ``producers``. A depth-first
        search, started from the steps in ``self.steps`` order so the result is
        deterministic, marks an edge as feedback when it points at a step that
        is still on the DFS stack. An acyclic pipeline therefore has no
        feedback edges regardless of step names, and a step consuming its own
        output is a feedback edge onto itself.

        Args:
            producers: Variable name -> producing step, as returned by
                :meth:`_analyze_variables`.

        Returns:
            Set of ``(id(producer), id(consumer))`` pairs.
        """
        # Keyed by id(): steps are identified by identity elsewhere (_node_id)
        # and need not be hashable.
        successors: Dict[int, List[Step]] = {id(step): [] for step in self.steps}
        for consumer in self.steps:
            for param in consumer.get_signature().parameters:
                producer = producers.get(param)
                if producer is not None:
                    successors[id(producer)].append(consumer)

        new, active, done = 0, 1, 2
        state = {id(step): new for step in self.steps}
        feedback: Set[Tuple[int, int]] = set()

        for root in self.steps:
            if state[id(root)] != new:
                continue
            state[id(root)] = active
            stack = [(root, iter(successors[id(root)]))]
            while stack:
                node, children = stack[-1]
                child = next(children, None)
                if child is None:
                    state[id(node)] = done
                    stack.pop()
                elif state[id(child)] == active:
                    feedback.add((id(node), id(child)))
                elif state[id(child)] == new:
                    state[id(child)] = active
                    stack.append((child, iter(successors[id(child)])))
        return feedback

    # --- Bipartite (Data Flow) Strategy ---

    def _build_bipartite_standard(self):
        """
        Construct a Data Flow Diagram (DFD).

        Every variable is a node, so edges always alternate
        variable -> step -> variable: inputs (and missing inputs) feed steps,
        steps feed intermediates, which feed further steps, and the last steps
        feed the final outputs.
        """
        inputs, intermediates, finals, missing, _producers = self._analyze_variables()

        # Variables are emitted in sorted order: set iteration order varies
        # between runs (string hash seed), which would make the layout unstable.

        # 1. Inputs and missing inputs, grouped via rank='source'
        with self.dot.subgraph(name='cluster_inputs') as c:
            c.attr(rank='source', style='invis')  # Invisible container for grouping
            for var in sorted(inputs):
                self._add_node(c, f"Var_{var}", var, self.STYLE_INPUT)
            for var in sorted(missing):
                self._add_node(c, f"Missing_{var}", f"{var} (?)", self.STYLE_MISSING)

        # 2. Finals, grouped via rank='sink'
        with self.dot.subgraph(name='cluster_finals') as c:
            c.attr(rank='sink', style='invis')
            for var in sorted(finals):
                self._add_node(c, f"Var_{var}", var, self.STYLE_FINAL)

        # 3. Intermediates
        for var in sorted(intermediates):
            self._add_node(self.dot, f"Var_{var}", var, self.STYLE_INTERMEDIATE)

        # 4. Steps
        for step in self.steps:
            self._add_step_node(self.dot, step)

        # 5. Edges
        for step in self.steps:
            step_id = self._node_id(step)

            # Variables into the step
            for param in step.get_signature().parameters:
                if param in missing:
                    self.dot.edge(f"Missing_{param}", step_id, style="dotted", color="#D32F2F")
                else:
                    # Either a provided input or a produced variable; both use Var_ ids
                    self.dot.edge(f"Var_{param}", step_id)

            # Step into its outputs
            for out in step.resolve_output_names():
                self.dot.edge(step_id, f"Var_{out}")

    # --- Flow Strategy (Process Flow) ---

    def _build_flow_standard(self):
        """
        Construct a Process Flow Diagram.

        Steps are the main nodes. Data is drawn as an explicit node only when it
        is an external input, a missing input or a final output; intermediate
        variables annotate the producer -> consumer edges. Edges that close a
        dependency cycle (see ``_find_feedback_edges``) are red and dashed.
        """
        inputs, _intermediates, finals, missing, producers = self._analyze_variables()
        # Feedback comes from the dependency structure, not from positions in
        # self.steps: that list is sorted by name, so comparing positions would
        # flag ordinary forward dependencies whenever the producer's name sorts
        # after the consumer's.
        feedback = self._find_feedback_edges(producers)

        # 1. Inputs (sorted so the DOT output is reproducible)
        with self.dot.subgraph(name='cluster_inputs') as c:
            c.attr(rank='source', style='invis')
            for var in sorted(inputs):
                self._add_node(c, f"Input_{var}", var, self.STYLE_INPUT)
            for var in sorted(missing):
                self._add_node(c, f"Missing_{var}", f"{var} (?)", self.STYLE_MISSING)

        # 2. Finals
        with self.dot.subgraph(name='cluster_finals') as c:
            c.attr(rank='sink', style='invis')
            for var in sorted(finals):
                # The producing step is linked straight to this node in step 5
                self._add_node(c, f"Final_{var}", var, self.STYLE_FINAL)

        # 3. Steps
        for step in self.steps:
            self._add_step_node(self.dot, step)

        # 4. Edges into each step
        for step in self.steps:
            step_id = self._node_id(step)
            for param in step.get_signature().parameters:
                if param in missing:
                    # Case A: required input nobody provides
                    self.dot.edge(f"Missing_{param}", step_id, style="dotted", color="#D32F2F")
                elif param in inputs:
                    # Case B: external input
                    self.dot.edge(f"Input_{param}", step_id)
                elif param in producers:
                    # Case C: produced by another step (intermediate)
                    producer = producers[param]
                    is_feedback = (id(producer), id(step)) in feedback
                    style = "dashed" if is_feedback else "solid"
                    color = "#D32F2F" if is_feedback else "#616161"
                    # xlabel, not label: splines=ortho cannot place edge labels.
                    self.dot.edge(
                        self._node_id(producer), step_id,
                        xlabel=param, style=style, color=color,
                    )

        # 5. Link steps to their final outputs
        for step in self.steps:
            step_id = self._node_id(step)
            for out in step.resolve_output_names():
                if out in finals:
                    self.dot.edge(step_id, f"Final_{out}")

    # --- Helpers ---

    def _node_id(self, step: Step) -> str:
        """Graphviz node id for a step; identity-based, so equal names never collide."""
        return f"Step_{id(step)}"

    def _add_node(self, graph, node_id: str, label: str, style_dict: Dict[str, str]):
        """Add a node to ``graph`` with ``style_dict`` attributes and ``label``."""
        # Copy so the shared class-level style dict is never mutated
        attrs = style_dict.copy()
        attrs['label'] = label
        graph.node(node_id, **attrs)

    def _add_step_node(self, graph, step: Step):
        """Add a step node whose label is the step name in bold."""
        attrs = self.STYLE_STEP.copy()
        # HTML-like label for bold text; &, < and > in the name must be escaped
        # or Graphviz rejects the label.
        attrs['label'] = f"<<b>{html.escape(str(step.name), quote=False)}</b>>"
        graph.node(self._node_id(step), **attrs)


# API adapter
def visualize_pipeline(
    steps: List[Step],
    inputs: Set[str],
    output_path: Optional[str] = None,
    orientation: str = "TB",
    graph_type: Literal["flow", "bipartite"] = "flow",
    view: bool = True
):
    """
    Build and render a pipeline diagram in one call.

    Args:
        steps: Pipeline steps to draw.
        inputs: Variable names the caller will provide; parameters that no
            step produces and that are absent here are drawn as missing.
        output_path: Destination file; its extension selects the output
            format. When omitted, ``PipelineVisualizer.render`` decides how
            the graph is displayed.
        orientation: Graphviz ``rankdir`` ('TB', 'LR', 'BT' or 'RL'). 'TD' is
            accepted as an alias for 'TB'.
        graph_type: ``"flow"`` (process view) or ``"bipartite"`` (data-flow view).
        view: Whether to open the rendered diagram in a viewer.

    Returns:
        The built ``PipelineVisualizer``, e.g. to inspect ``viz.dot.source``.

    Raises:
        ImportError: If the ``graphviz`` package is not installed.
    """
    # "TD" is Mermaid syntax; Graphviz's rankdir only understands TB/LR/BT/RL.
    if orientation == "TD":
        orientation = "TB"
    viz = PipelineVisualizer(steps, inputs, orientation)
    viz.build(graph_type).render(output_path, view=view)
    return viz