Coverage for agentos/feedback/learner.py: 39%

95 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS v0.30 反馈学习系统 — Human-in-the-loop + RLHF hooks。 

3支持人工评分、偏好学习、持续改进。 

4""" 

5 

import json
import os
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Callable, Optional

12 

13 

class FeedbackType(str, Enum):
    """Closed set of feedback kinds accepted from human reviewers.

    Because the enum mixes in ``str``, every member compares equal to its
    plain string value, e.g. ``FeedbackType.THUMB == "thumb"``.
    """

    THUMB = "thumb"            # thumbs up / thumbs down
    RATING = "rating"          # 1-5 star rating
    CORRECTIVE = "corrective"  # corrective instruction
    PREFERENCE = "preference"  # A/B preference
    DETAILED = "detailed"      # detailed written review

23 

24 

def _now_isoformat() -> str:
    """Return the current local time as an ISO-8601 string (naive, no UTC offset)."""
    return datetime.now().isoformat()


@dataclass
class FeedbackRecord:
    """One piece of human feedback attached to a session iteration."""

    session_id: str
    iteration: int
    feedback_type: FeedbackType
    # Feedback payload: free text, or a score/choice serialized as a string.
    content: str
    original_output: str = ""
    corrected_output: str = ""
    # Creation time, filled in with the current local time by default.
    timestamp: str = field(default_factory=_now_isoformat)
    # Free-form extra data; a fresh dict per instance.
    metadata: dict = field(default_factory=dict)

36 

37 

class FeedbackCollector:
    """Feedback collector — the entry point for human-in-the-loop (HITL) feedback.

    Records are kept in memory and, when ``storage_path`` is non-empty,
    appended to a JSON Lines file (one JSON object per line). If that file
    already exists it is loaded on construction.

    Args:
        storage_path: Path of the JSONL persistence file. Pass an empty
            string to keep feedback in memory only.
    """

    def __init__(self, storage_path: str = "./feedback_data.jsonl"):
        self.storage_path = storage_path
        self._records: list[FeedbackRecord] = []
        # Listeners invoked, in registration order, with every new record.
        self._callbacks: list[Callable[["FeedbackRecord"], None]] = []
        if storage_path and os.path.exists(storage_path):
            self._load()

    def collect(self, record: "FeedbackRecord") -> None:
        """Store *record*, persist it, then notify every registered callback.

        The record is stored and written to disk before any callback runs, so
        an exception raised by a callback propagates to the caller but the
        record itself is not lost.
        """
        self._records.append(record)
        self._save()
        for callback in self._callbacks:
            callback(record)

    def collect_thumbs(self, session_id: str, iteration: int, up: bool) -> None:
        """Record a thumbs-up (``up=True``) or thumbs-down as content "up"/"down"."""
        self.collect(FeedbackRecord(
            session_id=session_id,
            iteration=iteration,
            feedback_type=FeedbackType.THUMB,
            content="up" if up else "down",
        ))

    def collect_rating(self, session_id: str, iteration: int, rating: int, comment: str = "") -> None:
        """Record a rating; stored as ``str(rating)`` with the comment in ``metadata["comment"]``."""
        self.collect(FeedbackRecord(
            session_id=session_id,
            iteration=iteration,
            feedback_type=FeedbackType.RATING,
            content=str(rating),
            metadata={"comment": comment},
        ))

    def collect_corrective(self, session_id: str, iteration: int, correction: str, original: str = "") -> None:
        """Record a corrective instruction, optionally with the output it corrects."""
        self.collect(FeedbackRecord(
            session_id=session_id,
            iteration=iteration,
            feedback_type=FeedbackType.CORRECTIVE,
            content=correction,
            original_output=original,
        ))

    def on_feedback(
        self, callback: Callable[["FeedbackRecord"], None]
    ) -> Callable[["FeedbackRecord"], None]:
        """Register *callback* to be called with each newly collected record.

        Returns the callback unchanged, so this method can also be used as a
        decorator (``@collector.on_feedback``).
        """
        self._callbacks.append(callback)
        return callback

    def stats(self) -> dict:
        """Return aggregate counters over all collected records.

        Returns:
            A dict with keys ``total``, ``thumbs_up``, ``thumbs_down``,
            ``avg_rating`` (0.0 when there are no ratings),
            ``corrective_count`` and ``satisfaction`` (fraction of thumbs
            that are "up"; 0.0 when there are no thumbs).
        """
        thumbs_up = 0
        thumbs_down = 0
        ratings: list[float] = []
        corrective = 0
        for r in self._records:
            if r.feedback_type == FeedbackType.THUMB:
                # Any thumb content other than "up" counts as a thumbs-down.
                if r.content == "up":
                    thumbs_up += 1
                else:
                    thumbs_down += 1
            elif r.feedback_type == FeedbackType.RATING:
                # float() also accepts non-integer ratings such as "4.5";
                # int() raised ValueError there, and once such a record was
                # persisted stats() failed on every subsequent call.
                ratings.append(float(r.content))
            elif r.feedback_type == FeedbackType.CORRECTIVE:
                corrective += 1
        return {
            "total": len(self._records),
            "thumbs_up": thumbs_up,
            "thumbs_down": thumbs_down,
            "avg_rating": sum(ratings) / len(ratings) if ratings else 0.0,
            "corrective_count": corrective,
            "satisfaction": thumbs_up / max(thumbs_up + thumbs_down, 1),
        }

    def _save(self) -> None:
        """Append the most recently collected record to ``storage_path`` as one JSON line.

        No-op when persistence is disabled (empty ``storage_path``) or when
        there is no record yet. Parent directories are created as needed.
        """
        if not self.storage_path or not self._records:
            return
        r = self._records[-1]
        os.makedirs(os.path.dirname(self.storage_path) or ".", exist_ok=True)
        line = json.dumps({
            "session_id": r.session_id,
            "iteration": r.iteration,
            "feedback_type": r.feedback_type.value,
            "content": r.content,
            "original_output": r.original_output,
            "corrected_output": r.corrected_output,
            "timestamp": r.timestamp,
            "metadata": r.metadata,
        }, ensure_ascii=False)
        # ensure_ascii=False writes raw non-ASCII text (e.g. Chinese feedback);
        # force UTF-8 so the platform default codec cannot fail or mis-encode it.
        with open(self.storage_path, "a", encoding="utf-8") as f:
            f.write(line + "\n")

    def _load(self) -> None:
        """Load previously persisted records from ``storage_path`` (blank lines skipped).

        Raises:
            json.JSONDecodeError: a non-blank line is not valid JSON.
            KeyError: a line lacks a required key.
            ValueError: a line has an unknown ``feedback_type``.
        """
        # Must match the UTF-8 encoding used by _save.
        with open(self.storage_path, encoding="utf-8") as f:
            for raw in f:
                line = raw.strip()
                if not line:
                    continue
                d = json.loads(line)
                self._records.append(FeedbackRecord(
                    session_id=d["session_id"],
                    iteration=d["iteration"],
                    feedback_type=FeedbackType(d["feedback_type"]),
                    content=d["content"],
                    original_output=d.get("original_output", ""),
                    corrected_output=d.get("corrected_output", ""),
                    timestamp=d.get("timestamp", ""),
                    metadata=d.get("metadata", {}),
                ))

140 

141 

class PreferenceLearner:
    """Preference learner — extracts improvement signals from recent feedback.

    Keeps a sliding window of the most recent feedback summaries and derives
    textual improvement hints and a retrain decision from them.

    Args:
        window_size: Maximum number of feedback summaries retained. A value
            of 0 or less retains nothing.
    """

    # Number of leading characters of feedback content kept per summary.
    _CONTENT_PREVIEW_CHARS = 200

    def __init__(self, window_size: int = 100):
        self.window_size = window_size
        self._recent_patterns: list[dict] = []

    @staticmethod
    def _is_negative(pattern: dict) -> bool:
        """Return True for a thumbs-down or a corrective feedback summary."""
        kind = pattern["type"]
        return kind == "corrective" or (kind == "thumb" and pattern["content"] == "down")

    def learn_from_feedback(self, record: "FeedbackRecord") -> None:
        """Summarize *record* into the sliding window, evicting the oldest entries."""
        pattern = {
            "type": record.feedback_type.value,
            "content": record.content[:self._CONTENT_PREVIEW_CHARS],
            "session": record.session_id,
        }
        self._recent_patterns.append(pattern)
        # Explicit overflow instead of slicing with [-window_size:]: for
        # window_size == 0 that slice is [-0:], i.e. the whole list, so the
        # window was never trimmed and grew without bound.
        overflow = len(self._recent_patterns) - max(self.window_size, 0)
        if overflow > 0:
            self._recent_patterns = self._recent_patterns[overflow:]

    def get_improvement_hints(self) -> list[str]:
        """Return human-readable improvement suggestions for the current window.

        Adds a hint when the window contains corrective feedback, and another
        when thumbs-down outnumber thumbs-up.
        """
        hints = []
        corrections = [p for p in self._recent_patterns if p["type"] == "corrective"]
        if corrections:
            hints.append(f"最近 {len(corrections)} 条纠正反馈,建议调整输出风格")
        thumbs_down = sum(1 for p in self._recent_patterns if p["type"] == "thumb" and p["content"] == "down")
        thumbs_up = sum(1 for p in self._recent_patterns if p["type"] == "thumb" and p["content"] == "up")
        if thumbs_down > thumbs_up:
            hints.append("近期满意度下降,建议优化响应质量")
        return hints

    def should_retrain(self, threshold: float = 0.3) -> bool:
        """Return True when the share of negative feedback in the window exceeds *threshold*.

        Negative feedback is a thumbs-down or a corrective instruction;
        thumbs-up were previously miscounted as negative. An empty window
        never triggers retraining.
        """
        if not self._recent_patterns:
            return False
        negative = sum(1 for p in self._recent_patterns if self._is_negative(p))
        return negative / len(self._recent_patterns) > threshold