Coverage for session_mgmt_mcp/reflection_tools.py: 35.39%

201 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-01 05:22 -0700

1#!/usr/bin/env python3 

2"""Reflection Tools for Claude Session Management. 

3 

4Provides memory and conversation search capabilities using DuckDB and local embeddings. 

5""" 

6 

7import asyncio 

8import hashlib 

9import json 

10import os 

11import time 

12from datetime import UTC, datetime 

13from pathlib import Path 

14from typing import Any 

15 

16# Database and embedding imports 

17try: 

18 import duckdb 

19 

20 DUCKDB_AVAILABLE = True 

21except ImportError: 

22 DUCKDB_AVAILABLE = False 

23 

24try: 

25 import onnxruntime as ort 

26 from transformers import AutoTokenizer 

27 

28 ONNX_AVAILABLE = True 

29except ImportError: 

30 ONNX_AVAILABLE = False 

31 

32import numpy as np 

33 

34 

35class ReflectionDatabase: 

36 """Manages DuckDB database for conversation memory and reflection.""" 

37 

38 def __init__(self, db_path: str | None = None) -> None: 

39 self.db_path = db_path or os.path.expanduser("~/.claude/data/reflection.duckdb") 

40 Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) 

41 

42 self.conn: duckdb.DuckDBPyConnection | None = None 

43 self.onnx_session: ort.InferenceSession | None = None 

44 self.tokenizer = None 

45 self.embedding_dim = 384 # all-MiniLM-L6-v2 dimension 

46 

47 def __enter__(self): 

48 """Context manager entry.""" 

49 return self 

50 

51 def __exit__(self, exc_type, exc_val, exc_tb): 

52 """Context manager exit with cleanup.""" 

53 self.close() 

54 

55 def close(self) -> None: 

56 """Close database connection.""" 

57 if self.conn: 

58 try: 

59 self.conn.close() 

60 except Exception: 

61 pass # Ignore errors during cleanup 

62 finally: 

63 self.conn = None 

64 

65 def __del__(self) -> None: 

66 """Destructor to ensure cleanup.""" 

67 self.close() 

68 

69 async def initialize(self) -> None: 

70 """Initialize database and embedding models.""" 

71 if not DUCKDB_AVAILABLE: 71 ↛ 72line 71 didn't jump to line 72 because the condition on line 71 was never true

72 msg = "DuckDB not available. Install with: pip install duckdb" 

73 raise ImportError(msg) 

74 

75 # Initialize DuckDB connection with appropriate settings for concurrency 

76 self.conn = duckdb.connect(self.db_path) 

77 # DuckDB doesn't use SQLite-style PRAGMA commands 

78 # DuckDB handles concurrency automatically with MVCC 

79 

80 # Initialize ONNX embedding model 

81 if ONNX_AVAILABLE: 81 ↛ 82line 81 didn't jump to line 82 because the condition on line 81 was never true

82 try: 

83 # Load tokenizer 

84 self.tokenizer = AutoTokenizer.from_pretrained( 

85 "sentence-transformers/all-MiniLM-L6-v2", 

86 ) 

87 

88 # Try to load ONNX model 

89 model_path = os.path.expanduser( 

90 "~/.claude/all-MiniLM-L6-v2/onnx/model.onnx", 

91 ) 

92 if not os.path.exists(model_path): 

93 print("ONNX model not found, will use text search fallback") 

94 self.onnx_session = None 

95 else: 

96 self.onnx_session = ort.InferenceSession(model_path) 

97 self.embedding_dim = 384 

98 except Exception as e: 

99 print(f"ONNX model loading failed, using text search: {e}") 

100 self.onnx_session = None 

101 else: 

102 print("ONNX not available, using text search fallback") 

103 

104 # Create tables if they don't exist 

105 await self._ensure_tables() 

106 

107 async def _ensure_tables(self) -> None: 

108 """Ensure required tables exist.""" 

109 # Create conversations table 

110 self.conn.execute(""" 

111 CREATE TABLE IF NOT EXISTS conversations ( 

112 id VARCHAR PRIMARY KEY, 

113 content TEXT NOT NULL, 

114 embedding FLOAT[384], 

115 project VARCHAR, 

116 timestamp TIMESTAMP, 

117 metadata JSON 

118 ) 

119 """) 

120 

121 # Create reflections table 

122 self.conn.execute(""" 

123 CREATE TABLE IF NOT EXISTS reflections ( 

124 id VARCHAR PRIMARY KEY, 

125 content TEXT NOT NULL, 

126 embedding FLOAT[384], 

127 tags VARCHAR[], 

128 timestamp TIMESTAMP, 

129 metadata JSON 

130 ) 

131 """) 

132 

133 # Create project_groups table for multi-project coordination 

134 self.conn.execute(""" 

135 CREATE TABLE IF NOT EXISTS project_groups ( 

136 id VARCHAR PRIMARY KEY, 

137 name VARCHAR NOT NULL, 

138 description TEXT, 

139 projects VARCHAR[] NOT NULL, 

140 created_at TIMESTAMP DEFAULT NOW(), 

141 metadata JSON 

142 ) 

143 """) 

144 

145 # Create project_dependencies table for project relationships 

146 self.conn.execute(""" 

147 CREATE TABLE IF NOT EXISTS project_dependencies ( 

148 id VARCHAR PRIMARY KEY, 

149 source_project VARCHAR NOT NULL, 

150 target_project VARCHAR NOT NULL, 

151 dependency_type VARCHAR NOT NULL, 

152 description TEXT, 

153 created_at TIMESTAMP DEFAULT NOW(), 

154 metadata JSON, 

155 UNIQUE(source_project, target_project, dependency_type) 

156 ) 

157 """) 

158 

159 # Create session_links table for cross-project session coordination 

160 self.conn.execute(""" 

161 CREATE TABLE IF NOT EXISTS session_links ( 

162 id VARCHAR PRIMARY KEY, 

163 source_session_id VARCHAR NOT NULL, 

164 target_session_id VARCHAR NOT NULL, 

165 link_type VARCHAR NOT NULL, 

166 context TEXT, 

167 created_at TIMESTAMP DEFAULT NOW(), 

168 metadata JSON, 

169 UNIQUE(source_session_id, target_session_id, link_type) 

170 ) 

171 """) 

172 

173 # Create search_index table for advanced search capabilities 

174 self.conn.execute(""" 

175 CREATE TABLE IF NOT EXISTS search_index ( 

176 id VARCHAR PRIMARY KEY, 

177 content_type VARCHAR NOT NULL, -- 'conversation', 'reflection', 'file', 'project' 

178 content_id VARCHAR NOT NULL, 

179 indexed_content TEXT NOT NULL, 

180 search_metadata JSON, 

181 last_indexed TIMESTAMP DEFAULT NOW(), 

182 UNIQUE(content_type, content_id) 

183 ) 

184 """) 

185 

186 # Create search_facets table for faceted search 

187 self.conn.execute(""" 

188 CREATE TABLE IF NOT EXISTS search_facets ( 

189 id VARCHAR PRIMARY KEY, 

190 content_type VARCHAR NOT NULL, 

191 content_id VARCHAR NOT NULL, 

192 facet_name VARCHAR NOT NULL, 

193 facet_value VARCHAR NOT NULL, 

194 created_at TIMESTAMP DEFAULT NOW(), 

195 INDEX(facet_name, facet_value), 

196 INDEX(content_type, content_id) 

197 ) 

198 """) 

199 

200 # Create indices for better performance 

201 await self._ensure_indices() 

202 

203 self.conn.commit() 

204 

205 async def _ensure_indices(self) -> None: 

206 """Create indices for better query performance.""" 

207 indices = [ 

208 # Existing table indices 

209 "CREATE INDEX IF NOT EXISTS idx_conversations_project ON conversations(project)", 

210 "CREATE INDEX IF NOT EXISTS idx_conversations_timestamp ON conversations(timestamp)", 

211 "CREATE INDEX IF NOT EXISTS idx_reflections_timestamp ON reflections(timestamp)", 

212 # New multi-project indices 

213 "CREATE INDEX IF NOT EXISTS idx_project_groups_projects ON project_groups USING GIN(projects)", 

214 "CREATE INDEX IF NOT EXISTS idx_project_deps_source ON project_dependencies(source_project)", 

215 "CREATE INDEX IF NOT EXISTS idx_project_deps_target ON project_dependencies(target_project)", 

216 "CREATE INDEX IF NOT EXISTS idx_session_links_source ON session_links(source_session_id)", 

217 "CREATE INDEX IF NOT EXISTS idx_session_links_target ON session_links(target_session_id)", 

218 # Search indices 

219 "CREATE INDEX IF NOT EXISTS idx_search_index_type ON search_index(content_type)", 

220 "CREATE INDEX IF NOT EXISTS idx_search_index_last_indexed ON search_index(last_indexed)", 

221 "CREATE INDEX IF NOT EXISTS idx_search_facets_name_value ON search_facets(facet_name, facet_value)", 

222 ] 

223 

224 for index_sql in indices: 

225 try: 

226 self.conn.execute(index_sql) 

227 except Exception as e: 

228 # Some indices might not be supported in all DuckDB versions, continue 

229 print(f"Index creation skipped: {e}") 

230 

231 async def get_embedding(self, text: str) -> list[float]: 

232 """Get embedding for text using ONNX model.""" 

233 if self.onnx_session and self.tokenizer: 

234 

235 def _get_embedding(): 

236 # Tokenize text 

237 encoded = self.tokenizer( 

238 text, 

239 truncation=True, 

240 padding=True, 

241 return_tensors="np", 

242 ) 

243 

244 # Run inference 

245 outputs = self.onnx_session.run( 

246 None, 

247 { 

248 "input_ids": encoded["input_ids"], 

249 "attention_mask": encoded["attention_mask"], 

250 "token_type_ids": encoded.get( 

251 "token_type_ids", 

252 np.zeros_like(encoded["input_ids"]), 

253 ), 

254 }, 

255 ) 

256 

257 # Mean pooling 

258 embeddings = outputs[0] 

259 attention_mask = encoded["attention_mask"] 

260 masked_embeddings = embeddings * np.expand_dims(attention_mask, axis=-1) 

261 summed = np.sum(masked_embeddings, axis=1) 

262 counts = np.sum(attention_mask, axis=1, keepdims=True) 

263 mean_pooled = summed / counts 

264 

265 # Normalize 

266 norms = np.linalg.norm(mean_pooled, axis=1, keepdims=True) 

267 normalized = mean_pooled / norms 

268 

269 # Convert to float32 to match DuckDB FLOAT type 

270 return normalized[0].astype(np.float32).tolist() 

271 

272 return await asyncio.get_event_loop().run_in_executor(None, _get_embedding) 

273 

274 msg = "No embedding model available" 

275 raise RuntimeError(msg) 

276 

277 async def store_conversation(self, content: str, metadata: dict[str, Any]) -> str: 

278 """Store conversation with optional embedding.""" 

279 conversation_id = hashlib.md5(f"{content}_{time.time()}".encode()).hexdigest() 

280 

281 if ONNX_AVAILABLE and self.onnx_session: 

282 try: 

283 embedding = await self.get_embedding(content) 

284 except Exception: 

285 embedding = None # Fallback to no embedding 

286 else: 

287 embedding = None # Store without embedding 

288 

289 await asyncio.get_event_loop().run_in_executor( 

290 None, 

291 lambda: self.conn.execute( 

292 """ 

293 INSERT INTO conversations (id, content, embedding, project, timestamp, metadata) 

294 VALUES (?, ?, ?, ?, ?, ?) 

295 """, 

296 [ 

297 conversation_id, 

298 content, 

299 embedding, 

300 metadata.get("project"), 

301 datetime.now(UTC), 

302 json.dumps(metadata), 

303 ], 

304 ), 

305 ) 

306 

307 self.conn.commit() 

308 return conversation_id 

309 

310 async def store_reflection( 

311 self, 

312 content: str, 

313 tags: list[str] | None = None, 

314 ) -> str: 

315 """Store reflection/insight with optional embedding.""" 

316 reflection_id = hashlib.md5( 

317 f"reflection_{content}_{time.time()}".encode(), 

318 ).hexdigest() 

319 

320 if ONNX_AVAILABLE and self.onnx_session: 320 ↛ 321line 320 didn't jump to line 321 because the condition on line 320 was never true

321 try: 

322 embedding = await self.get_embedding(content) 

323 except Exception: 

324 embedding = None # Fallback to no embedding 

325 else: 

326 embedding = None # Store without embedding 

327 

328 await asyncio.get_event_loop().run_in_executor( 

329 None, 

330 lambda: self.conn.execute( 

331 """ 

332 INSERT INTO reflections (id, content, embedding, tags, timestamp, metadata) 

333 VALUES (?, ?, ?, ?, ?, ?) 

334 """, 

335 [ 

336 reflection_id, 

337 content, 

338 embedding, 

339 tags or [], 

340 datetime.now(UTC), 

341 json.dumps({"type": "reflection"}), 

342 ], 

343 ), 

344 ) 

345 

346 self.conn.commit() 

347 return reflection_id 

348 

349 async def search_conversations( 

350 self, 

351 query: str, 

352 limit: int = 5, 

353 min_score: float = 0.7, 

354 project: str | None = None, 

355 ) -> list[dict[str, Any]]: 

356 """Search conversations by text similarity (fallback to text search if no embeddings).""" 

357 if ONNX_AVAILABLE and self.onnx_session: 

358 # Use semantic search with embeddings 

359 try: 

360 query_embedding = await self.get_embedding(query) 

361 

362 sql = """ 

363 SELECT 

364 id, content, embedding, project, timestamp, metadata, 

365 array_cosine_similarity(embedding, CAST(? AS FLOAT[384])) as score 

366 FROM conversations 

367 WHERE embedding IS NOT NULL 

368 """ 

369 params = [query_embedding] 

370 

371 if project: 

372 sql += " AND project = ?" 

373 params.append(project) 

374 

375 sql += """ 

376 ORDER BY score DESC 

377 LIMIT ? 

378 """ 

379 params.append(limit) 

380 

381 results = await asyncio.get_event_loop().run_in_executor( 

382 None, 

383 lambda: self.conn.execute(sql, params).fetchall(), 

384 ) 

385 

386 return [ 

387 { 

388 "content": row[1], 

389 "score": float(row[6]), 

390 "timestamp": row[4], 

391 "project": row[3], 

392 "metadata": json.loads(row[5]) if row[5] else {}, 

393 } 

394 for row in results 

395 if float(row[6]) >= min_score 

396 ] 

397 except Exception as e: 

398 print(f"Semantic search failed, falling back to text search: {e}") 

399 # Fall through to text search 

400 

401 # Fallback to text search (if ONNX failed or not available) 

402 search_terms = query.lower().split() 

403 sql = "SELECT id, content, project, timestamp, metadata FROM conversations" 

404 params = [] 

405 

406 if project: 

407 sql += " WHERE project = ?" 

408 params.append(project) 

409 

410 sql += " ORDER BY timestamp DESC" 

411 

412 results = await asyncio.get_event_loop().run_in_executor( 

413 None, 

414 lambda: self.conn.execute(sql, params).fetchall(), 

415 ) 

416 

417 # Simple text matching score 

418 matches = [] 

419 for row in results: 

420 content_lower = row[1].lower() 

421 score = sum(1 for term in search_terms if term in content_lower) / len( 

422 search_terms, 

423 ) 

424 

425 if score > 0: # At least one term matches 

426 matches.append( 

427 { 

428 "content": row[1], 

429 "score": score, 

430 "timestamp": row[3], 

431 "project": row[2], 

432 "metadata": json.loads(row[4]) if row[4] else {}, 

433 }, 

434 ) 

435 

436 # Sort by score and return top matches 

437 matches.sort(key=lambda x: x["score"], reverse=True) 

438 return matches[:limit] 

439 

440 async def search_reflections( 

441 self, 

442 query: str, 

443 limit: int = 5, 

444 min_score: float = 0.7, 

445 ) -> list[dict[str, Any]]: 

446 """Search stored reflections by semantic similarity with text fallback.""" 

447 if ONNX_AVAILABLE and self.onnx_session: 447 ↛ 449line 447 didn't jump to line 449 because the condition on line 447 was never true

448 # Try semantic search first 

449 try: 

450 query_embedding = await self.get_embedding(query) 

451 

452 sql = """ 

453 SELECT 

454 id, content, embedding, tags, timestamp, metadata, 

455 array_cosine_similarity(embedding, CAST(? AS FLOAT[384])) as score 

456 FROM reflections 

457 WHERE embedding IS NOT NULL 

458 ORDER BY score DESC 

459 LIMIT ? 

460 """ 

461 

462 results = await asyncio.get_event_loop().run_in_executor( 

463 None, 

464 lambda: self.conn.execute(sql, [query_embedding, limit]).fetchall(), 

465 ) 

466 

467 semantic_results = [ 

468 { 

469 "content": row[1], 

470 "score": float(row[6]), 

471 "tags": row[3] if row[3] else [], 

472 "timestamp": row[4], 

473 "metadata": json.loads(row[5]) if row[5] else {}, 

474 } 

475 for row in results 

476 if float(row[6]) >= min_score 

477 ] 

478 

479 # If semantic search found results, return them 

480 if semantic_results: 

481 return semantic_results 

482 

483 except Exception as e: 

484 print(f"Semantic search failed, falling back to text search: {e}") 

485 

486 # Fallback to text search for reflections 

487 search_terms = query.lower().split() 

488 sql = "SELECT id, content, tags, timestamp, metadata FROM reflections ORDER BY timestamp DESC" 

489 

490 results = await asyncio.get_event_loop().run_in_executor( 

491 None, 

492 lambda: self.conn.execute(sql).fetchall(), 

493 ) 

494 

495 # Simple text matching score for reflections 

496 matches = [] 

497 for row in results: 

498 content_lower = row[1].lower() 

499 tags_lower = " ".join(row[2] if row[2] else []).lower() 

500 combined_text = f"{content_lower} {tags_lower}" 

501 

502 # Calculate match score 

503 score = sum(1 for term in search_terms if term in combined_text) / len( 

504 search_terms, 

505 ) 

506 

507 if score > 0: # At least one term matches 

508 matches.append( 

509 { 

510 "content": row[1], 

511 "score": score, 

512 "tags": row[2] if row[2] else [], 

513 "timestamp": row[3], 

514 "metadata": json.loads(row[4]) if row[4] else {}, 

515 }, 

516 ) 

517 

518 # Sort by score and return top matches 

519 matches.sort(key=lambda x: x["score"], reverse=True) 

520 return matches[:limit] 

521 

522 async def search_by_file( 

523 self, 

524 file_path: str, 

525 limit: int = 10, 

526 project: str | None = None, 

527 ) -> list[dict[str, Any]]: 

528 """Search conversations that mention a specific file.""" 

529 sql = """ 

530 SELECT id, content, project, timestamp, metadata 

531 FROM conversations 

532 WHERE content LIKE ? 

533 """ 

534 params = [f"%{file_path}%"] 

535 

536 if project: 

537 sql += " AND project = ?" 

538 params.append(project) 

539 

540 sql += " ORDER BY timestamp DESC LIMIT ?" 

541 params.append(limit) 

542 

543 results = await asyncio.get_event_loop().run_in_executor( 

544 None, 

545 lambda: self.conn.execute(sql, params).fetchall(), 

546 ) 

547 

548 return [ 

549 { 

550 "content": row[1], 

551 "project": row[2], 

552 "timestamp": row[3], 

553 "metadata": json.loads(row[4]) if row[4] else {}, 

554 } 

555 for row in results 

556 ] 

557 

558 async def get_stats(self) -> dict[str, Any]: 

559 """Get database statistics.""" 

560 try: 

561 conv_count = await asyncio.get_event_loop().run_in_executor( 

562 None, 

563 lambda: self.conn.execute( 

564 "SELECT COUNT(*) FROM conversations", 

565 ).fetchone()[0], 

566 ) 

567 

568 refl_count = await asyncio.get_event_loop().run_in_executor( 

569 None, 

570 lambda: self.conn.execute( 

571 "SELECT COUNT(*) FROM reflections", 

572 ).fetchone()[0], 

573 ) 

574 

575 provider = ( 

576 "onnx-runtime" 

577 if (self.onnx_session and ONNX_AVAILABLE) 

578 else "text-search-only" 

579 ) 

580 return { 

581 "conversations_count": conv_count, 

582 "reflections_count": refl_count, 

583 "embedding_provider": provider, 

584 "embedding_dimension": self.embedding_dim, 

585 "database_path": str(self.db_path), 

586 } 

587 except Exception as e: 

588 return {"error": f"Failed to get stats: {e}"} 

589 

590 

# Global database instance (singleton managed by get_reflection_database /
# cleanup_reflection_database below)
_reflection_db: ReflectionDatabase | None = None

593 

594 

async def get_reflection_database() -> ReflectionDatabase:
    """Get or create the shared reflection database instance.

    The global is assigned only after ``initialize()`` succeeds, so a
    failed initialization no longer caches a broken, half-initialized
    singleton that every later call would silently return.

    Returns:
        The initialized module-level ReflectionDatabase.
    """
    global _reflection_db
    if _reflection_db is None:
        db = ReflectionDatabase()
        await db.initialize()
        _reflection_db = db
    return _reflection_db

602 

603 

def cleanup_reflection_database() -> None:
    """Release the module-level reflection database singleton, if any."""
    global _reflection_db
    if _reflection_db is None:
        return
    _reflection_db.close()
    _reflection_db = None

610 

611 

612def get_current_project() -> str | None: 

613 """Get current project name from working directory.""" 

614 try: 

615 cwd = Path.cwd() 

616 # Try to detect project from common indicators 

617 if (cwd / "pyproject.toml").exists() or (cwd / "package.json").exists(): 617 ↛ 620line 617 didn't jump to line 620 because the condition on line 617 was always true

618 return cwd.name 

619 # Fallback to directory name 

620 return cwd.name if cwd.name != "." else None 

621 except Exception: 

622 return None