Coverage for agentos/core/middleware.py: 54%

148 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

"""
AgentOS v1.1.4 Agent Runtime Middleware Pipeline — 可组合的执行生命周期中间件。

在 Agent 执行的每个阶段(pre-LLM / post-LLM / pre-tool / post-tool,以及 on_start / on_error / on_complete)
插入策略检查、日志、脱敏、预算控制等拦截逻辑。

灵感来自 Microsoft Agent Framework 1.0 的 Middleware Pipeline 和 CrewAI Runtime Hooks。
"""

9 

10from __future__ import annotations 

11 

12from abc import ABC, abstractmethod 

13from dataclasses import dataclass, field 

14from enum import Enum 

15from typing import Any, Callable, Optional 

16 

17 

class MiddlewarePhase(str, Enum):
    """Lifecycle points at which a middleware can be triggered.

    Members subclass ``str``, so each compares equal to its raw value
    (e.g. ``MiddlewarePhase.PRE_LLM == "pre_llm"``).
    """

    # Around the model call.
    PRE_LLM = "pre_llm"  # before the LLM is invoked
    POST_LLM = "post_llm"  # after the LLM returns, before its output is parsed

    # Around each tool invocation.
    PRE_TOOL = "pre_tool"  # before a tool runs
    POST_TOOL = "post_tool"  # after a tool has run

    # Agent-level lifecycle events.
    ON_ERROR = "on_error"  # execution raised an error
    ON_START = "on_start"  # the agent is starting
    ON_COMPLETE = "on_complete"  # the agent finished executing

28 

29 

@dataclass
class MiddlewareContext:
    """State snapshot handed to each middleware.

    Only ``phase`` is required. The remaining fields default to empty values
    and are filled in by the caller as they become relevant to the phase.
    """

    phase: MiddlewarePhase
    agent_name: str = ""
    run_id: str = ""

    # --- LLM-phase fields ---
    prompt: Optional[str] = None
    model_name: Optional[str] = None
    llm_output: Optional[str] = None

    # --- tool-phase fields ---
    tool_name: Optional[str] = None
    tool_args: Optional[dict] = None
    tool_result: Any = None

    # --- error-phase field ---
    error: Optional[Exception] = None

    # Free-form extra data; every instance gets its own fresh dict.
    metadata: dict[str, Any] = field(default_factory=dict)

49 

50 

@dataclass
class MiddlewareDecision:
    """Verdict produced by a middleware (and by the pipeline for a phase)."""

    # Whether execution may continue; False stops the middleware chain.
    allow: bool = True
    # Human-readable explanation of the verdict.
    reason: str = ""
    # Replacement context (e.g. a prompt with PII masked); None = unchanged.
    modified_context: Optional[MiddlewareContext] = None
    # Decision action: allow / warn / block / transform / escalate.
    action: str = "allow"
    # Name of the party that blocked execution, if any.
    blocked_by: str = ""

69 

70 

class AgentMiddleware(ABC):
    """Base class for agent runtime middlewares.

    A subclass declares the lifecycle phases it listens to via ``phases`` and
    implements ``process``. Returning ``MiddlewareDecision(allow=False)`` from
    ``process`` stops the rest of the chain.
    """

    # Identifier used by the pipeline (e.g. for ``remove`` and ``blocked_by``).
    name: str = "base_middleware"

    @property
    def phases(self) -> list[MiddlewarePhase]:
        """Phases this middleware subscribes to; defaults to pre-LLM only."""
        default_phase = MiddlewarePhase.PRE_LLM
        return [default_phase]

    @abstractmethod
    async def process(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Inspect ``ctx`` and return a decision for the pipeline."""
        ...

89 

90 

# ── 内置中间件 ──────────────────────────────────────────────────────────────

92 

class PIIMaskingMiddleware(AgentMiddleware):
    """PII-masking middleware: redacts personal data from the prompt pre-LLM.

    Listens only to the pre-LLM phase. When PII is found it returns
    ``action="transform"`` with a *copy* of the context. In the copy,
    ``prompt`` is redacted and ``metadata["pii_count"]`` records how many
    items were masked. The caller's context object, including its
    ``metadata`` dict, is never mutated.
    """

    name = "pii_masking"

    @property
    def phases(self) -> list[MiddlewarePhase]:
        return [MiddlewarePhase.PRE_LLM]

    async def process(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Redact PII from ``ctx.prompt``.

        Args:
            ctx: current pipeline context; ``prompt`` and ``metadata`` are read.

        Returns:
            A ``transform`` decision carrying the redacted context copy when at
            least one PII item was found, otherwise a plain allow decision.
        """
        if not ctx.prompt:
            return MiddlewareDecision(allow=True)

        # Local imports, matching how this module loads project dependencies
        # lazily inside methods.
        from dataclasses import replace

        from agentos.security.guard import PIIDetector

        detector = PIIDetector(auto_redact=True)
        sanitized, items = detector.redact(ctx.prompt)
        count = len(items)
        if count == 0:
            return MiddlewareDecision(allow=True)

        # A fresh metadata dict is essential. A shallow field copy would share
        # ctx.metadata, so recording pii_count would leak into the caller's
        # original (unmasked) context.
        new_ctx = replace(
            ctx,
            prompt=sanitized,
            metadata={**ctx.metadata, "pii_count": count},
        )
        return MiddlewareDecision(
            allow=True,
            action="transform",
            reason=f"Masked {count} PII instances",
            modified_context=new_ctx,
        )

119 

120 

class BudgetGuardMiddleware(AgentMiddleware):
    """Budget guard: blocks LLM/tool calls once spend reaches the limit.

    It also emits a warning decision once spend crosses ``warn_ratio`` of the
    limit.

    Args:
        tracker: object exposing a ``total_cost`` attribute (cumulative
            spend). The guard is disabled when this is falsy.
        budget_limit: spending cap. The guard is disabled when ``<= 0``.
        warn_ratio: fraction of ``budget_limit`` at which to warn.
    """

    name = "budget_guard"

    def __init__(self, tracker=None, budget_limit: float = 0.0, warn_ratio: float = 0.8):
        self.tracker = tracker
        self.budget_limit = budget_limit
        self.warn_ratio = warn_ratio

    @property
    def phases(self) -> list[MiddlewarePhase]:
        return [MiddlewarePhase.PRE_LLM, MiddlewarePhase.PRE_TOOL]

    async def process(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Compare ``tracker.total_cost`` against ``budget_limit``.

        Returns block when the limit is reached, warn when the warn ratio is
        reached, and allow otherwise (or when the guard is disabled).
        """
        tracker, limit = self.tracker, self.budget_limit
        if not tracker or limit <= 0:
            return MiddlewareDecision(allow=True)

        spent = tracker.total_cost
        fraction_used = spent / limit

        # The hard limit is checked first so it wins over the warning.
        if fraction_used >= 1.0:
            return MiddlewareDecision(
                allow=False,
                action="block",
                reason=f"Budget exceeded: ${spent:.4f} / ${limit:.2f}",
                blocked_by=self.name,
            )

        if fraction_used >= self.warn_ratio:
            return MiddlewareDecision(
                allow=True,
                action="warn",
                reason=f"Budget warning: {fraction_used:.0%} used (${spent:.4f} / ${limit:.2f})",
            )

        return MiddlewareDecision(allow=True)

152 

153 

class ToolRiskGuardMiddleware(AgentMiddleware):
    """Tool-risk guard: gates tool calls by inferred risk in the pre-tool phase.

    Tools whose risk requires user confirmation are escalated. Tools whose
    risk level ranks above ``max_auto_level`` are blocked. Everything else is
    allowed.

    Args:
        max_auto_level: highest risk level (a ``ToolRiskLevel`` value) that may
            run without intervention.
    """

    name = "tool_risk_guard"

    def __init__(self, max_auto_level: str = "medium"):
        from agentos.tools.risk import ToolRiskLevel
        self.max_auto_level = ToolRiskLevel(max_auto_level)

    @property
    def phases(self) -> list[MiddlewarePhase]:
        return [MiddlewarePhase.PRE_TOOL]

    async def process(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Decide whether ``ctx.tool_name`` may run automatically."""
        tool = ctx.tool_name
        if not tool:
            return MiddlewareDecision(allow=True)

        from agentos.tools.risk import infer_risk_level
        risk = infer_risk_level(tool, tool_args=ctx.tool_args)

        # Confirmation-required tools are escalated before the level check.
        if risk.requires_user_confirm():
            return MiddlewareDecision(
                allow=False,
                action="escalate",
                reason=f"Tool '{tool}' requires user approval: {risk.description}",
                blocked_by=self.name,
            )

        # Severity ranking, least to most severe. Values outside this list
        # raise ValueError from list.index.
        severity = ["low", "medium", "high", "critical"]
        observed = risk.level.value
        ceiling = self.max_auto_level.value
        if severity.index(observed) <= severity.index(ceiling):
            return MiddlewareDecision(allow=True)

        return MiddlewareDecision(
            allow=False,
            action="block",
            reason=f"Tool '{tool}' risk {observed} exceeds auto limit {ceiling}",
            blocked_by=self.name,
        )

190 

191 

class AuditLogMiddleware(AgentMiddleware):
    """Audit-log middleware: records an audit line at every lifecycle phase.

    Each record goes to the ``agentos.audit`` logger at INFO level. The
    record contains the phase, agent name, run id, and tool name (``-`` when
    there is no tool). It never blocks execution.
    """

    name = "audit_log"

    @property
    def phases(self) -> list[MiddlewarePhase]:
        # POST_LLM is included so the trail genuinely covers every phase.
        return [
            MiddlewarePhase.ON_START,
            MiddlewarePhase.PRE_LLM,
            MiddlewarePhase.POST_LLM,
            MiddlewarePhase.PRE_TOOL,
            MiddlewarePhase.POST_TOOL,
            MiddlewarePhase.ON_ERROR,
            MiddlewarePhase.ON_COMPLETE,
        ]

    async def process(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Log one audit line for ``ctx`` and allow execution to continue."""
        import logging

        logger = logging.getLogger("agentos.audit")
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        # The rendered text matches the previous f-string exactly.
        logger.info(
            "[%s] agent=%s run=%s tool=%s",
            ctx.phase.value,
            ctx.agent_name,
            ctx.run_id,
            ctx.tool_name or "-",
        )
        return MiddlewareDecision(allow=True)

213 

214 

# ── 中间件管道 ──────────────────────────────────────────────────────────────

216 

class MiddlewarePipeline:
    """Orchestrate registered middlewares phase by phase.

    For each phase:
      1. select the middlewares that listen to that phase;
      2. run them in registration order;
      3. stop at the first decision with ``allow=False`` and return it;
      4. pass any ``modified_context`` on to the following middlewares.

    When a phase completes without a block, the returned decision carries
    the final context. Any ``action="warn"`` decisions (e.g. a budget
    warning) are surfaced as ``action="warn"``, with their reasons joined in
    ``reason``, instead of being silently dropped.
    """

    def __init__(self, middlewares: Optional[list[AgentMiddleware]] = None):
        # Copy, so add()/remove() never mutate the caller's list.
        self._middlewares: list[AgentMiddleware] = list(middlewares or [])

    def add(self, middleware: AgentMiddleware) -> MiddlewarePipeline:
        """Append ``middleware`` and return ``self`` to allow chaining."""
        self._middlewares.append(middleware)
        return self

    def remove(self, name: str) -> None:
        """Remove every registered middleware whose ``name`` equals ``name``."""
        self._middlewares = [m for m in self._middlewares if m.name != name]

    @property
    def middleware_names(self) -> list[str]:
        """Names of the registered middlewares, in registration order."""
        return [m.name for m in self._middlewares]

    async def execute_phase(
        self,
        phase: MiddlewarePhase,
        ctx: MiddlewareContext,
    ) -> MiddlewareDecision:
        """Run every middleware subscribed to ``phase`` against ``ctx``.

        Returns:
            The first blocking decision (with ``blocked_by`` set to the
            middleware that produced it). Otherwise, an allow decision holding
            the final context, with ``action="warn"`` and the collected
            reasons if any middleware warned.
        """
        current_ctx = ctx
        warnings: list[str] = []
        for mw in self._middlewares:
            if phase not in mw.phases:
                continue
            decision = await mw.process(current_ctx)
            if not decision.allow:
                # Attribute the block to the middleware that actually ran.
                decision.blocked_by = mw.name
                return decision
            if decision.action == "warn":
                warnings.append(decision.reason or mw.name)
            if decision.modified_context is not None:
                current_ctx = decision.modified_context
        if warnings:
            return MiddlewareDecision(
                allow=True,
                action="warn",
                reason="; ".join(warnings),
                modified_context=current_ctx,
            )
        return MiddlewareDecision(allow=True, modified_context=current_ctx)

    async def on_start(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the ON_START phase."""
        return await self.execute_phase(MiddlewarePhase.ON_START, ctx)

    async def pre_llm(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the PRE_LLM phase."""
        return await self.execute_phase(MiddlewarePhase.PRE_LLM, ctx)

    async def post_llm(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the POST_LLM phase."""
        return await self.execute_phase(MiddlewarePhase.POST_LLM, ctx)

    async def pre_tool(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the PRE_TOOL phase."""
        return await self.execute_phase(MiddlewarePhase.PRE_TOOL, ctx)

    async def post_tool(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the POST_TOOL phase."""
        return await self.execute_phase(MiddlewarePhase.POST_TOOL, ctx)

    async def on_error(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the ON_ERROR phase."""
        return await self.execute_phase(MiddlewarePhase.ON_ERROR, ctx)

    async def on_complete(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the ON_COMPLETE phase."""
        return await self.execute_phase(MiddlewarePhase.ON_COMPLETE, ctx)