Coverage for agentos/core/code_agent.py: 25%
171 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2AgentOS v1.1.9 — CodeAgent 模式。
4基因来源: Smolagents CodeAgent (HuggingFace)
6CodeAgent 允许 Agent 通过生成和执行 Python 代码来完成子任务,
7而非仅调用预定义工具。代码可以调用已注册的 tools + 安全内置函数。
9特性:
10- 多步执行:生成代码 → 执行 → 观察结果 → 继续
11- 安全沙箱:白名单模块、禁止危险操作、超时控制
12- Tools 集成:代码中直接调用 `tool_name(args)`
13- 内存持久:跨步骤的变量和结果通过 locals 传递
14"""
16from __future__ import annotations
18import ast
19import asyncio
20import inspect
21import sys
22import traceback
23from dataclasses import dataclass, field
24from typing import Any, Callable, Dict, List, Optional, Tuple
26from agentos.models.router import ModelRouter
29# ── 安全常量 ───────────────────────────────────
# Modules that generated code may import (checked by the AST scanner).
# NOTE(review): "pathlib" and "io" are on this list; pathlib can read and
# write files even though the `open` builtin is forbidden -- confirm that
# this is intended.
DEFAULT_ALLOWED_MODULES = frozenset((
    "math",
    "json",
    "re",
    "datetime",
    "collections",
    "itertools",
    "functools",
    "typing",
    "dataclasses",
    "decimal",
    "fractions",
    "statistics",
    "random",
    "string",
    "textwrap",
    "unicodedata",
    "hashlib",
    "base64",
    "binascii",
    "uuid",
    "copy",
    "pprint",
    "enum",
    "pathlib",
    "logging",
    "warnings",
    "csv",
    "html",
    "urllib.parse",
    "xml.etree.ElementTree",
    "operator",
    "heapq",
    "bisect",
    "array",
    "struct",
    "io",
    "os.path",
))

# Names that generated code must not call, either directly (`eval(...)`)
# or as the root of a dotted call (`os.system(...)`).
FORBIDDEN_CALLS = frozenset((
    # dynamic code execution / import
    "exec", "eval", "compile", "open", "__import__",
    # reflective attribute access
    "getattr", "setattr", "delattr", "hasattr",
    # namespace introspection
    "globals", "locals", "vars",
    # interactive builtins
    "breakpoint", "input",
    # system-level modules
    "os", "subprocess", "shutil", "sys",
    "ctypes", "socket", "pickle", "marshal",
    "multiprocessing", "threading", "signal",
))

# Captured stdout longer than this many characters is truncated.
MAX_OUTPUT_LENGTH = 10_000
56# ── 数据结构 ───────────────────────────────────
@dataclass
class CodeStep:
    """Record of one generate-and-execute step of a CodeAgent run."""

    # Step number within the run.
    step: int
    # The Python source that was extracted for this step.
    code: str
    # Value the code left in `__result__`, if any.
    result: Any = field(default=None)
    # Text the code wrote to standard output.
    stdout: str = field(default="")
    # Error description when the step failed, otherwise None.
    error: Optional[str] = field(default=None)
    # Time spent on the step, in milliseconds.
    duration_ms: float = field(default=0.0)
@dataclass
class CodeResult:
    """Outcome of a complete CodeAgent run."""

    # Whether the run finished with an answer.
    success: bool
    # The answer produced by the run (may be None).
    final_answer: Any = field(default=None)
    # Every step that was recorded, in execution order.
    steps: List["CodeStep"] = field(default_factory=list)
    # Total time of the run, in milliseconds.
    total_duration_ms: float = field(default=0.0)
    # Reason for failure when `success` is False, otherwise None.
    error: Optional[str] = field(default=None)
81# ── 代码安全检查器 ─────────────────────────────
class CodeGuard(ast.NodeVisitor):
    """AST scanner that records dangerous operations found in Python code.

    Visit a parsed tree with ``guard.visit(tree)`` and then read
    ``guard.violations``, a list of human-readable messages (empty when
    nothing was flagged).

    Args:
        allowed_modules: Dotted module names that may be imported. A module is
            accepted when it, or any of its parent packages, is listed (so
            listing ``"collections"`` also permits ``collections.abc``).
        forbidden_calls: Names that must not be called, either directly or as
            the root object of a dotted call. Defaults to the module-level
            ``FORBIDDEN_CALLS``.
    """

    def __init__(
        self,
        allowed_modules: frozenset,
        forbidden_calls: Optional[frozenset] = None,
    ):
        self.allowed_modules = allowed_modules
        self.forbidden_calls = (
            FORBIDDEN_CALLS if forbidden_calls is None else forbidden_calls
        )
        self.violations: List[str] = []

    def _module_allowed(self, module: str) -> bool:
        """Return True if *module* or one of its parent packages is allowed."""
        parts = module.split(".")
        return any(
            ".".join(parts[:end]) in self.allowed_modules
            for end in range(1, len(parts) + 1)
        )

    def visit_Import(self, node: ast.Import) -> None:
        """Flag ``import x`` statements whose module is not allowed."""
        for alias in node.names:
            if not self._module_allowed(alias.name):
                self.violations.append(f"import '{alias.name}' not allowed")

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Flag ``from x import y`` statements whose module is not allowed."""
        # `node.module` is None for `from . import x`; the empty string is
        # never allow-listed, so relative imports are always reported.
        module = node.module or ""
        if not self._module_allowed(module):
            self.violations.append(f"import from '{module}' not allowed")

    def visit_Call(self, node: ast.Call) -> None:
        """Flag calls to forbidden names and to attributes of forbidden roots."""
        func = node.func
        if isinstance(func, ast.Name):
            if func.id in self.forbidden_calls:
                self.violations.append(f"call to '{func.id}()' is forbidden")
        elif isinstance(func, ast.Attribute):
            # Walk `a.b.c` from the outermost attribute down to its root.
            attrs: List[str] = []
            curr: ast.AST = func
            while isinstance(curr, ast.Attribute):
                attrs.append(curr.attr)
                curr = curr.value
            # Compare the root *name* exactly. A prefix match on the dotted
            # string would also reject unrelated identifiers such as
            # `openai.chat()` or `input_data.get()`.
            if isinstance(curr, ast.Name) and curr.id in self.forbidden_calls:
                full = f"{curr.id}.{'.'.join(reversed(attrs))}"
                self.violations.append(f"call to '{full}()' is forbidden")
        # Keep descending so nested calls (arguments, chained calls) are seen.
        self.generic_visit(node)
def scan_code(code: str, allowed_modules: frozenset) -> List[str]:
    """Scan *code* with CodeGuard and return the violations it found.

    Source that does not parse yields an empty list: nothing is reported
    here for a syntax error.
    """
    try:
        module_ast = ast.parse(code)
    except SyntaxError:
        return []
    else:
        checker = CodeGuard(allowed_modules)
        checker.visit(module_ast)
        return checker.violations
130# ── 受控执行环境 ───────────────────────────────
def safe_exec(
    code: str,
    tools: Dict[str, Callable],
    state: Dict[str, Any],
    timeout: float,
    *,
    max_output_length: Optional[int] = None,
) -> Tuple[Any, str, Optional[str]]:
    """Execute *code* with tools and persisted state; capture output and errors.

    The code runs in a fresh globals dict containing, in increasing priority:
    variables saved by earlier calls in ``state["_vars"]``, the builtins, the
    *tools*, and the reserved names ``print``, ``__result__`` and ``state``.
    After a successful run every new public global (no leading underscore, not
    a tool, not reserved) is saved back to ``state["_vars"]`` so that the next
    call can use it directly.

    Args:
        code: Python source to execute.
        tools: Callables exposed to the code under their dict keys.
        state: Mutable dict shared across calls; also visible to the code as
            ``state``.
        timeout: Accepted for interface compatibility but NOT enforced here;
            the caller is responsible for bounding execution time.
        max_output_length: Maximum number of stdout characters to keep.
            Defaults to the module-level ``MAX_OUTPUT_LENGTH``.

    Returns:
        ``(result, stdout, error)`` where *result* is the final value of
        ``__result__`` (None on failure), *stdout* is the captured and possibly
        truncated output, and *error* is None on success or a string with the
        exception type, message and a short traceback.

    SECURITY NOTE(review): the complete builtins are exposed to the executed
    code, so the static scan performed by the caller is the only barrier.
    Aliasing (e.g. ``f = open; f(...)``) is not stopped here -- confirm that
    this level of isolation is acceptable for untrusted code.
    """
    from contextlib import redirect_stdout
    from io import StringIO

    limit = MAX_OUTPUT_LENGTH if max_output_length is None else max_output_length
    captured = StringIO()

    def _print(*args: Any, **kwargs: Any) -> None:
        # Write to the capture buffer directly unless the code chose a file.
        kwargs.setdefault("file", captured)
        print(*args, **kwargs)

    reserved = ("state", "print")
    result = None
    error = None

    try:
        # Restore variables persisted by earlier steps first, so that the
        # builtins, tools and reserved names below always take precedence.
        exec_globals: Dict[str, Any] = dict(state.get("_vars") or {})
        exec_globals["__builtins__"] = __builtins__
        exec_globals.update(tools)
        exec_globals.update({
            "print": _print,
            "__result__": None,
            "state": state,
        })
        # redirect_stdout also captures output written by tools and restores
        # sys.stdout on exit, even when the code raises.
        with redirect_stdout(captured):
            exec(compile(code, "<code_agent>", "exec"), exec_globals)
        result = exec_globals.get("__result__")
        for key, value in list(exec_globals.items()):
            if key.startswith("_") or key in tools or key in reserved:
                continue
            state.setdefault("_vars", {})[key] = value
    except (Exception, SystemExit) as e:
        # SystemExit is included so that `exit()` / `raise SystemExit` in the
        # executed code is reported as a step error rather than propagating.
        error = f"{type(e).__name__}: {e}\n{traceback.format_exc(limit=3)}"

    stdout = captured.getvalue()
    if len(stdout) > limit:
        stdout = stdout[:limit] + "\n... [truncated]"
    return result, stdout, error
173# ── 代码生成 Prompt ─────────────────────────────
# System prompt template. It is rendered with
# `.format(tools_description=...)`, which is why the braces of the example
# f-string are doubled.
CODE_AGENT_SYSTEM_PROMPT = (
    "You are a CodeAgent that solves tasks by writing and executing Python code.\n"
    "\n"
    "YOU MUST respond ONLY with Python code inside ```python ... ``` blocks.\n"
    "NO explanations, NO markdown outside the code block. Just the code.\n"
    "\n"
    "Available tools (callable as functions):\n"
    "{tools_description}\n"
    "\n"
    "To output the final answer, assign it to the variable `__result__`.\n"
    "You can store persistent data in the `state` dict.\n"
    "\n"
    "Example:\n"
    "```python\n"
    "# Use tools\n"
    'data = web_search("Python 3.12 release date")\n'
    "# Compute\n"
    "result = len(data)\n"
    "# Return\n"
    '__result__ = f"Found {{result}} results"\n'
    "```\n"
    "\n"
    "Now solve the following task. ONLY output the code block."
)
199# ── CodeAgent ───────────────────────────────────
class CodeAgent:
    """Agent that solves a task by asking an LLM for Python code and running it.

    Each step asks the model for code, scans it for security violations,
    executes it, and feeds the outcome back to the model on the next step,
    until an answer is produced or ``max_steps`` is reached.

    Args:
        tools: Callables exposed to generated code under their ``__name__``.
        model: Model identifier passed to the router.
        max_steps: Maximum number of generate/execute iterations.
        timeout_per_step: Per-step execution budget in seconds.
        allowed_modules: Modules that generated code may import.
    """

    def __init__(
        self,
        tools: List[Callable] | None = None,
        model: str = "gpt-4o",
        max_steps: int = 10,
        timeout_per_step: float = 30.0,
        allowed_modules: frozenset = DEFAULT_ALLOWED_MODULES,
    ):
        self.model = model
        self.max_steps = max_steps
        self.timeout_per_step = timeout_per_step
        self.allowed_modules = allowed_modules
        self._tools: Dict[str, Callable] = {
            tool.__name__: tool for tool in (tools or [])
        }

    @property
    def tools(self) -> Dict[str, Callable]:
        """Mapping of tool name to callable."""
        return self._tools

    def _tools_description(self) -> str:
        """Return one line per tool: name, signature and first docstring line."""
        lines = []
        for name, fn in self._tools.items():
            try:
                sig = str(inspect.signature(fn))
            except (TypeError, ValueError):
                # Some callables expose no introspectable signature.
                sig = "(...)"
            doc = (inspect.getdoc(fn) or "No description").split("\n")[0]
            lines.append(f" {name}{sig}: {doc}")
        return "\n".join(lines) if lines else " (no tools available)"

    @staticmethod
    def _build_user_prompt(task: str, step_num: int, steps: List[CodeStep]) -> str:
        """Return the user message for a step, with feedback from the last step."""
        # Decide on recorded steps, not on the step number: an earlier
        # iteration may have produced no code and therefore no step record.
        if not steps:
            return task
        last = steps[-1]
        if last.error:
            feedback = f"Error: {last.error}"
        else:
            rp = str(last.result)[:500] if last.result is not None else "None"
            op = last.stdout[:500] if last.stdout else ""
            feedback = f"Output: {op}\nResult: {rp}"
        return f"Step {step_num}: Continue.\nPrevious result:\n{feedback}\n\nTask: {task}"

    async def run(self, task: str, state: Dict[str, Any] | None = None) -> CodeResult:
        """Run the generate/scan/execute loop for *task*.

        Args:
            task: Natural-language task description sent to the model.
            state: Optional dict shared with the executed code across steps;
                a fresh ``{"_vars": {}}`` is used when omitted.

        Returns:
            A CodeResult. ``success`` is False when the model call fails or
            when ``max_steps`` is exhausted without an answer.
        """
        if state is None:
            state = {"_vars": {}}

        loop = asyncio.get_running_loop()
        total_start = loop.time()

        def elapsed_ms() -> float:
            return (loop.time() - total_start) * 1000

        # Loop-invariant: the rendered system prompt and the router.
        system_prompt = CODE_AGENT_SYSTEM_PROMPT.format(
            tools_description=self._tools_description()
        )
        router = ModelRouter()
        steps: List[CodeStep] = []

        for step_num in range(1, self.max_steps + 1):
            user_prompt = self._build_user_prompt(task, step_num, steps)

            try:
                response = await router.chat(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.0,
                    max_tokens=2048,
                )
            except Exception as e:
                return CodeResult(
                    success=False, steps=steps,
                    total_duration_ms=elapsed_ms(),
                    error=f"LLM error: {e}",
                )

            code = self._extract_code(response.content)
            if not code:
                # A reply without code after at least one step is treated as
                # "done": the last step's result becomes the answer.
                if steps:
                    return CodeResult(
                        success=True, final_answer=steps[-1].result, steps=steps,
                        total_duration_ms=elapsed_ms(),
                    )
                continue

            violations = scan_code(code, self.allowed_modules)
            if violations:
                steps.append(CodeStep(
                    step=step_num, code=code,
                    error=f"Security violation: {'; '.join(violations)}",
                ))
                continue

            step_start = loop.time()
            try:
                # The code runs in a worker thread. wait_for stops waiting
                # after the deadline, but the thread itself cannot be
                # cancelled and may keep running in the background.
                result, stdout, error = await asyncio.wait_for(
                    loop.run_in_executor(
                        None, safe_exec, code, self._tools, state, self.timeout_per_step
                    ),
                    timeout=self.timeout_per_step + 5,
                )
            except asyncio.TimeoutError:
                result, stdout, error = None, "", "TimeoutError: exceeded limit"

            steps.append(CodeStep(
                step=step_num, code=code, result=result, stdout=stdout,
                error=error, duration_ms=(loop.time() - step_start) * 1000,
            ))

            if error:
                continue

            # Finished when the code mentions `__result__` or produced a
            # non-None result.
            if result is not None or "__result__" in code:
                return CodeResult(
                    success=True, final_answer=result, steps=steps,
                    total_duration_ms=elapsed_ms(),
                )

        return CodeResult(
            success=False, final_answer=steps[-1].result if steps else None, steps=steps,
            total_duration_ms=elapsed_ms(),
            error=f"Max steps ({self.max_steps}) reached",
        )

    @staticmethod
    def _extract_code(content: str) -> Optional[str]:
        """Extract Python source from a model reply.

        Preference order: the first ```python fenced block, then the first
        fenced block of any kind, then the raw reply if it looks like code.
        Returns None when nothing usable is found (including empty or None
        content).
        """
        if not content:
            return None

        start = content.find("```python")
        if start == -1:
            start = content.find("```")
        if start != -1:
            # Text between the opening fence and the next fence (or the end).
            body = content[start + 3:].split("```", 1)[0]
            first_line, _, rest = body.partition("\n")
            if first_line.strip().lower() in ("python", "python3", "py", "py3"):
                # Drop the fence's language tag line so it is not executed.
                body = rest
            elif body.startswith("python"):
                # Code follows the tag on the same line: "```python x = 1```".
                body = body[len("python"):]
            return body.strip()

        if "print(" in content or "def " in content or "result" in content:
            return content.strip()
        return None