Coverage for session_mgmt_mcp/token_optimizer.py: 89.33%

264 statements  

coverage.py v7.10.6, created at 2025-09-01 05:22 -0700

1 #!/usr/bin/env python3

2 """Token Optimization for Session Management MCP Server.

3

4 Provides response chunking, content truncation, and context window monitoring

5 to reduce token usage while maintaining functionality.

6 """

7

8 import hashlib

9 import json

10 import re

11 from dataclasses import dataclass

12 from datetime import datetime, timedelta

13 from typing import Any

14

15 import tiktoken

16

17

18 @dataclass

19 class TokenUsageMetrics:

20 """Token usage metrics for monitoring.""" 

21 

22 request_tokens: int 

23 response_tokens: int 

24 total_tokens: int 

25 timestamp: str 

26 operation: str 

27 optimization_applied: str | None = None 

28 

29 

30 @dataclass

31 class ChunkResult:

32 """Result of response chunking.""" 

33 

34 chunks: list[str] 

35 total_chunks: int 

36 current_chunk: int 

37 cache_key: str 

38 metadata: dict[str, Any] 

39 

40 

41 class TokenOptimizer:

42 """Main token optimization class.""" 

43 

44 def __init__(self, max_tokens: int = 4000, chunk_size: int = 2000) -> None: 

45 self.max_tokens = max_tokens 

46 self.chunk_size = chunk_size 

47 self.encoding = self._get_encoding() 

48 self.usage_history: list[TokenUsageMetrics] = [] 

49 self.chunk_cache: dict[str, ChunkResult] = {} 

50 

51 # Token optimization strategies 

52 self.strategies = { 

53 "truncate_old": self._truncate_old_conversations, 

54 "summarize_content": self._summarize_long_content, 

55 "chunk_response": self._chunk_large_response, 

56 "filter_duplicates": self._filter_duplicate_content, 

57 "prioritize_recent": self._prioritize_recent_content, 

58 } 

59 

60 def _get_encoding(self): 

61 """Get tiktoken encoding for token counting.""" 

62 try: 

63 return tiktoken.get_encoding("cl100k_base") # GPT-4 encoding 

64 except Exception: 

65 # Fallback to approximate counting 

66 return None 

67 

68 def count_tokens(self, text: str) -> int: 

69 """Count tokens in text.""" 

70 if self.encoding: 

71 return len(self.encoding.encode(text)) 

72 # Rough approximation: ~4 chars per token 

73 return len(text) // 4 

74 
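As a quick illustration of the counting fallback above, here is a minimal usage sketch; it assumes the module is importable as session_mgmt_mcp.token_optimizer (the path in the report header), and the sample string is arbitrary:

    from session_mgmt_mcp.token_optimizer import TokenOptimizer

    optimizer = TokenOptimizer(max_tokens=4000, chunk_size=2000)
    # With tiktoken installed, this encodes via cl100k_base; if the encoding
    # cannot be loaded, count_tokens falls back to ~4 characters per token.
    print(optimizer.count_tokens("def greet(name): return f'hello {name}'"))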

75 def optimize_search_results( 

76 self, 

77 results: list[dict[str, Any]], 

78 strategy: str = "truncate_old", 

79 max_tokens: int | None = None, 

80 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

81 """Optimize search results to reduce token usage.""" 

82 max_tokens = max_tokens or self.max_tokens 

83 

84 if strategy in self.strategies:  (84 ↛ 90: line 84 didn't jump to line 90 because the condition on line 84 was always true)

85 optimized_results, optimization_info = self.strategies[strategy]( 

86 results, 

87 max_tokens, 

88 ) 

89 else: 

90 optimized_results, optimization_info = results, {"strategy": "none"} 

91 

92 # Track optimization metrics 

93 optimization_info["original_count"] = len(results) 

94 optimization_info["optimized_count"] = len(optimized_results) 

95 optimization_info["token_savings"] = self._calculate_token_savings( 

96 results, 

97 optimized_results, 

98 ) 

99 

100 return optimized_results, optimization_info 

101 
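A sketch of driving this public entry point with the truncate_old strategy; the result dicts are hypothetical but carry the content and timestamp keys the strategies read:

    from datetime import datetime, timedelta

    from session_mgmt_mcp.token_optimizer import TokenOptimizer

    optimizer = TokenOptimizer()
    results = [
        {"content": "Fixed the auth bug in login().",
         "timestamp": datetime.now().isoformat()},
        {"content": "Older discussion about logging setup.",
         "timestamp": (datetime.now() - timedelta(days=90)).isoformat()},
    ]
    optimized, info = optimizer.optimize_search_results(results, strategy="truncate_old")
    # info includes original_count, optimized_count, and token_savings
    print(info["token_savings"]["savings_percentage"])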

102 def _truncate_old_conversations( 

103 self, 

104 results: list[dict[str, Any]], 

105 max_tokens: int, 

106 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

107 """Truncate old conversations based on age and importance.""" 

108 if not results: 

109 return results, {"strategy": "truncate_old", "action": "no_results"} 

110 

111 # Sort by timestamp (newest first) 

112 sorted_results = sorted( 

113 results, 

114 key=lambda x: x.get("timestamp", ""), 

115 reverse=True, 

116 ) 

117 

118 optimized_results = [] 

119 current_tokens = 0 

120 truncation_count = 0 

121 

122 for result in sorted_results: 

123 content = result.get("content", "") 

124 content_tokens = self.count_tokens(content) 

125 

126 # Check if adding this result exceeds token limit 

127 if current_tokens + content_tokens > max_tokens: 

128 # Try truncating the content 

129 if len(optimized_results) < 3:  # Try to keep at least 3 recent results by truncating

130 truncated_content = self._truncate_content( 

131 content, 

132 max_tokens - current_tokens, 

133 ) 

134 if truncated_content:  (134 ↛ 122: line 134 didn't jump to line 122 because the condition on line 134 was always true)

135 result_copy = result.copy() 

136 result_copy["content"] = ( 

137 truncated_content + "... [truncated for token limit]" 

138 ) 

139 optimized_results.append(result_copy) 

140 truncation_count += 1 

141 break 

142 else: 

143 break 

144 else: 

145 optimized_results.append(result) 

146 current_tokens += content_tokens 

147 

148 return optimized_results, { 

149 "strategy": "truncate_old", 

150 "action": "truncated", 

151 "truncation_count": truncation_count, 

152 "final_token_count": current_tokens, 

153 } 

154 

155 def _summarize_long_content( 

156 self, 

157 results: list[dict[str, Any]], 

158 max_tokens: int, 

159 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

160 """Summarize long content to reduce tokens.""" 

161 optimized_results = [] 

162 summarized_count = 0 

163 

164 for result in results: 

165 content = result.get("content", "") 

166 content_tokens = self.count_tokens(content) 

167 

168 if content_tokens > 500: # Summarize content longer than 500 tokens 

169 summary = self._create_quick_summary(content) 

170 result_copy = result.copy() 

171 result_copy["content"] = summary + " [auto-summarized]" 

172 optimized_results.append(result_copy) 

173 summarized_count += 1 

174 else: 

175 optimized_results.append(result) 

176 

177 return optimized_results, { 

178 "strategy": "summarize_content", 

179 "action": "summarized", 

180 "summarized_count": summarized_count, 

181 } 

182 

183 def _chunk_large_response( 

184 self, 

185 results: list[dict[str, Any]], 

186 max_tokens: int, 

187 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

188 """Chunk large response into manageable pieces.""" 

189 if not results: 

190 return results, {"strategy": "chunk_response", "action": "no_results"} 

191 

192 # Estimate total tokens 

193 total_tokens = sum( 

194 self.count_tokens(str(result.get("content", ""))) for result in results 

195 ) 

196 

197 if total_tokens <= max_tokens:  (197 ↛ 198: line 197 didn't jump to line 198 because the condition on line 197 was never true)

198 return results, { 

199 "strategy": "chunk_response", 

200 "action": "no_chunking_needed", 

201 } 

202 

203 # Create chunks 

204 chunks = [] 

205 current_chunk = [] 

206 current_chunk_tokens = 0 

207 

208 for result in results: 

209 result_tokens = self.count_tokens(str(result.get("content", ""))) 

210 

211 if current_chunk_tokens + result_tokens > self.chunk_size and current_chunk: 

212 chunks.append(current_chunk.copy()) 

213 current_chunk = [result] 

214 current_chunk_tokens = result_tokens 

215 else: 

216 current_chunk.append(result) 

217 current_chunk_tokens += result_tokens 

218 

219 if current_chunk:  (219 ↛ 223: line 219 didn't jump to line 223 because the condition on line 219 was always true)

220 chunks.append(current_chunk) 

221 

222 # Return first chunk and create cache entry for the rest 

223 if chunks:  (223 ↛ 234: line 223 didn't jump to line 234 because the condition on line 223 was always true)

224 cache_key = self._create_chunk_cache_entry(chunks) 

225 return chunks[0], { 

226 "strategy": "chunk_response", 

227 "action": "chunked", 

228 "total_chunks": len(chunks), 

229 "current_chunk": 1, 

230 "cache_key": cache_key, 

231 "has_more": len(chunks) > 1, 

232 } 

233 

234 return results, {"strategy": "chunk_response", "action": "failed"} 

235 

236 def _filter_duplicate_content( 

237 self, 

238 results: list[dict[str, Any]], 

239 max_tokens: int, 

240 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

241 """Filter out duplicate or very similar content.""" 

242 if not results: 

243 return results, {"strategy": "filter_duplicates", "action": "no_results"} 

244 

245 seen_hashes = set() 

246 unique_results = [] 

247 duplicates_removed = 0 

248 

249 for result in results: 

250 content = result.get("content", "") 

251 # Create hash of normalized content 

252 normalized_content = re.sub(r"\s+", " ", content.lower().strip()) 

253 content_hash = hashlib.md5(normalized_content.encode()).hexdigest() 

254 

255 if content_hash not in seen_hashes: 

256 seen_hashes.add(content_hash) 

257 unique_results.append(result) 

258 else: 

259 duplicates_removed += 1 

260 

261 return unique_results, { 

262 "strategy": "filter_duplicates", 

263 "action": "filtered", 

264 "duplicates_removed": duplicates_removed, 

265 } 

266 
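Because deduplication hashes a lowercased, whitespace-collapsed copy of the content, entries that differ only in case or spacing collapse to one. A small sketch exercising the private helper directly for illustration (the inputs are hypothetical):

    from session_mgmt_mcp.token_optimizer import TokenOptimizer

    optimizer = TokenOptimizer()
    results = [
        {"content": "Restarted   the server."},
        {"content": "restarted the SERVER."},  # same text after normalization
    ]
    unique, info = optimizer._filter_duplicate_content(results, max_tokens=4000)
    assert info["duplicates_removed"] == 1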

267 def _prioritize_recent_content( 

268 self, 

269 results: list[dict[str, Any]], 

270 max_tokens: int, 

271 ) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

272 """Prioritize recent content and score-based ranking.""" 

273 if not results: 

274 return results, {"strategy": "prioritize_recent", "action": "no_results"} 

275 

276 # Calculate priority scores 

277 now = datetime.now() 

278 scored_results = [] 

279 

280 for result in results: 

281 score = 0.0 

282 

283 # Recency score (0-0.5) 

284 try: 

285 timestamp = datetime.fromisoformat(result.get("timestamp", "")) 

286 days_old = (now - timestamp).days 

287 recency_score = max(0, 0.5 - (days_old / 365) * 0.5) 

288 score += recency_score 

289 except (ValueError, TypeError): 

290 score += 0.1 # Default low recency score 

291 

292 # Relevance score if available (0-0.3) 

293 if "score" in result: 

294 score += result["score"] * 0.3 

295 

296 # Length penalty for very long content (0 to -0.2) 

297 content = result.get("content", "") 

298 if len(content) > 2000: 

299 score -= 0.2 

300 elif len(content) > 1000:  (300 ↛ 301: line 300 didn't jump to line 301 because the condition on line 300 was never true)

301 score -= 0.1 

302 

303 # Code/technical content bonus (0-0.2) 

304 if any( 

305 keyword in content.lower() 

306 for keyword in ["def ", "class ", "function", "error", "exception"] 

307 ): 

308 score += 0.2 

309 

310 scored_results.append((score, result)) 

311 

312 # Sort by priority score and take top results within token limit 

313 scored_results.sort(key=lambda x: x[0], reverse=True) 

314 

315 prioritized_results = [] 

316 current_tokens = 0 

317 

318 for score, result in scored_results: 

319 result_tokens = self.count_tokens(str(result.get("content", ""))) 

320 if current_tokens + result_tokens <= max_tokens: 

321 prioritized_results.append(result) 

322 current_tokens += result_tokens 

323 else: 

324 break 

325 

326 return prioritized_results, { 

327 "strategy": "prioritize_recent", 

328 "action": "prioritized", 

329 "final_token_count": current_tokens, 

330 } 

331 
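Worked example of the scoring above, for a hypothetical result: an entry 30 days old with a relevance score of 0.8, 1,500 characters of content, and a "def " keyword earns 0.5 - (30 / 365) * 0.5 ≈ 0.459 for recency, plus 0.8 * 0.3 = 0.24 for relevance, minus 0.1 for length, plus 0.2 for code content, for a total of roughly 0.80.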

332 def _truncate_content(self, content: str, max_tokens: int) -> str: 

333 """Truncate content to fit within token limit.""" 

334 if self.count_tokens(content) <= max_tokens:  (334 ↛ 335: line 334 didn't jump to line 335 because the condition on line 334 was never true)

335 return content 

336 

337 # Try to truncate at sentence boundaries 

338 sentences = content.split(". ") 

339 truncated = "" 

340 

341 for sentence in sentences:  (341 ↛ 348: line 341 didn't jump to line 348 because the loop on line 341 didn't complete)

342 test_content = truncated + sentence + ". " 

343 if self.count_tokens(test_content) <= max_tokens: 

344 truncated = test_content 

345 else: 

346 break 

347 

348 if not truncated:  (348 ↛ 350: line 348 didn't jump to line 350 because the condition on line 348 was never true)

349 # Fallback to character-based truncation 

350 if self.encoding: 

351 tokens = self.encoding.encode(content)[:max_tokens] 

352 truncated = self.encoding.decode(tokens) 

353 else: 

354 # Rough character limit 

355 char_limit = max_tokens * 4 

356 truncated = content[:char_limit] 

357 

358 return truncated.strip() 

359 

360 def _create_quick_summary(self, content: str, max_length: int = 200) -> str: 

361 """Create a quick summary of content.""" 

362 # Extract first and last sentences 

363 sentences = [s.strip() for s in content.split(".") if s.strip()] 

364 if not sentences:  (364 ↛ 365: line 364 didn't jump to line 365 because the condition on line 364 was never true)

365 return content[:max_length] 

366 

367 if len(sentences) == 1:  (367 ↛ 368: line 367 didn't jump to line 368 because the condition on line 367 was never true)

368 return sentences[0][:max_length] 

369 

370 first_sentence = sentences[0] 

371 last_sentence = sentences[-1] 

372 

373 summary = f"{first_sentence}. ... {last_sentence}" 

374 if len(summary) > max_length:  (374 ↛ 377: line 374 didn't jump to line 377 because the condition on line 374 was always true)

375 summary = first_sentence[: max_length - 3] + "..." 

376 

377 return summary 

378 

379 def _create_chunk_cache_entry(self, chunks: list[list[dict[str, Any]]]) -> str: 

380 """Create cache entry for chunked results.""" 

381 cache_key = hashlib.md5( 

382 f"chunks_{datetime.now().isoformat()}_{len(chunks)}".encode(), 

383 ).hexdigest() 

384 

385 chunk_result = ChunkResult( 

386 chunks=[json.dumps(chunk) for chunk in chunks], 

387 total_chunks=len(chunks), 

388 current_chunk=1, 

389 cache_key=cache_key, 

390 metadata={ 

391 "created": datetime.now().isoformat(), 

392 "expires": (datetime.now() + timedelta(hours=1)).isoformat(), 

393 }, 

394 ) 

395 

396 self.chunk_cache[cache_key] = chunk_result 

397 return cache_key 

398 

399 def get_chunk(self, cache_key: str, chunk_index: int) -> dict[str, Any] | None: 

400 """Get a specific chunk from cache.""" 

401 if cache_key not in self.chunk_cache: 

402 return None 

403 

404 chunk_result = self.chunk_cache[cache_key] 

405 

406 # Check expiration 

407 try: 

408 expires = datetime.fromisoformat(chunk_result.metadata["expires"]) 

409 if datetime.now() > expires:  (409 ↛ 410: line 409 didn't jump to line 410 because the condition on line 409 was never true)

410 del self.chunk_cache[cache_key] 

411 return None 

412 except (ValueError, KeyError): 

413 pass 

414 

415 if 1 <= chunk_index <= len(chunk_result.chunks): 

416 chunk_data = json.loads(chunk_result.chunks[chunk_index - 1]) 

417 return { 

418 "chunk": chunk_data, 

419 "current_chunk": chunk_index, 

420 "total_chunks": chunk_result.total_chunks, 

421 "cache_key": cache_key, 

422 "has_more": chunk_index < chunk_result.total_chunks, 

423 } 

424 

425 return None 

426 
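A sketch of paging through a chunked response; the payload below is hypothetical, and chunk_size is set low so chunking actually triggers:

    from session_mgmt_mcp.token_optimizer import TokenOptimizer

    optimizer = TokenOptimizer(chunk_size=50)
    results = [{"content": f"entry number {i} " * 20} for i in range(10)]
    first_page, info = optimizer.optimize_search_results(
        results, strategy="chunk_response", max_tokens=100
    )
    if info.get("has_more"):
        # Chunk indices are 1-based; page 1 was already returned above.
        page_two = optimizer.get_chunk(info["cache_key"], 2)
        # page_two carries chunk, current_chunk, total_chunks, and has_more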

427 def _calculate_token_savings( 

428 self, 

429 original: list[dict[str, Any]], 

430 optimized: list[dict[str, Any]], 

431 ) -> dict[str, int]: 

432 """Calculate token savings from optimization.""" 

433 original_tokens = sum( 

434 self.count_tokens(str(item.get("content", ""))) for item in original 

435 ) 

436 optimized_tokens = sum( 

437 self.count_tokens(str(item.get("content", ""))) for item in optimized 

438 ) 

439 

440 return { 

441 "original_tokens": original_tokens, 

442 "optimized_tokens": optimized_tokens, 

443 "tokens_saved": original_tokens - optimized_tokens, 

444 "savings_percentage": round( 

445 ((original_tokens - optimized_tokens) / original_tokens) * 100, 

446 1, 

447 ) 

448 if original_tokens > 0 

449 else 0, 

450 } 

451 

452 def track_usage( 

453 self, 

454 operation: str, 

455 request_tokens: int, 

456 response_tokens: int, 

457 optimization_applied: str | None = None, 

458 ) -> None: 

459 """Track token usage for monitoring.""" 

460 metrics = TokenUsageMetrics( 

461 request_tokens=request_tokens, 

462 response_tokens=response_tokens, 

463 total_tokens=request_tokens + response_tokens, 

464 timestamp=datetime.now().isoformat(), 

465 operation=operation, 

466 optimization_applied=optimization_applied, 

467 ) 

468 

469 self.usage_history.append(metrics) 

470 

471 # Keep only last 100 entries 

472 if len(self.usage_history) > 100:  (472 ↛ 473: line 472 didn't jump to line 473 because the condition on line 472 was never true)

473 self.usage_history = self.usage_history[-100:] 

474 

475 def get_usage_stats(self, hours: int = 24) -> dict[str, Any]: 

476 """Get token usage statistics.""" 

477 cutoff = datetime.now() - timedelta(hours=hours) 

478 

479 recent_usage = [ 

480 m 

481 for m in self.usage_history 

482 if datetime.fromisoformat(m.timestamp) > cutoff 

483 ] 

484 

485 if not recent_usage: 

486 return {"status": "no_data", "period_hours": hours} 

487 

488 total_tokens = sum(m.total_tokens for m in recent_usage) 

489 avg_tokens = total_tokens / len(recent_usage) 

490 

491 # Count optimizations applied 

492 optimizations = {} 

493 for metric in recent_usage: 

494 if metric.optimization_applied: 

495 optimizations[metric.optimization_applied] = ( 

496 optimizations.get(metric.optimization_applied, 0) + 1 

497 ) 

498 

499 return { 

500 "status": "success", 

501 "period_hours": hours, 

502 "total_requests": len(recent_usage), 

503 "total_tokens": total_tokens, 

504 "average_tokens_per_request": round(avg_tokens, 1), 

505 "optimizations_applied": optimizations, 

506 "estimated_cost_savings": self._estimate_cost_savings(recent_usage), 

507 } 

508 
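A sketch of the tracking flow; the operation name and token counts are invented for illustration:

    from session_mgmt_mcp.token_optimizer import TokenOptimizer

    optimizer = TokenOptimizer()
    optimizer.track_usage(
        "search",
        request_tokens=120,
        response_tokens=850,
        optimization_applied="prioritize_recent",
    )
    stats = optimizer.get_usage_stats(hours=24)
    # stats reports total_requests, total_tokens, average_tokens_per_request,
    # optimizations_applied, and estimated_cost_savings
    print(stats["average_tokens_per_request"])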

509 def _estimate_cost_savings( 

510 self, 

511 usage_metrics: list[TokenUsageMetrics], 

512 ) -> dict[str, float]: 

513 """Estimate cost savings from optimizations.""" 

514 # Rough cost estimation (adjust based on actual pricing) 

515 cost_per_1k_tokens = 0.01 # Example rate 

516 

517 optimized_requests = [m for m in usage_metrics if m.optimization_applied] 

518 if not optimized_requests:  (518 ↛ 519: line 518 didn't jump to line 519 because the condition on line 518 was never true)

519 return {"savings_usd": 0.0, "requests_optimized": 0} 

520 

521 # Assume ~30% token savings, the midpoint of an estimated 20-40% range

522 estimated_savings_tokens = sum(m.total_tokens * 0.3 for m in optimized_requests) 

523 estimated_savings_usd = (estimated_savings_tokens / 1000) * cost_per_1k_tokens 

524 

525 return { 

526 "savings_usd": round(estimated_savings_usd, 4), 

527 "requests_optimized": len(optimized_requests), 

528 "estimated_tokens_saved": int(estimated_savings_tokens), 

529 } 

530 

531 def cleanup_cache(self, max_age_hours: int = 1) -> int:

532 """Clean up expired cache entries.""" 

533 cutoff = datetime.now() - timedelta(hours=max_age_hours) 

534 expired_keys = [] 

535 

536 for key, chunk_result in self.chunk_cache.items(): 

537 try: 

538 expires = datetime.fromisoformat(chunk_result.metadata["expires"]) 

539 if expires < cutoff: 

540 expired_keys.append(key) 

541 except (ValueError, KeyError): 

542 expired_keys.append(key) # Remove entries with invalid expiration 

543 

544 for key in expired_keys: 

545 del self.chunk_cache[key] 

546 

547 return len(expired_keys) 

548 

549 

550 # Global optimizer instance

551 _token_optimizer: TokenOptimizer | None = None

552 

553 

554 def get_token_optimizer() -> TokenOptimizer:

555 """Get global token optimizer instance.""" 

556 global _token_optimizer 

557 if _token_optimizer is None: 

558 _token_optimizer = TokenOptimizer() 

559 return _token_optimizer 

560 

561 

562 async def optimize_search_response(

563 results: list[dict[str, Any]], 

564 strategy: str = "prioritize_recent", 

565 max_tokens: int = 4000, 

566) -> tuple[list[dict[str, Any]], dict[str, Any]]: 

567 """Async wrapper for search result optimization.""" 

568 optimizer = get_token_optimizer() 

569 return optimizer.optimize_search_results(results, strategy, max_tokens) 

570 

571 

572 async def get_cached_chunk(cache_key: str, chunk_index: int) -> dict[str, Any] | None:

573 """Async wrapper for chunk retrieval.""" 

574 optimizer = get_token_optimizer() 

575 return optimizer.get_chunk(cache_key, chunk_index) 

576 

577 

578 async def track_token_usage(

579 operation: str, 

580 request_tokens: int, 

581 response_tokens: int, 

582 optimization_applied: str | None = None, 

583) -> None: 

584 """Async wrapper for usage tracking.""" 

585 optimizer = get_token_optimizer() 

586 optimizer.track_usage( 

587 operation, 

588 request_tokens, 

589 response_tokens, 

590 optimization_applied, 

591 ) 

592 

593 

594 async def get_token_usage_stats(hours: int = 24) -> dict[str, Any]:

595 """Async wrapper for usage statistics.""" 

596 optimizer = get_token_optimizer() 

597 return optimizer.get_usage_stats(hours)
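
Finally, a minimal end-to-end sketch of the async wrappers, again assuming the import path from the report header; the payload and token counts are hypothetical:

    import asyncio

    from session_mgmt_mcp.token_optimizer import (
        get_token_usage_stats,
        optimize_search_response,
        track_token_usage,
    )

    async def main() -> None:
        results = [{"content": "hypothetical search hit",
                    "timestamp": "2025-09-01T05:00:00"}]
        optimized, info = await optimize_search_response(results)
        await track_token_usage("search", 100, 400,
                                optimization_applied=info.get("strategy"))
        print(await get_token_usage_stats(hours=1))

    asyncio.run(main())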