Coverage for agentos/feedback/learner.py: 39%
95 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
AgentOS v0.30 feedback learning system — human-in-the-loop + RLHF hooks.
Supports human ratings, preference learning, and continuous improvement.
4"""
import json
import logging
import os
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Optional
class FeedbackType(str, Enum):
    """Kinds of feedback a reviewer can submit.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    # Thumbs up / thumbs down.
    THUMB = "thumb"
    # Star rating from 1 to 5.
    RATING = "rating"
    # Corrective instruction.
    CORRECTIVE = "corrective"
    # A/B preference between two outputs.
    PREFERENCE = "preference"
    # Detailed free-form review.
    DETAILED = "detailed"
def _now_iso() -> str:
    """Return the current local time formatted with ``datetime.isoformat``."""
    return datetime.now().isoformat()


@dataclass
class FeedbackRecord:
    """A single piece of feedback attached to one iteration of a session."""

    session_id: str
    iteration: int
    feedback_type: FeedbackType
    # Feedback text, or the rating rendered as a string.
    content: str
    original_output: str = ""
    corrected_output: str = ""
    # Filled in at construction time when the caller does not supply one.
    timestamp: str = field(default_factory=_now_iso)
    # Each instance gets its own fresh dict.
    metadata: dict = field(default_factory=dict)
class FeedbackCollector:
    """Feedback collector — the human-in-the-loop feedback entry point.

    Records are kept in memory in arrival order. When ``storage_path`` is
    truthy, every collected record is also appended to that file as one JSON
    object per line (JSON Lines), and existing records are reloaded from it
    on construction.
    """

    def __init__(self, storage_path: str = "./feedback_data.jsonl"):
        """Create a collector and reload previously persisted records.

        Args:
            storage_path: Path of the JSONL file. A falsy value (such as
                ``""``) disables persistence.
        """
        self.storage_path = storage_path
        self._records: list[FeedbackRecord] = []
        self._callbacks: list[Callable[[FeedbackRecord], None]] = []
        if storage_path and os.path.exists(storage_path):
            self._load()

    def collect(self, record: FeedbackRecord) -> None:
        """Store ``record``, persist it, then notify registered callbacks.

        Callbacks run in registration order, after the record has been
        written to storage.
        """
        self._records.append(record)
        self._save()
        for cb in self._callbacks:
            cb(record)

    def collect_thumbs(self, session_id: str, iteration: int, up: bool) -> None:
        """Record a thumbs-up (``up=True``) or thumbs-down vote."""
        self.collect(FeedbackRecord(
            session_id=session_id,
            iteration=iteration,
            feedback_type=FeedbackType.THUMB,
            content="up" if up else "down",
        ))

    def collect_rating(self, session_id: str, iteration: int, rating: int, comment: str = "") -> None:
        """Record a numeric rating; ``comment`` is kept in the metadata."""
        self.collect(FeedbackRecord(
            session_id=session_id,
            iteration=iteration,
            feedback_type=FeedbackType.RATING,
            content=str(rating),
            metadata={"comment": comment},
        ))

    def collect_corrective(self, session_id: str, iteration: int, correction: str, original: str = "") -> None:
        """Record a corrective instruction, optionally with the original output."""
        self.collect(FeedbackRecord(
            session_id=session_id,
            iteration=iteration,
            feedback_type=FeedbackType.CORRECTIVE,
            content=correction,
            original_output=original,
        ))

    def on_feedback(self, callback: Callable[[FeedbackRecord], None]) -> None:
        """Register ``callback`` to be called with every newly collected record."""
        self._callbacks.append(callback)

    def stats(self) -> dict:
        """Summarise all records currently held in memory.

        Returns:
            A dict with the keys ``total``, ``thumbs_up``, ``thumbs_down``,
            ``avg_rating`` (``0.0`` when there are no ratings),
            ``corrective_count`` and ``satisfaction`` (share of thumbs that
            are "up"; ``0.0`` when there are no thumbs).

        Raises:
            ValueError: If a rating record's content is not an integer string.
        """
        thumbs = {"up": 0, "down": 0}
        ratings = []
        corrective = 0
        for r in self._records:
            if r.feedback_type == FeedbackType.THUMB:
                # Any thumb content other than "up" is counted as "down".
                thumbs["up" if r.content == "up" else "down"] += 1
            elif r.feedback_type == FeedbackType.RATING:
                ratings.append(int(r.content))
            elif r.feedback_type == FeedbackType.CORRECTIVE:
                corrective += 1
        total_thumbs = thumbs["up"] + thumbs["down"]
        return {
            "total": len(self._records),
            "thumbs_up": thumbs["up"],
            "thumbs_down": thumbs["down"],
            "avg_rating": sum(ratings) / len(ratings) if ratings else 0.0,
            "corrective_count": corrective,
            # max(..., 1) avoids a division by zero when no thumbs exist.
            "satisfaction": thumbs["up"] / max(total_thumbs, 1),
        }

    @staticmethod
    def _record_to_dict(r: FeedbackRecord) -> dict:
        """Convert a record into the JSON-serialisable dict stored on disk."""
        return {
            "session_id": r.session_id,
            "iteration": r.iteration,
            "feedback_type": r.feedback_type.value,
            "content": r.content,
            "original_output": r.original_output,
            "corrected_output": r.corrected_output,
            "timestamp": r.timestamp,
            "metadata": r.metadata,
        }

    def _save(self) -> None:
        """Append the most recently collected record to the storage file.

        Does nothing when persistence is disabled. The parent directory is
        created on demand.
        """
        if not self.storage_path:
            return
        os.makedirs(os.path.dirname(self.storage_path) or ".", exist_ok=True)
        # Explicit UTF-8: the payload is written with ensure_ascii=False, so
        # relying on the platform's default encoding can fail or produce a
        # file that cannot be read back on another machine.
        with open(self.storage_path, "a", encoding="utf-8") as f:
            # The [-1:] slice is empty when nothing has been collected yet.
            for r in self._records[-1:]:
                f.write(json.dumps(self._record_to_dict(r), ensure_ascii=False) + "\n")

    def _load(self) -> None:
        """Load all records from the storage file into memory.

        Blank lines are ignored. Lines that cannot be parsed into a record
        (truncated write, missing key, unknown feedback type) are logged and
        skipped so that one bad line does not make the collector unusable.
        """
        log = logging.getLogger(__name__)
        # errors="replace" keeps files written before the encoding was pinned
        # to UTF-8 loadable, at the cost of replacement characters in any
        # text that was stored in another encoding.
        with open(self.storage_path, encoding="utf-8", errors="replace") as f:
            for lineno, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    continue
                try:
                    d = json.loads(line)
                    record = FeedbackRecord(
                        session_id=d["session_id"],
                        iteration=d["iteration"],
                        feedback_type=FeedbackType(d["feedback_type"]),
                        content=d["content"],
                        original_output=d.get("original_output", ""),
                        corrected_output=d.get("corrected_output", ""),
                        timestamp=d.get("timestamp", ""),
                        metadata=d.get("metadata", {}),
                    )
                except (KeyError, TypeError, ValueError) as exc:
                    # json.JSONDecodeError is a ValueError subclass.
                    log.warning(
                        "Skipping malformed feedback record at %s:%d: %s",
                        self.storage_path, lineno, exc,
                    )
                    continue
                self._records.append(record)
class PreferenceLearner:
    """Preference learner — extracts improvement signals from recent feedback.

    Keeps a sliding window of at most ``window_size`` lightweight pattern
    dicts (``type``, ``content``, ``session``) built from feedback records.
    """

    def __init__(self, window_size: int = 100):
        """Create a learner that remembers the last ``window_size`` records."""
        self.window_size = window_size
        self._recent_patterns: list[dict] = []

    def learn_from_feedback(self, record: "FeedbackRecord") -> None:
        """Learn from a single feedback record.

        Only the feedback type, the first 200 characters of the content and
        the session id are retained.
        """
        self._recent_patterns.append({
            "type": record.feedback_type.value,
            "content": record.content[:200],
            "session": record.session_id,
        })
        # Compute the overflow explicitly: slicing with [-window_size:] keeps
        # the whole list when window_size is 0, so the window would never trim.
        overflow = len(self._recent_patterns) - max(self.window_size, 0)
        if overflow > 0:
            self._recent_patterns = self._recent_patterns[overflow:]

    def get_improvement_hints(self) -> list[str]:
        """Return human-readable improvement hints for the current window.

        A hint is added when the window contains corrective feedback, and
        another when thumbs-down votes outnumber thumbs-up votes.
        """
        hints = []
        corrections = [r for r in self._recent_patterns if r["type"] == "corrective"]
        if corrections:
            hints.append(f"最近 {len(corrections)} 条纠正反馈,建议调整输出风格")
        thumbs_down = sum(1 for r in self._recent_patterns if r["type"] == "thumb" and r["content"] == "down")
        thumbs_up = sum(1 for r in self._recent_patterns if r["type"] == "thumb" and r["content"] == "up")
        if thumbs_down > thumbs_up:
            hints.append("近期满意度下降,建议优化响应质量")
        return hints

    @staticmethod
    def _is_negative(pattern: dict) -> bool:
        """Return True for a corrective record or a thumbs-down vote."""
        if pattern["type"] == "corrective":
            return True
        return pattern["type"] == "thumb" and pattern["content"] == "down"

    def should_retrain(self, threshold: float = 0.3) -> bool:
        """Decide whether model fine-tuning should be triggered.

        Args:
            threshold: Share of negative records in the window that must be
                exceeded (strictly) to return True.

        Returns:
            False for an empty window; otherwise whether the fraction of
            negative records (thumbs-down and corrective feedback) is greater
            than ``threshold``. Thumbs-up votes are not counted as negative.
        """
        if not self._recent_patterns:
            return False
        negative = sum(1 for r in self._recent_patterns if self._is_negative(r))
        return negative / len(self._recent_patterns) > threshold