Coverage for agentos/cost/token_counter.py: 35%

132 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Token Counter — Model-aware token counting and cost estimation. 

3 

4Supports tiktoken-based counting for OpenAI models and approximate 

5counting for other providers (Anthropic, Google, local models). 

6""" 

7 

8from __future__ import annotations 

9 

10import math 

11from dataclasses import dataclass, field 

12from enum import Enum 

13from typing import Optional 

14 

15 

class ModelFamily(Enum):
    """Enumeration of model families.

    Each member's value is the canonical lowercase name fragment for that
    family; ``UNKNOWN`` is the catch-all for unrecognised models.
    """

    # OpenAI
    GPT4 = "gpt-4"
    GPT4O = "gpt-4o"
    GPT35 = "gpt-3.5-turbo"

    # Anthropic
    CLAUDE3 = "claude-3"
    CLAUDE35 = "claude-3.5"

    # Google
    GEMINI = "gemini"

    # Open-weight models
    LLAMA = "llama"
    MIXTRAL = "mixtral"

    # Fallback when no other family applies
    UNKNOWN = "unknown"

29 

30 

@dataclass
class TokenCount:
    """Token tally for a single text, a conversation, or an aggregate.

    All fields are plain values with defaults; ``total_tokens`` is not
    derived automatically, so whoever builds the instance must set it.
    """

    # Tokens attributed to the input (prompt) side.
    prompt_tokens: int = 0
    # Tokens attributed to the generated (completion) side.
    completion_tokens: int = 0
    # Combined token count as supplied by the creator of this record.
    total_tokens: int = 0
    # Model identifier the counts were produced for ("" when unspecified).
    model: str = ""

39 

40 

@dataclass
class CostEstimate:
    """Monetary estimate derived from a token tally.

    Every field has a default, so an empty ``CostEstimate()`` represents
    zero cost in USD with no associated token count.
    """

    # Cost attributed to prompt (input) tokens.
    prompt_cost: float = 0.0
    # Cost attributed to completion (output) tokens.
    completion_cost: float = 0.0
    # Overall cost as supplied by the creator of this record.
    total_cost: float = 0.0
    # Currency code for the amounts above.
    currency: str = "USD"
    # The token tally this estimate was computed from, if any.
    token_count: Optional[TokenCount] = None

50 

51 

# Pricing per 1M tokens (input, output) — updated mid-2025
#
# NOTE: order matters for lookups that return the FIRST key that is a prefix
# of the requested model name. Whenever one key is a prefix of another
# (e.g. "gpt-4o" vs "gpt-4o-mini", "gpt-4" vs "gpt-4-turbo"), the longer,
# more specific key must be listed first so that dated identifiers such as
# "gpt-4o-mini-2024-07-18" resolve to the right entry.
PRICING_TABLE: dict[str, tuple[float, float]] = {
    # OpenAI
    "gpt-4o-mini": (0.15, 0.60),
    "gpt-4o": (2.50, 10.00),
    "gpt-4-turbo": (10.00, 30.00),
    "gpt-4": (30.00, 60.00),
    "gpt-3.5-turbo": (0.50, 1.50),
    # Anthropic
    "claude-3.5-sonnet": (3.00, 15.00),
    "claude-3-opus": (15.00, 75.00),
    "claude-3-haiku": (0.25, 1.25),
    "claude-3-sonnet": (3.00, 15.00),
    # Google
    "gemini-1.5-pro": (1.25, 5.00),
    "gemini-1.5-flash": (0.075, 0.30),
    "gemini-2.0-flash": (0.10, 0.40),
    # Open-source (hosted)
    "llama-3-70b": (0.59, 0.79),
    "llama-3-8b": (0.06, 0.06),
    "mixtral-8x7b": (0.24, 0.24),
}

74 

75 

class TokenCounter:
    """
    Model-aware token counting and cost estimation.

    Uses tiktoken when it can be imported and falls back to a
    character-based approximation (``CHARS_PER_TOKEN``) otherwise. When
    tiktoken is present, model names it does not recognise are encoded
    with the generic ``cl100k_base`` encoding.

    Every call to :meth:`count` or :meth:`count_messages` appends exactly
    one entry to an internal usage log, which :meth:`get_total_usage` and
    :meth:`get_total_cost` aggregate and :meth:`reset_usage` clears.

    Example::

        counter = TokenCounter()
        tokens = counter.count("Hello, world!", model="gpt-4o")
        cost = counter.estimate_cost(tokens, model="gpt-4o")
    """

    # Characters per token — rough estimates per model family
    CHARS_PER_TOKEN: dict[ModelFamily, float] = {
        ModelFamily.GPT4: 3.5,
        ModelFamily.GPT4O: 3.8,
        ModelFamily.GPT35: 4.0,
        ModelFamily.CLAUDE3: 3.2,
        ModelFamily.CLAUDE35: 3.4,
        ModelFamily.GEMINI: 3.0,
        ModelFamily.LLAMA: 3.8,
        ModelFamily.MIXTRAL: 3.6,
        ModelFamily.UNKNOWN: 4.0,
    }

    def __init__(self) -> None:
        # Always define the attribute so it exists even when the import fails.
        self._tiktoken = None
        self._tiktoken_available = self._try_load_tiktoken()
        self._encoders: dict[str, object] = {}
        self._usage_log: list[TokenCount] = []

    def _try_load_tiktoken(self) -> bool:
        """Import tiktoken if installed; return whether it is usable."""
        try:
            import tiktoken
        except ImportError:
            return False
        self._tiktoken = tiktoken
        return True

    def _get_encoder(self, model: str):
        """Get tiktoken encoder for model, with caching.

        Returns ``None`` when tiktoken is unavailable or no encoder could
        be loaded, in which case callers use the character approximation.
        """
        if not self._tiktoken_available:
            return None

        cached = self._encoders.get(model)
        if cached is not None:
            return cached

        try:
            encoder = self._tiktoken.encoding_for_model(model)
        except KeyError:
            # Model name unknown to tiktoken: use the generic encoding.
            try:
                encoder = self._tiktoken.get_encoding("cl100k_base")
            except Exception:
                return None
        except Exception:
            # Any other failure while loading the encoder is treated the
            # same way as a failed fallback: counting stays best-effort.
            return None

        self._encoders[model] = encoder
        return encoder

    def _count_text(self, text: Optional[str], model: str) -> int:
        """Return the token count for ``text`` without touching the usage log.

        Empty or ``None`` text is zero tokens regardless of which counting
        path would otherwise be taken.
        """
        if not text:
            return 0

        encoder = self._get_encoder(model)
        if encoder is not None:
            # disallowed_special=() makes special-token literals such as
            # "<|endoftext|>" encode as ordinary text instead of raising.
            return len(encoder.encode(text, disallowed_special=()))

        family = self._classify_model(model)
        chars_per = self.CHARS_PER_TOKEN.get(family, 4.0)
        # Non-empty text is always at least one token.
        return max(1, int(len(text) / chars_per))

    def count(self, text: str, model: str = "gpt-4o") -> TokenCount:
        """
        Count tokens in text for a specific model.

        Args:
            text: The text to count tokens for. Empty text counts as 0.
            model: Model identifier string.

        Returns:
            TokenCount with prompt_tokens set (single text counts as prompt).
            The result is also appended to the usage log.
        """
        count_val = self._count_text(text, model)
        result = TokenCount(
            prompt_tokens=count_val,
            total_tokens=count_val,
            model=model,
        )
        self._usage_log.append(result)
        return result

    def count_messages(
        self, messages: list[dict[str, str]], model: str = "gpt-4o",
    ) -> TokenCount:
        """
        Count tokens for a list of chat messages.

        Args:
            messages: List of {"role": "...", "content": "..."} dicts.
                A missing or ``None`` content contributes only the
                per-message overhead.
            model: Model identifier.

        Returns:
            TokenCount with total prompt tokens. Exactly one entry (the
            aggregate) is appended to the usage log, so totals reported by
            get_total_usage() are not inflated by per-message entries.
        """
        total = 0
        for msg in messages:
            # Role overhead: ~4 tokens per message
            total += 4
            total += self._count_text(msg.get("content"), model)

        result = TokenCount(
            prompt_tokens=total,
            total_tokens=total,
            model=model,
        )
        self._usage_log.append(result)
        return result

    def estimate_cost(
        self, token_count: TokenCount, model: Optional[str] = None,
    ) -> CostEstimate:
        """
        Estimate USD cost from token usage.

        Args:
            token_count: Token counts from count() or count_messages().
            model: Override model for pricing lookup; when omitted or empty
                the model recorded on ``token_count`` is used.

        Returns:
            CostEstimate with prompt, completion and total USD cost.
        """
        input_price, output_price = self._get_pricing(model or token_count.model)

        # Prices are quoted per one million tokens.
        prompt_cost = (token_count.prompt_tokens / 1_000_000) * input_price
        completion_cost = (token_count.completion_tokens / 1_000_000) * output_price

        return CostEstimate(
            prompt_cost=prompt_cost,
            completion_cost=completion_cost,
            total_cost=prompt_cost + completion_cost,
            token_count=token_count,
        )

    def _get_pricing(self, model: str) -> tuple[float, float]:
        """Find closest pricing match for model.

        Lookup order: exact key, exact key ignoring case, then the LONGEST
        table key that is a prefix of the model name (so that
        "gpt-4o-mini-2024-07-18" resolves to "gpt-4o-mini", not "gpt-4o").
        Falls back to a default (input, output) price when nothing matches.
        """
        if model in PRICING_TABLE:
            return PRICING_TABLE[model]

        normalized = model.lower()
        if normalized in PRICING_TABLE:
            return PRICING_TABLE[normalized]

        # Longest-prefix match, independent of the table's insertion order.
        candidates = [key for key in PRICING_TABLE if normalized.startswith(key)]
        if candidates:
            return PRICING_TABLE[max(candidates, key=len)]

        # Default: conservative estimate
        return (1.00, 3.00)

    def _classify_model(self, model: str) -> ModelFamily:
        """Map a model identifier to a ModelFamily by substring matching.

        More specific fragments are tested before the general ones they
        contain (e.g. "gpt-4o" before "gpt-4").
        """
        model_lower = model.lower()
        if "gpt-4o" in model_lower:
            return ModelFamily.GPT4O
        if "gpt-4" in model_lower:
            return ModelFamily.GPT4
        if "gpt-3.5" in model_lower:
            return ModelFamily.GPT35
        # Accept both the dotted and the hyphenated spelling of 3.5.
        if "claude-3.5" in model_lower or "claude-3-5" in model_lower:
            return ModelFamily.CLAUDE35
        if "claude" in model_lower:
            return ModelFamily.CLAUDE3
        if "gemini" in model_lower:
            return ModelFamily.GEMINI
        if "llama" in model_lower:
            return ModelFamily.LLAMA
        if "mixtral" in model_lower:
            return ModelFamily.MIXTRAL
        return ModelFamily.UNKNOWN

    def get_total_usage(self) -> TokenCount:
        """Aggregate all logged usage."""
        prompt = sum(u.prompt_tokens for u in self._usage_log)
        completion = sum(u.completion_tokens for u in self._usage_log)
        return TokenCount(
            prompt_tokens=prompt,
            completion_tokens=completion,
            total_tokens=prompt + completion,
        )

    def get_total_cost(self) -> CostEstimate:
        """Estimate total cost of all logged usage.

        Each log entry is priced with its own recorded model; the returned
        estimate carries the summed prompt, completion and total costs
        together with the aggregated token count.
        """
        prompt_cost = 0.0
        completion_cost = 0.0
        total_cost = 0.0
        for entry in self._usage_log:
            cost = self.estimate_cost(entry)
            prompt_cost += cost.prompt_cost
            completion_cost += cost.completion_cost
            total_cost += cost.total_cost
        return CostEstimate(
            prompt_cost=prompt_cost,
            completion_cost=completion_cost,
            total_cost=total_cost,
            token_count=self.get_total_usage(),
        )

    def reset_usage(self) -> None:
        """Discard every entry in the usage log."""
        self._usage_log.clear()

    @staticmethod
    def format_cost(cost: CostEstimate) -> str:
        """Human-readable cost string.

        Uses more decimal places for smaller amounts: six below $0.01,
        four below $1.00, and two otherwise.
        """
        if cost.total_cost < 0.01:
            return f"${cost.total_cost:.6f}"
        if cost.total_cost < 1.0:
            return f"${cost.total_cost:.4f}"
        return f"${cost.total_cost:.2f}"

    @staticmethod
    def format_tokens(tokens: TokenCount) -> str:
        """Human-readable token count string.

        Counts below 1000 are shown verbatim; larger counts are shown in
        thousands with one decimal place and a "K" suffix.
        """
        if tokens.total_tokens < 1000:
            return str(tokens.total_tokens)
        return f"{tokens.total_tokens / 1000:.1f}K"