Coverage for agentos/tools/pipeline.py: 95%
146 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-03 07:03 +0800
"""
Pipeline — composable data processing pipeline for AgentOS.

Features:
- Linear and branching pipelines
- Fan-out / fan-in patterns
- Backpressure and quota control
- Stage-level error handling and retry
- Pipeline serialization for checkpoint/resume
"""
# NOTE(review): no backpressure/quota control or checkpoint serialization
# logic is apparent in Stage, Pipeline, ParallelPipeline, FilterStage or
# BatchStage — confirm these features are implemented elsewhere, or trim
# the feature list in this docstring.
12import threading
13import time
14from abc import ABC, abstractmethod
15from collections import deque
16from dataclasses import dataclass, field
17from enum import Enum, auto
18from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, TypeVar
# Generic parameters: a stage or pipeline transforms items of type T into U.
T, U = TypeVar("T"), TypeVar("U")


# ============================================================================
# Core Types
# ============================================================================
class StageStatus(Enum):
    """Lifecycle state reported by stages and pipelines via ``status``."""

    # Explicit values match what auto() assigns (1, 2, 3, ... in order).
    IDLE = 1
    RUNNING = 2
    PAUSED = 3
    STOPPED = 4
    ERROR = 5
@dataclass
class PipelineContext:
    """Mutable key/value store handed to every stage of a pipeline.

    Stages exchange values through :meth:`get` and :meth:`set`, which operate
    on ``data``. ``metadata`` is a second dict that these methods never touch.
    """

    # Values read and written by stages through get()/set().
    data: Dict[str, Any] = field(default_factory=dict)
    # Auxiliary information; not used by the methods below.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored under *key*, or *default* if it is absent."""
        store = self.data
        return store.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key*, replacing any existing value."""
        store = self.data
        store[key] = value
49# ============================================================================
50# Stage
51# ============================================================================
53_NO_FALLBACK = object()
class Stage(Generic[T, U], ABC):
    """Abstract pipeline stage. Transforms T → U.

    Subclasses implement :meth:`process`. Callers use :meth:`execute`, which
    wraps it with retries (linearly growing delay), an optional fallback via
    :meth:`on_error`, ``status`` tracking and processed/errored counters.
    """

    def __init__(self, name: str = "", max_retries: int = 0, retry_delay: float = 0.01):
        """Initialise the stage.

        Args:
            name: Display name; defaults to the concrete class name.
            max_retries: Extra attempts after the first failure. Negative
                values behave like 0 (a single attempt, no retries).
            retry_delay: Base back-off in seconds. The sleep before retry
                number ``k`` (1-based) is ``retry_delay * k``.

        Raises:
            ValueError: If *retry_delay* is negative.
        """
        if retry_delay < 0:
            raise ValueError(f"retry_delay must be >= 0, got {retry_delay!r}")
        self.name = name or self.__class__.__name__
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.status: StageStatus = StageStatus.IDLE
        # Exception that last made execute() fail; not cleared on success.
        self._error: Optional[Exception] = None
        self._items_processed: int = 0
        # Counts failed *attempts*: a retried item can be counted more than once.
        self._items_errored: int = 0

    @abstractmethod
    def process(self, item: T, ctx: PipelineContext) -> U:
        """Transform *item*; raise to signal failure (execute() may retry)."""

    def on_error(self, item: T, error: Exception, ctx: PipelineContext) -> Any:
        """Override to provide fallback on error. Return _NO_FALLBACK to propagate."""
        return _NO_FALLBACK

    def execute(self, item: T, ctx: PipelineContext) -> U:
        """Run :meth:`process` on *item* with retries and an optional fallback.

        Returns:
            The result of :meth:`process`, or, once every attempt has failed,
            whatever :meth:`on_error` returns (unless it is ``_NO_FALLBACK``).

        Raises:
            Exception: The last exception from :meth:`process` when
                :meth:`on_error` returns ``_NO_FALLBACK``, or any exception
                raised by :meth:`on_error` itself. In both cases ``status``
                is left at ERROR and ``_error`` records the exception.
        """
        self.status = StageStatus.RUNNING
        # Clamp: with a negative max_retries, range() would be empty and the
        # method would fall through, returning None with status stuck RUNNING.
        retries = max(self.max_retries, 0)
        for attempt in range(retries + 1):
            try:
                result = self.process(item, ctx)
            except Exception as e:
                self._items_errored += 1
                if attempt < retries:
                    time.sleep(self.retry_delay * (attempt + 1))
                    continue
                try:
                    fallback = self.on_error(item, e, ctx)
                except Exception as handler_error:
                    # A failing handler must not leave the stage marked RUNNING.
                    self._error = handler_error
                    self.status = StageStatus.ERROR
                    raise
                if fallback is not _NO_FALLBACK:
                    self.status = StageStatus.IDLE
                    return fallback
                self._error = e
                self.status = StageStatus.ERROR
                raise  # re-raises ``e`` with its original traceback
            self._items_processed += 1
            self.status = StageStatus.IDLE
            return result

    @property
    def stats(self) -> Dict[str, Any]:
        """Snapshot of name, status name and the processed/errored counters."""
        return {
            "name": self.name,
            "status": self.status.name,
            "items_processed": self._items_processed,
            "items_errored": self._items_errored,
        }
class LambdaStage(Stage[T, U]):
    """Stage whose :meth:`process` delegates to a user-supplied callable.

    Args:
        fn: Invoked as ``fn(item, ctx)``; its return value is the stage output.
        name: Display name, forwarded to :class:`Stage`.
        max_retries: Retry count, forwarded to :class:`Stage`.
    """

    def __init__(self, fn: Callable[[T, PipelineContext], U], name: str = "", max_retries: int = 0):
        self._fn = fn
        super().__init__(name=name, max_retries=max_retries)

    def process(self, item: T, ctx: PipelineContext) -> U:
        """Return ``fn(item, ctx)``."""
        fn = self._fn
        return fn(item, ctx)
# ============================================================================
# Pipeline
# ============================================================================

class Pipeline(Generic[T, U]):
    """Linear pipeline: a sequence of stages T → ? → ... → U.

    All runs share a single :class:`PipelineContext` (see :attr:`context`),
    so values a stage stores with ``ctx.set`` remain visible to later runs.
    """

    def __init__(self, name: str = "pipeline"):
        self.name = name
        self._stages: List[Stage] = []
        # Guards mutation of _stages; run() takes a snapshot under it.
        self._lock = threading.Lock()
        self._ctx = PipelineContext()
        self.status: StageStatus = StageStatus.IDLE

    def add_stage(self, stage: Stage) -> "Pipeline":
        """Append *stage* and return ``self`` for chaining."""
        with self._lock:
            self._stages.append(stage)
        return self

    def then(self, fn: Callable[[Any, PipelineContext], Any], name: str = "", max_retries: int = 0) -> "Pipeline":
        """Fluent API: add a lambda stage."""
        return self.add_stage(LambdaStage(fn, name=name, max_retries=max_retries))

    def run(self, input_item: T) -> U:
        """Run every stage in order on *input_item* and return the final output.

        Each stage's output is the next stage's input. An empty pipeline
        returns *input_item* unchanged. Exceptions from a stage propagate.

        Afterwards ``self.status`` is IDLE if the run completed and ERROR if
        it raised. This is derived from this run's own outcome rather than
        from the stages' ``status`` attributes. A Stage instance shared with
        another pipeline running on a different thread (e.g. a
        ParallelPipeline branch) could be mid-execution, which would make a
        successful run look failed.
        """
        with self._lock:
            stages = list(self._stages)
        self.status = StageStatus.RUNNING
        completed = False
        try:
            current = input_item
            for stage in stages:
                # NOTE(review): a stage that "drops" an item (FilterStage yields
                # None) does not short-circuit; None flows to later stages.
                current = stage.execute(current, self._ctx)
            completed = True
            return current
        finally:
            self.status = StageStatus.IDLE if completed else StageStatus.ERROR

    def run_batch(self, items: List[T]) -> List[U]:
        """Run the pipeline on each item in order; the first exception aborts the batch."""
        return [self.run(item) for item in items]

    @property
    def context(self) -> PipelineContext:
        """The context object shared by all runs of this pipeline."""
        return self._ctx

    @property
    def stats(self) -> Dict[str, Any]:
        """Pipeline name, status name and each stage's ``stats`` dict."""
        return {
            "name": self.name,
            "status": self.status.name,
            "stages": [s.stats for s in self._stages],
        }
# ============================================================================
# ParallelPipeline — Fan-out / Fan-in
# ============================================================================

class ParallelPipeline(Generic[T, U]):
    """Branches: split input across parallel stages, then merge results.

    Fan-out: single input → all branches simultaneously.
    Fan-in: all branch outputs → merge function → single output.
    """

    def __init__(self, name: str = "parallel_pipeline"):
        self.name = name
        self._branches: List[Pipeline] = []
        self._merge: Optional[Callable[[List[Any], PipelineContext], U]] = None
        # Guards mutation of _branches; run() takes a snapshot under it.
        self._lock = threading.Lock()
        self._ctx = PipelineContext()

    def branch(self, pipeline: Pipeline) -> "ParallelPipeline":
        """Add *pipeline* as another branch and return ``self`` for chaining."""
        with self._lock:
            self._branches.append(pipeline)
        return self

    def merge(self, fn: Callable[[List[Any], PipelineContext], U]) -> "ParallelPipeline":
        """Set the fan-in function, called as ``fn(results, ctx)``; returns ``self``."""
        self._merge = fn
        return self

    def run(self, input_item: T) -> U:
        """Run every branch on *input_item* concurrently, then merge.

        Each branch gets its own worker thread and receives the same
        *input_item* object. Results are collected in branch-registration
        order. If a merge function is set, its return value is returned;
        otherwise the list of results is. With no branches, the results list
        is empty.

        Raises:
            Exception: The first branch exception observed (in completion
                order). It propagates only after all branches have finished,
                because leaving the executor waits for outstanding work.
        """
        import concurrent.futures

        with self._lock:
            branches = list(self._branches)

        results: List[Any] = [None] * len(branches)
        # ThreadPoolExecutor rejects max_workers=0, so only create a pool
        # when there is at least one branch to run.
        if branches:
            with concurrent.futures.ThreadPoolExecutor(max_workers=len(branches)) as pool:
                futures = {
                    pool.submit(branch.run, input_item): idx
                    for idx, branch in enumerate(branches)
                }
                for future in concurrent.futures.as_completed(futures):
                    results[futures[future]] = future.result()

        if self._merge is not None:
            return self._merge(results, self._ctx)
        return results  # type: ignore

    @property
    def context(self) -> PipelineContext:
        """The context passed to the merge function."""
        return self._ctx
# ============================================================================
# Stage helpers
# ============================================================================

class FilterStage(Stage[T, T]):
    """Pass-through stage that filters items.

    Items for which ``predicate(item, ctx)`` is truthy are returned unchanged.
    Rejected items are returned as ``None``.
    """

    def __init__(self, predicate: Callable[[T, PipelineContext], bool], name: str = "filter", max_retries: int = 0):
        super().__init__(name=name, max_retries=max_retries)
        self._predicate = predicate

    def process(self, item: T, ctx: PipelineContext) -> Optional[T]:
        """Return *item* if the predicate accepts it, else ``None``."""
        # Rejection is an expected outcome, not a failure. Returning None
        # directly (instead of raising FilterDrop) keeps drops out of the
        # errored counter and stops Stage.execute from retrying, and sleeping
        # before each retry, a predicate that will reject again.
        return item if self._predicate(item, ctx) else None

    def on_error(self, item: T, error: Exception, ctx: PipelineContext) -> Optional[T]:
        """Map FilterDrop (e.g. raised by a predicate or subclass) to ``None``."""
        if isinstance(error, FilterDrop):
            return None
        return super().on_error(item, error, ctx)
class FilterDrop(Exception):
    """Raised to indicate that the current item should be dropped by a filter."""
class BatchStage(Stage[List[T], List[U]]):
    """Stage that passes an incoming list to *fn* in a single call.

    :meth:`process` does not accumulate or split items. It hands the list
    it receives straight to ``fn(item, ctx)`` and returns the result.
    """

    def __init__(self, batch_size: int, fn: Callable[[List[T], PipelineContext], List[U]], name: str = "batch", max_retries: int = 0):
        super().__init__(name=name, max_retries=max_retries)
        # NOTE(review): batch_size and _buffer are stored but never read by
        # process(); presumably intended for buffering — confirm intent.
        self.batch_size = batch_size
        self._buffer: List[T] = []
        self._fn = fn

    def process(self, item: List[T], ctx: PipelineContext) -> List[U]:
        """Return ``fn(item, ctx)`` for the whole list *item*."""
        batch_fn = self._fn
        return batch_fn(item, ctx)