Coverage for agentos/cost/token_counter.py: 35%
132 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
"""
Token Counter — Model-aware token counting and cost estimation.

Supports tiktoken-based counting for OpenAI models and approximate
counting for other providers (Anthropic, Google, local models).
"""
8from __future__ import annotations
10import math
11from dataclasses import dataclass, field
12from enum import Enum
13from typing import Optional
class ModelFamily(Enum):
    """Enumeration of model families.

    Each member's value is the lowercase name fragment associated with
    that family; ``UNKNOWN`` is the catch-all member.
    """

    # OpenAI
    GPT4 = "gpt-4"
    GPT4O = "gpt-4o"
    GPT35 = "gpt-3.5-turbo"

    # Anthropic
    CLAUDE3 = "claude-3"
    CLAUDE35 = "claude-3.5"

    # Google
    GEMINI = "gemini"

    # Open-weight models
    LLAMA = "llama"
    MIXTRAL = "mixtral"

    # Anything that does not fit one of the families above
    UNKNOWN = "unknown"
@dataclass
class TokenCount:
    """Token counts for a message or conversation.

    ``total_tokens`` may be supplied explicitly. When it is left at its
    default of 0 it is derived as ``prompt_tokens + completion_tokens`` so
    that an instance built from its parts alone is internally consistent.
    """

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    # Model identifier the counts belong to; empty when not tied to a model.
    model: str = ""

    def __post_init__(self) -> None:
        """Fill in ``total_tokens`` from the parts when it was not given."""
        # An explicit non-zero total is always respected as-is.
        if self.total_tokens == 0:
            self.total_tokens = self.prompt_tokens + self.completion_tokens
@dataclass
class CostEstimate:
    """Estimated cost for token usage.

    ``total_cost`` may be supplied explicitly. When it is left at its
    default of 0.0 it is derived as ``prompt_cost + completion_cost``.
    """

    prompt_cost: float = 0.0
    completion_cost: float = 0.0
    total_cost: float = 0.0
    currency: str = "USD"
    # Token usage the estimate was computed from, when known. The name is
    # quoted so the class does not depend on lazy annotation evaluation.
    token_count: Optional["TokenCount"] = None

    def __post_init__(self) -> None:
        """Fill in ``total_cost`` from the parts when it was not given."""
        # An explicit non-zero total is always respected as-is.
        if self.total_cost == 0.0:
            self.total_cost = self.prompt_cost + self.completion_cost
# Pricing per 1M tokens (input, output) — updated mid-2025
#
# Ordering matters: TokenCounter._get_pricing scans this table in insertion
# order when it falls back to prefix matching, so a key must never appear
# before a longer key that it is a prefix of (e.g. "gpt-4o-mini" has to come
# before "gpt-4o", and both before "gpt-4").
PRICING_TABLE: dict[str, tuple[float, float]] = {
    # OpenAI
    "gpt-4o-mini": (0.15, 0.60),
    "gpt-4o": (2.50, 10.00),
    "gpt-4-turbo": (10.00, 30.00),
    "gpt-4": (30.00, 60.00),
    "gpt-3.5-turbo": (0.50, 1.50),
    # Anthropic
    "claude-3.5-sonnet": (3.00, 15.00),
    "claude-3-opus": (15.00, 75.00),
    "claude-3-haiku": (0.25, 1.25),
    "claude-3-sonnet": (3.00, 15.00),
    # Google
    "gemini-1.5-pro": (1.25, 5.00),
    "gemini-1.5-flash": (0.075, 0.30),
    "gemini-2.0-flash": (0.10, 0.40),
    # Open-source (hosted)
    "llama-3-70b": (0.59, 0.79),
    "llama-3-8b": (0.06, 0.06),
    "mixtral-8x7b": (0.24, 0.24),
}
class TokenCounter:
    """
    Model-aware token counting and cost estimation.

    When tiktoken can be imported it is used for counting: models tiktoken
    recognises get their own encoding, any other model name falls back to
    the ``cl100k_base`` encoding. When tiktoken is unavailable (or no
    encoder can be obtained) a character-based approximation keyed on the
    model family is used instead.

    Every call to :meth:`count` or :meth:`count_messages` appends exactly
    one entry to an internal usage log, which :meth:`get_total_usage` and
    :meth:`get_total_cost` aggregate.

    Example::

        counter = TokenCounter()
        tokens = counter.count("Hello, world!", model="gpt-4o")
        cost = counter.estimate_cost(tokens, model="gpt-4o")
    """

    # Characters per token — rough estimates per model family
    CHARS_PER_TOKEN: dict[ModelFamily, float] = {
        ModelFamily.GPT4: 3.5,
        ModelFamily.GPT4O: 3.8,
        ModelFamily.GPT35: 4.0,
        ModelFamily.CLAUDE3: 3.2,
        ModelFamily.CLAUDE35: 3.4,
        ModelFamily.GEMINI: 3.0,
        ModelFamily.LLAMA: 3.8,
        ModelFamily.MIXTRAL: 3.6,
        ModelFamily.UNKNOWN: 4.0,
    }

    # Flat per-message overhead (role and framing) added by count_messages.
    _MESSAGE_OVERHEAD = 4

    # (input, output) USD per 1M tokens used when no pricing entry matches.
    _DEFAULT_PRICING = (1.00, 3.00)

    def __init__(self):
        """Probe for tiktoken and start with empty encoder cache and log."""
        self._tiktoken = None
        self._tiktoken_available = self._try_load_tiktoken()
        self._encoders: dict[str, object] = {}
        self._usage_log: list[TokenCount] = []

    def _try_load_tiktoken(self) -> bool:
        """Import tiktoken if installed; return whether it is usable."""
        try:
            import tiktoken
        except ImportError:
            return False
        self._tiktoken = tiktoken
        return True

    def _get_encoder(self, model: str):
        """Get tiktoken encoder for model, with caching.

        Returns ``None`` when tiktoken is unavailable or no encoding could
        be obtained. Model names tiktoken does not know (``KeyError``) fall
        back to the ``cl100k_base`` encoding.
        """
        if not self._tiktoken_available:
            return None

        cached = self._encoders.get(model)
        if cached is not None:
            return cached

        try:
            encoder = self._tiktoken.encoding_for_model(model)
        except KeyError:
            try:
                encoder = self._tiktoken.get_encoding("cl100k_base")
            except Exception:
                # Best effort: the caller falls back to the character
                # approximation when no encoder is returned.
                return None
        self._encoders[model] = encoder
        return encoder

    def _count_text(self, text: str, model: str) -> int:
        """Return the token count for ``text`` without touching the log.

        Empty (or ``None``) text is always 0 tokens, regardless of whether
        the tiktoken or the approximation path would have been taken.
        """
        if not text:
            return 0

        encoder = self._get_encoder(model)
        if encoder is not None:
            return len(encoder.encode(text))

        chars_per = self.CHARS_PER_TOKEN.get(self._classify_model(model), 4.0)
        # Non-empty text is never reported as fewer than one token.
        return max(1, int(len(text) / chars_per))

    def count(self, text: str, model: str = "gpt-4o") -> TokenCount:
        """
        Count tokens in text for a specific model.

        The result is appended to the usage log.

        Args:
            text: The text to count tokens for.
            model: Model identifier string.

        Returns:
            TokenCount with prompt_tokens set (single text counts as prompt).
        """
        count_val = self._count_text(text, model)
        result = TokenCount(
            prompt_tokens=count_val,
            total_tokens=count_val,
            model=model,
        )
        self._usage_log.append(result)
        return result

    def count_messages(
        self, messages: list[dict[str, str]], model: str = "gpt-4o",
    ) -> TokenCount:
        """
        Count tokens for a list of chat messages.

        Each message contributes a flat overhead of 4 tokens plus the token
        count of its ``"content"`` value (missing or empty content counts
        as 0). Only the aggregate is appended to the usage log, so a
        conversation is recorded once rather than once per message and
        again in total.

        Args:
            messages: List of {"role": "...", "content": "..."} dicts.
            model: Model identifier.

        Returns:
            TokenCount with total prompt tokens.
        """
        total = sum(
            self._MESSAGE_OVERHEAD + self._count_text(msg.get("content", ""), model)
            for msg in messages
        )
        result = TokenCount(
            prompt_tokens=total,
            total_tokens=total,
            model=model,
        )
        self._usage_log.append(result)
        return result

    def estimate_cost(
        self, token_count: TokenCount, model: Optional[str] = None,
    ) -> CostEstimate:
        """
        Estimate USD cost from token usage.

        Args:
            token_count: Token counts from count() or count_messages().
            model: Override model for pricing lookup; when omitted (or
                empty) ``token_count.model`` is used.

        Returns:
            CostEstimate with prompt, completion and total USD cost.
        """
        input_rate, output_rate = self._get_pricing(model or token_count.model)

        # Rates are quoted per one million tokens.
        prompt_cost = (token_count.prompt_tokens / 1_000_000) * input_rate
        completion_cost = (token_count.completion_tokens / 1_000_000) * output_rate

        return CostEstimate(
            prompt_cost=prompt_cost,
            completion_cost=completion_cost,
            total_cost=prompt_cost + completion_cost,
            token_count=token_count,
        )

    def _get_pricing(self, model: str) -> tuple[float, float]:
        """Find closest pricing match for model.

        Lookup is case-insensitive. An exact key wins; otherwise the
        longest PRICING_TABLE key that the model name starts with is used,
        so a dated variant such as ``gpt-4o-mini-2024-07-18`` resolves to
        ``gpt-4o-mini`` rather than the shorter ``gpt-4o``. With no match a
        default (input, output) rate is returned.
        """
        name = model.lower()
        if name in PRICING_TABLE:
            return PRICING_TABLE[name]

        prefixes = [key for key in PRICING_TABLE if name.startswith(key)]
        if prefixes:
            return PRICING_TABLE[max(prefixes, key=len)]

        # Default: conservative estimate
        return self._DEFAULT_PRICING

    def _classify_model(self, model: str) -> ModelFamily:
        """Map a model identifier to its ModelFamily by substring match.

        Matching is case-insensitive and ordered from most to least
        specific, because e.g. every ``gpt-4o`` name also contains
        ``gpt-4``. Both ``claude-3.5`` and the hyphenated ``claude-3-5``
        spelling are treated as the Claude 3.5 family.
        """
        model_lower = model.lower()
        if "gpt-4o" in model_lower:
            return ModelFamily.GPT4O
        if "gpt-4" in model_lower:
            return ModelFamily.GPT4
        if "gpt-3.5" in model_lower:
            return ModelFamily.GPT35
        if "claude-3.5" in model_lower or "claude-3-5" in model_lower:
            return ModelFamily.CLAUDE35
        if "claude" in model_lower:
            return ModelFamily.CLAUDE3
        if "gemini" in model_lower:
            return ModelFamily.GEMINI
        if "llama" in model_lower:
            return ModelFamily.LLAMA
        if "mixtral" in model_lower:
            return ModelFamily.MIXTRAL
        return ModelFamily.UNKNOWN

    def get_total_usage(self) -> TokenCount:
        """Aggregate all logged usage into one TokenCount (model left empty)."""
        prompt = sum(u.prompt_tokens for u in self._usage_log)
        completion = sum(u.completion_tokens for u in self._usage_log)
        return TokenCount(
            prompt_tokens=prompt,
            completion_tokens=completion,
            total_tokens=prompt + completion,
        )

    def get_total_cost(self) -> CostEstimate:
        """Estimate total cost of all logged usage.

        Each log entry is priced with its own model; the prompt, completion
        and total costs are summed across entries.
        """
        prompt_cost = 0.0
        completion_cost = 0.0
        total_cost = 0.0
        for entry in self._usage_log:
            cost = self.estimate_cost(entry)
            prompt_cost += cost.prompt_cost
            completion_cost += cost.completion_cost
            total_cost += cost.total_cost
        return CostEstimate(
            prompt_cost=prompt_cost,
            completion_cost=completion_cost,
            total_cost=total_cost,
            token_count=self.get_total_usage(),
        )

    def reset_usage(self) -> None:
        """Discard every entry in the usage log."""
        self._usage_log.clear()

    @staticmethod
    def format_cost(cost: CostEstimate) -> str:
        """Human-readable cost string.

        Uses six decimals below $0.01, four below $1.00, and two otherwise.
        """
        if cost.total_cost < 0.01:
            return f"${cost.total_cost:.6f}"
        if cost.total_cost < 1.0:
            return f"${cost.total_cost:.4f}"
        return f"${cost.total_cost:.2f}"

    @staticmethod
    def format_tokens(tokens: TokenCount) -> str:
        """Human-readable token count string.

        Totals below 1000 are printed verbatim; larger totals are shown in
        thousands with one decimal and a ``K`` suffix.
        """
        if tokens.total_tokens < 1000:
            return str(tokens.total_tokens)
        return f"{tokens.total_tokens / 1000:.1f}K"