Coverage for agentos/tests/test_subagent_parent_child.py: 99%

252 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-03 13:55 +0800

1"""测试 SubAgent 父子通信 — 状态共享、心跳、生命周期管理。""" 

2 

3import asyncio 

4import time 

5import pytest 

6 

7pytestmark = pytest.mark.asyncio 

8from agentos.subagent import ( 

9 SubAgentManager, 

10 SubAgentMode, 

11 SubAgentSpec, 

12 SubAgentResult, 

13 ChildStatus, 

14 ChildHeartbeat, 

15 ChildInfo, 

16 SharedState, 

17 ChildContext, 

18 ChildHandle, 

19) 

20 

21 

class TestSharedState:
    """Exercise SharedState's async and sync key/value operations."""

    async def test_set_get(self):
        """A stored value is readable; a missing key yields the supplied default."""
        state = SharedState()
        await state.set("key1", "val1")
        stored = await state.get("key1")
        assert stored == "val1"
        fallback = await state.get("missing", "def")
        assert fallback == "def"

    async def test_update_snapshot(self):
        """update() stores every pair of a mapping; snapshot() returns them all."""
        state = SharedState()
        await state.update({"a": 1, "b": 2})
        assert await state.snapshot() == {"a": 1, "b": 2}

    async def test_sync_ops(self):
        """The *_sync accessors round-trip a value without awaiting."""
        state = SharedState()
        state.set_sync("x", 42)
        assert state.get_sync("x") == 42

    async def test_concurrent_writes(self):
        """Two interleaved writers each leave their own key at its last value."""
        state = SharedState()
        count = 20

        async def write_sequence(name: str, total: int):
            # Yield to the event loop after every write so both writers interleave.
            for value in range(total):
                await state.set(name, value)
                await asyncio.sleep(0)

        await asyncio.gather(write_sequence("a", count), write_sequence("b", count))
        for name in ("a", "b"):
            assert await state.get(name) == count - 1

51 

52 

class TestChildContext:
    """Tests for ChildContext progress reporting, heartbeats and control checks."""

    @staticmethod
    def _heartbeat_sink():
        """Return ``(received, callback)``: a list and an async callback appending each heartbeat to it."""
        received = []

        async def callback(hb: ChildHeartbeat):
            received.append(hb)

        return received, callback

    async def test_progress_report(self):
        """report_progress() updates ctx.progress and emits one heartbeat."""
        hbs, hb_cb = self._heartbeat_sink()
        ctx = ChildContext("test-1", heartbeat_callback=hb_cb)
        await ctx.report_progress(0.5, "step1", "half done")
        assert ctx.progress == 0.5
        assert len(hbs) == 1
        assert hbs[0].progress == 0.5
        assert hbs[0].current_step == "step1"

    async def test_step_and_heartbeat(self):
        """Only send_heartbeat() emits here; it carries the iteration set by step()."""
        hbs, hb_cb = self._heartbeat_sink()
        ctx = ChildContext("test-2", heartbeat_callback=hb_cb)
        await ctx.step(1, "init")
        await ctx.send_heartbeat("alive")
        assert len(hbs) == 1
        assert hbs[0].iteration == 1

    async def test_done(self):
        """done() emits a COMPLETED heartbeat at full progress with the message."""
        hbs, hb_cb = self._heartbeat_sink()
        ctx = ChildContext("test-3", heartbeat_callback=hb_cb)
        await ctx.done("all good")
        assert len(hbs) == 1
        assert hbs[0].status == ChildStatus.COMPLETED
        assert hbs[0].progress == 1.0
        assert hbs[0].message == "all good"

    async def test_fail(self):
        """fail() emits a FAILED heartbeat carrying the error message."""
        hbs, hb_cb = self._heartbeat_sink()
        ctx = ChildContext("test-4", heartbeat_callback=hb_cb)
        await ctx.fail("something broke")
        assert len(hbs) == 1
        assert hbs[0].status == ChildStatus.FAILED
        assert hbs[0].message == "something broke"

    async def test_cancel_detection(self):
        """check_control() polls on_cancel and flips to CANCELLED once it returns True."""
        cancelled = [False]  # mutable cell so the closure sees later changes

        def on_cancel():
            return cancelled[0]

        ctx = ChildContext("test-5", on_cancel=on_cancel)
        assert not ctx.cancelled
        status = await ctx.check_control()
        assert status == ChildStatus.RUNNING

        cancelled[0] = True
        status = await ctx.check_control()
        assert status == ChildStatus.CANCELLED
        assert ctx.cancelled

    async def test_pause_resume(self):
        """While paused, check_control() invokes the on_pause hook and reports PAUSED."""
        pause_hook_calls = []

        async def on_pause():
            pause_hook_calls.append(True)

        ctx = ChildContext("test-6", on_pause=on_pause)
        # No public pause API on ChildContext is exercised here, so set the
        # private flag directly.
        ctx._paused = True
        status = await ctx.check_control()
        assert status == ChildStatus.PAUSED
        assert pause_hook_calls

133 

134 

class TestChildHandle:
    """Tests for ChildHandle: context creation, control, status and timeouts."""

    async def test_create_context(self):
        """create_context() returns a context bound to this handle and its shared state."""
        h = ChildHandle("h1", "do stuff", "fork")
        child_ctx = h.create_context()
        assert child_ctx.agent_id == "h1"
        assert h.context is child_ctx
        assert h.shared_state is child_ctx.shared_state

    async def test_pause_resume(self):
        """A fresh handle is IDLE; pause() then resume() move it to PAUSED then RUNNING."""
        h = ChildHandle("h2", "task", "fork")
        h.create_context()
        assert h.status == ChildStatus.IDLE

        transitions = (
            (h.pause, ChildStatus.PAUSED),
            (h.resume, ChildStatus.RUNNING),
        )
        for action, expected in transitions:
            await action()
            assert h.status == expected

    async def test_cancel(self):
        """cancel() puts the handle into CANCELLED."""
        h = ChildHandle("h3", "task", "fork")
        h.create_context()
        await h.cancel()
        assert h.status == ChildStatus.CANCELLED

    async def test_get_status(self):
        """get_status() reflects the handle's info fields and includes elapsed time."""
        h = ChildHandle("h4", "analyze", "fork")
        h.create_context()
        info = h.info
        info.progress, info.current_step, info.iterations = 0.7, "parsing", 12

        snapshot = h.get_status()
        expected = {
            "agent_id": "h4",
            "progress": 0.7,
            "current_step": "parsing",
            "iterations": 12,
        }
        for key, value in expected.items():
            assert snapshot[key] == value
        assert "elapsed" in snapshot

    async def test_timeout_detection(self):
        """check_timeout() is true once more than `timeout` seconds have passed."""
        h = ChildHandle("h5", "task", "fork", timeout=0.1)
        await asyncio.sleep(0.15)
        assert h.check_timeout()

    async def test_no_timeout_when_unset(self):
        """With timeout=None the handle never reports a timeout."""
        h = ChildHandle("h6", "task", "fork", timeout=None)
        assert not h.check_timeout()

    async def test_heartbeat_timeout(self):
        """Missing heartbeats for long enough trips check_heartbeat_timeout()."""
        h = ChildHandle("h7", "task", "fork", heartbeat_interval=0.1)
        await asyncio.sleep(0.35)
        assert h.check_heartbeat_timeout()

    async def test_heartbeat_updates_info(self):
        """A received heartbeat copies progress, step and iteration into handle.info."""
        h = ChildHandle("h8", "task", "fork")
        h.create_context()
        beat = ChildHeartbeat(
            agent_id="h8",
            progress=0.5,
            current_step="s1",
            message="working",
            iteration=5,
        )
        await h._receive_heartbeat(beat)
        assert h.info.progress == 0.5
        assert h.info.current_step == "s1"
        assert h.info.iterations == 5

    async def test_shared_state_parent_child(self):
        """Writes on either side of the handle/context pair are visible to the other."""
        h = ChildHandle("h9", "task", "fork")
        child_ctx = h.create_context()

        await child_ctx.shared_state.set("data", [1, 2, 3])
        seen_by_parent = await h.shared_state.get("data")
        assert seen_by_parent == [1, 2, 3]

        await h.shared_state.set("status", "ok")
        assert await child_ctx.shared_state.get("status") == "ok"

209 

210 

class TestSubAgentManager:
    """Integration tests for SubAgentManager spawning, control and bookkeeping."""

    async def test_spawn_fork_with_child_context(self):
        """A forked child reporting progress and done() ends COMPLETED with its output."""

        async def run_func(spec: SubAgentSpec, ctx: ChildContext):
            await ctx.report_progress(0.3, "init")
            await ctx.step(1, "load")
            await ctx.report_progress(0.7, "process")
            await ctx.done("success")
            return ("success", 2)

        mgr = SubAgentManager()
        result = await mgr.spawn_fork("test task", run_func=run_func)
        assert result.output == "success"
        assert result.iterations == 2
        assert result.handle is not None
        assert result.handle.status == ChildStatus.COMPLETED

    async def test_spawn_fork_failure(self):
        """An exception in run_func is reported as an error and a FAILED handle."""

        async def run_func(spec, ctx):
            await ctx.report_progress(0.1, "start")
            raise ValueError("boom")

        mgr = SubAgentManager()
        result = await mgr.spawn_fork("bad task", run_func=run_func)
        assert result.error == "boom"
        assert result.handle.status == ChildStatus.FAILED
        assert result.handle.info.error == "boom"

    async def test_spawn_fork_pause_resume_flow(self):
        """Run a child that polls check_control() twice before finishing.

        NOTE(review): the parent never pauses or resumes the child here, so
        this only verifies that the control-check path completes cleanly.
        TODO: drive a real parent pause -> child pause -> parent resume ->
        child continue round-trip.
        """
        state = {"phase": "init"}

        async def run_func(spec, ctx: ChildContext):
            state["phase"] = "running"
            await ctx.report_progress(0.2, "step1")

            # Check the control signal (would observe a parent pause).
            status = await ctx.check_control()
            if status == ChildStatus.PAUSED:
                state["phase"] = "paused"

            # Check again (simulates continuing after a resume).
            status = await ctx.check_control()
            if status == ChildStatus.RUNNING:
                state["phase"] = "resumed"

            await ctx.done("ok")
            return ("ok", 3)

        mgr = SubAgentManager()

        # Start the child in the background.
        task = asyncio.create_task(
            mgr.spawn_fork("pause test", run_func=run_func)
        )

        # Wait for completion before touching the result: Task.result() raises
        # InvalidStateError while the task is still pending.
        result = await task
        assert result.error in (None, "")
        assert result.handle is not None
        assert result.handle.status == ChildStatus.COMPLETED
        assert state["phase"] in ("running", "paused", "resumed")

    async def test_swarm_parallel(self):
        """spawn_swarm runs one child per task and every handle ends COMPLETED."""
        results_log = []

        async def run_func(spec: SubAgentSpec, ctx: ChildContext):
            await ctx.report_progress(0.5, spec.task)
            await asyncio.sleep(0.01)
            await ctx.done(f"done_{spec.task}")
            results_log.append(spec.task)
            return (f"done_{spec.task}", 1)

        mgr = SubAgentManager()
        results = await mgr.spawn_swarm(
            ["A", "B", "C"], run_func=run_func
        )
        assert len(results) == 3
        assert len(results_log) == 3
        for r in results:
            assert r.handle is not None
            assert r.handle.status == ChildStatus.COMPLETED

    async def test_cancel_all(self):
        """cancel_all() is observed by running children via check_control()."""

        async def run_func(spec, ctx: ChildContext):
            await ctx.report_progress(0.1, "init")
            for i in range(50):
                await ctx.step(i, f"step_{i}")
                status = await ctx.check_control()
                if status == ChildStatus.CANCELLED:
                    return ("cancelled", i)
                await asyncio.sleep(0.01)
            return ("done", 50)

        mgr = SubAgentManager()
        t1 = asyncio.create_task(
            mgr.spawn_fork("long task 1", run_func=run_func)
        )
        t2 = asyncio.create_task(
            mgr.spawn_fork("long task 2", run_func=run_func)
        )

        # Let both children get into their step loop (~0.5 s total) before cancelling.
        await asyncio.sleep(0.05)
        await mgr.cancel_all()

        r1, r2 = await asyncio.gather(t1, t2)
        assert r1.handle.status == ChildStatus.CANCELLED
        assert r2.handle.status == ChildStatus.CANCELLED

    async def test_list_children(self):
        """A spawned child appears in list_children() under its agent_id."""
        mgr = SubAgentManager()
        r = await mgr.spawn_fork("task1")
        children = mgr.list_children()
        assert len(children) == 1
        assert children[0]["agent_id"] == r.agent_id

    async def test_cleanup(self):
        """cleanup() with a negative max age removes every tracked child."""
        mgr = SubAgentManager()
        await mgr.spawn_fork("cleanup test")
        assert len(mgr._agents) == 1

        cleaned = await mgr.cleanup(max_age_seconds=-1.0)
        assert cleaned == 1
        assert len(mgr._agents) == 0

    async def test_heartbeat_monitoring(self):
        """monitor_heartbeats() runs against a handle whose timeout has elapsed."""
        mgr = SubAgentManager()
        handle = ChildHandle("hb-test", "task", "fork", timeout=0.05)
        mgr._agents["hb-test"] = handle
        handle.info.status = ChildStatus.RUNNING

        monitor = asyncio.create_task(mgr.monitor_heartbeats(interval=0.02))
        await asyncio.sleep(0.1)
        monitor.cancel()
        try:
            await monitor
        except asyncio.CancelledError:
            pass

        # NOTE(review): RUNNING was set by this test itself, so this assertion
        # cannot fail; tighten to TIMEOUT once the monitor's expected
        # transition is confirmed.
        assert handle.status in (ChildStatus.TIMEOUT, ChildStatus.RUNNING)