Coverage for session_mgmt_mcp/reflection_tools.py: 35.39%
201 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-01 05:22 -0700
1#!/usr/bin/env python3
2"""Reflection Tools for Claude Session Management.
4Provides memory and conversation search capabilities using DuckDB and local embeddings.
5"""
7import asyncio
8import hashlib
9import json
10import os
11import time
12from datetime import UTC, datetime
13from pathlib import Path
14from typing import Any
16# Database and embedding imports
17try:
18 import duckdb
20 DUCKDB_AVAILABLE = True
21except ImportError:
22 DUCKDB_AVAILABLE = False
24try:
25 import onnxruntime as ort
26 from transformers import AutoTokenizer
28 ONNX_AVAILABLE = True
29except ImportError:
30 ONNX_AVAILABLE = False
32import numpy as np
35class ReflectionDatabase:
36 """Manages DuckDB database for conversation memory and reflection."""
38 def __init__(self, db_path: str | None = None) -> None:
39 self.db_path = db_path or os.path.expanduser("~/.claude/data/reflection.duckdb")
40 Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
42 self.conn: duckdb.DuckDBPyConnection | None = None
43 self.onnx_session: ort.InferenceSession | None = None
44 self.tokenizer = None
45 self.embedding_dim = 384 # all-MiniLM-L6-v2 dimension
47 def __enter__(self):
48 """Context manager entry."""
49 return self
51 def __exit__(self, exc_type, exc_val, exc_tb):
52 """Context manager exit with cleanup."""
53 self.close()
55 def close(self) -> None:
56 """Close database connection."""
57 if self.conn:
58 try:
59 self.conn.close()
60 except Exception:
61 pass # Ignore errors during cleanup
62 finally:
63 self.conn = None
65 def __del__(self) -> None:
66 """Destructor to ensure cleanup."""
67 self.close()
69 async def initialize(self) -> None:
70 """Initialize database and embedding models."""
71 if not DUCKDB_AVAILABLE: 71 ↛ 72line 71 didn't jump to line 72 because the condition on line 71 was never true
72 msg = "DuckDB not available. Install with: pip install duckdb"
73 raise ImportError(msg)
75 # Initialize DuckDB connection with appropriate settings for concurrency
76 self.conn = duckdb.connect(self.db_path)
77 # DuckDB doesn't use SQLite-style PRAGMA commands
78 # DuckDB handles concurrency automatically with MVCC
80 # Initialize ONNX embedding model
81 if ONNX_AVAILABLE: 81 ↛ 82line 81 didn't jump to line 82 because the condition on line 81 was never true
82 try:
83 # Load tokenizer
84 self.tokenizer = AutoTokenizer.from_pretrained(
85 "sentence-transformers/all-MiniLM-L6-v2",
86 )
88 # Try to load ONNX model
89 model_path = os.path.expanduser(
90 "~/.claude/all-MiniLM-L6-v2/onnx/model.onnx",
91 )
92 if not os.path.exists(model_path):
93 print("ONNX model not found, will use text search fallback")
94 self.onnx_session = None
95 else:
96 self.onnx_session = ort.InferenceSession(model_path)
97 self.embedding_dim = 384
98 except Exception as e:
99 print(f"ONNX model loading failed, using text search: {e}")
100 self.onnx_session = None
101 else:
102 print("ONNX not available, using text search fallback")
104 # Create tables if they don't exist
105 await self._ensure_tables()
107 async def _ensure_tables(self) -> None:
108 """Ensure required tables exist."""
109 # Create conversations table
110 self.conn.execute("""
111 CREATE TABLE IF NOT EXISTS conversations (
112 id VARCHAR PRIMARY KEY,
113 content TEXT NOT NULL,
114 embedding FLOAT[384],
115 project VARCHAR,
116 timestamp TIMESTAMP,
117 metadata JSON
118 )
119 """)
121 # Create reflections table
122 self.conn.execute("""
123 CREATE TABLE IF NOT EXISTS reflections (
124 id VARCHAR PRIMARY KEY,
125 content TEXT NOT NULL,
126 embedding FLOAT[384],
127 tags VARCHAR[],
128 timestamp TIMESTAMP,
129 metadata JSON
130 )
131 """)
133 # Create project_groups table for multi-project coordination
134 self.conn.execute("""
135 CREATE TABLE IF NOT EXISTS project_groups (
136 id VARCHAR PRIMARY KEY,
137 name VARCHAR NOT NULL,
138 description TEXT,
139 projects VARCHAR[] NOT NULL,
140 created_at TIMESTAMP DEFAULT NOW(),
141 metadata JSON
142 )
143 """)
145 # Create project_dependencies table for project relationships
146 self.conn.execute("""
147 CREATE TABLE IF NOT EXISTS project_dependencies (
148 id VARCHAR PRIMARY KEY,
149 source_project VARCHAR NOT NULL,
150 target_project VARCHAR NOT NULL,
151 dependency_type VARCHAR NOT NULL,
152 description TEXT,
153 created_at TIMESTAMP DEFAULT NOW(),
154 metadata JSON,
155 UNIQUE(source_project, target_project, dependency_type)
156 )
157 """)
159 # Create session_links table for cross-project session coordination
160 self.conn.execute("""
161 CREATE TABLE IF NOT EXISTS session_links (
162 id VARCHAR PRIMARY KEY,
163 source_session_id VARCHAR NOT NULL,
164 target_session_id VARCHAR NOT NULL,
165 link_type VARCHAR NOT NULL,
166 context TEXT,
167 created_at TIMESTAMP DEFAULT NOW(),
168 metadata JSON,
169 UNIQUE(source_session_id, target_session_id, link_type)
170 )
171 """)
173 # Create search_index table for advanced search capabilities
174 self.conn.execute("""
175 CREATE TABLE IF NOT EXISTS search_index (
176 id VARCHAR PRIMARY KEY,
177 content_type VARCHAR NOT NULL, -- 'conversation', 'reflection', 'file', 'project'
178 content_id VARCHAR NOT NULL,
179 indexed_content TEXT NOT NULL,
180 search_metadata JSON,
181 last_indexed TIMESTAMP DEFAULT NOW(),
182 UNIQUE(content_type, content_id)
183 )
184 """)
186 # Create search_facets table for faceted search
187 self.conn.execute("""
188 CREATE TABLE IF NOT EXISTS search_facets (
189 id VARCHAR PRIMARY KEY,
190 content_type VARCHAR NOT NULL,
191 content_id VARCHAR NOT NULL,
192 facet_name VARCHAR NOT NULL,
193 facet_value VARCHAR NOT NULL,
194 created_at TIMESTAMP DEFAULT NOW(),
195 INDEX(facet_name, facet_value),
196 INDEX(content_type, content_id)
197 )
198 """)
200 # Create indices for better performance
201 await self._ensure_indices()
203 self.conn.commit()
205 async def _ensure_indices(self) -> None:
206 """Create indices for better query performance."""
207 indices = [
208 # Existing table indices
209 "CREATE INDEX IF NOT EXISTS idx_conversations_project ON conversations(project)",
210 "CREATE INDEX IF NOT EXISTS idx_conversations_timestamp ON conversations(timestamp)",
211 "CREATE INDEX IF NOT EXISTS idx_reflections_timestamp ON reflections(timestamp)",
212 # New multi-project indices
213 "CREATE INDEX IF NOT EXISTS idx_project_groups_projects ON project_groups USING GIN(projects)",
214 "CREATE INDEX IF NOT EXISTS idx_project_deps_source ON project_dependencies(source_project)",
215 "CREATE INDEX IF NOT EXISTS idx_project_deps_target ON project_dependencies(target_project)",
216 "CREATE INDEX IF NOT EXISTS idx_session_links_source ON session_links(source_session_id)",
217 "CREATE INDEX IF NOT EXISTS idx_session_links_target ON session_links(target_session_id)",
218 # Search indices
219 "CREATE INDEX IF NOT EXISTS idx_search_index_type ON search_index(content_type)",
220 "CREATE INDEX IF NOT EXISTS idx_search_index_last_indexed ON search_index(last_indexed)",
221 "CREATE INDEX IF NOT EXISTS idx_search_facets_name_value ON search_facets(facet_name, facet_value)",
222 ]
224 for index_sql in indices:
225 try:
226 self.conn.execute(index_sql)
227 except Exception as e:
228 # Some indices might not be supported in all DuckDB versions, continue
229 print(f"Index creation skipped: {e}")
231 async def get_embedding(self, text: str) -> list[float]:
232 """Get embedding for text using ONNX model."""
233 if self.onnx_session and self.tokenizer:
235 def _get_embedding():
236 # Tokenize text
237 encoded = self.tokenizer(
238 text,
239 truncation=True,
240 padding=True,
241 return_tensors="np",
242 )
244 # Run inference
245 outputs = self.onnx_session.run(
246 None,
247 {
248 "input_ids": encoded["input_ids"],
249 "attention_mask": encoded["attention_mask"],
250 "token_type_ids": encoded.get(
251 "token_type_ids",
252 np.zeros_like(encoded["input_ids"]),
253 ),
254 },
255 )
257 # Mean pooling
258 embeddings = outputs[0]
259 attention_mask = encoded["attention_mask"]
260 masked_embeddings = embeddings * np.expand_dims(attention_mask, axis=-1)
261 summed = np.sum(masked_embeddings, axis=1)
262 counts = np.sum(attention_mask, axis=1, keepdims=True)
263 mean_pooled = summed / counts
265 # Normalize
266 norms = np.linalg.norm(mean_pooled, axis=1, keepdims=True)
267 normalized = mean_pooled / norms
269 # Convert to float32 to match DuckDB FLOAT type
270 return normalized[0].astype(np.float32).tolist()
272 return await asyncio.get_event_loop().run_in_executor(None, _get_embedding)
274 msg = "No embedding model available"
275 raise RuntimeError(msg)
277 async def store_conversation(self, content: str, metadata: dict[str, Any]) -> str:
278 """Store conversation with optional embedding."""
279 conversation_id = hashlib.md5(f"{content}_{time.time()}".encode()).hexdigest()
281 if ONNX_AVAILABLE and self.onnx_session:
282 try:
283 embedding = await self.get_embedding(content)
284 except Exception:
285 embedding = None # Fallback to no embedding
286 else:
287 embedding = None # Store without embedding
289 await asyncio.get_event_loop().run_in_executor(
290 None,
291 lambda: self.conn.execute(
292 """
293 INSERT INTO conversations (id, content, embedding, project, timestamp, metadata)
294 VALUES (?, ?, ?, ?, ?, ?)
295 """,
296 [
297 conversation_id,
298 content,
299 embedding,
300 metadata.get("project"),
301 datetime.now(UTC),
302 json.dumps(metadata),
303 ],
304 ),
305 )
307 self.conn.commit()
308 return conversation_id
310 async def store_reflection(
311 self,
312 content: str,
313 tags: list[str] | None = None,
314 ) -> str:
315 """Store reflection/insight with optional embedding."""
316 reflection_id = hashlib.md5(
317 f"reflection_{content}_{time.time()}".encode(),
318 ).hexdigest()
320 if ONNX_AVAILABLE and self.onnx_session: 320 ↛ 321line 320 didn't jump to line 321 because the condition on line 320 was never true
321 try:
322 embedding = await self.get_embedding(content)
323 except Exception:
324 embedding = None # Fallback to no embedding
325 else:
326 embedding = None # Store without embedding
328 await asyncio.get_event_loop().run_in_executor(
329 None,
330 lambda: self.conn.execute(
331 """
332 INSERT INTO reflections (id, content, embedding, tags, timestamp, metadata)
333 VALUES (?, ?, ?, ?, ?, ?)
334 """,
335 [
336 reflection_id,
337 content,
338 embedding,
339 tags or [],
340 datetime.now(UTC),
341 json.dumps({"type": "reflection"}),
342 ],
343 ),
344 )
346 self.conn.commit()
347 return reflection_id
349 async def search_conversations(
350 self,
351 query: str,
352 limit: int = 5,
353 min_score: float = 0.7,
354 project: str | None = None,
355 ) -> list[dict[str, Any]]:
356 """Search conversations by text similarity (fallback to text search if no embeddings)."""
357 if ONNX_AVAILABLE and self.onnx_session:
358 # Use semantic search with embeddings
359 try:
360 query_embedding = await self.get_embedding(query)
362 sql = """
363 SELECT
364 id, content, embedding, project, timestamp, metadata,
365 array_cosine_similarity(embedding, CAST(? AS FLOAT[384])) as score
366 FROM conversations
367 WHERE embedding IS NOT NULL
368 """
369 params = [query_embedding]
371 if project:
372 sql += " AND project = ?"
373 params.append(project)
375 sql += """
376 ORDER BY score DESC
377 LIMIT ?
378 """
379 params.append(limit)
381 results = await asyncio.get_event_loop().run_in_executor(
382 None,
383 lambda: self.conn.execute(sql, params).fetchall(),
384 )
386 return [
387 {
388 "content": row[1],
389 "score": float(row[6]),
390 "timestamp": row[4],
391 "project": row[3],
392 "metadata": json.loads(row[5]) if row[5] else {},
393 }
394 for row in results
395 if float(row[6]) >= min_score
396 ]
397 except Exception as e:
398 print(f"Semantic search failed, falling back to text search: {e}")
399 # Fall through to text search
401 # Fallback to text search (if ONNX failed or not available)
402 search_terms = query.lower().split()
403 sql = "SELECT id, content, project, timestamp, metadata FROM conversations"
404 params = []
406 if project:
407 sql += " WHERE project = ?"
408 params.append(project)
410 sql += " ORDER BY timestamp DESC"
412 results = await asyncio.get_event_loop().run_in_executor(
413 None,
414 lambda: self.conn.execute(sql, params).fetchall(),
415 )
417 # Simple text matching score
418 matches = []
419 for row in results:
420 content_lower = row[1].lower()
421 score = sum(1 for term in search_terms if term in content_lower) / len(
422 search_terms,
423 )
425 if score > 0: # At least one term matches
426 matches.append(
427 {
428 "content": row[1],
429 "score": score,
430 "timestamp": row[3],
431 "project": row[2],
432 "metadata": json.loads(row[4]) if row[4] else {},
433 },
434 )
436 # Sort by score and return top matches
437 matches.sort(key=lambda x: x["score"], reverse=True)
438 return matches[:limit]
440 async def search_reflections(
441 self,
442 query: str,
443 limit: int = 5,
444 min_score: float = 0.7,
445 ) -> list[dict[str, Any]]:
446 """Search stored reflections by semantic similarity with text fallback."""
447 if ONNX_AVAILABLE and self.onnx_session: 447 ↛ 449line 447 didn't jump to line 449 because the condition on line 447 was never true
448 # Try semantic search first
449 try:
450 query_embedding = await self.get_embedding(query)
452 sql = """
453 SELECT
454 id, content, embedding, tags, timestamp, metadata,
455 array_cosine_similarity(embedding, CAST(? AS FLOAT[384])) as score
456 FROM reflections
457 WHERE embedding IS NOT NULL
458 ORDER BY score DESC
459 LIMIT ?
460 """
462 results = await asyncio.get_event_loop().run_in_executor(
463 None,
464 lambda: self.conn.execute(sql, [query_embedding, limit]).fetchall(),
465 )
467 semantic_results = [
468 {
469 "content": row[1],
470 "score": float(row[6]),
471 "tags": row[3] if row[3] else [],
472 "timestamp": row[4],
473 "metadata": json.loads(row[5]) if row[5] else {},
474 }
475 for row in results
476 if float(row[6]) >= min_score
477 ]
479 # If semantic search found results, return them
480 if semantic_results:
481 return semantic_results
483 except Exception as e:
484 print(f"Semantic search failed, falling back to text search: {e}")
486 # Fallback to text search for reflections
487 search_terms = query.lower().split()
488 sql = "SELECT id, content, tags, timestamp, metadata FROM reflections ORDER BY timestamp DESC"
490 results = await asyncio.get_event_loop().run_in_executor(
491 None,
492 lambda: self.conn.execute(sql).fetchall(),
493 )
495 # Simple text matching score for reflections
496 matches = []
497 for row in results:
498 content_lower = row[1].lower()
499 tags_lower = " ".join(row[2] if row[2] else []).lower()
500 combined_text = f"{content_lower} {tags_lower}"
502 # Calculate match score
503 score = sum(1 for term in search_terms if term in combined_text) / len(
504 search_terms,
505 )
507 if score > 0: # At least one term matches
508 matches.append(
509 {
510 "content": row[1],
511 "score": score,
512 "tags": row[2] if row[2] else [],
513 "timestamp": row[3],
514 "metadata": json.loads(row[4]) if row[4] else {},
515 },
516 )
518 # Sort by score and return top matches
519 matches.sort(key=lambda x: x["score"], reverse=True)
520 return matches[:limit]
522 async def search_by_file(
523 self,
524 file_path: str,
525 limit: int = 10,
526 project: str | None = None,
527 ) -> list[dict[str, Any]]:
528 """Search conversations that mention a specific file."""
529 sql = """
530 SELECT id, content, project, timestamp, metadata
531 FROM conversations
532 WHERE content LIKE ?
533 """
534 params = [f"%{file_path}%"]
536 if project:
537 sql += " AND project = ?"
538 params.append(project)
540 sql += " ORDER BY timestamp DESC LIMIT ?"
541 params.append(limit)
543 results = await asyncio.get_event_loop().run_in_executor(
544 None,
545 lambda: self.conn.execute(sql, params).fetchall(),
546 )
548 return [
549 {
550 "content": row[1],
551 "project": row[2],
552 "timestamp": row[3],
553 "metadata": json.loads(row[4]) if row[4] else {},
554 }
555 for row in results
556 ]
558 async def get_stats(self) -> dict[str, Any]:
559 """Get database statistics."""
560 try:
561 conv_count = await asyncio.get_event_loop().run_in_executor(
562 None,
563 lambda: self.conn.execute(
564 "SELECT COUNT(*) FROM conversations",
565 ).fetchone()[0],
566 )
568 refl_count = await asyncio.get_event_loop().run_in_executor(
569 None,
570 lambda: self.conn.execute(
571 "SELECT COUNT(*) FROM reflections",
572 ).fetchone()[0],
573 )
575 provider = (
576 "onnx-runtime"
577 if (self.onnx_session and ONNX_AVAILABLE)
578 else "text-search-only"
579 )
580 return {
581 "conversations_count": conv_count,
582 "reflections_count": refl_count,
583 "embedding_provider": provider,
584 "embedding_dimension": self.embedding_dim,
585 "database_path": str(self.db_path),
586 }
587 except Exception as e:
588 return {"error": f"Failed to get stats: {e}"}
# Global database instance.
# Module-level singleton managed by get_reflection_database() and
# cleanup_reflection_database(); stays None until first use.
_reflection_db: ReflectionDatabase | None = None
async def get_reflection_database() -> ReflectionDatabase:
    """Return the shared ReflectionDatabase, creating it on first call.

    The instance is cached in the module-level ``_reflection_db`` global and
    initialized (connection + schema) before it is first returned.
    """
    global _reflection_db
    if _reflection_db is not None:
        return _reflection_db
    _reflection_db = ReflectionDatabase()
    await _reflection_db.initialize()
    return _reflection_db
def cleanup_reflection_database() -> None:
    """Close and discard the module-level reflection database singleton."""
    global _reflection_db
    if _reflection_db is None:
        return
    _reflection_db.close()
    _reflection_db = None
612def get_current_project() -> str | None:
613 """Get current project name from working directory."""
614 try:
615 cwd = Path.cwd()
616 # Try to detect project from common indicators
617 if (cwd / "pyproject.toml").exists() or (cwd / "package.json").exists(): 617 ↛ 620line 617 didn't jump to line 620 because the condition on line 617 was always true
618 return cwd.name
619 # Fallback to directory name
620 return cwd.name if cwd.name != "." else None
621 except Exception:
622 return None