Coverage for agentos/prompts/few_shot.py: 37%

109 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Few-Shot Example Management — intelligent few-shot selection strategies. 

3 

4Supports similarity-based, random, diversity-maximizing, and 

5custom selection algorithms for constructing optimal few-shot prompts. 

6""" 

7 

8from dataclasses import dataclass, field 

9from enum import Enum 

10from typing import Any, Dict, Iterable, List, Optional, Sequence 

11import hashlib 

12import random 

13import re 

14 

15 

class SelectionStrategy(str, Enum):
    """Names of the available few-shot example selection policies.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value (for example ``SelectionStrategy.RANDOM == "random"``).
    """

    # Sample examples at random, without replacement.
    RANDOM = "random"
    # Rank examples by token overlap (Jaccard) with the query.
    SIMILARITY = "similarity"
    # Greedy selection intended to spread the chosen examples apart.
    DIVERSITY = "diversity"
    # Most recently appended examples first.
    RECENCY = "recency"
    # Round-robin across example labels.
    LABEL_BALANCED = "label_balanced"
    # NOTE(review): no dedicated selector for this value is visible in this
    # file; FewShotSelector.select falls back to similarity ranking.
    ACTIVE_LEARNING = "active_learning"

25 

26 

@dataclass
class Example:
    """A single input/output demonstration for few-shot prompting.

    When ``id`` is left empty, a 12-character identifier is derived from the
    MD5 digest of ``input`` followed by ``output``, so two examples with the
    same text receive the same ID.
    """

    # Prompt-side text of the demonstration.
    input: str
    # Expected completion for ``input``.
    output: str
    # Identifier; auto-generated in ``__post_init__`` when empty.
    id: str = ""
    # Optional category name; empty string means unlabeled.
    label: str = ""
    # Free-form extra data attached to the example.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Utility score; defaults to 0.0.
    score: float = 0.0

    def __post_init__(self):
        """Fill in a content-derived ``id`` when the caller did not supply one."""
        if not self.id:
            # MD5 is used purely as a stable fingerprint, not for security.
            # Declaring that keeps construction working on FIPS-restricted
            # OpenSSL builds, where a plain ``hashlib.md5()`` call raises.
            # NOTE: input and output are concatenated without a separator, so
            # ("ab", "c") and ("a", "bc") share an ID; the scheme is kept
            # as-is so previously generated IDs remain valid.
            digest = hashlib.md5(
                f"{self.input}{self.output}".encode(),
                usedforsecurity=False,
            )
            self.id = digest.hexdigest()[:12]

43 

44 

class FewShotSelector:
    """Select and format few-shot examples for a given query.

    The selector keeps a mutable pool of ``Example`` objects and picks up to
    ``k`` of them according to ``strategy``. Randomised strategies draw from
    a generator private to the instance, so constructing a selector does not
    disturb the module-level ``random`` state.

    A strategy without a dedicated selector (currently
    ``SelectionStrategy.ACTIVE_LEARNING``) falls back to similarity ranking.

    Usage::

        examples = [
            Example(input="What is 2+2?", output="4", label="math"),
            Example(input="Capital of France?", output="Paris", label="geo"),
        ]
        selector = FewShotSelector(examples, strategy=SelectionStrategy.SIMILARITY)
        prompt = selector.build_prompt("What is 3+5?", base_instruction="Answer:")
    """

    DEFAULT_FORMAT = "Q: {input}\nA: {output}"
    MAX_TOKEN_ESTIMATE = 4096

    def __init__(
        self,
        examples: Sequence[Example],
        strategy: SelectionStrategy = SelectionStrategy.SIMILARITY,
        max_examples: int = 5,
        example_format: str = "",
        seed: int = 42,
    ):
        """Create a selector over a copy of ``examples``.

        Args:
            examples: Initial example pool (copied into a list).
            strategy: Selection policy used by :meth:`select`.
            max_examples: Default number of examples when ``k`` is not given.
            example_format: ``str.format`` template with ``{input}`` and
                ``{output}`` fields; empty means ``DEFAULT_FORMAT``.
            seed: Seed for this selector's private random generator.
        """
        self.examples = list(examples)
        self.strategy = strategy
        self.max_examples = max_examples
        self.example_format = example_format or self.DEFAULT_FORMAT
        # A private generator gives reproducible selections without reseeding
        # the process-wide ``random`` module as a side effect.
        self._rng = random.Random(seed)

    def select(self, query: str, k: Optional[int] = None) -> list[Example]:
        """Select up to ``k`` examples for ``query``.

        Args:
            query: Text the examples are selected for.
            k: Number of examples wanted; ``None`` means ``max_examples``.

        Returns:
            The chosen examples; empty when the pool is empty or ``k <= 0``.
        """
        # Only ``None`` means "use the default": an explicit 0 asks for none.
        if k is None:
            k = self.max_examples
        # Non-positive counts would otherwise hit slice corner cases such as
        # ``examples[-0:]``, which yields the whole pool.
        if k <= 0 or not self.examples:
            return []

        strategy_map = {
            SelectionStrategy.RANDOM: self._select_random,
            SelectionStrategy.SIMILARITY: self._select_similarity,
            SelectionStrategy.DIVERSITY: self._select_diversity,
            SelectionStrategy.RECENCY: self._select_recency,
            SelectionStrategy.LABEL_BALANCED: self._select_label_balanced,
        }
        selector_fn = strategy_map.get(self.strategy, self._select_similarity)
        return selector_fn(query, k)

    def build_prompt(
        self,
        query: str,
        base_instruction: str = "",
        k: Optional[int] = None,
    ) -> str:
        """Build a complete few-shot prompt string.

        The prompt is the optional instruction, then each selected example
        rendered with ``example_format``, then the query rendered with an
        empty output, all joined by blank lines.
        """
        selected = self.select(query, k)
        parts: list[str] = []
        if base_instruction:
            parts.append(base_instruction)
        parts.extend(
            self.example_format.format(input=ex.input, output=ex.output)
            for ex in selected
        )
        # The query goes last with an empty answer slot for the model to fill.
        parts.append(self.example_format.format(input=query, output=""))
        return "\n\n".join(parts)

    def add_example(self, example: Example):
        """Add a new example to the end of the pool."""
        self.examples.append(example)

    def remove_example(self, example_id: str):
        """Remove every example whose ID equals ``example_id``."""
        self.examples = [e for e in self.examples if e.id != example_id]

    def set_score(self, example_id: str, score: float):
        """Set the score of the first example whose ID matches, if any."""
        for ex in self.examples:
            if ex.id == example_id:
                ex.score = score
                break

    def _select_random(self, _query: str, k: int) -> list[Example]:
        """Sample up to ``k`` examples without replacement."""
        return self._rng.sample(self.examples, min(k, len(self.examples)))

    def _select_similarity(self, query: str, k: int) -> list[Example]:
        """Rank examples by Jaccard token overlap with the query."""
        query_tokens = set(_tokenize(query))
        scored = [
            (ex, self._jaccard(query_tokens, ex))
            for ex in self.examples
        ]
        # The sort is stable, so equally scored examples keep pool order.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [ex for ex, _ in scored[:k]]

    def _select_diversity(self, _query: str, k: int) -> list[Example]:
        """Maximize diversity via greedy farthest-first selection.

        Starting from a random example, repeatedly add the candidate whose
        highest similarity to any already chosen example is lowest.
        """
        total = len(self.examples)
        if k >= total:
            return list(self.examples)

        # Tokenize every input once instead of once per comparison.
        token_sets = [set(_tokenize(ex.input)) for ex in self.examples]
        start = self._rng.randrange(total)
        chosen = [start]
        # closest[i]: highest similarity between candidate i and the chosen set.
        closest = {
            i: self._jaccard_sets(token_sets[i], token_sets[start])
            for i in range(total)
            if i != start
        }
        while len(chosen) < k and closest:
            # Farthest-first: take the candidate least similar to its nearest
            # chosen neighbour (ties resolve to the earliest pool position).
            pick = min(closest, key=closest.__getitem__)
            del closest[pick]
            chosen.append(pick)
            for i in closest:
                closest[i] = max(
                    closest[i],
                    self._jaccard_sets(token_sets[i], token_sets[pick]),
                )
        return [self.examples[i] for i in chosen]

    def _select_recency(self, _query: str, k: int) -> list[Example]:
        """Most recent examples first (assumes append order = recency)."""
        if k <= 0:
            # ``examples[-0:]`` would return the entire pool.
            return []
        return self.examples[-k:][::-1]

    def _select_label_balanced(self, _query: str, k: int) -> list[Example]:
        """Balance selection across labels by round-robin random draws."""
        by_label: dict[str, list[Example]] = {}
        for ex in self.examples:
            by_label.setdefault(ex.label or "_unlabeled", []).append(ex)
        labels = list(by_label)
        result: list[Example] = []
        idx = 0
        # Cycle through labels, skipping exhausted ones, until k are drawn or
        # every per-label pool is empty.
        while len(result) < k and any(by_label.values()):
            pool = by_label[labels[idx % len(labels)]]
            if pool:
                result.append(pool.pop(self._rng.randrange(len(pool))))
            idx += 1
        return result

    @staticmethod
    def _jaccard_sets(tokens_a: set[str], tokens_b: set[str]) -> float:
        """Return the Jaccard index of two token sets (0.0 if either is empty)."""
        if not tokens_a or not tokens_b:
            return 0.0
        return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)

    @staticmethod
    def _jaccard(tokens_a: set[str], example: Example) -> float:
        """Return the Jaccard index between ``tokens_a`` and an example's input."""
        return FewShotSelector._jaccard_sets(
            tokens_a, set(_tokenize(example.input))
        )

183 

184 

def build_examples(
    pairs: Iterable[tuple[str, str]],
    labels: Iterable[str] | None = None,
    metadata: list[dict] | None = None,
) -> list[Example]:
    """Create ``Example`` objects from (input, output) pairs.

    Labels and metadata are matched to pairs by position. A pair with no
    corresponding entry gets an empty label and an empty metadata dict.
    Supplied metadata dicts are passed through as-is, not copied.

    Args:
        pairs: Iterable of (input, output) tuples.
        labels: Optional labels, one per pair.
        metadata: Optional metadata dicts, one per pair.

    Returns:
        List of ``Example`` objects in the order of ``pairs``.
    """
    # Materialise both optional inputs up front so they can be indexed.
    known_labels = list(labels) if labels else []
    known_meta = list(metadata) if metadata else []
    label_count = len(known_labels)
    meta_count = len(known_meta)
    return [
        Example(
            input=source,
            output=target,
            label=known_labels[pos] if pos < label_count else "",
            metadata=known_meta[pos] if pos < meta_count else {},
        )
        for pos, (source, target) in enumerate(pairs)
    ]

212 

213 

214def _tokenize(text: str) -> list[str]: 

215 """Simple whitespace+punctuation tokenizer.""" 

216 return re.findall(r"\w+", text.lower())