Coverage for agentos/tests/test_1_1_4_features.py: 0%

186 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1"""v1.1.4 新特性集成测试。""" 

2from __future__ import annotations 

3 

4import pytest 

5 

6 

7# ══════════════════════════════════════════════════════════════════════════════ 

8# ToolRiskRating 测试 

9# ══════════════════════════════════════════════════════════════════════════════ 

10 

class TestToolRiskRating:
    """Tests for ``agentos.tools.risk``: the risk-level enum, rating defaults,
    the confirmation rule, named presets and keyword-based risk inference."""

    def test_risk_level_enum(self):
        """``ToolRiskLevel`` has exactly the members LOW, MEDIUM, HIGH and CRITICAL."""
        from agentos.tools.risk import ToolRiskLevel
        assert ToolRiskLevel.LOW.value == "low"
        assert ToolRiskLevel.MEDIUM.value == "medium"
        assert ToolRiskLevel.CRITICAL.value == "critical"
        assert len(list(ToolRiskLevel)) == 4
        # The length check alone cannot tell which four members exist; pin the names.
        assert {level.name for level in ToolRiskLevel} == {"LOW", "MEDIUM", "HIGH", "CRITICAL"}

    def test_risk_rating_defaults(self):
        """A default rating is medium, reversible, and needs neither approval nor confirmation."""
        from agentos.tools.risk import ToolRiskRating
        r = ToolRiskRating()
        assert r.level.value == "medium"
        assert r.reversible is True
        assert r.requires_approval is False
        assert r.requires_user_confirm() is False

    def test_requires_confirm_high(self):
        """HIGH risk requires user confirmation."""
        from agentos.tools.risk import ToolRiskRating, ToolRiskLevel
        r = ToolRiskRating(level=ToolRiskLevel.HIGH)
        assert r.requires_user_confirm() is True

    def test_requires_confirm_critical(self):
        """CRITICAL risk requires user confirmation."""
        from agentos.tools.risk import ToolRiskRating, ToolRiskLevel
        r = ToolRiskRating(level=ToolRiskLevel.CRITICAL)
        assert r.requires_user_confirm() is True

    def test_requires_confirm_financial(self):
        """A financial impact forces confirmation even at the default level."""
        from agentos.tools.risk import ToolRiskRating
        r = ToolRiskRating(financial_impact=True)
        assert r.requires_user_confirm() is True

    def test_get_risk_preset_list_files(self):
        """The ``list_files`` preset is LOW risk."""
        from agentos.tools.risk import get_risk_preset, ToolRiskLevel
        r = get_risk_preset("list_files")
        assert r is not None
        assert r.level == ToolRiskLevel.LOW

    def test_get_risk_preset_delete_file(self):
        """The ``delete_file`` preset is HIGH risk and requires approval."""
        from agentos.tools.risk import get_risk_preset, ToolRiskLevel
        r = get_risk_preset("delete_file")
        assert r is not None
        assert r.level == ToolRiskLevel.HIGH
        assert r.requires_approval is True

    def test_get_risk_preset_payment(self):
        """The ``execute_payment`` preset is CRITICAL and flagged as financial."""
        from agentos.tools.risk import get_risk_preset, ToolRiskLevel
        r = get_risk_preset("execute_payment")
        assert r is not None
        assert r.level == ToolRiskLevel.CRITICAL
        assert r.financial_impact is True

    def test_get_risk_preset_case_insensitive(self):
        """Preset lookup ignores case and resolves to the ``delete_file`` preset."""
        from agentos.tools.risk import get_risk_preset, ToolRiskLevel
        upper = get_risk_preset("DELETE_FILE")
        # ``is not None`` alone would also pass if the upper-case name matched a
        # different preset; check the attributes the lowercase preset is known to have.
        assert upper is not None
        assert upper.level == ToolRiskLevel.HIGH
        assert upper.requires_approval is True

    def test_infer_risk_level_keyword_delete(self):
        """A 'delete' keyword in the description infers HIGH risk."""
        from agentos.tools.risk import infer_risk_level, ToolRiskLevel
        r = infer_risk_level("purge_records", "delete all records")
        assert r.level == ToolRiskLevel.HIGH

    def test_infer_risk_level_keyword_write(self):
        """A write-style tool name such as ``update_*`` infers MEDIUM risk."""
        from agentos.tools.risk import infer_risk_level, ToolRiskLevel
        r = infer_risk_level("update_profile")
        assert r.level == ToolRiskLevel.MEDIUM

    def test_infer_risk_level_default(self):
        """A name with no risky keyword infers LOW risk."""
        from agentos.tools.risk import infer_risk_level, ToolRiskLevel
        r = infer_risk_level("get_status")
        assert r.level == ToolRiskLevel.LOW

79 

80 

81# ══════════════════════════════════════════════════════════════════════════════ 

82# Middleware Pipeline 测试 

83# ══════════════════════════════════════════════════════════════════════════════ 

84 

class TestMiddlewarePipeline:
    """Tests for ``MiddlewarePipeline``: pass-through, blocking, transforming,
    adding/removing middleware, and phase filtering."""

    @pytest.mark.asyncio
    async def test_empty_pipeline_allows(self):
        """A pipeline with no middleware allows a PRE_LLM request."""
        from agentos.core.middleware import MiddlewarePipeline, MiddlewarePhase, MiddlewareContext
        pipe = MiddlewarePipeline()
        ctx = MiddlewareContext(phase=MiddlewarePhase.PRE_LLM, prompt="hello")
        decision = await pipe.pre_llm(ctx)
        assert decision.allow is True

    @pytest.mark.asyncio
    async def test_blocking_middleware(self):
        """A middleware returning ``allow=False`` blocks the request, and its reason
        appears in the pipeline's decision."""
        from agentos.core.middleware import (
            AgentMiddleware, MiddlewarePipeline, MiddlewarePhase,
            MiddlewareContext, MiddlewareDecision,
        )

        class Blocker(AgentMiddleware):
            """Rejects every PRE_LLM request."""
            name = "blocker"

            @property
            def phases(self):
                return [MiddlewarePhase.PRE_LLM]

            async def process(self, ctx):
                return MiddlewareDecision(allow=False, reason="blocked by test", action="block")

        pipe = MiddlewarePipeline([Blocker()])
        ctx = MiddlewareContext(phase=MiddlewarePhase.PRE_LLM, prompt="test")
        decision = await pipe.pre_llm(ctx)
        assert decision.allow is False
        assert "blocked by test" in decision.reason

    @pytest.mark.asyncio
    async def test_transform_middleware(self):
        """A ``transform`` decision carries the modified context out of the pipeline."""
        from agentos.core.middleware import (
            AgentMiddleware, MiddlewarePipeline, MiddlewarePhase,
            MiddlewareContext, MiddlewareDecision,
        )

        class UpperCaseTransform(AgentMiddleware):
            """Upper-cases the prompt of PRE_LLM requests."""
            name = "upper"

            @property
            def phases(self):
                return [MiddlewarePhase.PRE_LLM]

            async def process(self, ctx):
                if ctx.prompt:
                    # Build a new context instead of mutating the caller's one;
                    # ``**`` already produces a fresh kwargs mapping.
                    new_ctx = MiddlewareContext(**ctx.__dict__)
                    new_ctx.prompt = ctx.prompt.upper()
                    return MiddlewareDecision(allow=True, action="transform", modified_context=new_ctx)
                return MiddlewareDecision(allow=True)

        pipe = MiddlewarePipeline([UpperCaseTransform()])
        ctx = MiddlewareContext(phase=MiddlewarePhase.PRE_LLM, prompt="hello")
        decision = await pipe.pre_llm(ctx)
        assert decision.allow is True
        assert decision.modified_context is not None
        assert decision.modified_context.prompt == "HELLO"

    # ``add``/``remove``/``middleware_names`` are used without ``await``, so these
    # tests are plain synchronous functions and need no asyncio marker.
    def test_chain_add(self):
        """``add`` registers a middleware under its name."""
        from agentos.core.middleware import MiddlewarePipeline, AuditLogMiddleware
        pipe = MiddlewarePipeline()
        pipe.add(AuditLogMiddleware())
        assert "audit_log" in pipe.middleware_names

    def test_remove(self):
        """``remove`` unregisters a middleware by name."""
        from agentos.core.middleware import MiddlewarePipeline, AuditLogMiddleware
        pipe = MiddlewarePipeline([AuditLogMiddleware()])
        pipe.remove("audit_log")
        assert "audit_log" not in pipe.middleware_names

    @pytest.mark.asyncio
    async def test_phase_filtering(self):
        """Middleware is consulted only for the phases it declares."""
        from agentos.core.middleware import (
            MiddlewarePipeline, MiddlewarePhase, MiddlewareContext, PIIMaskingMiddleware,
        )
        pipe = MiddlewarePipeline([PIIMaskingMiddleware()])
        # PIIMaskingMiddleware only listens on PRE_LLM
        ctx = MiddlewareContext(phase=MiddlewarePhase.PRE_TOOL, tool_name="test")
        decision = await pipe.pre_tool(ctx)
        assert decision.allow is True  # It should pass through since no middleware listens

166 

167 

168# ══════════════════════════════════════════════════════════════════════════════ 

169# Enhanced CostTracker + RunCostSession 测试 

170# ══════════════════════════════════════════════════════════════════════════════ 

171 

class TestRunCostSession:
    """Tests for the per-run ``RunCostSession`` aggregate (call count, cost,
    token totals and duration)."""

    def test_session_lifecycle(self):
        """A new session is empty; appended records drive call count and cost,
        and setting ``finished_at`` yields a positive duration."""
        import time
        from agentos.cost.tracker import RunCostSession, UsageRecord

        session = RunCostSession(run_id="test-123")
        assert session.run_id == "test-123"
        assert session.call_count == 0
        assert session.total_cost == 0

        # Record some usage
        session.records.append(UsageRecord(
            model="deepseek-v3.1", input_tokens=1000, output_tokens=500,
            cost_usd=0.01, run_id="test-123",
        ))
        session.records.append(UsageRecord(
            model="deepseek-v3.1", input_tokens=2000, output_tokens=800,
            cost_usd=0.02, run_id="test-123",
        ))
        # time.time() resolution is platform-dependent, so a finish time stamped
        # immediately after construction can equal the start time and make
        # ``duration_seconds > 0`` flaky. Finish one second later instead.
        # NOTE(review): assumes the start time is a time.time() epoch value.
        session.finished_at = time.time() + 1.0

        assert session.call_count == 2
        # total_cost is a float sum; compare with a tolerance, not exact equality.
        assert session.total_cost == pytest.approx(0.03)
        assert session.duration_seconds > 0

    def test_total_tokens(self):
        """``total_tokens`` sums input and output tokens across all records."""
        from agentos.cost.tracker import RunCostSession, UsageRecord
        session = RunCostSession(run_id="t")
        session.records.append(UsageRecord(model="m", input_tokens=100, output_tokens=50, cost_usd=0.0))
        session.records.append(UsageRecord(model="m", input_tokens=200, output_tokens=100, cost_usd=0.0))
        assert session.total_tokens == {"input": 300, "output": 150, "total": 450}

203 

204 

205class TestCostTrackerEnhanced: 

206 def test_start_end_session(self): 

207 from agentos.cost.tracker import CostTracker 

208 tracker = CostTracker() 

209 rid = tracker.start_session() 

210 assert len(tracker.active_sessions) == 1 

211 session = tracker.end_session(rid) 

212 assert session is not None 

213 assert session.finished_at is not None 

214 assert len(tracker.active_sessions) == 0 

215 

216 def test_record_with_session(self): 

217 from agentos.cost.tracker import CostTracker 

218 tracker = CostTracker() 

219 rid = tracker.start_session() 

220 tracker.record("deepseek-v3.1", {"prompt_tokens": 1000, "completion_tokens": 500}, run_id=rid) 

221 tracker.record("deepseek-v3.1", {"prompt_tokens": 500, "completion_tokens": 200}, run_id=rid) 

222 session = tracker.end_session(rid) 

223 assert session.call_count == 2 

224 assert session.total_cost > 0 

225 

226 def test_cost_by_session(self): 

227 from agentos.cost.tracker import CostTracker 

228 tracker = CostTracker() 

229 r1 = tracker.start_session() 

230 r2 = tracker.start_session() 

231 tracker.record("deepseek-v3.1", {"prompt_tokens": 1000, "completion_tokens": 100}, run_id=r1) 

232 tracker.record("deepseek-v3.1", {"prompt_tokens": 500, "completion_tokens": 50}, run_id=r2) 

233 costs = tracker.cost_by_session() 

234 assert r1 in costs 

235 assert r2 in costs 

236 assert costs[r2] < costs[r1] 

237 

238 def test_get_session_active_and_completed(self): 

239 from agentos.cost.tracker import CostTracker 

240 tracker = CostTracker() 

241 rid = tracker.start_session() 

242 assert tracker.get_session(rid) is not None 

243 tracker.end_session(rid) 

244 assert tracker.get_session(rid) is not None # Should find completed 

245 

246 def test_record_with_cache(self): 

247 from agentos.cost.tracker import CostTracker 

248 tracker = CostTracker() 

249 cost = tracker.record_with_cache("deepseek-v3.1", 1000, 500) 

250 assert cost > 0 

251 assert tracker.total_cost == cost