Coverage for agentos/core/code_agent.py: 25%

171 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS v1.1.9 — CodeAgent 模式。 

3 

4基因来源: Smolagents CodeAgent (HuggingFace) 

5 

6CodeAgent 允许 Agent 通过生成和执行 Python 代码来完成子任务, 

7而非仅调用预定义工具。代码可以调用已注册的 tools + 安全内置函数。 

8 

9特性: 

10- 多步执行:生成代码 → 执行 → 观察结果 → 继续 

11- 安全沙箱:白名单模块、禁止危险操作、超时控制 

12- Tools 集成:代码中直接调用 `tool_name(args)` 

13- 内存持久:跨步骤的变量和结果通过 locals 传递 

14""" 

15 

16from __future__ import annotations 

17 

18import ast 

19import asyncio 

20import inspect 

21import sys 

22import traceback 

23from dataclasses import dataclass, field 

24from typing import Any, Callable, Dict, List, Optional, Tuple 

25 

26from agentos.models.router import ModelRouter 

27 

28 

29# ── 安全常量 ─────────────────────────────────── 

30 

# Module names that generated code is allowed to import. CodeGuard compares
# the names found in `import` / `from ... import` statements against this set.
# NOTE(review): several entries ("pathlib", "io", "logging") can touch the
# filesystem, and `import os.path` binds the whole `os` module — confirm this
# whitelist matches the intended sandbox policy.
DEFAULT_ALLOWED_MODULES = frozenset({
    "math", "json", "re", "datetime", "collections",
    "itertools", "functools", "typing", "dataclasses",
    "decimal", "fractions", "statistics", "random",
    "string", "textwrap", "unicodedata", "hashlib",
    "base64", "binascii", "uuid", "copy", "pprint",
    "enum", "pathlib", "logging", "warnings",
    "csv", "html", "urllib.parse", "xml.etree.ElementTree",
    "operator", "heapq", "bisect", "array",
    "struct", "io", "os.path",
})

# Names that CodeGuard.visit_Call refuses as call targets. The set mixes
# builtin functions (exec, eval, open, ...) with module names (os, sys, ...)
# so that both `eval(...)` and `os.system(...)` style calls are reported.
FORBIDDEN_CALLS = frozenset({
    "exec", "eval", "compile", "open", "__import__",
    "getattr", "setattr", "delattr", "hasattr",
    "globals", "locals", "vars",
    "breakpoint", "input",
    "os", "subprocess", "shutil", "sys",
    "ctypes", "socket", "pickle", "marshal",
    "multiprocessing", "threading", "signal",
})

# Maximum number of captured stdout characters returned by safe_exec; longer
# output is cut and suffixed with a truncation marker.
MAX_OUTPUT_LENGTH = 10000

54 

55 

56# ── 数据结构 ─────────────────────────────────── 

57 

@dataclass
class CodeStep:
    """Record of a single CodeAgent execution step."""

    step: int  # step number within the run (CodeAgent.run counts from 1)
    code: str  # Python source extracted from the model response
    result: Any = None  # value of `__result__` after execution, None if unset
    stdout: str = ""  # text captured from print() during execution
    error: Optional[str] = None  # failure description; None when the step succeeded
    duration_ms: float = 0.0  # execution time in milliseconds

68 

69 

@dataclass
class CodeResult:
    """Overall outcome of a CodeAgent run."""

    success: bool  # whether the run finished with an answer
    final_answer: Any = None  # result of the step that ended the run
    steps: List[CodeStep] = field(default_factory=list)  # per-step records, in order
    total_duration_ms: float = 0.0  # duration of the whole run in milliseconds
    error: Optional[str] = None  # reason for failure; None on success

79 

80 

81# ── 代码安全检查器 ───────────────────────────── 

82 

class CodeGuard(ast.NodeVisitor):
    """AST scanner that records dangerous operations in generated code.

    Walk a parsed module with ``visit``; every finding is appended to
    ``violations`` as a human-readable message. Two things are checked:

    * imports, against the ``allowed_modules`` whitelist;
    * calls whose target is, or is rooted at, a name in ``FORBIDDEN_CALLS``.

    NOTE(review): this is a static best-effort filter, not a sandbox. It does
    not follow aliases (``f = eval; f(...)``) or attribute tricks, and
    ``import os.path`` still binds the whole ``os`` module.
    """

    def __init__(self, allowed_modules: frozenset):
        self.allowed_modules = allowed_modules
        self.violations: List[str] = []

    def _is_allowed_module(self, module: str) -> bool:
        """Return True if *module* or one of its parent packages is whitelisted."""
        parts = module.split(".")
        return any(
            ".".join(parts[:end]) in self.allowed_modules
            for end in range(1, len(parts) + 1)
        )

    def visit_Import(self, node: ast.Import) -> None:
        """Report every ``import x`` whose module is not whitelisted."""
        for alias in node.names:
            if not self._is_allowed_module(alias.name):
                self.violations.append(f"import '{alias.name}' not allowed")

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Report ``from x import y`` when ``x`` is not whitelisted.

        The full dotted path is checked, so whitelisted dotted entries such as
        ``urllib.parse`` are accepted. Relative imports have no module name
        and are always reported.
        """
        module = node.module or ""
        if not self._is_allowed_module(module):
            self.violations.append(f"import from '{module}' not allowed")

    def visit_Call(self, node: ast.Call) -> None:
        """Report calls to forbidden names or to attributes of forbidden names."""
        func = node.func
        if isinstance(func, ast.Name):
            if func.id in FORBIDDEN_CALLS:
                self.violations.append(f"call to '{func.id}()' is forbidden")
        elif isinstance(func, ast.Attribute):
            # Unwind `a.b.c` down to its root name `a`.
            attrs = []
            curr = func
            while isinstance(curr, ast.Attribute):
                attrs.append(curr.attr)
                curr = curr.value
            # Compare the root name exactly. A prefix test on the dotted path
            # would also flag unrelated names such as `inputs.append()`.
            if isinstance(curr, ast.Name) and curr.id in FORBIDDEN_CALLS:
                full = ".".join([curr.id, *reversed(attrs)])
                self.violations.append(f"call to '{full}()' is forbidden")
        # Keep walking so nested calls in arguments are inspected too.
        self.generic_visit(node)

118 

119 

def scan_code(code: str, allowed_modules: frozenset) -> List[str]:
    """Return the security violations CodeGuard finds in *code*.

    Args:
        code: Python source to inspect.
        allowed_modules: Whitelist of importable module names.

    Returns:
        A list of violation messages; empty when nothing was found. Source
        that cannot be parsed also yields an empty list: nothing can be
        inspected, and executing it will raise the syntax error anyway.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: NUL bytes in the source on Python < 3.12.
        return []
    except RecursionError:
        return ["code is nested too deeply to analyse"]

    guard = CodeGuard(allowed_modules)
    try:
        guard.visit(tree)
    except RecursionError:
        # The tree could not be fully walked; report it, never treat as clean.
        return [*guard.violations, "code is nested too deeply to analyse"]
    return guard.violations

128 

129 

130# ── 受控执行环境 ─────────────────────────────── 

131 

def safe_exec(
    code: str,
    tools: Dict[str, Callable],
    state: Dict[str, Any],
    timeout: float,
) -> Tuple[Any, str, Optional[str]]:
    """Execute generated code and collect its result, output and error.

    The code runs with the registered tools, a ``state`` dict and a capturing
    ``print`` in its global namespace. Variables persisted by earlier calls
    in ``state["_vars"]`` are restored as globals before execution, and new
    public globals are written back afterwards on success.

    Args:
        code: Python source to run.
        tools: Callables exposed to the code under their dict key.
        state: Mutable dict shared across steps; ``state["_vars"]`` holds
            the persisted variables.
        timeout: Accepted for interface compatibility. It is NOT enforced
            here; the caller is responsible for bounding the wait.

    Returns:
        ``(result, stdout, error)`` where ``result`` is the final value of
        ``__result__``, ``stdout`` is the captured print output (truncated to
        ``MAX_OUTPUT_LENGTH`` characters) and ``error`` is ``None`` on
        success or an exception summary with a short traceback.
    """
    from io import StringIO

    stdout_capture = StringIO()

    def _captured_print(*args: Any, **kwargs: Any) -> None:
        # Write to this call's private buffer and leave `sys.stdout` alone.
        # This function runs in executor threads, and overlapping runs (for
        # example one still executing after a timeout) could otherwise restore
        # the wrong stream. Output printed inside tools is not captured.
        kwargs.setdefault("file", stdout_capture)
        print(*args, **kwargs)

    result = None
    error = None

    try:
        # Restore variables persisted by previous steps first, so the fixed
        # names below always take precedence over them.
        exec_globals: Dict[str, Any] = dict(state.get("_vars") or {})
        # SECURITY NOTE: the full builtins are exposed. Isolation relies
        # entirely on the static scan done before this function is called.
        exec_globals["__builtins__"] = __builtins__
        exec_globals.update(tools)
        exec_globals.update({
            "print": _captured_print,
            "__result__": None,
            "state": state,
        })
        compiled = compile(code, "<code_agent>", "exec")
        exec(compiled, exec_globals)
        result = exec_globals.get("__result__")

        # Persist public globals for later steps. Private names, tools and
        # the injected helpers are skipped.
        persisted = state.setdefault("_vars", {})
        for key, value in exec_globals.items():
            if key.startswith("_") or key in tools or key in ("state", "print"):
                continue
            persisted[key] = value
    except (Exception, SystemExit) as e:
        # SystemExit is caught so `exit()` in generated code is reported as a
        # step error and does not propagate into the caller.
        error = f"{type(e).__name__}: {e}\n{traceback.format_exc(limit=3)}"

    stdout = stdout_capture.getvalue()
    if len(stdout) > MAX_OUTPUT_LENGTH:
        stdout = stdout[:MAX_OUTPUT_LENGTH] + "\n... [truncated]"
    return result, stdout, error

171 

172 

173# ── 代码生成 Prompt ───────────────────────────── 

174 

# System prompt template, filled in with `str.format`. `{tools_description}`
# is the only placeholder; the doubled braces in the example (`{{result}}`)
# are format escapes that render as literal `{result}`.
CODE_AGENT_SYSTEM_PROMPT = """You are a CodeAgent that solves tasks by writing and executing Python code.

YOU MUST respond ONLY with Python code inside ```python ... ``` blocks.
NO explanations, NO markdown outside the code block. Just the code.

Available tools (callable as functions):
{tools_description}

To output the final answer, assign it to the variable `__result__`.
You can store persistent data in the `state` dict.

Example:
```python
# Use tools
data = web_search("Python 3.12 release date")
# Compute
result = len(data)
# Return
__result__ = f"Found {{result}} results"
```

Now solve the following task. ONLY output the code block."""

197 

198 

199# ── CodeAgent ─────────────────────────────────── 

200 

class CodeAgent:
    """Agent that completes a task by generating and executing Python code.

    Each step asks the model for a code block, scans it with ``scan_code``,
    runs it through ``safe_exec`` in the default executor and feeds the
    outcome back into the next prompt. The run ends when the code mentions
    ``__result__`` or produces a result, when the model stops sending code,
    or when ``max_steps`` is exhausted.
    """

    def __init__(
        self,
        tools: List[Callable] | None = None,
        model: str = "gpt-4o",
        max_steps: int = 10,
        timeout_per_step: float = 30.0,
        allowed_modules: frozenset = DEFAULT_ALLOWED_MODULES,
    ):
        """Configure the agent.

        Args:
            tools: Callables exposed to generated code under their
                ``__name__``.
            model: Model identifier passed to the router.
            max_steps: Maximum number of generate/execute iterations.
            timeout_per_step: Seconds allowed per code execution; the wait
                for the executor is bounded at this value plus 5 seconds.
            allowed_modules: Whitelist of modules generated code may import.
        """
        self.model = model
        self.max_steps = max_steps
        self.timeout_per_step = timeout_per_step
        self.allowed_modules = allowed_modules
        self._tools: Dict[str, Callable] = {
            tool.__name__: tool for tool in tools or ()
        }

    @property
    def tools(self) -> Dict[str, Callable]:
        """Registered tools, keyed by the name generated code calls them by."""
        return self._tools

    def _tools_description(self) -> str:
        """Render one ``name(signature): summary`` line per registered tool."""
        lines = []
        for name, fn in self._tools.items():
            try:
                sig = str(inspect.signature(fn))
            except (TypeError, ValueError):
                # Some builtins and C callables expose no signature.
                sig = "(...)"
            doc = (inspect.getdoc(fn) or "No description").split("\n")[0]
            lines.append(f"  {name}{sig}: {doc}")
        return "\n".join(lines) if lines else "  (no tools available)"

    @staticmethod
    def _build_user_prompt(task: str, step_num: int, steps: List[CodeStep]) -> str:
        """Build the user message: the bare task, or task plus last-step feedback."""
        # No recorded step yet. This is the first step, or earlier responses
        # contained no code, so there is nothing to report back.
        if not steps:
            return task
        last = steps[-1]
        if last.error:
            feedback = f"Error: {last.error}"
        else:
            rp = str(last.result)[:500] if last.result is not None else "None"
            op = last.stdout[:500] if last.stdout else ""
            feedback = f"Output: {op}\nResult: {rp}"
        return f"Step {step_num}: Continue.\nPrevious result:\n{feedback}\n\nTask: {task}"

    async def run(self, task: str, state: Dict[str, Any] | None = None) -> CodeResult:
        """Run the generate → scan → execute loop for *task*.

        Args:
            task: Natural-language task description sent to the model.
            state: Optional dict shared with the executed code across steps;
                a fresh ``{"_vars": {}}`` is created when omitted.

        Returns:
            A ``CodeResult``. ``success`` is False when the model call
            fails, when the model stops sending code right after a failed
            step, or when ``max_steps`` is reached without an answer.
        """
        if state is None:
            state = {"_vars": {}}

        loop = asyncio.get_running_loop()
        total_start = loop.time()

        def elapsed_ms() -> float:
            return (loop.time() - total_start) * 1000

        # The system prompt does not change between steps; build it once.
        system_prompt = CODE_AGENT_SYSTEM_PROMPT.format(
            tools_description=self._tools_description()
        )
        steps: List[CodeStep] = []

        for step_num in range(1, self.max_steps + 1):
            user_prompt = self._build_user_prompt(task, step_num, steps)

            # NOTE(review): a router is created per step, as before; hoist it
            # out of the loop if ModelRouter is known to be reusable.
            router = ModelRouter()
            try:
                response = await router.chat(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.0,
                    max_tokens=2048,
                )
            except Exception as e:
                return CodeResult(
                    success=False, steps=steps,
                    total_duration_ms=elapsed_ms(),
                    error=f"LLM error: {e}",
                )

            code = self._extract_code(response.content)
            if not code:
                if steps:
                    # The model stopped producing code: finish with the last
                    # step's outcome, and do not report a failed step as success.
                    last = steps[-1]
                    return CodeResult(
                        success=last.error is None, final_answer=last.result,
                        steps=steps, total_duration_ms=elapsed_ms(),
                        error=last.error,
                    )
                continue

            violations = scan_code(code, self.allowed_modules)
            if violations:
                steps.append(CodeStep(
                    step=step_num, code=code,
                    error=f"Security violation: {'; '.join(violations)}",
                ))
                continue

            step_start = loop.time()
            try:
                # NOTE: a timeout only stops the wait. The worker thread
                # cannot be cancelled and may keep running (and mutating
                # `state`) in the background.
                result, stdout, error = await asyncio.wait_for(
                    loop.run_in_executor(
                        None, safe_exec, code, self._tools, state, self.timeout_per_step
                    ),
                    timeout=self.timeout_per_step + 5,
                )
            except asyncio.TimeoutError:
                result, stdout, error = None, "", "TimeoutError: exceeded limit"

            steps.append(CodeStep(
                step=step_num, code=code, result=result, stdout=stdout,
                error=error, duration_ms=(loop.time() - step_start) * 1000,
            ))

            if error:
                continue

            # Done once the code mentions `__result__` or produced a result.
            if "__result__" in code or result is not None:
                return CodeResult(
                    success=True, final_answer=result, steps=steps,
                    total_duration_ms=elapsed_ms(),
                )

        return CodeResult(
            success=False, final_answer=steps[-1].result if steps else None, steps=steps,
            total_duration_ms=elapsed_ms(),
            error=f"Max steps ({self.max_steps}) reached",
        )

    @staticmethod
    def _extract_code(content: str) -> Optional[str]:
        """Pull Python source out of a model response.

        A ```` ```python ```` fence is preferred, then any other fence. A
        language tag alone on the fence line (``python``, ``python3``,
        ``py``, any case) is dropped. Unfenced text is returned as-is only
        when it looks like code. Returns ``None`` when nothing usable is
        found, including for empty or ``None`` content.
        """
        if not content:
            return None

        fence = "```"
        start = content.find(fence + "python")
        if start == -1:
            start = content.find(fence)
        if start != -1:
            body = content[start + len(fence):].split(fence, 1)[0]
            info, newline, rest = body.partition("\n")
            if newline and info.strip().lower() in ("python", "python3", "py"):
                body = rest
            elif body.startswith("python"):
                # Inline form such as "```python x = 1```".
                body = body[len("python"):]
            return body.strip()

        if "print(" in content or "def " in content or "result" in content:
            return content.strip()
        return None

334 return None