Coverage for agentos/prompts/few_shot.py: 37%
109 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2Few-Shot Example Management — intelligent few-shot selection strategies.
Supports similarity-based, random, diversity-maximizing, recency-based, and
label-balanced selection strategies for constructing few-shot prompts.
6"""
8from dataclasses import dataclass, field
9from enum import Enum
10from typing import Any, Dict, Iterable, List, Optional, Sequence
11import hashlib
12import random
13import re
class SelectionStrategy(str, Enum):
    """Strategy for selecting few-shot examples.

    Members subclass ``str``, so each compares equal to its plain string
    value (e.g. ``SelectionStrategy.RANDOM == "random"``).
    """

    # Uniform random sample from the example pool.
    RANDOM = "random"
    # Rank examples by Jaccard token overlap between the query and example inputs.
    SIMILARITY = "similarity"
    # Greedy farthest-first selection intended to spread the chosen examples apart.
    DIVERSITY = "diversity"
    # Most recently appended examples, newest first.
    RECENCY = "recency"
    # Round-robin across example labels, random pick within each label.
    LABEL_BALANCED = "label_balanced"
    # NOTE(review): FewShotSelector.select has no dedicated handler for this
    # member; it falls back to SIMILARITY.
    ACTIVE_LEARNING = "active_learning"
@dataclass
class Example:
    """A single training example for few-shot learning.

    Attributes:
        input: Example prompt text; similarity strategies tokenize this field.
        output: Expected answer text for ``input``.
        id: Identifier; derived from ``input`` and ``output`` when left empty.
        label: Optional category label, used for label-balanced selection.
        metadata: Free-form per-example data (a fresh dict per instance).
        score: Utility score for the example.
    """

    input: str
    output: str
    id: str = ""
    label: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)
    score: float = 0.0

    def __post_init__(self):
        """Derive a deterministic 12-hex-character ``id`` when none was given."""
        if not self.id:
            # MD5 serves only as a content fingerprint, not for security.
            # ``usedforsecurity=False`` keeps it usable in restricted (e.g.
            # FIPS-mode) environments where a plain ``hashlib.md5()`` call
            # is rejected; the digest itself is unchanged.
            # NOTE: input and output are concatenated without a separator, so
            # e.g. ("ab", "c") and ("a", "bc") receive the same id. The scheme
            # is kept as-is so previously generated ids remain stable.
            digest = hashlib.md5(
                f"{self.input}{self.output}".encode(),
                usedforsecurity=False,
            ).hexdigest()
            self.id = digest[:12]
class FewShotSelector:
    """Selects and formats the best few-shot examples for a given query.

    Usage::

        examples = [
            Example(input="What is 2+2?", output="4", label="math"),
            Example(input="Capital of France?", output="Paris", label="geo"),
        ]
        selector = FewShotSelector(examples, strategy=SelectionStrategy.SIMILARITY)
        prompt = selector.build_prompt("What is 3+5?", base_instruction="Answer:")
    """

    DEFAULT_FORMAT = "Q: {input}\nA: {output}"
    # Not read by any method of this class.
    MAX_TOKEN_ESTIMATE = 4096

    def __init__(
        self,
        examples: Sequence[Example],
        strategy: SelectionStrategy = SelectionStrategy.SIMILARITY,
        max_examples: int = 5,
        example_format: str = "",
        seed: int = 42,
    ):
        """Create a selector over a pool of examples.

        Args:
            examples: Initial example pool; copied into a new list.
            strategy: Strategy used by :meth:`select`.
            max_examples: Number of examples selected when ``k`` is omitted.
            example_format: ``str.format`` template with ``{input}`` and
                ``{output}`` fields; empty means :attr:`DEFAULT_FORMAT`.
            seed: Seed for this selector's private random generator.
        """
        self.examples = list(examples)
        self.strategy = strategy
        self.max_examples = max_examples
        self.example_format = example_format or self.DEFAULT_FORMAT
        # A per-instance generator keeps results reproducible for this
        # selector without reseeding the process-wide ``random`` module.
        self._rng = random.Random(seed)

    def select(self, query: str, k: Optional[int] = None) -> list[Example]:
        """Select up to ``k`` examples for ``query``.

        Args:
            query: Text the examples are selected for.
            k: Number of examples; ``None`` means ``max_examples``.
                Zero or a negative value yields an empty list.

        Returns:
            The selected examples (fewer than ``k`` when the pool is smaller).
            Strategies without a dedicated selector (currently
            ``ACTIVE_LEARNING``) fall back to similarity selection.
        """
        if k is None:
            k = self.max_examples
        if k <= 0 or not self.examples:
            return []

        strategy_map = {
            SelectionStrategy.RANDOM: self._select_random,
            SelectionStrategy.SIMILARITY: self._select_similarity,
            SelectionStrategy.DIVERSITY: self._select_diversity,
            SelectionStrategy.RECENCY: self._select_recency,
            SelectionStrategy.LABEL_BALANCED: self._select_label_balanced,
        }
        selector_fn = strategy_map.get(self.strategy, self._select_similarity)
        return selector_fn(query, k)

    def build_prompt(
        self,
        query: str,
        base_instruction: str = "",
        k: Optional[int] = None,
    ) -> str:
        """Build a complete few-shot prompt string.

        The prompt contains, in order:

        * ``base_instruction``, if non-empty;
        * each selected example rendered with ``example_format``;
        * ``query`` rendered with the same template and an empty output.

        The parts are joined by blank lines.
        """
        selected = self.select(query, k)
        parts: list[str] = []
        if base_instruction:
            parts.append(base_instruction)
        for ex in selected:
            parts.append(self.example_format.format(input=ex.input, output=ex.output))
        parts.append(self.example_format.format(input=query, output=""))
        return "\n\n".join(parts)

    def add_example(self, example: Example):
        """Append a new example to the pool."""
        self.examples.append(example)

    def remove_example(self, example_id: str):
        """Remove every example whose ``id`` equals ``example_id``."""
        self.examples = [e for e in self.examples if e.id != example_id]

    def set_score(self, example_id: str, score: float):
        """Set ``score`` on the first example whose ``id`` matches ``example_id``."""
        for ex in self.examples:
            if ex.id == example_id:
                ex.score = score
                break

    def _select_random(self, _query: str, k: int) -> list[Example]:
        """Uniform random sample of up to ``k`` examples."""
        return self._rng.sample(self.examples, min(k, len(self.examples)))

    def _select_similarity(self, query: str, k: int) -> list[Example]:
        """Jaccard-based token similarity for fast selection."""
        query_tokens = set(_tokenize(query))
        scored = [(ex, self._jaccard(query_tokens, ex)) for ex in self.examples]
        # list.sort is stable (also with reverse=True), so equally similar
        # examples keep their pool order.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [ex for ex, _ in scored[:k]]

    def _select_diversity(self, _query: str, k: int) -> list[Example]:
        """Maximize diversity via greedy farthest-first traversal.

        The first example is random. Each subsequent pick is the remaining
        example whose highest Jaccard similarity to any already-selected
        example is lowest, i.e. the one farthest from everything chosen.
        """
        if k >= len(self.examples):
            return list(self.examples)
        # Tokenize each example once. Track positions rather than using
        # equality so duplicate examples are handled independently.
        token_sets = [self._token_set(ex) for ex in self.examples]
        first = self._rng.randrange(len(self.examples))
        chosen = [first]
        remaining = [i for i in range(len(self.examples)) if i != first]
        while len(chosen) < k and remaining:
            best = min(
                remaining,
                key=lambda i: max(
                    self._jaccard_sets(token_sets[i], token_sets[j])
                    for j in chosen
                ),
            )
            chosen.append(best)
            remaining.remove(best)
        return [self.examples[i] for i in chosen]

    def _select_recency(self, _query: str, k: int) -> list[Example]:
        """Most recent examples first (assumes append order = recency)."""
        if k <= 0:
            # Guard against ``examples[-0:]``, which would return the whole pool.
            return []
        return list(reversed(self.examples[-k:]))

    def _select_label_balanced(self, _query: str, k: int) -> list[Example]:
        """Balance selection across unique labels.

        Examples are grouped by label (empty labels share the
        ``"_unlabeled"`` group). The groups are visited round-robin in
        first-seen order. Each visit takes one random remaining example
        from the group, until ``k`` examples are chosen or all groups are
        exhausted.
        """
        by_label: dict[str, list[Example]] = {}
        for ex in self.examples:
            by_label.setdefault(ex.label or "_unlabeled", []).append(ex)
        labels = list(by_label)
        result: list[Example] = []
        idx = 0
        while len(result) < k and any(by_label.values()):
            pool = by_label[labels[idx % len(labels)]]
            if pool:
                result.append(pool.pop(self._rng.randrange(len(pool))))
            idx += 1
        return result

    @staticmethod
    def _token_set(example: Example) -> set[str]:
        """Return the set of tokens in ``example.input``."""
        return set(_tokenize(example.input))

    @staticmethod
    def _jaccard_sets(tokens_a: set[str], tokens_b: set[str]) -> float:
        """Jaccard similarity of two token sets; 0.0 if either is empty."""
        if not tokens_a or not tokens_b:
            return 0.0
        return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)

    @staticmethod
    def _jaccard(tokens_a: set[str], example: Example) -> float:
        """Jaccard similarity between ``tokens_a`` and ``example.input``'s tokens."""
        return FewShotSelector._jaccard_sets(
            tokens_a, FewShotSelector._token_set(example)
        )
def build_examples(
    pairs: Iterable[tuple[str, str]],
    labels: Iterable[str] | None = None,
    metadata: list[dict] | None = None,
) -> list[Example]:
    """Create ``Example`` objects from ``(input, output)`` pairs.

    Args:
        pairs: ``(input, output)`` tuples, one per example.
        labels: Labels matched to ``pairs`` by position. Pairs past the end
            of ``labels`` (or every pair, when it is falsy) get ``""``.
        metadata: Metadata dicts matched to ``pairs`` by position. Pairs past
            its end get a fresh empty dict. Supplied dicts are passed through
            as-is, not copied.

    Returns:
        A new list with one ``Example`` per pair, in input order.
    """
    label_seq = [] if not labels else list(labels)
    meta_seq = [] if not metadata else list(metadata)

    def _nth(seq, idx, default):
        # Positional lookup with a fallback for indices past the end.
        return seq[idx] if idx < len(seq) else default

    return [
        Example(
            input=question,
            output=answer,
            label=_nth(label_seq, idx, ""),
            metadata=_nth(meta_seq, idx, {}),
        )
        for idx, (question, answer) in enumerate(pairs)
    ]
def _tokenize(text: str) -> list[str]:
    """Split *text* into lowercase word tokens.

    A token is a maximal run of regex word characters (letters, digits,
    underscore); whitespace and punctuation act only as separators.
    """
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r"\w+", lowered)]