# Coverage for agentos/swarm/result_fusion.py: 17% (136 statements)
# Report generated by coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
v1.9.4: LLM-as-Judge Result Fusion engine.

Aggregates multiple agent outputs with weighted fusion,
confidence scoring, and LLM-as-Judge quality arbitration.
"""
from __future__ import annotations

import json as _json
import logging
import math
import os
from dataclasses import dataclass, field
from itertools import combinations
from typing import Any
# Prompt template sent to the LLM judge. ``{task}`` and ``{candidates}`` are
# filled in with ``str.format``; the literal JSON braces in the instructions
# are doubled (``{{`` / ``}}``) so they survive formatting unchanged.
_JUDGE_PROMPT = """You are an expert quality judge. Given a task and multiple candidate results,
select the best result or synthesize a combined result.

Task: {task}

Candidates:
{candidates}

Instructions:
1. Evaluate each candidate for correctness, completeness, and clarity
2. If one candidate is clearly best, output: {{"action": "select", "best_index": N, "reason": "..."}}
3. If candidates complement each other, output: {{"action": "merge", "merged": "...", "reason": "..."}}
4. If all candidates are poor, output: {{"action": "reject", "reason": "..."}}

Output ONLY the JSON object, no other text.
JSON:"""
@dataclass
class FusedResult:
    """Result of fusion operation.

    Attributes:
        merged: The fused output; ``None`` until a result is chosen.
        best_index: Position of the selected candidate, ``-1`` when unset.
        confidence: Confidence score for the fused result (default ``0.0``).
        action: ``"select"``, ``"merge"`` or ``"reject"``; ``"none"`` until set.
        reason: Free-text explanation for the chosen action.
        individual_scores: Per-agent scores keyed by agent name.
        all_outputs: Agent outputs keyed by agent name.
    """

    merged: Any = None
    best_index: int = -1
    confidence: float = 0.0
    action: str = "none"  # select | merge | reject
    reason: str = ""
    # default_factory gives every instance its own dict (no shared mutable default).
    individual_scores: dict[str, float] = field(default_factory=dict)
    all_outputs: dict[str, Any] = field(default_factory=dict)
_logger = logging.getLogger(__name__)


class ResultFusion:
    """LLM-as-Judge result aggregation engine.

    Combines outputs from multiple agents with:
    - Weighted-vote aggregation
    - LLM-as-Judge for quality arbitration
    - Confidence scoring
    """

    # Substrings (matched case-insensitively) that mark an output as a failure.
    _ERROR_KEYWORDS = ("error", "exception", "traceback", "failed", "错误", "失败")
    # Markup fragments that suggest a structured answer (matched verbatim).
    _STRUCTURE_MARKERS = ("```", "##", "# ", "**", "<table")
    # Substrings (matched case-insensitively) that suggest a decisive answer.
    _CONFIDENCE_KEYWORDS = ("recommend", "建议", "recommendation", "conclusion")

    def __init__(
        self,
        strategy: str = "auto",
        llm_model: str = "gpt-4o-mini",
    ):
        """Store the fusion configuration.

        Args:
            strategy: Fusion strategy name. It is stored but not consulted by
                any method of this class.
            llm_model: Model name sent to the chat-completions endpoint.
        """
        self._strategy = strategy
        self._llm_model = llm_model

    def fuse(
        self,
        task: str,
        outputs: dict[str, Any],
        weights: dict[str, float] | None = None,
    ) -> FusedResult:
        """Fuse multiple agent outputs into a single result.

        Args:
            task: Original task description
            outputs: Dict of agent_name -> agent_output
            weights: Optional dict of agent_name -> weight (default: equal)

        Returns:
            FusedResult with merged output and confidence. With no outputs the
            result is a rejection; with exactly one output it is selected
            directly. Otherwise the LLM judge's verdict is returned when one
            is available (including a "reject" verdict), and weighted
            aggregation is used only when the judge returns ``None``.
        """
        if not outputs:
            return FusedResult(action="reject", reason="No outputs to fuse")

        if len(outputs) == 1:
            name, value = next(iter(outputs.items()))
            return FusedResult(
                merged=value,
                best_index=0,
                confidence=0.7,
                action="select",
                reason="Single output",
                individual_scores={name: 0.7},
                # Shallow copy so the result does not alias the caller's dict.
                all_outputs=dict(outputs),
            )

        weights = weights or {k: 1.0 for k in outputs}

        # Step 1: Compute individual scores
        scores = self._compute_scores(outputs, weights)

        # Step 2: Try LLM judge for arbitration
        llm_result = self._llm_judge(task, outputs)
        if llm_result is not None:
            return llm_result

        # Step 3: Fallback — weighted aggregation
        return self._weighted_aggregate(outputs, scores)

    def _compute_scores(
        self,
        outputs: dict[str, Any],
        weights: dict[str, float],
    ) -> dict[str, float]:
        """Score each output as ``weight * heuristic quality``, rounded to 3 places.

        Agents missing from ``weights`` get a weight of 1.0.
        """
        return {
            name: round(
                float(weights.get(name, 1.0)) * self._quality_heuristic(output), 3
            )
            for name, output in outputs.items()
        }

    def _quality_heuristic(self, output: Any) -> float:
        """Heuristic quality score based on output characteristics.

        Starts from a 0.5 baseline and adjusts for length, error keywords,
        structural markup and confidence keywords. ``None`` or an output whose
        string form is empty scores 0.1.

        Returns:
            A score clamped to the range [0.0, 1.0].
        """
        text = str(output) if output is not None else ""
        if not text:
            return 0.1

        score = 0.5  # baseline

        # Length heuristic: too short is suspicious, reasonable length is good
        length = len(text)
        if 100 < length < 2000:
            score += 0.15
        elif 50 <= length <= 100:
            score += 0.05
        elif length > 5000:
            score += 0.05

        # Keyword checks are case-insensitive so that capitalised forms such
        # as "Traceback (most recent call last)" or "Error:" are detected.
        lowered = text.lower()

        # Error patterns: a single penalty no matter how many keywords match.
        if any(kw in lowered for kw in self._ERROR_KEYWORDS):
            score -= 0.15

        # Structure bonus
        if any(marker in text for marker in self._STRUCTURE_MARKERS):
            score += 0.1

        # Confidence keywords: one bonus per matching keyword. Note that
        # "recommendation" also matches "recommend", so it counts twice.
        for kw in self._CONFIDENCE_KEYWORDS:
            if kw in lowered:
                score += 0.05

        return max(0.0, min(1.0, score))

    def _llm_judge(
        self, task: str, outputs: dict[str, Any]
    ) -> FusedResult | None:
        """Use LLM to judge and fuse results. Returns None on failure.

        ``None`` is returned when ``OPENAI_API_KEY`` is unset or empty, when
        the HTTP status is not 200, when the reply contains no parsable JSON
        object, or when any exception is raised along the way. Failures are
        logged at debug level rather than raised.
        """
        api_key = os.environ.get("OPENAI_API_KEY", "")
        if not api_key:
            return None

        try:
            # Imported lazily so the module loads without this third-party package.
            import requests

            # Each candidate is truncated to 300 characters in the prompt.
            candidates_str = "\n".join(
                f"[{i}] {name}: {str(output)[:300]}"
                for i, (name, output) in enumerate(outputs.items())
            )
            prompt = _JUDGE_PROMPT.format(task=task, candidates=candidates_str)

            resp = requests.post(
                "https://api.openai.com/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": self._llm_model,
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": 0.0,
                    "max_tokens": 500,
                },
                timeout=30,
            )
            if resp.status_code != 200:
                _logger.debug("LLM judge returned HTTP %s", resp.status_code)
                return None

            text = resp.json()["choices"][0]["message"]["content"]
            data = self._extract_json_object(text)
            if data is None:
                _logger.debug("LLM judge reply contained no JSON object")
                return None
            return self._result_from_verdict(data, outputs)
        except Exception:
            # Deliberate best-effort boundary: any failure yields None.
            _logger.debug("LLM judge failed; returning None", exc_info=True)
            return None

    @staticmethod
    def _extract_json_object(text: str) -> dict[str, Any] | None:
        """Parse the span from the first ``{`` to the last ``}`` in ``text`` as JSON.

        Returns the decoded dict, or ``None`` when no such span exists or it
        is not valid JSON.
        """
        start = text.find("{")
        end = text.rfind("}") + 1
        if start == -1 or end == 0:
            return None
        try:
            data = _json.loads(text[start:end])
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _result_from_verdict(
        data: dict[str, Any], outputs: dict[str, Any]
    ) -> FusedResult:
        """Translate the judge's decoded JSON verdict into a FusedResult.

        The action is normalised (trimmed, lowercased); anything other than
        "select" or "merge" is treated as "reject". Stored outputs are
        truncated to 200 characters, and every agent gets a placeholder score
        of 0.5 because this path does not use the heuristic scores.
        """
        action = str(data.get("action", "reject")).strip().lower()
        result = FusedResult(
            reason=data.get("reason", ""),
            all_outputs={k: str(v)[:200] for k, v in outputs.items()},
            individual_scores={k: 0.5 for k in outputs},
        )

        if action == "select":
            agent_names = list(outputs)
            # Clamp so an out-of-range index from the model cannot raise.
            idx = int(data.get("best_index", 0))
            idx = max(0, min(idx, len(agent_names) - 1))
            result.action = "select"
            result.best_index = idx
            result.merged = outputs[agent_names[idx]]
            result.confidence = 0.8
        elif action == "merge":
            result.action = "merge"
            result.merged = data.get("merged", "")
            result.confidence = 0.75
            result.best_index = -1
        else:  # reject, or an unrecognised action
            result.action = "reject"
            result.merged = None
            result.confidence = 0.0

        return result

    def _weighted_aggregate(
        self,
        outputs: dict[str, Any],
        scores: dict[str, float],
    ) -> FusedResult:
        """Weighted vote aggregation fallback.

        Selects the highest-scoring output. Confidence is the winner's share
        of the total score, clamped to [0.0, 1.0], or 0.3 when the total is
        not positive. A rejection is returned when there are no scores or the
        best score is zero.
        """
        max_score = max(scores.values()) if scores else 0.0
        if max_score == 0:
            return FusedResult(action="reject", reason="All outputs scored zero")

        # Find best candidate
        best_name = max(scores, key=scores.get)  # type: ignore[arg-type]

        # Normalize scores to confidence; the clamp guards against negative
        # weights pushing the ratio outside [0, 1].
        total = sum(scores.values())
        confidence = max(0.0, min(1.0, max_score / total)) if total > 0 else 0.3

        # Consensus only changes the reported reason, not the selection.
        consensus = self._check_consensus([str(v) for v in outputs.values()])
        reason = (
            "Consensus among outputs"
            if consensus
            else "Weighted voting (no consensus)"
        )

        return FusedResult(
            merged=outputs[best_name],
            best_index=list(outputs).index(best_name),
            confidence=confidence,
            action="select",
            reason=reason,
            individual_scores=scores,
            all_outputs={k: str(v)[:200] for k, v in outputs.items()},
        )

    def _check_consensus(self, outputs: list[str]) -> bool:
        """Check if string outputs are similar enough for consensus.

        Similarity is the word-set overlap (intersection over union) of the
        lowercased, whitespace-split outputs, averaged over every pair.

        Returns:
            ``True`` when fewer than two outputs are given or the average
            overlap exceeds 0.4; ``False`` when any output has no words.
        """
        if len(outputs) < 2:
            return True

        words = [set(o.lower().split()) for o in outputs]
        if not all(words):
            return False

        # Every set is non-empty here, so each union is non-empty, and with
        # at least two outputs there is at least one pair.
        overlaps = [len(a & b) / len(a | b) for a, b in combinations(words, 2)]
        return sum(overlaps) / len(overlaps) > 0.4