Coverage for agentos/core/middleware.py: 54%
148 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2AgentOS v1.1.4 Agent Runtime Middleware Pipeline — 可组合的执行生命周期中间件。
4在 Agent 执行的每个阶段(pre-LLM / post-LLM / pre-tool / post-tool)
5插入策略检查、日志、脱敏、预算控制等拦截逻辑。
7灵感来自 Microsoft Agent Framework 1.0 的 Middleware Pipeline 和 CrewAI Runtime Hooks。
8"""
10from __future__ import annotations
12from abc import ABC, abstractmethod
13from dataclasses import dataclass, field
14from enum import Enum
15from typing import Any, Callable, Optional
class MiddlewarePhase(str, Enum):
    """Lifecycle points at which middleware can be invoked.

    Members subclass ``str``, so each one compares equal to its raw value
    (for example ``MiddlewarePhase.PRE_LLM == "pre_llm"``).
    """

    # Model-call boundaries.
    PRE_LLM = "pre_llm"  # before the LLM is called
    POST_LLM = "post_llm"  # after the LLM returns, before its output is parsed
    # Tool-call boundaries.
    PRE_TOOL = "pre_tool"  # before a tool is invoked
    POST_TOOL = "post_tool"  # after a tool returns
    # Run-level events.
    ON_ERROR = "on_error"  # when execution hits an error
    ON_START = "on_start"  # when the agent starts
    ON_COMPLETE = "on_complete"  # when the agent finishes executing
@dataclass
class MiddlewareContext:
    """State handed to a middleware for one execution phase.

    Only ``phase`` is required; the other fields default to empty values and
    are filled in for whichever phase (LLM call, tool call, error) applies.
    """

    phase: MiddlewarePhase
    agent_name: str = ""
    run_id: str = ""

    # --- LLM-call fields ---
    prompt: Optional[str] = None
    model_name: Optional[str] = None
    llm_output: Optional[str] = None

    # --- tool-call fields ---
    tool_name: Optional[str] = None
    tool_args: Optional[dict] = None
    tool_result: Any = None

    # --- error field ---
    error: Optional[Exception] = None

    # Free-form extra metadata; default_factory gives each instance its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class MiddlewareDecision:
    """Verdict produced by a middleware for one phase.

    A bare ``MiddlewareDecision()`` means "allow, nothing changed".
    """

    # Whether execution may continue; False stops the middleware chain.
    allow: bool = True
    # Human-readable justification for the verdict.
    reason: str = ""
    # Replacement context for downstream middleware (e.g. a PII-masked prompt).
    modified_context: Optional[MiddlewareContext] = None
    # Decision action: allow / warn / block / transform / escalate.
    action: str = "allow"
    # Name of the party that blocked execution, if any.
    blocked_by: str = ""
class AgentMiddleware(ABC):
    """Abstract base class for agent-runtime middleware.

    A subclass declares which phases it listens to via ``phases`` and
    implements ``process``. Returning ``MiddlewareDecision(allow=False)``
    halts the execution chain.
    """

    name: str = "base_middleware"

    @property
    def phases(self) -> list[MiddlewarePhase]:
        """Phases this middleware subscribes to; ``PRE_LLM`` only by default."""
        default_phases = [MiddlewarePhase.PRE_LLM]
        return default_phases

    @abstractmethod
    async def process(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Inspect ``ctx`` and return a decision. Subclasses must override."""
        ...
91# ── 内置中间件 ──────────────────────────────────────────────────────────────
class PIIMaskingMiddleware(AgentMiddleware):
    """PII-masking middleware: redacts PII from the prompt in the pre-LLM phase."""

    name = "pii_masking"

    @property
    def phases(self) -> list[MiddlewarePhase]:
        return [MiddlewarePhase.PRE_LLM]

    async def process(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Redact PII found in ``ctx.prompt``.

        Returns:
            A plain allow decision in either of two cases: the prompt is
            empty/None, or the detector redacted nothing. Otherwise, an
            ``allow=True, action="transform"`` decision whose
            ``modified_context`` is a copy of ``ctx``. That copy holds the
            sanitized prompt and has ``metadata["pii_count"]`` set to the
            number of redacted items. The incoming ``ctx`` and its
            ``metadata`` dict are left untouched.
        """
        if not ctx.prompt:
            return MiddlewareDecision(allow=True)
        # Imported lazily (presumably to avoid a load-time dependency / import cycle).
        from dataclasses import replace

        from agentos.security.guard import PIIDetector

        detector = PIIDetector(auto_redact=True)
        sanitized, items = detector.redact(ctx.prompt)
        count = len(items)
        if count == 0:
            return MiddlewareDecision(allow=True)
        # replace() copies shallowly, so build a NEW metadata dict. Writing into
        # the shared dict would leak "pii_count" back into the caller's ctx.
        new_ctx = replace(
            ctx,
            prompt=sanitized,
            metadata={**ctx.metadata, "pii_count": count},
        )
        return MiddlewareDecision(
            allow=True,
            action="transform",
            reason=f"Masked {count} PII instances",
            modified_context=new_ctx,
        )
class BudgetGuardMiddleware(AgentMiddleware):
    """Budget guard: checks spend against a limit before LLM and tool calls.

    Args:
        tracker: Object exposing a ``total_cost`` attribute (amount spent so
            far). A falsy tracker disables the guard.
        budget_limit: Spending cap; a value ``<= 0`` disables the guard.
        warn_ratio: Fraction of ``budget_limit`` at or above which a
            non-blocking warning is returned.
    """

    name = "budget_guard"

    def __init__(self, tracker=None, budget_limit: float = 0.0, warn_ratio: float = 0.8):
        self.tracker = tracker
        self.budget_limit = budget_limit
        self.warn_ratio = warn_ratio

    @property
    def phases(self) -> list[MiddlewarePhase]:
        return [MiddlewarePhase.PRE_LLM, MiddlewarePhase.PRE_TOOL]

    async def process(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Block at >= 100% of the budget, warn at >= ``warn_ratio``, else allow."""
        guard_disabled = (not self.tracker) or self.budget_limit <= 0
        if guard_disabled:
            return MiddlewareDecision(allow=True)

        limit = self.budget_limit
        spent = self.tracker.total_cost
        used = spent / limit

        # Check the hard limit first, then the soft warning threshold.
        if used >= 1.0:
            return MiddlewareDecision(
                allow=False,
                action="block",
                reason=f"Budget exceeded: ${spent:.4f} / ${limit:.2f}",
                blocked_by=self.name,
            )
        if used >= self.warn_ratio:
            return MiddlewareDecision(
                allow=True,
                action="warn",
                reason=f"Budget warning: {used:.0%} used (${spent:.4f} / ${limit:.2f})",
            )
        return MiddlewareDecision(allow=True)
class ToolRiskGuardMiddleware(AgentMiddleware):
    """Tool-risk guard: before a tool runs, blocks or escalates it based on its risk.

    Args:
        max_auto_level: Highest risk level allowed to run automatically. It is
            converted with ``ToolRiskLevel(...)``; the comparison below expects
            one of ``"low"``, ``"medium"``, ``"high"``, ``"critical"``.
    """

    name = "tool_risk_guard"

    def __init__(self, max_auto_level: str = "medium"):
        from agentos.tools.risk import ToolRiskLevel

        self.max_auto_level = ToolRiskLevel(max_auto_level)

    @property
    def phases(self) -> list[MiddlewarePhase]:
        return [MiddlewarePhase.PRE_TOOL]

    async def process(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Escalate, block, or allow the tool named in ``ctx.tool_name``."""
        tool = ctx.tool_name
        if not tool:
            return MiddlewareDecision(allow=True)

        from agentos.tools.risk import infer_risk_level

        risk = infer_risk_level(tool, tool_args=ctx.tool_args)

        # A risk assessment that demands user confirmation is escalated before
        # the level comparison is even considered.
        if risk.requires_user_confirm():
            return MiddlewareDecision(
                allow=False,
                action="escalate",
                reason=f"Tool '{tool}' requires user approval: {risk.description}",
                blocked_by=self.name,
            )

        severity = ("low", "medium", "high", "critical")  # ascending
        observed = risk.level.value
        ceiling = self.max_auto_level.value
        if severity.index(observed) <= severity.index(ceiling):
            return MiddlewareDecision(allow=True)

        return MiddlewareDecision(
            allow=False,
            action="block",
            reason=f"Tool '{tool}' risk {observed} exceeds auto limit {ceiling}",
            blocked_by=self.name,
        )
class AuditLogMiddleware(AgentMiddleware):
    """Audit-log middleware: records one INFO line for every lifecycle phase.

    Subscribes to every ``MiddlewarePhase`` (including ``POST_LLM``) so the
    audit trail has no gaps. It always returns an allow decision.
    """

    name = "audit_log"

    @property
    def phases(self) -> list[MiddlewarePhase]:
        # Derived from the enum so a newly added phase is audited automatically.
        return list(MiddlewarePhase)

    async def process(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Log phase, agent, run id and tool name to the ``agentos.audit`` logger."""
        import logging

        # %-style args defer string formatting until a handler actually emits
        # the record; the resulting message is identical to the f-string form.
        logging.getLogger("agentos.audit").info(
            "[%s] agent=%s run=%s tool=%s",
            ctx.phase.value,
            ctx.agent_name,
            ctx.run_id,
            ctx.tool_name or "-",
        )
        return MiddlewareDecision(allow=True)
215# ── 中间件管道 ──────────────────────────────────────────────────────────────
class MiddlewarePipeline:
    """Orchestrates registered middleware phase by phase.

    For each phase the pipeline:
      1. selects the middleware whose ``phases`` include that phase;
      2. awaits them one at a time, in registration order;
      3. stops at the first decision with ``allow=False`` and returns it,
         setting its ``blocked_by`` to that middleware's name;
      4. passes any ``modified_context`` on to the middleware that follow.

    When every middleware allows, the returned decision also reports the
    non-blocking outcomes instead of discarding them:

    - ``action`` is ``"warn"`` if any middleware warned. Otherwise it is the
      first other non-``"allow"`` action seen (e.g. ``"transform"``), or
      ``"allow"``.
    - ``reason`` joins every non-empty reason as ``"<name>: <reason>"``,
      separated by ``"; "``.
    """

    def __init__(self, middlewares: Optional[list[AgentMiddleware]] = None):
        # Copy so later add()/remove() never mutate the caller's list.
        self._middlewares: list[AgentMiddleware] = list(middlewares or [])

    def add(self, middleware: AgentMiddleware) -> MiddlewarePipeline:
        """Append ``middleware`` and return ``self`` to allow chaining."""
        self._middlewares.append(middleware)
        return self

    def remove(self, name: str) -> None:
        """Drop every middleware whose ``name`` equals ``name`` (no-op if none match)."""
        self._middlewares = [m for m in self._middlewares if m.name != name]

    @property
    def middleware_names(self) -> list[str]:
        """Names of the registered middleware, in execution order."""
        return [m.name for m in self._middlewares]

    async def execute_phase(
        self,
        phase: MiddlewarePhase,
        ctx: MiddlewareContext,
    ) -> MiddlewareDecision:
        """Run every middleware subscribed to ``phase`` against ``ctx``.

        Args:
            phase: Phase being executed; selects which middleware run.
                ``ctx.phase`` is neither consulted nor updated here.
            ctx: Initial context; replaced by any ``modified_context`` a
                middleware returns.

        Returns:
            Either the blocking middleware's own decision (with ``blocked_by``
            set), or an ``allow=True`` decision. The latter's
            ``modified_context`` is the final context, and its
            ``action``/``reason`` summarize the non-blocking decisions as
            described in the class docstring.
        """
        current_ctx = ctx
        action = "allow"
        reasons: list[str] = []
        for mw in self._middlewares:
            if phase not in mw.phases:
                continue
            decision = await mw.process(current_ctx)
            if not decision.allow:
                decision.blocked_by = mw.name
                return decision
            if decision.modified_context is not None:
                current_ctx = decision.modified_context
            if decision.reason:
                reasons.append(f"{mw.name}: {decision.reason}")
            # Surface non-blocking outcomes (e.g. budget warnings) to the caller:
            # a warning always wins; otherwise keep the first non-"allow" action.
            if decision.action == "warn" or (action == "allow" and decision.action != "allow"):
                action = decision.action
        return MiddlewareDecision(
            allow=True,
            action=action,
            reason="; ".join(reasons),
            modified_context=current_ctx,
        )

    async def on_start(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the ``ON_START`` phase."""
        return await self.execute_phase(MiddlewarePhase.ON_START, ctx)

    async def pre_llm(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the ``PRE_LLM`` phase."""
        return await self.execute_phase(MiddlewarePhase.PRE_LLM, ctx)

    async def post_llm(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the ``POST_LLM`` phase."""
        return await self.execute_phase(MiddlewarePhase.POST_LLM, ctx)

    async def pre_tool(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the ``PRE_TOOL`` phase."""
        return await self.execute_phase(MiddlewarePhase.PRE_TOOL, ctx)

    async def post_tool(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the ``POST_TOOL`` phase."""
        return await self.execute_phase(MiddlewarePhase.POST_TOOL, ctx)

    async def on_error(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the ``ON_ERROR`` phase."""
        return await self.execute_phase(MiddlewarePhase.ON_ERROR, ctx)

    async def on_complete(self, ctx: MiddlewareContext) -> MiddlewareDecision:
        """Run the ``ON_COMPLETE`` phase."""
        return await self.execute_phase(MiddlewarePhase.ON_COMPLETE, ctx)