Coverage for session_buddy/token_optimizer.py: 89.02%

253 statements  

coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1#!/usr/bin/env python3 

2"""Token Optimization for Session Management MCP Server. 

3 

4Provides response chunking, content truncation, and context window monitoring 

5to reduce token usage while maintaining functionality. 

6""" 

7 

8import hashlib 

9import json 

10import operator 

11from dataclasses import dataclass 

12from datetime import datetime, timedelta 

13from typing import Any 

14 

15import tiktoken 

16from session_buddy.acb_cache_adapter import ACBChunkCache, get_chunk_cache 

17 

18 

19@dataclass 

20class TokenUsageMetrics: 

21 """Token usage metrics for monitoring.""" 

22 

23 request_tokens: int 

24 response_tokens: int 

25 total_tokens: int 

26 timestamp: str 

27 operation: str 

28 optimization_applied: str | None = None 

29 

30 

31@dataclass 

32class ChunkResult: 

33 """Result of response chunking.""" 

34 

35 chunks: list[str] 

36 total_chunks: int 

37 current_chunk: int 

38 cache_key: str 

39 metadata: dict[str, Any] 

40 

41 

42class TokenOptimizer: 

43 """Main token optimization class.""" 

44 

45 def __init__(self, max_tokens: int = 4000, chunk_size: int = 2000) -> None: 

46 self.max_tokens = max_tokens 

47 self.chunk_size = chunk_size 

48 self.encoding = self._get_encoding() 

49 self.usage_history: list[TokenUsageMetrics] = [] 

50 self.chunk_cache: ACBChunkCache = get_chunk_cache() # ACB-backed cache 

51 

52 # Token optimization strategies 

53 self.strategies = { 

54 "truncate_old": self._truncate_old_conversations, 

55 "summarize_content": self._summarize_long_content, 

56 "chunk_response": self._chunk_large_response, 

57 "filter_duplicates": self._filter_duplicate_content, 

58 "prioritize_recent": self._prioritize_recent_content, 

59 } 

60 

61 def _get_encoding(self) -> Any: 

62 """Get tiktoken encoding for token counting.""" 

63 try: 

64 return tiktoken.get_encoding("cl100k_base") # GPT-4 encoding 

65 except Exception: 

66 # Fallback to approximate counting 

67 return None 

68 

69 def count_tokens(self, text: str) -> int: 

70 """Count tokens in text.""" 

71 if self.encoding: 

72 return len(self.encoding.encode(text)) 

73 # Rough approximation: ~4 chars per token 

74 return len(text) // 4 

75 
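A minimal standalone sketch of the two counting paths above (assuming tiktoken is installed; the fallback mirrors the ~4-chars-per-token heuristic):

    import tiktoken

    def count_tokens(text: str) -> int:
        try:
            enc = tiktoken.get_encoding("cl100k_base")  # GPT-4 family encoding
            return len(enc.encode(text))
        except Exception:
            return len(text) // 4  # rough fallback: ~4 characters per token

    # "Hello, world!" encodes to 4 tokens under cl100k_base;
    # the character heuristic estimates 13 // 4 = 3.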

76 async def optimize_search_results( 

77 self, 

78 results: list[dict[str, Any]], 

79 strategy: str = "truncate_old", 

80 max_tokens: int | None = None, 

81 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

82 """Optimize search results to reduce token usage.""" 

83 max_tokens = max_tokens or self.max_tokens 

84 

85 optimization_info: dict[str, Any] = {} 

86 if strategy in self.strategies:  [86 ↛ 93: condition was always true]

87 optimized_results, info = await self.strategies[strategy]( 

88 results, 

89 max_tokens, 

90 ) 

91 optimization_info = info 

92 else: 

93 optimized_results = results # type: ignore[assignment] 

94 optimization_info["strategy"] = "none" 

95 

96 # Track optimization metrics 

97 optimization_info["original_count"] = len(results) 

98 optimization_info["optimized_count"] = len(optimized_results) 

99 optimization_info["token_savings"] = self._calculate_token_savings( 

100 results, 

101 optimized_results, 

102 ) 

103 

104 return optimized_results, optimization_info 

105 
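The method above dispatches through a dict of bound coroutine methods, falling back to the unmodified results for unknown strategy names. A standalone sketch of the same pattern (names here are illustrative):

    import asyncio
    from typing import Any, Awaitable, Callable

    async def keep_all(items: list[Any], budget: int) -> list[Any]:
        return items  # no-op strategy

    STRATEGIES: dict[str, Callable[[list[Any], int], Awaitable[list[Any]]]] = {
        "none": keep_all,
    }

    async def run(name: str, items: list[Any], budget: int) -> list[Any]:
        handler = STRATEGIES.get(name, keep_all)  # unknown names fall back
        return await handler(items, budget)

    print(asyncio.run(run("none", [1, 2, 3], 100)))  # -> [1, 2, 3]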

106 async def _truncate_old_conversations( 

107 self, 

108 results: list[dict[str, Any]], 

109 max_tokens: int, 

110 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

111 """Truncate old conversations based on age and importance.""" 

112 if not results: 

113 return results, {"strategy": "truncate_old", "action": "no_results"} 

114 

115 # Sort by timestamp (newest first) 

116 sorted_results = sorted( 

117 results, 

118 key=lambda x: x.get("timestamp", ""), 

119 reverse=True, 

120 ) 

121 

122 optimized_results: list[dict[str, Any]] = [] 

123 current_tokens = 0 

124 truncation_count = 0 

125 

126 for result in sorted_results: 

127 content = result.get("content", "") 

128 content_tokens = self.count_tokens(content) 

129 

130 # Check if adding this result exceeds token limit 

131 if current_tokens + content_tokens > max_tokens: 

132 # Try truncating the content 

133 if len(optimized_results) < 3:  # Always keep at least 3 recent results  [133 ↛ 147: condition was always true]

134 truncated_content = self._truncate_content( 

135 content, 

136 max_tokens - current_tokens, 

137 ) 

138 if truncated_content:  [138 ↛ 126: condition was always true]

139 result_copy = result.copy() 

140 result_copy["content"] = ( 

141 truncated_content + "... [truncated for token limit]" 

142 ) 

143 optimized_results.append(result_copy) 

144 truncation_count += 1 

145 break 

146 else: 

147 break 

148 else: 

149 optimized_results.append(result) 

150 current_tokens += content_tokens 

151 

152 return optimized_results, { 

153 "strategy": "truncate_old", 

154 "action": "truncated", 

155 "truncation_count": truncation_count, 

156 "final_token_count": current_tokens, 

157 } 

158 

159 async def _summarize_long_content( 

160 self, 

161 results: list[dict[str, Any]], 

162 max_tokens: int, 

163 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

164 """Summarize long content to reduce tokens.""" 

165 optimized_results = [] 

166 summarized_count = 0 

167 

168 for result in results: 

169 content = result.get("content", "") 

170 content_tokens = self.count_tokens(content) 

171 

172 if content_tokens > 500: # Summarize content longer than 500 tokens 

173 summary = self._create_quick_summary(content) 

174 result_copy = result.copy() 

175 result_copy["content"] = summary + " [auto-summarized]" 

176 optimized_results.append(result_copy) 

177 summarized_count += 1 

178 else: 

179 optimized_results.append(result) 

180 

181 return optimized_results, { 

182 "strategy": "summarize_content", 

183 "action": "summarized", 

184 "summarized_count": summarized_count, 

185 } 

186 

187 async def _chunk_large_response( 

188 self, 

189 results: list[dict[str, Any]], 

190 max_tokens: int, 

191 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

192 """Chunk large response into manageable pieces.""" 

193 if not results: 

194 return results, {"strategy": "chunk_response", "action": "no_results"} 

195 

196 # Estimate total tokens 

197 total_tokens = sum( 

198 self.count_tokens(str(result.get("content", ""))) for result in results 

199 ) 

200 

201 if total_tokens <= max_tokens:  [201 ↛ 202: condition was never true]

202 return results, { 

203 "strategy": "chunk_response", 

204 "action": "no_chunking_needed", 

205 } 

206 

207 # Create chunks 

208 chunks: list[list[dict[str, Any]]] = [] 

209 current_chunk: list[dict[str, Any]] = [] 

210 current_chunk_tokens = 0 

211 

212 for result in results: 

213 result_tokens = self.count_tokens(str(result.get("content", ""))) 

214 

215 if current_chunk_tokens + result_tokens > self.chunk_size and current_chunk: 

216 chunks.append(current_chunk.copy()) 

217 current_chunk = [result] 

218 current_chunk_tokens = result_tokens 

219 else: 

220 current_chunk.append(result) 

221 current_chunk_tokens += result_tokens 

222 

223 if current_chunk:  [223 ↛ 227: condition was always true]

224 chunks.append(current_chunk) 

225 

226 # Return first chunk and create cache entry for the rest 

227 if chunks:  [227 ↛ 238: condition was always true]

228 cache_key = await self._create_chunk_cache_entry(chunks) 

229 return chunks[0], { 

230 "strategy": "chunk_response", 

231 "action": "chunked", 

232 "total_chunks": len(chunks), 

233 "current_chunk": 1, 

234 "cache_key": cache_key, 

235 "has_more": len(chunks) > 1, 

236 } 

237 

238 return results, {"strategy": "chunk_response", "action": "failed"} 

239 

240 async def _filter_duplicate_content( 

241 self, 

242 results: list[dict[str, Any]], 

243 max_tokens: int, 

244 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

245 """Filter out duplicate or very similar content.""" 

246 if not results: 

247 return results, {"strategy": "filter_duplicates", "action": "no_results"} 

248 

249 seen_hashes = set() 

250 unique_results = [] 

251 duplicates_removed = 0 

252 

253 for result in results: 

254 content = result.get("content", "") 

255 # Create hash of normalized content using validated pattern 

256 from session_buddy.utils.regex_patterns import SAFE_PATTERNS 

257 

258 normalize_pattern = SAFE_PATTERNS["whitespace_normalize"] 

259 normalized_content = normalize_pattern.apply(content.lower().strip()) 

260 content_hash = hashlib.md5( 

261 normalized_content.encode(), 

262 usedforsecurity=False, 

263 ).hexdigest() 

264 

265 if content_hash not in seen_hashes: 

266 seen_hashes.add(content_hash) 

267 unique_results.append(result) 

268 else: 

269 duplicates_removed += 1 

270 

271 return unique_results, { 

272 "strategy": "filter_duplicates", 

273 "action": "filtered", 

274 "duplicates_removed": duplicates_removed, 

275 } 

276 
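The filter above keys on an MD5 digest of lower-cased, whitespace-normalized content. A self-contained sketch, substituting re.sub for the project's SAFE_PATTERNS helper (an assumption, since that pattern isn't shown here):

    import hashlib
    import re

    def dedup(items: list[dict]) -> list[dict]:
        seen: set[str] = set()
        unique: list[dict] = []
        for item in items:
            normalized = re.sub(r"\s+", " ", item.get("content", "").lower().strip())
            digest = hashlib.md5(normalized.encode(), usedforsecurity=False).hexdigest()
            if digest not in seen:  # first occurrence wins
                seen.add(digest)
                unique.append(item)
        return unique

    print(dedup([{"content": "Hello  World"}, {"content": "hello world"}]))  # one item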

277 async def _prioritize_recent_content( 

278 self, 

279 results: list[dict[str, Any]], 

280 max_tokens: int, 

281 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

282 """Prioritize recent content and score-based ranking.""" 

283 if not results: 

284 return results, {"strategy": "prioritize_recent", "action": "no_results"} 

285 

286 # Calculate priority scores 

287 now = datetime.now() 

288 scored_results = [] 

289 

290 for result in results: 

291 score = 0.0 

292 

293 # Recency score (0-0.5) 

294 try: 

295 timestamp = datetime.fromisoformat(result.get("timestamp", "")) 

296 days_old = (now - timestamp).days 

297 recency_score = max(0, 0.5 - (days_old / 365) * 0.5) 

298 score += recency_score 

299 except (ValueError, TypeError): 

300 score += 0.1 # Default low recency score 

301 

302 # Relevance score if available (0-0.3) 

303 if "score" in result: 

304 score += result["score"] * 0.3 

305 

306 # Length penalty for very long content (0 to -0.2) 

307 content = result.get("content", "") 

308 if len(content) > 2000: 

309 score -= 0.2 

310 elif len(content) > 1000:  [310 ↛ 311: condition was never true]

311 score -= 0.1 

312 

313 # Code/technical content bonus (0-0.2) 

314 if any( 

315 keyword in content.lower() 

316 for keyword in ("def ", "class ", "function", "error", "exception") 

317 ): 

318 score += 0.2 

319 

320 scored_results.append((score, result)) 

321 

322 # Sort by priority score and take top results within token limit 

323 scored_results.sort(key=operator.itemgetter(0), reverse=True) 

324 

325 prioritized_results = [] 

326 current_tokens = 0 

327 

328 for score, result in scored_results: 

329 result_tokens = self.count_tokens(str(result.get("content", ""))) 

330 if current_tokens + result_tokens <= max_tokens: 

331 prioritized_results.append(result) 

332 current_tokens += result_tokens 

333 else: 

334 break 

335 

336 return prioritized_results, { 

337 "strategy": "prioritize_recent", 

338 "action": "prioritized", 

339 "final_token_count": current_tokens, 

340 } 

341 
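Worked example of the scoring above, for a hypothetical 73-day-old result carrying a relevance score of 0.5 and short content containing "def ":

    # recency:   max(0, 0.5 - (73 / 365) * 0.5) = 0.4
    # relevance: 0.5 * 0.3                      = 0.15
    # length:    under 1000 chars               = no penalty
    # technical: contains "def "                = +0.2
    # total priority score                      = 0.75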

342 def _truncate_content(self, content: str, max_tokens: int) -> str: 

343 """Truncate content to fit within token limit.""" 

344 if self.count_tokens(content) <= max_tokens:  [344 ↛ 345: condition was never true]

345 return content 

346 

347 # Try to truncate at sentence boundaries 

348 sentences = content.split(". ") 

349 truncated = "" 

350 

351 for sentence in sentences:  [351 ↛ 358: loop didn't complete]

352 test_content = truncated + sentence + ". " 

353 if self.count_tokens(test_content) <= max_tokens: 

354 truncated = test_content 

355 else: 

356 break 

357 

358 if not truncated:  [358 ↛ 360: condition was never true]

359 # Fallback to character-based truncation 

360 if self.encoding: 

361 tokens = self.encoding.encode(content)[:max_tokens] 

362 truncated = self.encoding.decode(tokens) 

363 else: 

364 # Rough character limit 

365 char_limit = max_tokens * 4 

366 truncated = content[:char_limit] 

367 

368 return truncated.strip() 

369 

370 def _create_quick_summary(self, content: str, max_length: int = 200) -> str: 

371 """Create a quick summary of content.""" 

372 # Extract first and last sentences 

373 sentences = [s.strip() for s in content.split(".") if s.strip()] 

374 if not sentences:  [374 ↛ 375: condition was never true]

375 return content[:max_length] 

376 

377 if len(sentences) == 1:  [377 ↛ 378: condition was never true]

378 return sentences[0][:max_length] 

379 

380 first_sentence = sentences[0] 

381 last_sentence = sentences[-1] 

382 

383 summary = f"{first_sentence}. ... {last_sentence}" 

384 if len(summary) > max_length:  [384 ↛ 387: condition was always true]

385 summary = first_sentence[: max_length - 3] + "..." 

386 

387 return summary 

388 

389 async def _create_chunk_cache_entry( 

390 self, 

391 chunks: list[list[dict[str, Any]]], 

392 ) -> str: 

393 """Create cache entry for chunked results.""" 

394 cache_key = hashlib.md5( 

395 f"chunks_{datetime.now().isoformat()}_{len(chunks)}".encode(), 

396 usedforsecurity=False, 

397 ).hexdigest() 

398 

399 chunk_result = ChunkResult( 

400 chunks=[json.dumps(chunk) for chunk in chunks], 

401 total_chunks=len(chunks), 

402 current_chunk=1, 

403 cache_key=cache_key, 

404 metadata={ 

405 "created": datetime.now().isoformat(), 

406 "expires": (datetime.now() + timedelta(hours=1)).isoformat(), 

407 }, 

408 ) 

409 

410 await self.chunk_cache.set(cache_key, chunk_result) 

411 return cache_key 

412 

413 async def get_chunk( 

414 self, 

415 cache_key: str, 

416 chunk_index: int, 

417 ) -> dict[str, Any] | None: 

418 """Get a specific chunk from cache. 

419 

420 Args: 

421 cache_key: Unique cache key for the chunked data 

422 chunk_index: Index of chunk to retrieve (1-indexed) 

423 

424 Returns: 

425 Dict with chunk data and metadata, or None if not found 

426 

427 """ 

428 if not await self.chunk_cache.__contains__(cache_key):  # "in" can't await an async __contains__, so call it directly 

429 return None 

430 

431 chunk_result = await self.chunk_cache.get(cache_key) 

432 

433 if chunk_result and 1 <= chunk_index <= len(chunk_result.chunks): 

434 chunk_data = json.loads(chunk_result.chunks[chunk_index - 1]) 

435 return { 

436 "chunk": chunk_data, 

437 "current_chunk": chunk_index, 

438 "total_chunks": chunk_result.total_chunks, 

439 "cache_key": cache_key, 

440 "has_more": chunk_index < chunk_result.total_chunks, 

441 } 

442 

443 return None 

444 
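A hedged usage sketch for paginating through cached chunks with the method above (read_all_chunks is a hypothetical helper; cache_key comes from an earlier chunking pass):

    async def read_all_chunks(optimizer: TokenOptimizer, cache_key: str) -> list:
        pages: list = []
        index = 1  # chunks are 1-indexed
        while True:
            page = await optimizer.get_chunk(cache_key, index)
            if page is None:  # unknown key or index out of range
                break
            pages.append(page["chunk"])
            if not page["has_more"]:
                break
            index += 1
        return pages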

445 def _calculate_token_savings( 

446 self, 

447 original: list[dict[str, Any]], 

448 optimized: list[dict[str, Any]], 

449 ) -> dict[str, int]: 

450 """Calculate token savings from optimization.""" 

451 original_tokens = sum( 

452 self.count_tokens(str(item.get("content", ""))) for item in original 

453 ) 

454 optimized_tokens = sum( 

455 self.count_tokens(str(item.get("content", ""))) for item in optimized 

456 ) 

457 

458 return { 

459 "original_tokens": original_tokens, 

460 "optimized_tokens": optimized_tokens, 

461 "tokens_saved": original_tokens - optimized_tokens, 

462 "savings_percentage": int( 

463 round( 

464 ((original_tokens - optimized_tokens) / original_tokens) * 100, 

465 1, 

466 ), 

467 ) 

468 if original_tokens > 0 

469 else 0, 

470 } 
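Worked example of the savings calculation above, for 1,000 original tokens reduced to 600:

    # tokens_saved       = 1000 - 600 = 400
    # savings_percentage = int(round((400 / 1000) * 100, 1)) = 40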

471 

472 def track_usage( 

473 self, 

474 operation: str, 

475 request_tokens: int, 

476 response_tokens: int, 

477 optimization_applied: str | None = None, 

478 ) -> None: 

479 """Track token usage for monitoring.""" 

480 metrics = TokenUsageMetrics( 

481 request_tokens=request_tokens, 

482 response_tokens=response_tokens, 

483 total_tokens=request_tokens + response_tokens, 

484 timestamp=datetime.now().isoformat(), 

485 operation=operation, 

486 optimization_applied=optimization_applied, 

487 ) 

488 

489 self.usage_history.append(metrics) 

490 

491 # Keep only last 100 entries 

492 if len(self.usage_history) > 100:  [492 ↛ 493: condition was never true]

493 self.usage_history = self.usage_history[-100:] 

494 

495 def get_usage_stats(self, hours: int = 24) -> dict[str, Any]: 

496 """Get token usage statistics.""" 

497 cutoff = datetime.now() - timedelta(hours=hours) 

498 

499 recent_usage = [ 

500 m 

501 for m in self.usage_history 

502 if datetime.fromisoformat(m.timestamp) > cutoff 

503 ] 

504 

505 if not recent_usage:  [505 ↛ 506: condition was never true]

506 return {"status": "no_data", "period_hours": hours} 

507 

508 total_tokens = sum(m.total_tokens for m in recent_usage) 

509 avg_tokens = total_tokens / len(recent_usage) 

510 

511 # Count optimizations applied 

512 optimizations: dict[str, int] = {} 

513 for metric in recent_usage: 

514 if metric.optimization_applied: 

515 optimizations[metric.optimization_applied] = ( 

516 optimizations.get(metric.optimization_applied, 0) + 1 

517 ) 

518 

519 return { 

520 "status": "success", 

521 "period_hours": hours, 

522 "total_requests": len(recent_usage), 

523 "total_tokens": total_tokens, 

524 "average_tokens_per_request": round(avg_tokens, 1), 

525 "optimizations_applied": optimizations, 

526 "estimated_cost_savings": self._estimate_cost_savings(recent_usage), 

527 } 

528 

529 def _estimate_cost_savings( 

530 self, 

531 usage_metrics: list[TokenUsageMetrics], 

532 ) -> dict[str, float]: 

533 """Estimate cost savings from optimizations.""" 

534 # Rough cost estimation (adjust based on actual pricing) 

535 cost_per_1k_tokens = 0.01 # Example rate 

536 

537 optimized_requests = [m for m in usage_metrics if m.optimization_applied] 

538 if not optimized_requests:  [538 ↛ 539: condition was never true]

539 return {"savings_usd": 0.0, "requests_optimized": 0} 

540 

541 # Estimate 20-40% token savings from optimization; use the 30% midpoint 

542 estimated_savings_tokens = sum(m.total_tokens * 0.3 for m in optimized_requests) 

543 estimated_savings_usd = (estimated_savings_tokens / 1000) * cost_per_1k_tokens 

544 

545 return { 

546 "savings_usd": round(estimated_savings_usd, 4), 

547 "requests_optimized": len(optimized_requests), 

548 "estimated_tokens_saved": int(estimated_savings_tokens), 

549 } 
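Worked example of the estimate above, for optimized requests totaling 10,000 tokens:

    # estimated_savings_tokens = 10,000 * 0.3 = 3,000
    # estimated_savings_usd    = (3,000 / 1,000) * 0.01 = $0.03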

550 

551 async def cleanup_cache(self, max_age_hours: int = 1) -> int: 

552 """Clean up expired cache entries asynchronously.""" 

553 # ACB cache with TTL handles cleanup automatically 

554 return 0 

555 

556 

557# Global optimizer instance 

558_token_optimizer: TokenOptimizer | None = None 

559 

560 

561def get_token_optimizer() -> TokenOptimizer: 

562 """Get global token optimizer instance.""" 

563 global _token_optimizer 

564 if _token_optimizer is None: 

565 _token_optimizer = TokenOptimizer() 

566 return _token_optimizer 

567 

568 

569async def optimize_search_response( 

570 results: list[dict[str, Any]], 

571 strategy: str = "prioritize_recent", 

572 max_tokens: int = 4000, 

573) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

574 """Async wrapper for search result optimization.""" 

575 optimizer = get_token_optimizer() 

576 return await optimizer.optimize_search_results(results, strategy, max_tokens) 

577 

578 

579async def get_cached_chunk(cache_key: str, chunk_index: int) -> dict[str, Any] | None: 

580 """Async wrapper for chunk retrieval.""" 

581 optimizer = get_token_optimizer() 

582 return await optimizer.get_chunk(cache_key, chunk_index) 

583 

584 

585async def track_token_usage( 

586 operation: str, 

587 request_tokens: int, 

588 response_tokens: int, 

589 optimization_applied: str | None = None, 

590) -> None: 

591 """Async wrapper for usage tracking.""" 

592 optimizer = get_token_optimizer() 

593 optimizer.track_usage( 

594 operation, 

595 request_tokens, 

596 response_tokens, 

597 optimization_applied, 

598 ) 

599 

600 

601async def get_token_usage_stats(hours: int = 24) -> dict[str, Any]: 

602 """Async wrapper for usage statistics.""" 

603 optimizer = get_token_optimizer() 

604 return optimizer.get_usage_stats(hours)
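A minimal end-to-end sketch using the async wrappers above (the sample results and token budget are illustrative; assumes the package is importable as session_buddy.token_optimizer):

    import asyncio

    from session_buddy.token_optimizer import optimize_search_response

    async def main() -> None:
        results = [
            {"content": "def handler(): ...", "timestamp": "2026-01-03T12:00:00", "score": 0.9},
            {"content": "older design note", "timestamp": "2025-06-01T08:00:00"},
        ]
        optimized, info = await optimize_search_response(
            results, strategy="prioritize_recent", max_tokens=500
        )
        print(len(optimized), info["token_savings"])

    asyncio.run(main())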