Coverage for session_buddy/token_optimizer.py: 89.02% (253 statements)

#!/usr/bin/env python3
"""Token Optimization for Session Management MCP Server.

Provides response chunking, content truncation, and context window monitoring
to reduce token usage while maintaining functionality.
"""

import hashlib
import json
import operator
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any

import tiktoken
from session_buddy.acb_cache_adapter import ACBChunkCache, get_chunk_cache


@dataclass
class TokenUsageMetrics:
    """Token usage metrics for monitoring."""

    request_tokens: int
    response_tokens: int
    total_tokens: int
    timestamp: str
    operation: str
    optimization_applied: str | None = None


@dataclass
class ChunkResult:
    """Result of response chunking."""

    chunks: list[str]
    total_chunks: int
    current_chunk: int
    cache_key: str
    metadata: dict[str, Any]


class TokenOptimizer:
    """Main token optimization class."""

    def __init__(self, max_tokens: int = 4000, chunk_size: int = 2000) -> None:
        self.max_tokens = max_tokens
        self.chunk_size = chunk_size
        self.encoding = self._get_encoding()
        self.usage_history: list[TokenUsageMetrics] = []
        self.chunk_cache: ACBChunkCache = get_chunk_cache()  # ACB-backed cache

        # Token optimization strategies
        self.strategies = {
            "truncate_old": self._truncate_old_conversations,
            "summarize_content": self._summarize_long_content,
            "chunk_response": self._chunk_large_response,
            "filter_duplicates": self._filter_duplicate_content,
            "prioritize_recent": self._prioritize_recent_content,
        }

    def _get_encoding(self) -> Any:
        """Get tiktoken encoding for token counting."""
        try:
            return tiktoken.get_encoding("cl100k_base")  # GPT-4 encoding
        except Exception:
            # Fallback to approximate counting
            return None

    def count_tokens(self, text: str) -> int:
        """Count tokens in text."""
        if self.encoding:
            return len(self.encoding.encode(text))
        # Rough approximation: ~4 chars per token
        return len(text) // 4

    async def optimize_search_results(
        self,
        results: list[dict[str, Any]],
        strategy: str = "truncate_old",
        max_tokens: int | None = None,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Optimize search results to reduce token usage."""
        max_tokens = max_tokens or self.max_tokens

        optimization_info: dict[str, Any] = {}
        if strategy in self.strategies:
            optimized_results, info = await self.strategies[strategy](
                results,
                max_tokens,
            )
            optimization_info = info
        else:
            optimized_results = results  # type: ignore[assignment]
            optimization_info["strategy"] = "none"

        # Track optimization metrics
        optimization_info["original_count"] = len(results)
        optimization_info["optimized_count"] = len(optimized_results)
        optimization_info["token_savings"] = self._calculate_token_savings(
            results,
            optimized_results,
        )

        return optimized_results, optimization_info

    async def _truncate_old_conversations(
        self,
        results: list[dict[str, Any]],
        max_tokens: int,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Truncate old conversations based on age and importance."""
        if not results:
            return results, {"strategy": "truncate_old", "action": "no_results"}

        # Sort by timestamp (newest first)
        sorted_results = sorted(
            results,
            key=lambda x: x.get("timestamp", ""),
            reverse=True,
        )

        optimized_results: list[dict[str, Any]] = []
        current_tokens = 0
        truncation_count = 0

        for result in sorted_results:
            content = result.get("content", "")
            content_tokens = self.count_tokens(content)

            # Check if adding this result exceeds token limit
            if current_tokens + content_tokens > max_tokens:
                # Try truncating the content
                if len(optimized_results) < 3:  # Always keep at least 3 recent results
                    truncated_content = self._truncate_content(
                        content,
                        max_tokens - current_tokens,
                    )
                    if truncated_content:
                        result_copy = result.copy()
                        result_copy["content"] = (
                            truncated_content + "... [truncated for token limit]"
                        )
                        optimized_results.append(result_copy)
                        truncation_count += 1
                        break
                else:
                    break
            else:
                optimized_results.append(result)
                current_tokens += content_tokens

        return optimized_results, {
            "strategy": "truncate_old",
            "action": "truncated",
            "truncation_count": truncation_count,
            "final_token_count": current_tokens,
        }

    async def _summarize_long_content(
        self,
        results: list[dict[str, Any]],
        max_tokens: int,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Summarize long content to reduce tokens."""
        optimized_results = []
        summarized_count = 0

        for result in results:
            content = result.get("content", "")
            content_tokens = self.count_tokens(content)

            if content_tokens > 500:  # Summarize content longer than 500 tokens
                summary = self._create_quick_summary(content)
                result_copy = result.copy()
                result_copy["content"] = summary + " [auto-summarized]"
                optimized_results.append(result_copy)
                summarized_count += 1
            else:
                optimized_results.append(result)

        return optimized_results, {
            "strategy": "summarize_content",
            "action": "summarized",
            "summarized_count": summarized_count,
        }

    async def _chunk_large_response(
        self,
        results: list[dict[str, Any]],
        max_tokens: int,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Chunk large response into manageable pieces."""
        if not results:
            return results, {"strategy": "chunk_response", "action": "no_results"}

        # Estimate total tokens
        total_tokens = sum(
            self.count_tokens(str(result.get("content", ""))) for result in results
        )

        if total_tokens <= max_tokens:
            return results, {
                "strategy": "chunk_response",
                "action": "no_chunking_needed",
            }

        # Create chunks
        chunks: list[list[dict[str, Any]]] = []
        current_chunk: list[dict[str, Any]] = []
        current_chunk_tokens = 0

        for result in results:
            result_tokens = self.count_tokens(str(result.get("content", "")))

            if current_chunk_tokens + result_tokens > self.chunk_size and current_chunk:
                chunks.append(current_chunk.copy())
                current_chunk = [result]
                current_chunk_tokens = result_tokens
            else:
                current_chunk.append(result)
                current_chunk_tokens += result_tokens

        if current_chunk:
            chunks.append(current_chunk)

        # Return first chunk and create cache entry for the rest
        if chunks:
            cache_key = await self._create_chunk_cache_entry(chunks)
            return chunks[0], {
                "strategy": "chunk_response",
                "action": "chunked",
                "total_chunks": len(chunks),
                "current_chunk": 1,
                "cache_key": cache_key,
                "has_more": len(chunks) > 1,
            }

        return results, {"strategy": "chunk_response", "action": "failed"}

    async def _filter_duplicate_content(
        self,
        results: list[dict[str, Any]],
        max_tokens: int,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Filter out duplicate or very similar content."""
        if not results:
            return results, {"strategy": "filter_duplicates", "action": "no_results"}

        # Use the validated whitespace-normalization pattern for hashing
        from session_buddy.utils.regex_patterns import SAFE_PATTERNS

        normalize_pattern = SAFE_PATTERNS["whitespace_normalize"]

        seen_hashes = set()
        unique_results = []
        duplicates_removed = 0

        for result in results:
            content = result.get("content", "")
            # Create hash of normalized content using validated pattern
            normalized_content = normalize_pattern.apply(content.lower().strip())
            content_hash = hashlib.md5(
                normalized_content.encode(),
                usedforsecurity=False,
            ).hexdigest()

            if content_hash not in seen_hashes:
                seen_hashes.add(content_hash)
                unique_results.append(result)
            else:
                duplicates_removed += 1

        return unique_results, {
            "strategy": "filter_duplicates",
            "action": "filtered",
            "duplicates_removed": duplicates_removed,
        }

    async def _prioritize_recent_content(
        self,
        results: list[dict[str, Any]],
        max_tokens: int,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Prioritize recent content using score-based ranking."""
        if not results:
            return results, {"strategy": "prioritize_recent", "action": "no_results"}

        # Calculate priority scores
        now = datetime.now()
        scored_results = []

        for result in results:
            score = 0.0

            # Recency score (0-0.5)
            try:
                timestamp = datetime.fromisoformat(result.get("timestamp", ""))
                days_old = (now - timestamp).days
                recency_score = max(0, 0.5 - (days_old / 365) * 0.5)
                score += recency_score
            except (ValueError, TypeError):
                score += 0.1  # Default low recency score

            # Relevance score if available (0-0.3)
            if "score" in result:
                score += result["score"] * 0.3

            # Length penalty for very long content (0 to -0.2)
            content = result.get("content", "")
            if len(content) > 2000:
                score -= 0.2
            elif len(content) > 1000:
                score -= 0.1

            # Code/technical content bonus (0-0.2)
            if any(
                keyword in content.lower()
                for keyword in ("def ", "class ", "function", "error", "exception")
            ):
                score += 0.2

            scored_results.append((score, result))

        # Sort by priority score and take top results within token limit
        scored_results.sort(key=operator.itemgetter(0), reverse=True)

        prioritized_results = []
        current_tokens = 0

        for score, result in scored_results:
            result_tokens = self.count_tokens(str(result.get("content", "")))
            if current_tokens + result_tokens <= max_tokens:
                prioritized_results.append(result)
                current_tokens += result_tokens
            else:
                break

        return prioritized_results, {
            "strategy": "prioritize_recent",
            "action": "prioritized",
            "final_token_count": current_tokens,
        }

    def _truncate_content(self, content: str, max_tokens: int) -> str:
        """Truncate content to fit within token limit."""
        if self.count_tokens(content) <= max_tokens:
            return content

        # Try to truncate at sentence boundaries
        sentences = content.split(". ")
        truncated = ""

        for sentence in sentences:
            test_content = truncated + sentence + ". "
            if self.count_tokens(test_content) <= max_tokens:
                truncated = test_content
            else:
                break

        if not truncated:
            # Fallback to character-based truncation
            if self.encoding:
                tokens = self.encoding.encode(content)[:max_tokens]
                truncated = self.encoding.decode(tokens)
            else:
                # Rough character limit
                char_limit = max_tokens * 4
                truncated = content[:char_limit]

        return truncated.strip()

    def _create_quick_summary(self, content: str, max_length: int = 200) -> str:
        """Create a quick summary of content."""
        # Extract first and last sentences
        sentences = [s.strip() for s in content.split(".") if s.strip()]
        if not sentences:
            return content[:max_length]

        if len(sentences) == 1:
            return sentences[0][:max_length]

        first_sentence = sentences[0]
        last_sentence = sentences[-1]

        summary = f"{first_sentence}. ... {last_sentence}"
        if len(summary) > max_length:
            summary = first_sentence[: max_length - 3] + "..."

        return summary

    async def _create_chunk_cache_entry(
        self,
        chunks: list[list[dict[str, Any]]],
    ) -> str:
        """Create cache entry for chunked results."""
        cache_key = hashlib.md5(
            f"chunks_{datetime.now().isoformat()}_{len(chunks)}".encode(),
            usedforsecurity=False,
        ).hexdigest()

        chunk_result = ChunkResult(
            chunks=[json.dumps(chunk) for chunk in chunks],
            total_chunks=len(chunks),
            current_chunk=1,
            cache_key=cache_key,
            metadata={
                "created": datetime.now().isoformat(),
                "expires": (datetime.now() + timedelta(hours=1)).isoformat(),
            },
        )

        await self.chunk_cache.set(cache_key, chunk_result)
        return cache_key

    async def get_chunk(
        self,
        cache_key: str,
        chunk_index: int,
    ) -> dict[str, Any] | None:
        """Get a specific chunk from cache.

        Args:
            cache_key: Unique cache key for the chunked data
            chunk_index: Index of chunk to retrieve (1-indexed)

        Returns:
            Dict with chunk data and metadata, or None if not found

        """
        if not await self.chunk_cache.__contains__(cache_key):
            return None

        chunk_result = await self.chunk_cache.get(cache_key)

        if chunk_result and 1 <= chunk_index <= len(chunk_result.chunks):
            chunk_data = json.loads(chunk_result.chunks[chunk_index - 1])
            return {
                "chunk": chunk_data,
                "current_chunk": chunk_index,
                "total_chunks": chunk_result.total_chunks,
                "cache_key": cache_key,
                "has_more": chunk_index < chunk_result.total_chunks,
            }

        return None

    def _calculate_token_savings(
        self,
        original: list[dict[str, Any]],
        optimized: list[dict[str, Any]],
    ) -> dict[str, int]:
        """Calculate token savings from optimization."""
        original_tokens = sum(
            self.count_tokens(str(item.get("content", ""))) for item in original
        )
        optimized_tokens = sum(
            self.count_tokens(str(item.get("content", ""))) for item in optimized
        )

        return {
            "original_tokens": original_tokens,
            "optimized_tokens": optimized_tokens,
            "tokens_saved": original_tokens - optimized_tokens,
            "savings_percentage": int(
                round(
                    ((original_tokens - optimized_tokens) / original_tokens) * 100,
                    1,
                ),
            )
            if original_tokens > 0
            else 0,
        }

    def track_usage(
        self,
        operation: str,
        request_tokens: int,
        response_tokens: int,
        optimization_applied: str | None = None,
    ) -> None:
        """Track token usage for monitoring."""
        metrics = TokenUsageMetrics(
            request_tokens=request_tokens,
            response_tokens=response_tokens,
            total_tokens=request_tokens + response_tokens,
            timestamp=datetime.now().isoformat(),
            operation=operation,
            optimization_applied=optimization_applied,
        )

        self.usage_history.append(metrics)

        # Keep only last 100 entries
        if len(self.usage_history) > 100:
            self.usage_history = self.usage_history[-100:]

    def get_usage_stats(self, hours: int = 24) -> dict[str, Any]:
        """Get token usage statistics."""
        cutoff = datetime.now() - timedelta(hours=hours)

        recent_usage = [
            m
            for m in self.usage_history
            if datetime.fromisoformat(m.timestamp) > cutoff
        ]

        if not recent_usage:
            return {"status": "no_data", "period_hours": hours}

        total_tokens = sum(m.total_tokens for m in recent_usage)
        avg_tokens = total_tokens / len(recent_usage)

        # Count optimizations applied
        optimizations: dict[str, int] = {}
        for metric in recent_usage:
            if metric.optimization_applied:
                optimizations[metric.optimization_applied] = (
                    optimizations.get(metric.optimization_applied, 0) + 1
                )

        return {
            "status": "success",
            "period_hours": hours,
            "total_requests": len(recent_usage),
            "total_tokens": total_tokens,
            "average_tokens_per_request": round(avg_tokens, 1),
            "optimizations_applied": optimizations,
            "estimated_cost_savings": self._estimate_cost_savings(recent_usage),
        }

    def _estimate_cost_savings(
        self,
        usage_metrics: list[TokenUsageMetrics],
    ) -> dict[str, float]:
        """Estimate cost savings from optimizations."""
        # Rough cost estimation (adjust based on actual pricing)
        cost_per_1k_tokens = 0.01  # Example rate

        optimized_requests = [m for m in usage_metrics if m.optimization_applied]
        if not optimized_requests:
            return {"savings_usd": 0.0, "requests_optimized": 0}

        # Estimate 20-40% token savings from optimization (use the 30% midpoint)
        estimated_savings_tokens = sum(m.total_tokens * 0.3 for m in optimized_requests)
        estimated_savings_usd = (estimated_savings_tokens / 1000) * cost_per_1k_tokens

        return {
            "savings_usd": round(estimated_savings_usd, 4),
            "requests_optimized": len(optimized_requests),
            "estimated_tokens_saved": int(estimated_savings_tokens),
        }

    async def cleanup_cache(self, max_age_hours: int = 1) -> int:
        """Clean up expired cache entries asynchronously."""
        # ACB cache with TTL handles cleanup automatically
        return 0


# Global optimizer instance
_token_optimizer: TokenOptimizer | None = None


def get_token_optimizer() -> TokenOptimizer:
    """Get global token optimizer instance."""
    global _token_optimizer
    if _token_optimizer is None:
        _token_optimizer = TokenOptimizer()
    return _token_optimizer


async def optimize_search_response(
    results: list[dict[str, Any]],
    strategy: str = "prioritize_recent",
    max_tokens: int = 4000,
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Async wrapper for search result optimization."""
    optimizer = get_token_optimizer()
    return await optimizer.optimize_search_results(results, strategy, max_tokens)


async def get_cached_chunk(cache_key: str, chunk_index: int) -> dict[str, Any] | None:
    """Async wrapper for chunk retrieval."""
    optimizer = get_token_optimizer()
    return await optimizer.get_chunk(cache_key, chunk_index)


async def track_token_usage(
    operation: str,
    request_tokens: int,
    response_tokens: int,
    optimization_applied: str | None = None,
) -> None:
    """Async wrapper for usage tracking."""
    optimizer = get_token_optimizer()
    optimizer.track_usage(
        operation,
        request_tokens,
        response_tokens,
        optimization_applied,
    )


async def get_token_usage_stats(hours: int = 24) -> dict[str, Any]:
    """Async wrapper for usage statistics."""
    optimizer = get_token_optimizer()
    return optimizer.get_usage_stats(hours)


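# Illustrative usage sketch (not part of the original module). It shows how the
# module-level wrappers and the chunking path above might be exercised together;
# the sample results, token limits, and chunk sizes are hypothetical values
# chosen only to trigger each branch.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Hypothetical search results standing in for real session data
        sample_results = [
            {
                "content": "def example(): ...",
                "timestamp": datetime.now().isoformat(),
                "score": 0.9,
            },
            {
                "content": "older conversation text",
                "timestamp": (datetime.now() - timedelta(days=30)).isoformat(),
            },
        ]

        # Score-based prioritization through the module-level wrapper
        optimized, info = await optimize_search_response(
            sample_results,
            strategy="prioritize_recent",
            max_tokens=1000,
        )
        await track_token_usage(
            operation="demo_search",
            request_tokens=50,
            response_tokens=sum(
                get_token_optimizer().count_tokens(str(r.get("content", "")))
                for r in optimized
            ),
            optimization_applied=info.get("strategy"),
        )
        print(info)

        # Chunking path: a deliberately tiny chunk_size forces multiple chunks
        optimizer = TokenOptimizer(max_tokens=50, chunk_size=5)
        first_chunk, chunk_info = await optimizer.optimize_search_results(
            sample_results,
            strategy="chunk_response",
            max_tokens=5,
        )
        print(f"first chunk holds {len(first_chunk)} result(s)")
        if chunk_info.get("has_more"):
            print(await optimizer.get_chunk(chunk_info["cache_key"], 2))

        print(await get_token_usage_stats(hours=1))

    asyncio.run(_demo())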