Coverage for smart_pipeline / core.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-02 13:37 +0200

1import logging 

2from dataclasses import dataclass, field 

3from typing import Callable, List, Literal 

4 

5from .models import Step 

6from .solvers import Solver, DAGSolver 

7from .visualization import visualize_pipeline 

8 

# Logger scoped to this module's import path, so applications can adjust its level
# independently of other loggers.
logger: logging.Logger = logging.getLogger(__name__)

11 

@dataclass
class Pipeline:
    """Ordered collection of steps whose execution is delegated to a Solver."""

    # Registered steps, in the order they were added.
    steps: list[Step] = field(default_factory=list)
    # Executes the steps in ``run``; each Pipeline gets its own DAGSolver by default.
    solver: Solver = field(default_factory=DAGSolver)

    def add(self, fn: Callable, outputs: list[str] = None) -> "Pipeline":
        """
        Add a step to the pipeline.

        :param fn: The function to execute.
        :param outputs: Optional list of variable names this function produces.
            Passed unchanged to ``Step``; how ``None`` is interpreted is up to ``Step``.
        :return: ``self``, so calls can be chained (``pipe.add(a).add(b)``).
        """
        step = Step(fn, outputs)
        self.steps.append(step)
        logger.debug("Added step '%s' to pipeline.", step.name)
        return self

    def step(self, fn: Callable = None, *, outputs: List[str] = None):
        """
        Decorator to register a step.

        Supports both the bare and the parameterized form::

            @pipe.step
            def f(...): ...

            @pipe.step(outputs=["x"])
            def g(...): ...

        The decorated function is returned unchanged.

        :param fn: The function being decorated (bare form only).
        :param outputs: Optional list of variable names the function produces.
        :raises TypeError: if ``fn`` is given but is not callable, e.g.
            ``@pipe.step(["x"])``. In that case the outputs list was passed
            positionally and would otherwise be silently discarded.
        """
        if fn is not None:
            if not callable(fn):
                raise TypeError(
                    f"Pipeline.step() expected a callable, got {type(fn).__name__!r}; "
                    "pass outputs as a keyword argument: @pipeline.step(outputs=[...])"
                )
            # Bare form: @pipe.step
            self.add(fn, outputs=outputs)
            return fn

        # Parameterized form: @pipe.step(outputs=[...])
        def wrapper(func):
            self.add(func, outputs=outputs)
            return func

        return wrapper

    def run(self, **inputs):
        """
        Delegates the execution to the configured Solver.

        :param inputs: Initial named values, passed to ``self.solver.solve`` as a dict.
        :return: Whatever ``self.solver.solve`` returns, unchanged.
        :raises Exception: Any exception raised by the solver is logged and re-raised as-is.
        """
        logger.info(
            "Starting pipeline execution with %d steps and inputs: %s",
            len(self.steps),
            list(inputs),
        )
        try:
            result = self.solver.solve(self.steps, inputs)
        except Exception as e:
            # Log at this boundary, then re-raise so callers still see the original error.
            logger.error("Pipeline execution failed: %s", e)
            raise
        logger.info("Pipeline execution completed successfully.")
        return result

    def visualize(
        self,
        inputs: List[str] = None,
        output_path: str = None,
        orientation: Literal["TB", "LR"] = "TB",
        graph_type: Literal["flow", "bipartite"] = "flow",
        view: bool = True,
    ):
        """
        Generates a Graphviz diagram of the pipeline.

        All rendering is delegated to ``visualize_pipeline``.

        :param inputs: Names of externally supplied inputs; ``None`` means none.
            Converted to a set before being passed on.
        :param output_path: Passed through to ``visualize_pipeline``.
        :param orientation: Layout direction, ``"TB"`` or ``"LR"``.
        :param graph_type: Diagram style, ``"flow"`` or ``"bipartite"``.
        :param view: Passed through to ``visualize_pipeline``.
        """
        input_set = set(inputs or [])
        logger.debug("Generating visualization (%s) for pipeline.", graph_type)

        visualize_pipeline(
            steps=self.steps,
            inputs=input_set,
            output_path=output_path,
            orientation=orientation,
            graph_type=graph_type,
            view=view,
        )