Coverage for session_mgmt_mcp/token_optimizer.py: 89.33% of 264 statements (coverage.py v7.10.6, created at 2025-09-01 05:22 -0700)
#!/usr/bin/env python3
"""Token Optimization for Session Management MCP Server.

Provides response chunking, content truncation, and context window monitoring
to reduce token usage while maintaining functionality.
"""

import hashlib
import json
import re
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any

import tiktoken


@dataclass
class TokenUsageMetrics:
    """Token usage metrics for monitoring."""

    request_tokens: int
    response_tokens: int
    total_tokens: int
    timestamp: str
    operation: str
    optimization_applied: str | None = None


@dataclass
class ChunkResult:
    """Result of response chunking."""

    chunks: list[str]
    total_chunks: int
    current_chunk: int
    cache_key: str
    metadata: dict[str, Any]


class TokenOptimizer:
    """Main token optimization class."""

    def __init__(self, max_tokens: int = 4000, chunk_size: int = 2000) -> None:
        self.max_tokens = max_tokens
        self.chunk_size = chunk_size
        self.encoding = self._get_encoding()
        self.usage_history: list[TokenUsageMetrics] = []
        self.chunk_cache: dict[str, ChunkResult] = {}

        # Token optimization strategies
        self.strategies = {
            "truncate_old": self._truncate_old_conversations,
            "summarize_content": self._summarize_long_content,
            "chunk_response": self._chunk_large_response,
            "filter_duplicates": self._filter_duplicate_content,
            "prioritize_recent": self._prioritize_recent_content,
        }
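
    # Strategy dispatch sketch (hypothetical data; see optimize_search_results):
    #     opt = TokenOptimizer()
    #     results = [{"content": "fix auth bug", "timestamp": "2025-08-30T10:00:00"}]
    #     trimmed, info = opt.optimize_search_results(results, strategy="truncate_old")
    #     info["strategy"]  # -> "truncate_old"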

    def _get_encoding(self):
        """Get tiktoken encoding for token counting."""
        try:
            return tiktoken.get_encoding("cl100k_base")  # GPT-4 encoding
        except Exception:
            # Fallback to approximate counting
            return None

    def count_tokens(self, text: str) -> int:
        """Count tokens in text."""
        if self.encoding:
            return len(self.encoding.encode(text))
        # Rough approximation: ~4 chars per token
        return len(text) // 4
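
    # count_tokens sketch: with tiktoken available this counts real cl100k_base
    # tokens; when the encoding failed to load it approximates len(text) // 4,
    # e.g. a 400-character string is reported as ~100 tokens.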

    def optimize_search_results(
        self,
        results: list[dict[str, Any]],
        strategy: str = "truncate_old",
        max_tokens: int | None = None,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Optimize search results to reduce token usage."""
        max_tokens = max_tokens or self.max_tokens

        if strategy in self.strategies:
            optimized_results, optimization_info = self.strategies[strategy](
                results,
                max_tokens,
            )
        else:
            optimized_results, optimization_info = results, {"strategy": "none"}

        # Track optimization metrics
        optimization_info["original_count"] = len(results)
        optimization_info["optimized_count"] = len(optimized_results)
        optimization_info["token_savings"] = self._calculate_token_savings(
            results,
            optimized_results,
        )

        return optimized_results, optimization_info

    def _truncate_old_conversations(
        self,
        results: list[dict[str, Any]],
        max_tokens: int,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Truncate old conversations based on age and importance."""
        if not results:
            return results, {"strategy": "truncate_old", "action": "no_results"}

        # Sort by timestamp (newest first)
        sorted_results = sorted(
            results,
            key=lambda x: x.get("timestamp", ""),
            reverse=True,
        )

        optimized_results = []
        current_tokens = 0
        truncation_count = 0

        for result in sorted_results:
            content = result.get("content", "")
            content_tokens = self.count_tokens(content)

            # Check if adding this result exceeds token limit
            if current_tokens + content_tokens > max_tokens:
                # Try truncating the content
                if len(optimized_results) < 3:  # Always keep at least 3 recent results
                    truncated_content = self._truncate_content(
                        content,
                        max_tokens - current_tokens,
                    )
                    if truncated_content:
                        result_copy = result.copy()
                        result_copy["content"] = (
                            truncated_content + "... [truncated for token limit]"
                        )
                        optimized_results.append(result_copy)
                        truncation_count += 1
                        break
                else:
                    break
            else:
                optimized_results.append(result)
                current_tokens += content_tokens

        return optimized_results, {
            "strategy": "truncate_old",
            "action": "truncated",
            "truncation_count": truncation_count,
            "final_token_count": current_tokens,
        }
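
    # Truncation walkthrough (hypothetical numbers): with max_tokens=100 and three
    # 60-token results, the newest fits outright (60 tokens); the next would exceed
    # the budget, and since fewer than 3 results are kept so far it is cut down to
    # the remaining ~40 tokens, tagged "... [truncated for token limit]", and the
    # loop stops there.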

    def _summarize_long_content(
        self,
        results: list[dict[str, Any]],
        max_tokens: int,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Summarize long content to reduce tokens."""
        optimized_results = []
        summarized_count = 0

        for result in results:
            content = result.get("content", "")
            content_tokens = self.count_tokens(content)

            if content_tokens > 500:  # Summarize content longer than 500 tokens
                summary = self._create_quick_summary(content)
                result_copy = result.copy()
                result_copy["content"] = summary + " [auto-summarized]"
                optimized_results.append(result_copy)
                summarized_count += 1
            else:
                optimized_results.append(result)

        return optimized_results, {
            "strategy": "summarize_content",
            "action": "summarized",
            "summarized_count": summarized_count,
        }

    def _chunk_large_response(
        self,
        results: list[dict[str, Any]],
        max_tokens: int,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Chunk large response into manageable pieces."""
        if not results:
            return results, {"strategy": "chunk_response", "action": "no_results"}

        # Estimate total tokens
        total_tokens = sum(
            self.count_tokens(str(result.get("content", ""))) for result in results
        )

        if total_tokens <= max_tokens:
            return results, {
                "strategy": "chunk_response",
                "action": "no_chunking_needed",
            }

        # Create chunks
        chunks = []
        current_chunk = []
        current_chunk_tokens = 0

        for result in results:
            result_tokens = self.count_tokens(str(result.get("content", "")))

            if current_chunk_tokens + result_tokens > self.chunk_size and current_chunk:
                chunks.append(current_chunk.copy())
                current_chunk = [result]
                current_chunk_tokens = result_tokens
            else:
                current_chunk.append(result)
                current_chunk_tokens += result_tokens

        if current_chunk:
            chunks.append(current_chunk)

        # Return first chunk and create cache entry for the rest
        if chunks:
            cache_key = self._create_chunk_cache_entry(chunks)
            return chunks[0], {
                "strategy": "chunk_response",
                "action": "chunked",
                "total_chunks": len(chunks),
                "current_chunk": 1,
                "cache_key": cache_key,
                "has_more": len(chunks) > 1,
            }

        return results, {"strategy": "chunk_response", "action": "failed"}
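
    # Chunking round-trip sketch (hypothetical sizes): with chunk_size=2000, a
    # 5000-token result set splits into roughly three chunks; the first is returned
    # immediately and the rest are cached for an hour under info["cache_key"]:
    #     first, info = opt.optimize_search_results(results, strategy="chunk_response")
    #     second = opt.get_chunk(info["cache_key"], 2)  # {"chunk": [...], "has_more": True, ...}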

    def _filter_duplicate_content(
        self,
        results: list[dict[str, Any]],
        max_tokens: int,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Filter out duplicate or very similar content."""
        if not results:
            return results, {"strategy": "filter_duplicates", "action": "no_results"}

        seen_hashes = set()
        unique_results = []
        duplicates_removed = 0

        for result in results:
            content = result.get("content", "")
            # Create hash of normalized content
            normalized_content = re.sub(r"\s+", " ", content.lower().strip())
            content_hash = hashlib.md5(normalized_content.encode()).hexdigest()

            if content_hash not in seen_hashes:
                seen_hashes.add(content_hash)
                unique_results.append(result)
            else:
                duplicates_removed += 1

        return unique_results, {
            "strategy": "filter_duplicates",
            "action": "filtered",
            "duplicates_removed": duplicates_removed,
        }
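
    # Dedup sketch: case and runs of whitespace are normalized before hashing, so
    # "Fix  Auth Bug" and "fix auth bug" produce the same MD5 digest and the later
    # occurrence is dropped; only exact post-normalization duplicates are caught.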

    def _prioritize_recent_content(
        self,
        results: list[dict[str, Any]],
        max_tokens: int,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Prioritize recent content using score-based ranking."""
        if not results:
            return results, {"strategy": "prioritize_recent", "action": "no_results"}

        # Calculate priority scores
        now = datetime.now()
        scored_results = []

        for result in results:
            score = 0.0

            # Recency score (0-0.5)
            try:
                timestamp = datetime.fromisoformat(result.get("timestamp", ""))
                days_old = (now - timestamp).days
                recency_score = max(0, 0.5 - (days_old / 365) * 0.5)
                score += recency_score
            except (ValueError, TypeError):
                score += 0.1  # Default low recency score

            # Relevance score if available (0-0.3)
            if "score" in result:
                score += result["score"] * 0.3

            # Length penalty for very long content (0 to -0.2)
            content = result.get("content", "")
            if len(content) > 2000:
                score -= 0.2
            elif len(content) > 1000:
                score -= 0.1

            # Code/technical content bonus (0-0.2)
            if any(
                keyword in content.lower()
                for keyword in ["def ", "class ", "function", "error", "exception"]
            ):
                score += 0.2

            scored_results.append((score, result))

        # Sort by priority score and take top results within token limit
        scored_results.sort(key=lambda x: x[0], reverse=True)

        prioritized_results = []
        current_tokens = 0

        for score, result in scored_results:
            result_tokens = self.count_tokens(str(result.get("content", "")))
            if current_tokens + result_tokens <= max_tokens:
                prioritized_results.append(result)
                current_tokens += result_tokens
            else:
                break

        return prioritized_results, {
            "strategy": "prioritize_recent",
            "action": "prioritized",
            "final_token_count": current_tokens,
        }
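
    # Scoring walkthrough (hypothetical result): a 7-day-old hit with search score
    # 0.8 whose content contains "def ":
    #   recency 0.5 - (7 / 365) * 0.5 ≈ 0.490, relevance 0.8 * 0.3 = 0.240,
    #   code bonus 0.2, no length penalty -> total ≈ 0.930.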

    def _truncate_content(self, content: str, max_tokens: int) -> str:
        """Truncate content to fit within token limit."""
        if self.count_tokens(content) <= max_tokens:
            return content

        # Try to truncate at sentence boundaries
        sentences = content.split(". ")
        truncated = ""

        for sentence in sentences:
            test_content = truncated + sentence + ". "
            if self.count_tokens(test_content) <= max_tokens:
                truncated = test_content
            else:
                break

        if not truncated:
            # Fallback to character-based truncation
            if self.encoding:
                tokens = self.encoding.encode(content)[:max_tokens]
                truncated = self.encoding.decode(tokens)
            else:
                # Rough character limit
                char_limit = max_tokens * 4
                truncated = content[:char_limit]

        return truncated.strip()

    def _create_quick_summary(self, content: str, max_length: int = 200) -> str:
        """Create a quick summary of content."""
        # Extract first and last sentences
        sentences = [s.strip() for s in content.split(".") if s.strip()]
        if not sentences:
            return content[:max_length]

        if len(sentences) == 1:
            return sentences[0][:max_length]

        first_sentence = sentences[0]
        last_sentence = sentences[-1]

        summary = f"{first_sentence}. ... {last_sentence}"
        if len(summary) > max_length:
            summary = first_sentence[: max_length - 3] + "..."

        return summary

    def _create_chunk_cache_entry(self, chunks: list[list[dict[str, Any]]]) -> str:
        """Create cache entry for chunked results."""
        cache_key = hashlib.md5(
            f"chunks_{datetime.now().isoformat()}_{len(chunks)}".encode(),
        ).hexdigest()

        chunk_result = ChunkResult(
            chunks=[json.dumps(chunk) for chunk in chunks],
            total_chunks=len(chunks),
            current_chunk=1,
            cache_key=cache_key,
            metadata={
                "created": datetime.now().isoformat(),
                "expires": (datetime.now() + timedelta(hours=1)).isoformat(),
            },
        )

        self.chunk_cache[cache_key] = chunk_result
        return cache_key

    def get_chunk(self, cache_key: str, chunk_index: int) -> dict[str, Any] | None:
        """Get a specific chunk from cache."""
        if cache_key not in self.chunk_cache:
            return None

        chunk_result = self.chunk_cache[cache_key]

        # Check expiration
        try:
            expires = datetime.fromisoformat(chunk_result.metadata["expires"])
            if datetime.now() > expires:
                del self.chunk_cache[cache_key]
                return None
        except (ValueError, KeyError):
            pass

        if 1 <= chunk_index <= len(chunk_result.chunks):
            chunk_data = json.loads(chunk_result.chunks[chunk_index - 1])
            return {
                "chunk": chunk_data,
                "current_chunk": chunk_index,
                "total_chunks": chunk_result.total_chunks,
                "cache_key": cache_key,
                "has_more": chunk_index < chunk_result.total_chunks,
            }

        return None

    def _calculate_token_savings(
        self,
        original: list[dict[str, Any]],
        optimized: list[dict[str, Any]],
    ) -> dict[str, int]:
        """Calculate token savings from optimization."""
        original_tokens = sum(
            self.count_tokens(str(item.get("content", ""))) for item in original
        )
        optimized_tokens = sum(
            self.count_tokens(str(item.get("content", ""))) for item in optimized
        )

        return {
            "original_tokens": original_tokens,
            "optimized_tokens": optimized_tokens,
            "tokens_saved": original_tokens - optimized_tokens,
            "savings_percentage": round(
                ((original_tokens - optimized_tokens) / original_tokens) * 100,
                1,
            )
            if original_tokens > 0
            else 0,
        }
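
    # Savings arithmetic sketch: original 1200 tokens, optimized 780 ->
    # tokens_saved = 420 and savings_percentage = round((420 / 1200) * 100, 1) = 35.0.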

    def track_usage(
        self,
        operation: str,
        request_tokens: int,
        response_tokens: int,
        optimization_applied: str | None = None,
    ) -> None:
        """Track token usage for monitoring."""
        metrics = TokenUsageMetrics(
            request_tokens=request_tokens,
            response_tokens=response_tokens,
            total_tokens=request_tokens + response_tokens,
            timestamp=datetime.now().isoformat(),
            operation=operation,
            optimization_applied=optimization_applied,
        )

        self.usage_history.append(metrics)

        # Keep only last 100 entries
        if len(self.usage_history) > 100:
            self.usage_history = self.usage_history[-100:]

    def get_usage_stats(self, hours: int = 24) -> dict[str, Any]:
        """Get token usage statistics."""
        cutoff = datetime.now() - timedelta(hours=hours)

        recent_usage = [
            m
            for m in self.usage_history
            if datetime.fromisoformat(m.timestamp) > cutoff
        ]

        if not recent_usage:
            return {"status": "no_data", "period_hours": hours}

        total_tokens = sum(m.total_tokens for m in recent_usage)
        avg_tokens = total_tokens / len(recent_usage)

        # Count optimizations applied
        optimizations = {}
        for metric in recent_usage:
            if metric.optimization_applied:
                optimizations[metric.optimization_applied] = (
                    optimizations.get(metric.optimization_applied, 0) + 1
                )

        return {
            "status": "success",
            "period_hours": hours,
            "total_requests": len(recent_usage),
            "total_tokens": total_tokens,
            "average_tokens_per_request": round(avg_tokens, 1),
            "optimizations_applied": optimizations,
            "estimated_cost_savings": self._estimate_cost_savings(recent_usage),
        }

    def _estimate_cost_savings(
        self,
        usage_metrics: list[TokenUsageMetrics],
    ) -> dict[str, float]:
        """Estimate cost savings from optimizations."""
        # Rough cost estimation (adjust based on actual pricing)
        cost_per_1k_tokens = 0.01  # Example rate

        optimized_requests = [m for m in usage_metrics if m.optimization_applied]
        if not optimized_requests:
            return {"savings_usd": 0.0, "requests_optimized": 0}

        # Estimate 20-40% token savings from optimization
        estimated_savings_tokens = sum(m.total_tokens * 0.3 for m in optimized_requests)
        estimated_savings_usd = (estimated_savings_tokens / 1000) * cost_per_1k_tokens

        return {
            "savings_usd": round(estimated_savings_usd, 4),
            "requests_optimized": len(optimized_requests),
            "estimated_tokens_saved": int(estimated_savings_tokens),
        }
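
    # Cost arithmetic sketch: two optimized requests totalling 3000 tokens assume a
    # flat 30% saving -> 900 tokens, and (900 / 1000) * 0.01 = $0.009 at the
    # illustrative rate above.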

    def cleanup_cache(self, max_age_hours: int = 1) -> int:
        """Clean up expired cache entries."""
        cutoff = datetime.now() - timedelta(hours=max_age_hours)
        expired_keys = []

        for key, chunk_result in self.chunk_cache.items():
            try:
                expires = datetime.fromisoformat(chunk_result.metadata["expires"])
                if expires < cutoff:
                    expired_keys.append(key)
            except (ValueError, KeyError):
                expired_keys.append(key)  # Remove entries with invalid expiration

        for key in expired_keys:
            del self.chunk_cache[key]

        return len(expired_keys)


# Global optimizer instance
_token_optimizer: TokenOptimizer | None = None


def get_token_optimizer() -> TokenOptimizer:
    """Get global token optimizer instance."""
    global _token_optimizer
    if _token_optimizer is None:
        _token_optimizer = TokenOptimizer()
    return _token_optimizer


async def optimize_search_response(
    results: list[dict[str, Any]],
    strategy: str = "prioritize_recent",
    max_tokens: int = 4000,
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Async wrapper for search result optimization."""
    optimizer = get_token_optimizer()
    return optimizer.optimize_search_results(results, strategy, max_tokens)


async def get_cached_chunk(cache_key: str, chunk_index: int) -> dict[str, Any] | None:
    """Async wrapper for chunk retrieval."""
    optimizer = get_token_optimizer()
    return optimizer.get_chunk(cache_key, chunk_index)


async def track_token_usage(
    operation: str,
    request_tokens: int,
    response_tokens: int,
    optimization_applied: str | None = None,
) -> None:
    """Async wrapper for usage tracking."""
    optimizer = get_token_optimizer()
    optimizer.track_usage(
        operation,
        request_tokens,
        response_tokens,
        optimization_applied,
    )


async def get_token_usage_stats(hours: int = 24) -> dict[str, Any]:
    """Async wrapper for usage statistics."""
    optimizer = get_token_optimizer()
    return optimizer.get_usage_stats(hours)
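

# Minimal end-to-end usage sketch (not part of the server wiring; fabricated data).
# Works with or without tiktoken installed, since counting falls back to chars // 4.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Fabricated search results tagged with the current time.
        sample = [
            {
                "content": f"def handler_{i}(): ...",
                "timestamp": datetime.now().isoformat(),
            }
            for i in range(5)
        ]
        optimized, info = await optimize_search_response(
            sample, strategy="prioritize_recent"
        )
        await track_token_usage(
            "demo",
            request_tokens=50,
            response_tokens=200,
            optimization_applied=info["strategy"],
        )
        print(info["token_savings"])
        print(await get_token_usage_stats(hours=1))

    asyncio.run(_demo())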