Coverage for session_buddy / tools / memory_tools.py: 81.12%
298 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
1#!/usr/bin/env python3
2"""Memory and reflection management MCP tools.
4This module provides tools for storing, searching, and managing reflections and conversation memories.
6Refactored to use utility modules for reduced code duplication.
7"""
9from __future__ import annotations
11import asyncio
12import operator
13import typing as t
14from datetime import datetime
15from typing import TYPE_CHECKING, Any
17from session_buddy.utils.database_helpers import require_reflection_database
18from session_buddy.utils.error_handlers import (
19 DatabaseUnavailableError,
20 ValidationError,
21 _get_logger,
22 validate_required,
23)
24from session_buddy.utils.messages import ToolMessages
25from session_buddy.utils.tool_wrapper import format_reflection_result
27if TYPE_CHECKING:
28 from session_buddy.adapters.reflection_adapter import ReflectionDatabaseAdapter
# Memoised outcome of the dependency probe in _check_reflection_tools_available();
# None means the probe has not run yet in this process.
_reflection_tools_available: bool | None = None
# Module-wide database handle, created lazily by _get_reflection_database() and
# dropped again by the reset tool.
_reflection_db: ReflectionDatabaseAdapter | None = None
def _check_reflection_tools_available() -> bool:
    """Report whether the optional reflection dependencies can be imported.

    The probe (an ``import duckdb``) runs at most once per process. Its
    outcome is memoised in the module-level ``_reflection_tools_available``
    and returned directly on later calls.
    """
    global _reflection_tools_available
    if _reflection_tools_available is None:
        try:
            import duckdb  # noqa: F401  (imported only to test availability)
        except ImportError:
            _reflection_tools_available = False
        else:
            _reflection_tools_available = True
    return _reflection_tools_available
async def _get_reflection_database() -> ReflectionDatabaseAdapter:
    """Return the shared reflection database, creating it on first use.

    Kept as a separate coroutine so tests can patch it. The handle lives in
    the module-level ``_reflection_db`` and is reused by every later call.
    """
    global _reflection_db
    if _reflection_db is None:
        _reflection_db = await require_reflection_database()
    return _reflection_db
async def _execute_database_tool(
    operation: t.Callable[[ReflectionDatabaseAdapter], t.Awaitable[t.Any]],
    formatter: t.Callable[[t.Any], str],
    operation_name: str,
    validator: t.Callable[[], None] | None = None,
) -> str:
    """Run a database-backed tool and turn its outcome into reply text.

    Args:
        operation: Coroutine function that receives the reflection database.
        formatter: Converts the operation's raw result into the reply string.
        operation_name: Human-readable name used in error messages and logs.
        validator: Optional pre-flight check, run before the database is
            fetched.

    Returns:
        The formatted result. On failure, a ToolMessages string instead:
        ``validation_error`` for ValidationError, ``not_available`` for
        DatabaseUnavailableError, and ``operation_failed`` (after logging)
        for any other ``Exception``.
    """
    try:
        if validator:
            validator()
        reply = formatter(await operation(await _get_reflection_database()))
    except ValidationError as exc:
        return ToolMessages.validation_error(operation_name, str(exc))
    except DatabaseUnavailableError as exc:
        return ToolMessages.not_available(operation_name, str(exc))
    except Exception as exc:
        _get_logger().exception(f"Error in {operation_name}: {exc}")
        return ToolMessages.operation_failed(operation_name, exc)
    return reply
async def _execute_simple_database_tool(
    operation: t.Callable[[ReflectionDatabaseAdapter], t.Awaitable[str]],
    operation_name: str,
) -> str:
    """Run an operation that already produces its final reply text.

    There is no validator or formatter step. DatabaseUnavailableError maps to
    ``ToolMessages.not_available``. Any other ``Exception`` is logged and
    mapped to ``ToolMessages.operation_failed``.
    """
    try:
        database = await _get_reflection_database()
        reply = await operation(database)
    except DatabaseUnavailableError as exc:
        return ToolMessages.not_available(operation_name, str(exc))
    except Exception as exc:
        _get_logger().exception(f"Error in {operation_name}: {exc}")
        return ToolMessages.operation_failed(operation_name, exc)
    return reply
94def _format_score(score: float) -> str:
95 """Format a score as a percentage or relevance indicator."""
96 return f"{score:.2f}"
99# ============================================================================
100# Store Reflection Tool
101# ============================================================================
104async def _store_reflection_operation(
105 db: ReflectionDatabaseAdapter, content: str, tags: list[str]
106) -> dict[str, Any]:
107 """Execute reflection storage operation."""
108 success = await db.store_reflection(content, tags=tags)
109 return {
110 "success": success,
111 "content": content,
112 "tags": tags,
113 "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
114 }
def _format_store_reflection_result(result: dict[str, Any]) -> str:
    """Turn a stored-reflection dict into the reply text.

    ``success`` and ``content`` are required keys. ``tags`` and
    ``timestamp`` are optional and passed on as ``None`` when absent.
    """
    success, content = result["success"], result["content"]
    return format_reflection_result(
        success, content, result.get("tags"), result.get("timestamp")
    )
async def _store_reflection_impl(content: str, tags: list[str] | None = None) -> str:
    """Back the ``store_reflection`` MCP tool.

    Validates ``content`` and stores it with ``tags`` (an empty list when
    omitted), then returns the formatted reply. Validation and availability
    problems become ToolMessages strings. Any other failure is logged and
    reported inline.
    """
    if not _check_reflection_tools_available():
        return "Reflection tools not available. Install dependencies: uv sync --extra embeddings"

    try:
        validate_required(content, "content")
        database = await _get_reflection_database()
        outcome = await _store_reflection_operation(database, content, tags or [])
        reply = _format_store_reflection_result(outcome)
    except ValidationError as exc:
        return ToolMessages.validation_error("Store reflection", str(exc))
    except DatabaseUnavailableError as exc:
        return ToolMessages.not_available("Store reflection", str(exc))
    except Exception as exc:
        _get_logger().exception(f"Error storing reflection: {exc}")
        return f"Error storing reflection: {exc}"
    return reply
146# ============================================================================
147# Quick Search Tool
148# ============================================================================
151async def _quick_search_operation(
152 db: ReflectionDatabaseAdapter,
153 query: str,
154 project: str | None,
155 min_score: float,
156) -> str:
157 """Execute quick search operation and format results."""
158 results = await db.search_conversations(
159 query=query,
160 project=project,
161 limit=1,
162 min_score=min_score,
163 )
165 lines = [f"🔍 Quick search for: '{query}'"]
167 if results:
168 result = results[0]
169 lines.extend(
170 (
171 "📊 Found results (showing top 1)",
172 f"📝 {ToolMessages.truncate_text(result['content'], 150)}",
173 )
174 )
175 if result.get("project"): 175 ↛ 177line 175 didn't jump to line 177 because the condition on line 175 was always true
176 lines.append(f"📁 Project: {result['project']}")
177 if result.get("score") is not None: 177 ↛ 179line 177 didn't jump to line 179 because the condition on line 177 was always true
178 lines.append(f"⭐ Relevance: {_format_score(result['score'])}")
179 lines.append(f"📅 Date: {result.get('timestamp', 'Unknown')}")
180 else:
181 lines.extend(
182 (
183 "🔍 No results found",
184 "💡 Try adjusting your search terms or lowering min_score",
185 )
186 )
188 return "\n".join(lines)
async def _quick_search_impl(
    query: str,
    min_score: float = 0.7,
    project: str | None = None,
) -> str:
    """Back the ``quick_search`` MCP tool, which reports only the top match."""
    if not _check_reflection_tools_available():
        return "Reflection tools not available. Install dependencies: uv sync --extra embeddings"

    async def _run(database: ReflectionDatabaseAdapter) -> str:
        return await _quick_search_operation(database, query, project, min_score)

    return await _execute_simple_database_tool(_run, "Quick search")
206# ============================================================================
207# Search Summary Tool
208# ============================================================================
211async def _analyze_project_distribution(
212 results: list[dict[str, Any]],
213) -> dict[str, int]:
214 """Analyze project distribution of search results."""
215 projects: dict[str, int] = {}
216 for result in results:
217 proj = result.get("project", "Unknown")
218 projects[proj] = projects.get(proj, 0) + 1
219 return projects
222async def _analyze_relevance_scores(
223 results: list[dict[str, Any]],
224) -> tuple[float, list[float]]:
225 """Analyze relevance scores of search results."""
226 scores = [r.get("score", 0.0) for r in results if r.get("score") is not None]
227 avg_score = sum(scores) / len(scores) if scores else 0.0
228 return avg_score, scores
231async def _extract_common_themes(
232 results: list[dict[str, Any]],
233) -> list[tuple[str, int]]:
234 """Extract common themes from search results."""
235 all_content = " ".join([r["content"] for r in results])
236 words = all_content.lower().split()
237 word_freq: dict[str, int] = {}
239 for word in words:
240 if len(word) > 4: # Skip short words
241 word_freq[word] = word_freq.get(word, 0) + 1
243 if word_freq: 243 ↛ 245line 243 didn't jump to line 245 because the condition on line 243 was always true
244 return sorted(word_freq.items(), key=operator.itemgetter(1), reverse=True)[:5]
245 return []
248async def _format_search_summary(
249 query: str,
250 results: list[dict[str, Any]],
251) -> str:
252 """Format complete search summary."""
253 lines = [
254 f"📊 Search Summary for: '{query}'",
255 "=" * 50,
256 ]
258 if not results:
259 lines.extend(
260 [
261 "🔍 No results found",
262 "💡 Try different search terms or lower the min_score threshold",
263 ]
264 )
265 return "\n".join(lines)
267 # Basic stats
268 lines.append(f"📈 Total results: {len(results)}")
270 # Project distribution
271 projects = await _analyze_project_distribution(results)
272 if len(projects) > 1: 272 ↛ 280line 272 didn't jump to line 280 because the condition on line 272 was always true
273 lines.append("📁 Project distribution:")
274 for proj, count in sorted(
275 projects.items(), key=operator.itemgetter(1), reverse=True
276 ):
277 lines.append(f" • {proj}: {count} results")
279 # Time distribution
280 timestamps = [r.get("timestamp") for r in results if r.get("timestamp")]
281 if timestamps: 281 ↛ 285line 281 didn't jump to line 285 because the condition on line 281 was always true
282 lines.append(f"📅 Time range: {len(timestamps)} results with dates")
284 # Relevance scores
285 avg_score, scores = await _analyze_relevance_scores(results)
286 if scores: 286 ↛ 290line 286 didn't jump to line 290 because the condition on line 286 was always true
287 lines.append(f"⭐ Average relevance: {_format_score(avg_score)}")
289 # Common themes
290 top_words = await _extract_common_themes(results)
291 if top_words: 291 ↛ 296line 291 didn't jump to line 296 because the condition on line 291 was always true
292 lines.append("🔤 Common themes:")
293 for word, freq in top_words:
294 lines.append(f" • {word}: {freq} mentions")
296 return "\n".join(lines)
async def _search_summary_operation(
    db: ReflectionDatabaseAdapter,
    query: str,
    project: str | None,
    min_score: float,
    limit: int = 20,
) -> str:
    """Search for ``query`` and return an aggregate summary of the matches.

    Args:
        db: Reflection database adapter to search.
        query: Free-text search query.
        project: Optional project filter forwarded to the adapter.
        min_score: Minimum relevance score forwarded to the adapter.
        limit: Maximum number of results to aggregate. Defaults to 20, the
            value this function previously hard-coded.
    """
    results = await db.search_conversations(
        query=query,
        project=project,
        limit=limit,
        min_score=min_score,
    )
    return await _format_search_summary(query, results)


async def _search_summary_impl(
    query: str,
    min_score: float = 0.7,
    project: str | None = None,
    limit: int = 20,
) -> str:
    """Back the ``search_summary`` MCP tool.

    Args:
        query: Free-text search query.
        min_score: Minimum relevance score for included results.
        project: Optional project filter.
        limit: Maximum number of results to aggregate (default 20).

    Returns:
        The summary text. On failure, a ToolMessages "not available" message
        or an inline error string.
    """
    if not _check_reflection_tools_available():
        return "Reflection tools not available. Install dependencies: uv sync --extra embeddings"

    try:
        db = await _get_reflection_database()
        return await _search_summary_operation(db, query, project, min_score, limit)
    except DatabaseUnavailableError as e:
        return ToolMessages.not_available("Search summary", str(e))
    except Exception as e:
        _get_logger().exception(f"Search summary error: {e}")
        return f"Search summary error: {e}"
334# ============================================================================
335# Search by File Tool
336# ============================================================================
339async def _format_file_search_results(
340 file_path: str,
341 results: list[dict[str, Any]],
342) -> str:
343 """Format file search results."""
344 lines = [
345 f"📁 Searching conversations about: {file_path}",
346 "=" * 50,
347 ]
349 if not results:
350 lines.extend(
351 [
352 "🔍 No conversations found about this file",
353 "💡 The file might not have been discussed in previous sessions",
354 ]
355 )
356 return "\n".join(lines)
358 lines.append(f"📈 Found {len(results)} relevant conversations:")
360 for i, result in enumerate(results, 1):
361 lines.append(
362 f"\n{i}. 📝 {ToolMessages.truncate_text(result['content'], 200)}",
363 )
364 if result.get("project"): 364 ↛ 366line 364 didn't jump to line 366 because the condition on line 364 was always true
365 lines.append(f" 📁 Project: {result['project']}")
366 if result.get("score") is not None: 366 ↛ 368line 366 didn't jump to line 368 because the condition on line 366 was always true
367 lines.append(f" ⭐ Relevance: {_format_score(result['score'])}")
368 if result.get("timestamp"): 368 ↛ 360line 368 didn't jump to line 360 because the condition on line 368 was always true
369 lines.append(f" 📅 Date: {result['timestamp']}")
371 return "\n".join(lines)
async def _search_by_file_operation(
    db: ReflectionDatabaseAdapter,
    file_path: str,
    limit: int,
    project: str | None,
) -> str:
    """Search conversations using ``file_path`` as the query text and format them.

    No ``min_score`` is passed, so whatever default the adapter uses applies.
    """
    matches = await db.search_conversations(query=file_path, project=project, limit=limit)
    return await _format_file_search_results(file_path, matches)
async def _search_by_file_impl(
    file_path: str,
    limit: int = 10,
    project: str | None = None,
) -> str:
    """Back the ``search_by_file`` MCP tool.

    DatabaseUnavailableError becomes a ToolMessages "not available" reply.
    Any other ``Exception`` is logged and reported inline.
    """
    if not _check_reflection_tools_available():
        return "Reflection tools not available. Install dependencies: uv sync --extra embeddings"

    try:
        database = await _get_reflection_database()
        report = await _search_by_file_operation(database, file_path, limit, project)
    except DatabaseUnavailableError as exc:
        return ToolMessages.not_available("Search by file", str(exc))
    except Exception as exc:
        _get_logger().exception(f"File search error: {exc}")
        return f"File search error: {exc}"
    return report
408# ============================================================================
409# Search by Concept Tool
410# ============================================================================
413async def _format_concept_search_results(
414 concept: str,
415 results: list[dict[str, Any]],
416 include_files: bool,
417) -> str:
418 """Format concept search results."""
419 lines = [
420 f"🧠 Searching for concept: '{concept}'",
421 "=" * 50,
422 ]
424 if not results:
425 lines.extend(
426 [
427 "🔍 No conversations found about this concept",
428 "💡 Try related terms or broader concepts",
429 ]
430 )
431 return "\n".join(lines)
433 lines.append(f"📈 Found {len(results)} related conversations:")
435 for i, result in enumerate(results, 1):
436 lines.append(
437 f"\n{i}. 📝 {ToolMessages.truncate_text(result['content'], 250)}",
438 )
439 if result.get("project"): 439 ↛ 441line 439 didn't jump to line 441 because the condition on line 439 was always true
440 lines.append(f" 📁 Project: {result['project']}")
441 if result.get("score") is not None: 441 ↛ 443line 441 didn't jump to line 443 because the condition on line 441 was always true
442 lines.append(f" ⭐ Relevance: {_format_score(result['score'])}")
443 if result.get("timestamp"): 443 ↛ 446line 443 didn't jump to line 446 because the condition on line 443 was always true
444 lines.append(f" 📅 Date: {result['timestamp']}")
446 if include_files and result.get("files"): 446 ↛ 447line 446 didn't jump to line 447 because the condition on line 446 was never true
447 files = result["files"][:3]
448 if files:
449 lines.append(f" 📄 Files: {', '.join(files)}")
451 return "\n".join(lines)
async def _search_by_concept_operation(
    db: ReflectionDatabaseAdapter,
    concept: str,
    include_files: bool,
    limit: int,
    project: str | None,
) -> str:
    """Search conversations using ``concept`` as the query text and format them.

    No ``min_score`` is passed, so whatever default the adapter uses applies.
    """
    matches = await db.search_conversations(query=concept, project=project, limit=limit)
    return await _format_concept_search_results(concept, matches, include_files)
async def _search_by_concept_impl(
    concept: str,
    include_files: bool = True,
    limit: int = 10,
    project: str | None = None,
) -> str:
    """Back the ``search_by_concept`` MCP tool.

    DatabaseUnavailableError becomes a ToolMessages "not available" reply.
    Any other ``Exception`` is logged and reported inline.
    """
    if not _check_reflection_tools_available():
        return "Reflection tools not available. Install dependencies: uv sync --extra embeddings"

    try:
        database = await _get_reflection_database()
        report = await _search_by_concept_operation(
            database, concept, include_files, limit, project
        )
    except DatabaseUnavailableError as exc:
        return ToolMessages.not_available("Search by concept", str(exc))
    except Exception as exc:
        _get_logger().exception(f"Concept search error: {exc}")
        return f"Concept search error: {exc}"
    return report
492# ============================================================================
493# Reflection Stats Tool
494# ============================================================================
497def _format_stats_new(stats: dict[str, t.Any]) -> list[str]:
498 """Format statistics in new format (conversations_count, reflections_count)."""
499 conv_count = stats.get("conversations_count", 0)
500 refl_count = stats.get("reflections_count", 0)
501 provider = stats.get("embedding_provider", "unknown")
503 return [
504 f"📈 Total conversations: {conv_count}",
505 f"💭 Total reflections: {refl_count}",
506 f"🔧 Embedding provider: {provider}",
507 f"\n🏥 Database health: {'✅ Healthy' if (conv_count + refl_count) > 0 else '⚠️ Empty'}",
508 ]
def _format_new_stats(stats: dict[str, t.Any]) -> list[str]:
    """Backward-compatible alias that delegates to ``_format_stats_new``."""
    formatted = _format_stats_new(stats)
    return formatted
516def _format_stats_old(stats: dict[str, t.Any]) -> list[str]:
517 """Format statistics in old/test format (total_reflections, projects, date_range)."""
518 output = [
519 f"📈 Total reflections: {stats.get('total_reflections', 0)}",
520 f"📁 Projects: {stats.get('projects', 0)}",
521 ]
523 # Add date range if present
524 date_range = stats.get("date_range")
525 if isinstance(date_range, dict):
526 output.append(
527 f"📅 Date range: {date_range.get('start')} to {date_range.get('end')}"
528 )
530 # Add recent activity if present
531 recent_activity = stats.get("recent_activity", [])
532 if recent_activity:
533 output.append("\n🕐 Recent activity:")
534 output.extend([f" • {activity}" for activity in recent_activity[:5]])
536 # Database health
537 is_healthy = stats.get("total_reflections", 0) > 0
538 output.append(f"\n🏥 Database health: {'✅ Healthy' if is_healthy else '⚠️ Empty'}")
540 return output
def _format_old_stats(stats: dict[str, t.Any]) -> list[str]:
    """Backward-compatible alias that delegates to ``_format_stats_old``."""
    formatted = _format_stats_old(stats)
    return formatted
548async def _reflection_stats_operation(db: ReflectionDatabaseAdapter) -> str:
549 """Execute reflection stats operation."""
550 stats = await db.get_stats()
552 lines = ["📊 Reflection Database Statistics", "=" * 40]
554 if stats and "error" not in stats: 554 ↛ 561line 554 didn't jump to line 561 because the condition on line 554 was always true
555 # Format based on stat structure
556 if "conversations_count" in stats: 556 ↛ 559line 556 didn't jump to line 559 because the condition on line 556 was always true
557 lines.extend(_format_stats_new(stats))
558 else:
559 lines.extend(_format_stats_old(stats))
560 else:
561 lines.extend(
562 [
563 "📊 No statistics available",
564 "💡 Database may be empty or inaccessible",
565 ]
566 )
568 return "\n".join(lines)
async def _reflection_stats_impl() -> str:
    """Back the ``reflection_stats`` MCP tool."""
    if not _check_reflection_tools_available():
        return "Reflection tools not available. Install dependencies: uv sync --extra embeddings"
    # The operation already has the (db) -> str shape the runner expects.
    return await _execute_simple_database_tool(
        _reflection_stats_operation, "Reflection stats"
    )
582# ============================================================================
583# Reset Database Tool
584# ============================================================================
587async def _close_db_connection(conn: t.Any) -> None:
588 """Close database connection, handling both async and sync cases."""
589 close_method = getattr(conn, "close", None)
590 if not callable(close_method): 590 ↛ 591line 590 didn't jump to line 591 because the condition on line 590 was never true
591 return
593 result = close_method()
594 if asyncio.iscoroutine(result): 594 ↛ exitline 594 didn't return from function '_close_db_connection' because the condition on line 594 was always true
595 await result
598async def _close_db_object(db_obj: t.Any) -> None:
599 """Close database object using async or sync close method."""
600 # Try async close first
601 aclose_method = getattr(db_obj, "aclose", None)
602 if callable(aclose_method): 602 ↛ 609line 602 didn't jump to line 609 because the condition on line 602 was always true
603 result = aclose_method()
604 if asyncio.iscoroutine(result): 604 ↛ 606line 604 didn't jump to line 606 because the condition on line 604 was always true
605 await result
606 return
608 # Fallback to sync close
609 close_method = getattr(db_obj, "close", None)
610 if callable(close_method):
611 close_method()
async def _close_reflection_db_safely(db_obj: t.Any) -> None:
    """Close a reflection database object and, first, its ``conn`` attribute.

    Works with both legacy objects (a truthy ``conn`` attribute, closed
    first) and adapter-style objects (only the object itself is closed).
    Exceptions raised while closing are not caught here.
    """
    legacy_conn = getattr(db_obj, "conn", None)
    if legacy_conn:
        await _close_db_connection(legacy_conn)
    await _close_db_object(db_obj)
async def _reset_reflection_database_impl() -> str:
    """Back the ``reset_reflection_database`` MCP tool.

    Forgets the cached database handle, closes it on a best-effort basis, and
    opens a fresh one via ``_get_reflection_database``.

    Returns:
        A success message, or ``ToolMessages.operation_failed`` if a new
        connection cannot be established.
    """
    global _reflection_db

    if not _check_reflection_tools_available():
        return "Reflection tools not available. Install dependencies: uv sync --extra embeddings"

    try:
        stale_db = _reflection_db
        # Drop the cached handle *before* closing it. If close() raised while
        # the cache still pointed at the half-closed object, every later reset
        # and tool call would keep reusing that broken handle.
        _reflection_db = None
        if stale_db is not None:
            try:
                await _close_reflection_db_safely(stale_db)
            except Exception as close_error:
                # Closing is best effort: this tool exists to get a working
                # connection, so log the failure and still reconnect.
                _get_logger().exception(
                    f"Error closing reflection database during reset: {close_error}"
                )

        await _get_reflection_database()

        lines = [
            "🔄 Reflection database connection reset",
            "✅ New connection established successfully",
            "💡 Database locks should be resolved",
        ]
        return "\n".join(lines)

    except Exception as e:
        _get_logger().exception(f"Error in Reset database: {e}")
        return ToolMessages.operation_failed("Reset database", e)
652# ============================================================================
653# MCP Tool Registration
654# ============================================================================
def register_memory_tools(mcp_server: Any) -> None:
    """Attach the memory/reflection tools to ``mcp_server``.

    Each tool is a thin async wrapper that delegates to the matching
    ``_*_impl`` coroutine. The wrappers' names, parameters and docstrings are
    presumably what MCP clients see (tool name/schema/description), so they
    are kept stable.

    NOTE(review): some accepted parameters are not forwarded yet.
    ``search_summary`` ignores ``limit``; ``search_by_file`` and
    ``search_by_concept`` ignore ``min_score``; ``reflection_stats`` ignores
    ``project``.
    """

    @mcp_server.tool()  # type: ignore[misc]
    async def store_reflection(content: str, tags: list[str] | None = None) -> str:
        """Store an important insight or reflection for future reference."""
        return await _store_reflection_impl(content, tags)

    @mcp_server.tool()  # type: ignore[misc]
    async def quick_search(
        query: str,
        min_score: float = 0.7,
        project: str | None = None,
    ) -> str:
        """Quick search that returns only the count and top result for fast overview."""
        return await _quick_search_impl(query, min_score, project)

    @mcp_server.tool()  # type: ignore[misc]
    async def search_summary(
        query: str,
        limit: int = 10,
        project: str | None = None,
        min_score: float = 0.7,
    ) -> str:
        """Get aggregated insights from search results without individual result details."""
        # `limit` is part of the public signature but is not passed through.
        return await _search_summary_impl(query, min_score, project)

    @mcp_server.tool()  # type: ignore[misc]
    async def search_by_file(
        file_path: str,
        limit: int = 10,
        project: str | None = None,
        min_score: float = 0.7,
    ) -> str:
        """Search for conversations that analyzed a specific file."""
        # `min_score` is part of the public signature but is not passed through.
        return await _search_by_file_impl(file_path, limit, project)

    @mcp_server.tool()  # type: ignore[misc]
    async def search_by_concept(
        concept: str,
        include_files: bool = True,
        limit: int = 10,
        project: str | None = None,
        min_score: float = 0.7,
    ) -> str:
        """Search for conversations about a specific development concept."""
        # `min_score` is part of the public signature but is not passed through.
        return await _search_by_concept_impl(concept, include_files, limit, project)

    @mcp_server.tool()  # type: ignore[misc]
    async def reflection_stats(project: str | None = None) -> str:
        """Get statistics about the reflection database."""
        # Stats are database-wide; `project` is accepted but unused.
        return await _reflection_stats_impl()

    @mcp_server.tool()  # type: ignore[misc]
    async def reset_reflection_database() -> str:
        """Reset the reflection database connection to fix lock issues."""
        return await _reset_reflection_database_impl()