Coverage for agentos/tools/pipeline.py: 0%

146 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 07:07 +0800

"""
Pipeline — composable data processing pipeline for AgentOS.

Features:
- Linear and branching pipelines
- Fan-out / fan-in patterns
- Backpressure and quota control
- Stage-level error handling and retry
- Pipeline serialization for checkpoint/resume
"""

11 

12import threading 

13import time 

14from abc import ABC, abstractmethod 

15from collections import deque 

16from dataclasses import dataclass, field 

17from enum import Enum, auto 

18from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, TypeVar 

19 

# Generic type parameters used by the stage and pipeline classes below:
# T is the type of the item a stage/pipeline receives, U the type it produces.
T = TypeVar("T")
U = TypeVar("U")

22 

23 

24# ============================================================================ 

25# Core Types 

26# ============================================================================ 

27 

class StageStatus(Enum):
    """Lifecycle states reported by stages and pipelines.

    The numeric values are the ones ``enum.auto()`` would assign (1..5 in
    declaration order).
    """

    IDLE = 1      # not currently executing; last run succeeded or used a fallback
    RUNNING = 2   # an ``execute``/``run`` call is in progress
    PAUSED = 3    # declared here; not set by any code visible in this module view
    STOPPED = 4   # declared here; not set by any code visible in this module view
    ERROR = 5     # the last run ended with an exception that was propagated

34 

35 

@dataclass
class PipelineContext:
    """Mutable key/value state handed to every stage invocation.

    ``data`` backs the :meth:`get` / :meth:`set` helpers.  ``metadata`` is a
    second, independent dictionary that those helpers never touch.  Both
    default to a fresh empty dict per instance.
    """

    data: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored under *key* in ``data``, or *default* if absent."""
        store = self.data
        return store.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key* in ``data``, replacing any previous value."""
        store = self.data
        store[key] = value

47 

48 

49# ============================================================================ 

50# Stage 

51# ============================================================================ 

52 

# Sentinel returned by ``Stage.on_error`` to mean "no fallback, re-raise".
# A dedicated object is used so that ``None`` remains a legitimate fallback value.
_NO_FALLBACK = object()


class Stage(Generic[T, U], ABC):
    """Abstract pipeline stage. Transforms T → U.

    Subclasses implement :meth:`process`; callers invoke :meth:`execute`,
    which adds retry, fallback and bookkeeping around it.

    Args:
        name: Display name; defaults to the concrete class name when empty.
        max_retries: Number of additional attempts after the first failure.
            Negative values are treated as 0 by :meth:`execute`.
    """

    # Base delay, in seconds, of the linear back-off between retry attempts
    # (attempt k sleeps ``k * _RETRY_BACKOFF_SECONDS``).  Subclasses may override.
    _RETRY_BACKOFF_SECONDS = 0.01

    def __init__(self, name: str = "", max_retries: int = 0):
        self.name = name or self.__class__.__name__
        self.max_retries = max_retries
        self.status: StageStatus = StageStatus.IDLE
        self._error: Optional[Exception] = None
        self._items_processed: int = 0
        self._items_errored: int = 0

    @abstractmethod
    def process(self, item: T, ctx: PipelineContext) -> U: ...

    def on_error(self, item: T, error: Exception, ctx: PipelineContext) -> Any:
        """Override to provide fallback on error. Return _NO_FALLBACK to propagate."""
        return _NO_FALLBACK

    def execute(self, item: T, ctx: PipelineContext) -> U:
        """Run :meth:`process` on *item* with retry and fallback handling.

        Returns:
            The value from :meth:`process`, or the value from :meth:`on_error`
            when every attempt failed and a fallback was supplied.

        Raises:
            Exception: the last error from :meth:`process` when all attempts
                failed and :meth:`on_error` returned ``_NO_FALLBACK``, or any
                error raised by :meth:`on_error` itself.  In both cases the
                stage status is set to ``ERROR``.
        """
        self.status = StageStatus.RUNNING
        # Clamp: a negative max_retries would make the range empty, so process()
        # would never run and the stage would return None while still RUNNING.
        retries = max(0, self.max_retries)
        for attempt in range(retries + 1):
            try:
                result = self.process(item, ctx)
            except Exception as exc:
                # Counts every failed attempt, including ones that are retried.
                self._items_errored += 1
                if attempt < retries:
                    time.sleep(self._RETRY_BACKOFF_SECONDS * (attempt + 1))
                    continue
                try:
                    fallback = self.on_error(item, exc, ctx)
                except Exception as handler_exc:
                    # Do not leave the stage stuck in RUNNING if the handler fails.
                    self._error = handler_exc
                    self.status = StageStatus.ERROR
                    raise
                if fallback is not _NO_FALLBACK:
                    # Fallback results are not counted in _items_processed.
                    self.status = StageStatus.IDLE
                    return fallback
                self._error = exc
                self.status = StageStatus.ERROR
                raise
            else:
                self._items_processed += 1
                self.status = StageStatus.IDLE
                return result

    @property
    def stats(self) -> Dict[str, Any]:
        """Snapshot of the stage name, status name and success/failure counters."""
        return {
            "name": self.name,
            "status": self.status.name,
            "items_processed": self._items_processed,
            "items_errored": self._items_errored,
        }

103 

104 

class LambdaStage(Stage[T, U]):
    """Stage whose processing step is a user-supplied callable.

    The callable receives ``(item, ctx)`` and its return value becomes the
    stage output.
    """

    def __init__(
        self,
        fn: Callable[[T, PipelineContext], U],
        name: str = "",
        max_retries: int = 0,
    ):
        self._fn = fn
        super().__init__(name=name, max_retries=max_retries)

    def process(self, item: T, ctx: PipelineContext) -> U:
        """Delegate to the wrapped callable and return whatever it returns."""
        callback = self._fn
        return callback(item, ctx)

113 

114 

115# ============================================================================ 

116# Pipeline 

117# ============================================================================ 

118 

class Pipeline(Generic[T, U]):
    """Ordered chain of stages: each stage's output is the next stage's input.

    All stages of one pipeline are executed with the same
    :class:`PipelineContext` instance, available through :attr:`context`.
    """

    def __init__(self, name: str = "pipeline"):
        self.name = name
        self.status: StageStatus = StageStatus.IDLE
        self._ctx = PipelineContext()
        self._lock = threading.Lock()
        self._stages: List[Stage] = []

    def add_stage(self, stage: Stage) -> "Pipeline":
        """Append *stage* to the end of the chain and return ``self`` for chaining."""
        with self._lock:
            self._stages.append(stage)
        return self

    def then(self, fn: Callable[[Any, PipelineContext], Any], name: str = "", max_retries: int = 0) -> "Pipeline":
        """Fluent helper: wrap *fn* in a :class:`LambdaStage` and append it."""
        wrapped = LambdaStage(fn, name=name, max_retries=max_retries)
        return self.add_stage(wrapped)

    def run(self, input_item: T) -> U:
        """Push one item through every stage in order and return the final value.

        Any exception raised by a stage propagates to the caller.  Whether the
        run succeeds or fails, the pipeline status afterwards is ``IDLE`` only
        if every stage reports ``IDLE``; otherwise it is ``ERROR``.
        """
        self.status = StageStatus.RUNNING
        value = input_item
        try:
            for step in self._stages:
                value = step.execute(value, self._ctx)
            return value
        finally:
            any_not_idle = any(
                step.status != StageStatus.IDLE for step in self._stages
            )
            self.status = StageStatus.ERROR if any_not_idle else StageStatus.IDLE

    def run_batch(self, items: List[T]) -> List[U]:
        """Run the pipeline once per item, sequentially, and collect the outputs."""
        return [self.run(entry) for entry in items]

    @property
    def context(self) -> PipelineContext:
        """The context object shared by all stages of this pipeline."""
        return self._ctx

    @property
    def stats(self) -> Dict[str, Any]:
        """Pipeline name and status plus the ``stats`` dict of each stage, in order."""
        summary: Dict[str, Any] = {
            "name": self.name,
            "status": self.status.name,
        }
        summary["stages"] = [stage.stats for stage in self._stages]
        return summary

168 

169 

170# ============================================================================ 

171# ParallelPipeline — Fan-out / Fan-in 

172# ============================================================================ 

173 

class ParallelPipeline(Generic[T, U]):
    """Branches: split input across parallel stages, then merge results.

    Fan-out: single input → all branches simultaneously.
    Fan-in:  all branch outputs → merge function → single output.

    Each branch is a :class:`Pipeline` and runs with its own context; the
    context owned by this object is only handed to the merge function.
    """

    def __init__(self, name: str = "parallel_pipeline"):
        self.name = name
        self._branches: List[Pipeline] = []
        self._merge: Optional[Callable[[List[Any], PipelineContext], U]] = None
        self._lock = threading.Lock()
        self._ctx = PipelineContext()

    def branch(self, pipeline: Pipeline) -> "ParallelPipeline":
        """Register *pipeline* as an additional branch; returns ``self``."""
        with self._lock:
            self._branches.append(pipeline)
        return self

    def merge(self, fn: Callable[[List[Any], PipelineContext], U]) -> "ParallelPipeline":
        """Set the fan-in function, called as ``fn(results, ctx)``; returns ``self``."""
        self._merge = fn
        return self

    def run(self, input_item: T) -> U:
        """Run every branch on *input_item* concurrently and combine the outputs.

        Branch outputs are collected in branch-registration order.  If a merge
        function is set its return value is returned; otherwise the list of
        branch outputs is returned as-is.  With no branches registered the
        result list is empty and no thread pool is created.

        Raises:
            Exception: whatever a branch's ``run`` raised, re-raised when its
                future result is collected.
        """
        import concurrent.futures

        # Snapshot once under the lock so the pool size, the future map and the
        # results list all agree even if branch() is called concurrently.
        with self._lock:
            branches = list(self._branches)

        results: List[Any] = [None] * len(branches)
        # ThreadPoolExecutor rejects max_workers=0, so skip the pool when empty.
        if branches:
            with concurrent.futures.ThreadPoolExecutor(max_workers=len(branches)) as pool:
                index_of = {
                    pool.submit(branch.run, input_item): i
                    for i, branch in enumerate(branches)
                }
                for future in concurrent.futures.as_completed(index_of):
                    results[index_of[future]] = future.result()

        if self._merge is not None:
            return self._merge(results, self._ctx)
        return results  # type: ignore

    @property
    def context(self) -> PipelineContext:
        """The context passed to the merge function."""
        return self._ctx

217 

218 

219# ============================================================================ 

220# Stage helpers 

221# ============================================================================ 

222 

class FilterStage(Stage[T, T]):
    """Stage that lets an item through only when a predicate accepts it.

    Accepted items are returned unchanged.  A rejected item makes
    :meth:`process` raise :class:`FilterDrop`, which :meth:`on_error` turns
    into a ``None`` result.

    NOTE(review): because rejection travels through the exception path of
    ``Stage.execute``, a rejected item is counted in ``items_errored`` and,
    when ``max_retries`` is non-zero, the predicate is re-evaluated on each
    retry before the drop takes effect.
    """

    def __init__(
        self,
        predicate: Callable[[T, PipelineContext], bool],
        name: str = "filter",
        max_retries: int = 0,
    ):
        self._predicate = predicate
        super().__init__(name=name, max_retries=max_retries)

    def process(self, item: T, ctx: PipelineContext) -> T:
        """Return *item* if the predicate is truthy for it, else raise ``FilterDrop``."""
        accepted = self._predicate(item, ctx)
        if accepted:
            return item
        raise FilterDrop()

    def on_error(self, item: T, error: Exception, ctx: PipelineContext) -> Optional[T]:
        """Map ``FilterDrop`` to ``None``; defer every other error to the base class."""
        if not isinstance(error, FilterDrop):
            return super().on_error(item, error, ctx)
        return None

238 

239 

class FilterDrop(Exception):
    """Raised by ``FilterStage.process`` to mark an item as rejected by the filter."""

243 

244 

class BatchStage(Stage[List[T], List[U]]):
    """Stage that applies a list-to-list callable to an incoming batch.

    NOTE(review): ``batch_size`` and ``_buffer`` are initialised here, but the
    ``process`` method below does not consult them; it forwards the incoming
    list straight to the callable.  Any accumulation logic would have to live
    outside the code visible here — confirm.
    """

    def __init__(
        self,
        batch_size: int,
        fn: Callable[[List[T], PipelineContext], List[U]],
        name: str = "batch",
        max_retries: int = 0,
    ):
        super().__init__(name=name, max_retries=max_retries)
        self._fn = fn
        self._buffer: List[T] = []
        self.batch_size = batch_size

    def process(self, item: List[T], ctx: PipelineContext) -> List[U]:
        """Call the wrapped function with ``(item, ctx)`` and return its result."""
        handler = self._fn
        return handler(item, ctx)