Coverage for session_buddy / reflection_tools.py: 75.78%
450 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
1#!/usr/bin/env python3
2"""Reflection Tools for Claude Session Management.
4Provides memory and conversation search capabilities using DuckDB and local embeddings.
6DEPRECATION NOTICE (Phase 2.7 - January 2025):
7 The ReflectionDatabase class in this module is deprecated and will be removed
8 in a future release. Please use ReflectionDatabaseAdapter from
9 session_buddy.adapters.reflection_adapter instead.
11 Migration Guide:
12 # Old (deprecated):
13 from session_buddy.reflection_tools import ReflectionDatabase
15 # New (recommended):
16 from session_buddy.adapters.reflection_adapter import ReflectionDatabaseAdapter
18 The adapter provides the same API while using ACB (Asynchronous Component Base)
19 for improved connection pooling, lifecycle management, and integration with
20 the dependency injection system.
21"""
23import asyncio
24import base64
25import hashlib
26import json
27import os
28import threading
29import time
30import warnings
31from contextlib import suppress
32from datetime import UTC, datetime
33from pathlib import Path
34from types import TracebackType
35from typing import TYPE_CHECKING, Any, Self
37if TYPE_CHECKING:
38 import duckdb
39 import onnxruntime as ort
40 from transformers import AutoTokenizer
42# Database and embedding imports
43try:
44 import duckdb
46 DUCKDB_AVAILABLE = True
47except ImportError:
48 DUCKDB_AVAILABLE = False
50import tempfile
52try:
53 import onnxruntime as ort
54 from transformers import AutoTokenizer
56 ONNX_AVAILABLE = True
57except ImportError:
58 ONNX_AVAILABLE = False
60import operator
62import numpy as np
64# Import the new adapter for replacement
65from session_buddy.adapters.reflection_adapter import ReflectionDatabaseAdapter
67_DB_PATH_UNSET = object()
70_SURROGATE_PREFIX = "__SB64__"
73def _encode_text_for_db(text: str) -> str:
74 try:
75 text.encode("utf-8")
76 return text
77 except UnicodeEncodeError:
78 data = text.encode("utf-8", "surrogatepass")
79 return _SURROGATE_PREFIX + base64.b64encode(data).decode("ascii")
82def _decode_text_from_db(text: str) -> str:
83 if text.startswith(_SURROGATE_PREFIX):
84 data = base64.b64decode(text[len(_SURROGATE_PREFIX) :])
85 return data.decode("utf-8", "surrogatepass")
86 return text
89class ReflectionDatabase:
90 """Manages DuckDB database for conversation memory and reflection.
92 DEPRECATED: This class is deprecated as of Phase 2.7 (January 2025).
93 Use ReflectionDatabaseAdapter from session_buddy.adapters.reflection_adapter instead.
95 The adapter provides the same API with improved ACB integration:
96 - Connection pooling and lifecycle management
97 - Dependency injection support
98 - Better async/await patterns
100 This class will be removed in a future release.
101 """
103 def __init__(self, db_path: str | None | object = _DB_PATH_UNSET) -> None:
104 # Issue deprecation warning
105 warnings.warn(
106 "ReflectionDatabase is deprecated and will be removed in a future release. "
107 "Use ReflectionDatabaseAdapter from session_buddy.adapters.reflection_adapter instead.",
108 DeprecationWarning,
109 stacklevel=2,
110 )
112 if db_path is None:
113 msg = "db_path cannot be None"
114 raise TypeError(msg)
116 if db_path is _DB_PATH_UNSET:
117 resolved_path: str = os.path.expanduser("~/.claude/data/reflection.duckdb")
118 else:
119 resolved_path = os.path.expanduser(str(db_path))
121 # Special-case empty path: treat as in-memory to avoid filesystem issues
122 if resolved_path in {"", ":memory:"}:
123 self.db_path = ":memory:"
124 self.is_temp_db = True
125 else:
126 self.db_path = resolved_path
127 self.is_temp_db = False
129 # Use thread-local storage for connections to avoid threading issues
130 self.local = threading.local()
131 self.lock = threading.RLock() # Re-entrant for nested access in temp DB
132 self.onnx_session: ort.InferenceSession | None = None
133 self.tokenizer = None
134 self.embedding_dim = 384 # all-MiniLM-L6-v2 dimension
135 self._initialized = False # Track initialization state
    @property
    def conn(self) -> duckdb.DuckDBPyConnection | None:
        """Get the connection for the current thread (for backward compatibility).

        Returns None when the calling thread has not created a connection yet.
        """
        return getattr(self.local, "conn", None)
    def __enter__(self) -> Self:
        """Context manager entry; returns self without performing any I/O."""
        return self
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Context manager exit; closes this thread's connection via close()."""
        self.close()
    async def __aenter__(self) -> Self:
        """Async context manager entry; initializes DB and embedding models."""
        await self.initialize()
        return self
    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Async context manager exit; closes this thread's connection."""
        self.close()
169 def close(self) -> None:
170 """Close database connections for all threads."""
171 if hasattr(self.local, "conn") and self.local.conn:
172 try:
173 self.local.conn.close()
174 except Exception:
175 # nosec B110 - intentionally suppressing exceptions during cleanup
176 pass # Ignore errors during cleanup
177 finally:
178 self.local.conn = None
    def __del__(self) -> None:
        """Destructor to ensure cleanup.

        Best-effort only: __del__ may run during interpreter shutdown, so
        close() must tolerate (and does tolerate) partially torn-down state.
        """
        self.close()
    async def initialize(self) -> None:
        """Initialize database and embedding models.

        Steps:
          1. Optionally load the ONNX embedding model + tokenizer (skipped
             under pytest or when files/deps are missing; search then falls
             back to plain text matching).
          2. Ensure the parent directory exists for file-backed databases.
          3. Create base tables/indexes through a short-lived connection.
          4. Mark the instance initialized and bind a connection for the
             current thread (shared connection for in-memory databases).

        Raises:
            ImportError: if DuckDB is not installed.
            RuntimeError: if the database cannot be opened.
        """
        if not DUCKDB_AVAILABLE:
            msg = "DuckDB not available. Install with: pip install duckdb"
            raise ImportError(msg)

        # Initialize ONNX embedding model (never during pytest runs)
        if ONNX_AVAILABLE and not os.environ.get("PYTEST_CURRENT_TEST"):
            try:
                model_path = os.path.expanduser(
                    "~/.claude/all-MiniLM-L6-v2/onnx/model.onnx",
                )
                if Path(model_path).exists():
                    # Load tokenizer with revision pinning for security
                    self.tokenizer = AutoTokenizer.from_pretrained(
                        "sentence-transformers/all-MiniLM-L6-v2",
                        revision="7dbbc90392e2f80f3d3c277d6e90027e55de9125",  # Pin to specific commit
                    )
                    self.onnx_session = ort.InferenceSession(model_path)
                    self.embedding_dim = 384
                else:
                    self.onnx_session = None
                    self.tokenizer = None
            except Exception:
                # Embeddings are optional; degrade to text-only search.
                self.onnx_session = None
                self.tokenizer = None
        else:
            self.onnx_session = None
            self.tokenizer = None

        if not self.is_temp_db:
            with suppress(Exception):
                Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        # Create tables if they don't exist. A direct, short-lived connection
        # is used here because _get_conn() refuses to run before
        # self._initialized is set.
        try:
            temp_conn = duckdb.connect(
                self.db_path, config={"allow_unsigned_extensions": True}
            )
        except Exception as e:
            msg = f"Database connection error (directory/permission): {e}"
            raise RuntimeError(msg) from e
        try:
            # Create conversations table
            temp_conn.execute("""
                CREATE TABLE IF NOT EXISTS conversations (
                    id VARCHAR PRIMARY KEY,
                    content TEXT NOT NULL,
                    embedding FLOAT[384],
                    project VARCHAR,
                    timestamp TIMESTAMP,
                    metadata JSON
                )
            """)

            # Create reflections table
            temp_conn.execute("""
                CREATE TABLE IF NOT EXISTS reflections (
                    id VARCHAR PRIMARY KEY,
                    content TEXT NOT NULL,
                    embedding FLOAT[384],
                    project VARCHAR,
                    tags VARCHAR[],
                    timestamp TIMESTAMP,
                    metadata JSON
                )
            """)

            # Create reflection_tags table for tag-based search
            temp_conn.execute("""
                CREATE TABLE IF NOT EXISTS reflection_tags (
                    reflection_id VARCHAR,
                    tag VARCHAR,
                    PRIMARY KEY (reflection_id, tag)
                )
            """)

            # Create indexes for performance
            temp_conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_conversations_project ON conversations(project)"
            )
            temp_conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_conversations_timestamp ON conversations(timestamp)"
            )
            temp_conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_reflections_project ON reflections(project)"
            )
            temp_conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_reflections_timestamp ON reflections(timestamp)"
            )
            temp_conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_reflection_tags_tag ON reflection_tags(tag)"
            )
        finally:
            temp_conn.close()

        # Now mark as initialized
        self._initialized = True

        # Create the connection for the current thread so that the conn property works
        if self.is_temp_db:
            # For temp DBs, create the shared connection. NOTE(review): an
            # in-memory DuckDB is per-connection, so the schema created via
            # temp_conn above does not exist on this new connection — hence
            # the explicit re-creation below.
            with self.lock:
                self._shared_conn = duckdb.connect(
                    self.db_path, config={"allow_unsigned_extensions": True}
                )
                # Create tables in the shared connection for in-memory DB
                self._initialize_shared_tables()
                # Backward-compat: expose connection via thread-local conn property
                self.local.conn = self._shared_conn
        else:
            # For non-temporary DBs, create a connection in the local storage
            self.local.conn = duckdb.connect(
                self.db_path, config={"allow_unsigned_extensions": True}
            )
302 def _get_conn(self) -> duckdb.DuckDBPyConnection:
303 """Get database connection for the current thread, initializing if needed."""
304 if not self._initialized:
305 msg = "Database connection not initialized. Call initialize() first"
306 raise RuntimeError(msg)
308 # For test environments using in-memory DB, create a shared connection with locking
309 if self.is_temp_db:
310 with self.lock:
311 if not hasattr(self, "_shared_conn"): 311 ↛ 312line 311 didn't jump to line 312 because the condition on line 311 was never true
312 self._shared_conn = duckdb.connect(
313 self.db_path, config={"allow_unsigned_extensions": True}
314 )
315 # Create tables in the shared connection for in-memory DB
316 self._initialize_shared_tables()
317 self.local.conn = self._shared_conn
318 return self._shared_conn
320 # For normal environments, use thread-local storage
321 if not hasattr(self.local, "conn") or self.local.conn is None:
322 self.local.conn = duckdb.connect(
323 self.db_path, config={"allow_unsigned_extensions": True}
324 )
325 return self.local.conn
327 def _initialize_shared_tables(self) -> None:
328 """Initialize tables in the shared connection for in-memory databases."""
329 # Access the shared connection through the instance variable
330 conn = getattr(self, "_shared_conn", None)
331 if not conn: 331 ↛ 332line 331 didn't jump to line 332 because the condition on line 331 was never true
332 return # Defensive check
334 # Create conversations table
335 conn.execute("""
336 CREATE TABLE IF NOT EXISTS conversations (
337 id VARCHAR PRIMARY KEY,
338 content TEXT NOT NULL,
339 embedding FLOAT[384],
340 project VARCHAR,
341 timestamp TIMESTAMP,
342 metadata JSON
343 )
344 """)
346 # Create reflections table
347 conn.execute("""
348 CREATE TABLE IF NOT EXISTS reflections (
349 id VARCHAR PRIMARY KEY,
350 content TEXT NOT NULL,
351 embedding FLOAT[384],
352 project VARCHAR,
353 tags VARCHAR[],
354 timestamp TIMESTAMP,
355 metadata JSON
356 )
357 """)
359 # Create reflection_tags table for tag-based search (no FK: DuckDB has limitations on updates)
360 conn.execute("""
361 CREATE TABLE IF NOT EXISTS reflection_tags (
362 reflection_id VARCHAR,
363 tag VARCHAR,
364 PRIMARY KEY (reflection_id, tag)
365 )
366 """)
368 # Create project_groups table for multi-project coordination
369 conn.execute("""
370 CREATE TABLE IF NOT EXISTS project_groups (
371 id VARCHAR PRIMARY KEY,
372 name VARCHAR NOT NULL,
373 description TEXT,
374 projects VARCHAR[] NOT NULL,
375 created_at TIMESTAMP DEFAULT NOW(),
376 metadata JSON
377 )
378 """)
380 # Create project_dependencies table for project relationships
381 conn.execute("""
382 CREATE TABLE IF NOT EXISTS project_dependencies (
383 id VARCHAR PRIMARY KEY,
384 source_project VARCHAR NOT NULL,
385 target_project VARCHAR NOT NULL,
386 dependency_type VARCHAR NOT NULL,
387 description TEXT,
388 created_at TIMESTAMP DEFAULT NOW(),
389 metadata JSON,
390 UNIQUE(source_project, target_project, dependency_type)
391 )
392 """)
394 # Create session_links table for cross-project session coordination
395 conn.execute("""
396 CREATE TABLE IF NOT EXISTS session_links (
397 id VARCHAR PRIMARY KEY,
398 source_session_id VARCHAR NOT NULL,
399 target_session_id VARCHAR NOT NULL,
400 link_type VARCHAR NOT NULL,
401 context TEXT,
402 created_at TIMESTAMP DEFAULT NOW(),
403 metadata JSON,
404 UNIQUE(source_session_id, target_session_id, link_type)
405 )
406 """)
408 # Create search_index table for advanced search capabilities
409 conn.execute("""
410 CREATE TABLE IF NOT EXISTS search_index (
411 id VARCHAR PRIMARY KEY,
412 content_type VARCHAR NOT NULL, -- 'conversation', 'reflection', 'file', 'project'
413 content_id VARCHAR NOT NULL,
414 indexed_content TEXT NOT NULL,
415 search_metadata JSON,
416 last_indexed TIMESTAMP DEFAULT NOW(),
417 UNIQUE(content_type, content_id)
418 )
419 """)
421 # Create search_facets table for faceted search
422 conn.execute("""
423 CREATE TABLE IF NOT EXISTS search_facets (
424 id VARCHAR PRIMARY KEY,
425 content_type VARCHAR NOT NULL,
426 content_id VARCHAR NOT NULL,
427 facet_name VARCHAR NOT NULL,
428 facet_value VARCHAR NOT NULL,
429 created_at TIMESTAMP DEFAULT NOW()
430 )
431 """)
433 async def _ensure_tables(self) -> None:
434 """Ensure required tables exist."""
435 # Create conversations table
436 self._get_conn().execute("""
437 CREATE TABLE IF NOT EXISTS conversations (
438 id VARCHAR PRIMARY KEY,
439 content TEXT NOT NULL,
440 embedding FLOAT[384],
441 project VARCHAR,
442 timestamp TIMESTAMP,
443 metadata JSON
444 )
445 """)
447 # Create reflections table
448 self._get_conn().execute("""
449 CREATE TABLE IF NOT EXISTS reflections (
450 id VARCHAR PRIMARY KEY,
451 content TEXT NOT NULL,
452 embedding FLOAT[384],
453 tags VARCHAR[],
454 timestamp TIMESTAMP,
455 metadata JSON
456 )
457 """)
459 # Create project_groups table for multi-project coordination
460 self._get_conn().execute("""
461 CREATE TABLE IF NOT EXISTS project_groups (
462 id VARCHAR PRIMARY KEY,
463 name VARCHAR NOT NULL,
464 description TEXT,
465 projects VARCHAR[] NOT NULL,
466 created_at TIMESTAMP DEFAULT NOW(),
467 metadata JSON
468 )
469 """)
471 # Create project_dependencies table for project relationships
472 self._get_conn().execute("""
473 CREATE TABLE IF NOT EXISTS project_dependencies (
474 id VARCHAR PRIMARY KEY,
475 source_project VARCHAR NOT NULL,
476 target_project VARCHAR NOT NULL,
477 dependency_type VARCHAR NOT NULL,
478 description TEXT,
479 created_at TIMESTAMP DEFAULT NOW(),
480 metadata JSON,
481 UNIQUE(source_project, target_project, dependency_type)
482 )
483 """)
485 # Create session_links table for cross-project session coordination
486 self._get_conn().execute("""
487 CREATE TABLE IF NOT EXISTS session_links (
488 id VARCHAR PRIMARY KEY,
489 source_session_id VARCHAR NOT NULL,
490 target_session_id VARCHAR NOT NULL,
491 link_type VARCHAR NOT NULL,
492 context TEXT,
493 created_at TIMESTAMP DEFAULT NOW(),
494 metadata JSON,
495 UNIQUE(source_session_id, target_session_id, link_type)
496 )
497 """)
499 # Create search_index table for advanced search capabilities
500 self._get_conn().execute("""
501 CREATE TABLE IF NOT EXISTS search_index (
502 id VARCHAR PRIMARY KEY,
503 content_type VARCHAR NOT NULL, -- 'conversation', 'reflection', 'file', 'project'
504 content_id VARCHAR NOT NULL,
505 indexed_content TEXT NOT NULL,
506 search_metadata JSON,
507 last_indexed TIMESTAMP DEFAULT NOW(),
508 UNIQUE(content_type, content_id)
509 )
510 """)
512 # Create search_facets table for faceted search
513 self._get_conn().execute("""
514 CREATE TABLE IF NOT EXISTS search_facets (
515 id VARCHAR PRIMARY KEY,
516 content_type VARCHAR NOT NULL,
517 content_id VARCHAR NOT NULL,
518 facet_name VARCHAR NOT NULL,
519 facet_value VARCHAR NOT NULL,
520 created_at TIMESTAMP DEFAULT NOW()
521 )
522 """)
524 # Create indices for better performance
525 await self._ensure_indices()
527 async def _ensure_indices(self) -> None:
528 """Create indices for better query performance."""
529 indices = [
530 # Existing table indices
531 "CREATE INDEX IF NOT EXISTS idx_conversations_project ON conversations(project)",
532 "CREATE INDEX IF NOT EXISTS idx_conversations_timestamp ON conversations(timestamp)",
533 "CREATE INDEX IF NOT EXISTS idx_reflections_timestamp ON reflections(timestamp)",
534 # New multi-project indices
535 "CREATE INDEX IF NOT EXISTS idx_project_deps_source ON project_dependencies(source_project)",
536 "CREATE INDEX IF NOT EXISTS idx_project_deps_target ON project_dependencies(target_project)",
537 "CREATE INDEX IF NOT EXISTS idx_session_links_source ON session_links(source_session_id)",
538 "CREATE INDEX IF NOT EXISTS idx_session_links_target ON session_links(target_session_id)",
539 # Search indices
540 "CREATE INDEX IF NOT EXISTS idx_search_index_type ON search_index(content_type)",
541 "CREATE INDEX IF NOT EXISTS idx_search_index_last_indexed ON search_index(last_indexed)",
542 "CREATE INDEX IF NOT EXISTS idx_search_facets_name_value ON search_facets(facet_name, facet_value)",
543 "CREATE INDEX IF NOT EXISTS idx_search_facets_content ON search_facets(content_type, content_id)",
544 ]
546 for index_sql in indices:
547 with suppress(Exception):
548 # Some indices might not be supported in all DuckDB versions, continue
549 self._get_conn().execute(index_sql)
    async def get_embedding(self, text: str) -> list[float]:
        """Get embedding for text using ONNX model.

        Tokenization and inference run in a thread-pool executor so the
        event loop is not blocked. The token embeddings are mean-pooled over
        the attention mask and L2-normalized, then returned as a list of
        float32 values (384 dims for all-MiniLM-L6-v2).

        Raises:
            RuntimeError: if no ONNX session/tokenizer is loaded.
        """
        if self.onnx_session and self.tokenizer:

            def _get_embedding() -> list[float]:
                # Tokenize text
                assert self.tokenizer is not None  # For type checker
                encoded = self.tokenizer(
                    text,
                    truncation=True,
                    padding=True,
                    return_tensors="np",
                )

                # Run inference
                assert self.onnx_session is not None  # For type checker
                outputs = self.onnx_session.run(
                    None,
                    {
                        "input_ids": encoded["input_ids"],
                        "attention_mask": encoded["attention_mask"],
                        # Some tokenizers omit token_type_ids; default to zeros.
                        "token_type_ids": encoded.get(
                            "token_type_ids",
                            np.zeros_like(encoded["input_ids"]),
                        ),
                    },
                )

                # Mean pooling over non-padding tokens only
                embeddings = outputs[0]
                attention_mask = encoded["attention_mask"]
                masked_embeddings = embeddings * np.expand_dims(attention_mask, axis=-1)
                summed = np.sum(masked_embeddings, axis=1)
                counts = np.sum(attention_mask, axis=1, keepdims=True)
                mean_pooled = summed / counts

                # Normalize to unit length so cosine similarity is a dot product
                norms = np.linalg.norm(mean_pooled, axis=1, keepdims=True)
                normalized = mean_pooled / norms

                # Convert to float32 to match DuckDB FLOAT type
                return normalized[0].astype(np.float32).tolist()

            return await asyncio.get_event_loop().run_in_executor(None, _get_embedding)

        msg = "No embedding model available"
        raise RuntimeError(msg)
599 async def store_conversation(self, content: str, metadata: dict[str, Any]) -> str:
600 """Store conversation with optional embedding."""
601 conversation_id = hashlib.md5(
602 f"{content}_{time.time()}".encode("utf-8", "surrogatepass"),
603 usedforsecurity=False,
604 ).hexdigest()
606 db_content = _encode_text_for_db(content)
608 embedding: list[float] | None = None
610 if ONNX_AVAILABLE and self.onnx_session: 610 ↛ 611line 610 didn't jump to line 611 because the condition on line 610 was never true
611 try:
612 embedding = await self.get_embedding(content)
613 except Exception:
614 embedding = None # Fallback to no embedding
615 else:
616 embedding = None # Store without embedding
618 # For synchronized database access in test environments using in-memory DB
619 if self.is_temp_db:
620 # Use lock to protect database operations for in-memory DB
621 with self.lock:
622 self._get_conn().execute(
623 """
624 INSERT INTO conversations (id, content, embedding, project, timestamp, metadata)
625 VALUES (?, ?, ?, ?, ?, ?)
626 """,
627 [
628 conversation_id,
629 db_content,
630 embedding,
631 metadata.get("project"),
632 datetime.now(UTC),
633 json.dumps(metadata),
634 ],
635 )
636 else:
637 # For normal file-based DB, run in executor for thread safety
638 await asyncio.get_event_loop().run_in_executor(
639 None,
640 lambda: self._get_conn().execute(
641 """
642 INSERT INTO conversations (id, content, embedding, project, timestamp, metadata)
643 VALUES (?, ?, ?, ?, ?, ?)
644 """,
645 [
646 conversation_id,
647 db_content,
648 embedding,
649 metadata.get("project"),
650 datetime.now(UTC),
651 json.dumps(metadata),
652 ],
653 ),
654 )
656 # DuckDB is ACID-compliant by default, explicit commit is not required for individual operations
657 # However, if needed, we can call commit on the thread-local connection
658 # self._get_conn().commit()
659 return conversation_id
661 async def store_reflection(
662 self,
663 content: str,
664 tags: list[str] | None = None,
665 project: str | None = None,
666 ) -> str:
667 """Store reflection/insight with optional embedding."""
668 if content is None:
669 msg = "content cannot be None"
670 raise TypeError(msg)
672 reflection_id = hashlib.md5(
673 f"reflection_{content}_{time.time()}".encode("utf-8", "surrogatepass"),
674 usedforsecurity=False,
675 ).hexdigest()
677 db_content = _encode_text_for_db(content)
679 tags_list = tags or []
681 embedding: list[float] | None = None
683 if ONNX_AVAILABLE and self.onnx_session: 683 ↛ 684line 683 didn't jump to line 684 because the condition on line 683 was never true
684 try:
685 embedding = await self.get_embedding(content)
686 except Exception:
687 embedding = None # Fallback to no embedding
688 else:
689 embedding = None # Store without embedding
691 def _store() -> None:
692 conn = self._get_conn()
693 conn.execute(
694 """
695 INSERT INTO reflections (id, content, embedding, project, tags, timestamp, metadata)
696 VALUES (?, ?, ?, ?, ?, ?, ?)
697 """,
698 [
699 reflection_id,
700 db_content,
701 embedding,
702 project,
703 tags_list,
704 datetime.now(UTC),
705 json.dumps({"type": "reflection", "project": project}),
706 ],
707 )
708 conn.execute(
709 "DELETE FROM reflection_tags WHERE reflection_id = ?",
710 [reflection_id],
711 )
712 tags_unique = list(dict.fromkeys(tags_list))
713 for tag in tags_unique:
714 conn.execute(
715 "INSERT INTO reflection_tags (reflection_id, tag) VALUES (?, ?)",
716 [reflection_id, tag],
717 )
719 # For synchronized database access in test environments using in-memory DB
720 if self.is_temp_db:
721 with self.lock:
722 _store()
723 else:
724 await asyncio.get_event_loop().run_in_executor(None, _store)
726 # DuckDB is ACID-compliant by default, explicit commit is not required for individual operations
727 # However, if needed, we can call commit on the thread-local connection
728 # self._get_conn().commit()
729 return reflection_id
731 async def get_reflection(self, reflection_id: str | None) -> dict[str, Any] | None:
732 """Get a reflection by ID."""
733 if not reflection_id or not isinstance(reflection_id, str):
734 return None
735 if len(reflection_id) < 5 or len(reflection_id) > 128:
736 return None
738 rows = await self._execute_query(
739 "SELECT id, content, project, tags, timestamp, metadata FROM reflections WHERE id = ?",
740 [reflection_id],
741 )
742 if not rows:
743 return None
745 row = rows[0]
746 return {
747 "id": row[0],
748 "content": _decode_text_from_db(row[1]),
749 "project": row[2],
750 "tags": list(row[3]) if row[3] else [],
751 "timestamp": row[4],
752 "metadata": json.loads(row[5]) if row[5] else {},
753 }
    async def update_reflection(
        self,
        reflection_id: str | None,
        content: str | None,
        tags: list[str] | None = None,
        project: str | None = None,
    ) -> None:
        """Update an existing reflection.

        This is best-effort: updating a non-existent reflection is a no-op,
        and invalid IDs are ignored silently.

        Args:
            reflection_id: ID of the reflection to update.
            content: Replacement content (required).
            tags: Replacement tags; the previous tag set is fully replaced.
            project: New project, or None to keep the stored value (COALESCE).

        Raises:
            TypeError: if content is None.
        """
        if (
            reflection_id is None
            or not isinstance(reflection_id, str)
            or not reflection_id
        ):
            return
        if content is None:
            msg = "content cannot be None"
            raise TypeError(msg)

        tags_list = tags or []

        db_content = _encode_text_for_db(content)

        embedding: list[float] | None = None
        if ONNX_AVAILABLE and self.onnx_session:
            # Best-effort: an embedding failure leaves the column NULL.
            with suppress(Exception):
                embedding = await self.get_embedding(content)

        def _update() -> None:
            conn = self._get_conn()

            # Explicit existence check so a miss is a silent no-op.
            result = conn.execute(
                "SELECT COUNT(*) FROM reflections WHERE id = ?",
                [reflection_id],
            ).fetchone()
            exists = result[0] if result else 0
            if exists <= 0:
                return

            conn.execute(
                """
                UPDATE reflections
                SET content = ?,
                    embedding = ?,
                    tags = ?,
                    project = COALESCE(?, project),
                    timestamp = ?,
                    metadata = ?
                WHERE id = ?
                """,
                [
                    db_content,
                    embedding,
                    tags_list,
                    project,
                    datetime.now(UTC),
                    json.dumps({"type": "reflection", "project": project}),
                    reflection_id,
                ],
            )
            # Rebuild the tag index: delete old rows, insert deduped tags.
            conn.execute(
                "DELETE FROM reflection_tags WHERE reflection_id = ?",
                [reflection_id],
            )
            tags_unique = list(dict.fromkeys(tags_list))
            for tag in tags_unique:
                conn.execute(
                    "INSERT INTO reflection_tags (reflection_id, tag) VALUES (?, ?)",
                    [reflection_id, tag],
                )

        if self.is_temp_db:
            # Shared in-memory connection: serialize with the lock.
            with self.lock:
                _update()
        else:
            await asyncio.get_event_loop().run_in_executor(None, _update)
834 async def delete_reflection(self, reflection_id: str | None) -> None:
835 """Delete a reflection by ID.
837 Deleting a non-existent reflection is a no-op.
838 """
839 if reflection_id is None:
840 msg = "reflection_id cannot be None"
841 raise TypeError(msg)
842 if not isinstance(reflection_id, str) or not reflection_id: 842 ↛ 843line 842 didn't jump to line 843 because the condition on line 842 was never true
843 msg = "reflection_id must be a non-empty string"
844 raise ValueError(msg)
846 def _delete() -> None:
847 conn = self._get_conn()
848 conn.execute(
849 "DELETE FROM reflection_tags WHERE reflection_id = ?",
850 [reflection_id],
851 )
852 conn.execute(
853 "DELETE FROM reflections WHERE id = ?",
854 [reflection_id],
855 )
857 if self.is_temp_db: 857 ↛ 858line 857 didn't jump to line 858 because the condition on line 857 was never true
858 with self.lock:
859 _delete()
860 else:
861 await asyncio.get_event_loop().run_in_executor(None, _delete)
863 async def search_conversations(
864 self,
865 query: str,
866 limit: int = 5,
867 min_score: float = 0.7,
868 project: str | None = None,
869 ) -> list[dict[str, Any]]:
870 """Search conversations by text similarity (fallback to text search if no embeddings)."""
871 if ONNX_AVAILABLE and self.onnx_session: 871 ↛ 872line 871 didn't jump to line 872 because the condition on line 871 was never true
872 return await self._semantic_search_conversations(
873 query, limit, min_score, project
874 )
875 return await self._text_search_conversations(query, limit, project)
    async def _semantic_search_conversations(
        self, query: str, limit: int, min_score: float, project: str | None
    ) -> list[dict[str, Any]]:
        """Semantic search implementation with embeddings.

        The whole semantic path runs inside suppress(Exception): any failure
        (embedding, SQL, similarity) falls through to the text search at the
        bottom instead of raising.
        """
        with suppress(Exception):
            query_embedding = await self.get_embedding(query)

            # Cosine similarity against the stored FLOAT[384] embeddings.
            sql = """
                SELECT
                    id, content, embedding, project, timestamp, metadata,
                    array_cosine_similarity(embedding, CAST(? AS FLOAT[384])) as score
                FROM conversations
                WHERE embedding IS NOT NULL
            """
            params: list[Any] = [query_embedding]

            if project:
                sql += " AND project = ?"
                params.append(project)

            sql += """
                ORDER BY score DESC
                LIMIT ?
            """
            params.append(limit)

            # For synchronized database access in test environments using in-memory DB
            if self.is_temp_db:
                # Use lock to protect database operations for in-memory DB
                with self.lock:
                    results = self._get_conn().execute(sql, params).fetchall()
            else:
                # For normal file-based DB, run in executor for thread safety
                results = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: self._get_conn().execute(sql, params).fetchall(),
                )

            # Build results and log accesses into v2 access log (best-effort)
            filtered = [row for row in results if float(row[6]) >= min_score]
            self._log_accesses([str(row[0]) for row in filtered])

            return [
                {
                    "content": _decode_text_from_db(row[1]),
                    "score": float(row[6]),
                    "timestamp": row[4],
                    "project": row[3],
                    "metadata": json.loads(row[5]) if row[5] else {},
                }
                for row in filtered
            ]

        # If semantic search fails or is not available, fallback to text search
        return await self._text_search_conversations(query, limit, project)
933 async def _text_search_conversations(
934 self, query: str, limit: int, project: str | None
935 ) -> list[dict[str, Any]]:
936 """Fallback text search implementation."""
937 search_terms = query.lower().split()
939 # Return empty list when query is empty
940 if not search_terms:
941 return []
943 sql = "SELECT id, content, project, timestamp, metadata FROM conversations"
944 params = []
946 if project: 946 ↛ 947line 946 didn't jump to line 947 because the condition on line 946 was never true
947 sql += " WHERE project = ?"
948 params.append(project)
950 sql += " ORDER BY timestamp DESC"
952 # For synchronized database access in test environments using in-memory DB
953 if self.is_temp_db:
954 # Use lock to protect database operations for in-memory DB
955 with self.lock:
956 results = self._get_conn().execute(sql, params).fetchall()
957 else:
958 # For normal file-based DB, run in executor for thread safety
959 results = await asyncio.get_event_loop().run_in_executor(
960 None,
961 lambda: self._get_conn().execute(sql, params).fetchall(),
962 )
964 # Simple text matching score
965 matches = []
966 matched_ids: list[str] = []
967 for row in results:
968 content = _decode_text_from_db(row[1])
969 content_lower = content.lower()
970 score = sum(1 for term in search_terms if term in content_lower) / len(
971 search_terms,
972 )
974 if score > 0: # At least one term matches
975 matches.append(
976 {
977 "content": content,
978 "score": score,
979 "timestamp": row[3],
980 "project": row[2],
981 "metadata": json.loads(row[4]) if row[4] else {},
982 },
983 )
984 with suppress(Exception):
985 matched_ids.append(str(row[0]))
987 # Sort by score and return top matches, then log accesses
988 matches.sort(key=operator.itemgetter("score"), reverse=True)
989 top = matches[:limit]
990 self._log_accesses(matched_ids[:limit])
991 return top
993 def _log_accesses(self, conv_ids: list[str]) -> None:
994 """Helper to log memory accesses."""
995 from contextlib import suppress
997 with suppress(Exception):
998 from session_buddy.memory.persistence import (
999 log_memory_access as _log_access,
1000 )
1002 for conv_id in conv_ids:
1003 _log_access(conv_id, access_type="search")
1005 async def search_reflections(
1006 self,
1007 query: str,
1008 limit: int = 5,
1009 project: str | None = None,
1010 *,
1011 tags: list[str] | None = None,
1012 min_score: float = 0.7,
1013 ) -> list[dict[str, Any]]:
1014 """Search stored reflections by semantic similarity with text fallback."""
1015 if query is None:
1016 msg = "query cannot be None"
1017 raise TypeError(msg)
1018 if limit <= 0:
1019 return []
1021 results = await self._semantic_reflection_search(
1022 query,
1023 limit,
1024 min_score,
1025 project,
1026 tags,
1027 )
1028 if results is not None: 1028 ↛ 1029line 1028 didn't jump to line 1029 because the condition on line 1028 was never true
1029 return results
1031 return await self._text_reflection_search(query, limit, project, tags)
    async def _semantic_reflection_search(
        self,
        query: str,
        limit: int,
        min_score: float,
        project: str | None,
        tags: list[str] | None,
    ) -> list[dict[str, Any]] | None:
        """Run semantic reflection search if ONNX embeddings available.

        Returns:
            Scored result dicts when the embedding query succeeds AND at
            least one row scores >= ``min_score``; otherwise ``None``,
            which signals the caller to fall back to plain text search.
        """
        # Without an ONNX runtime and a loaded session there is nothing to do.
        if not (ONNX_AVAILABLE and self.onnx_session):
            return None

        # Best-effort: any failure inside this block (embedding generation,
        # SQL execution, row decoding) silently falls through to the final
        # ``return None`` so the caller can use the text fallback.
        with suppress(Exception):
            query_embedding = await self.get_embedding(query)
            # DuckDB cosine similarity over the fixed 384-dim embedding column.
            sql = """
            SELECT
                id, content, project, tags, timestamp, metadata,
                array_cosine_similarity(embedding, CAST(? AS FLOAT[384])) as score
            FROM reflections
            WHERE embedding IS NOT NULL
            """

            params: list[Any] = [query_embedding]
            if project is not None:
                sql += " AND project = ?"
                params.append(project)

            if tags:
                # Match rows containing ANY of the requested tags.
                tag_clauses = " OR ".join(["list_contains(tags, ?)"] * len(tags))
                sql += f" AND ({tag_clauses})"
                params.extend(tags)

            sql += """
            ORDER BY score DESC
            LIMIT ?
            """

            params.append(limit)
            results = await self._execute_query(sql, params)
            # Keep only rows at or above the similarity threshold.
            semantic_results = [
                {
                    "id": row[0],
                    "content": _decode_text_from_db(row[1]),
                    "score": float(row[6]),
                    "project": row[2],
                    "tags": list(row[3]) if row[3] else [],
                    "timestamp": row[4],
                    "metadata": json.loads(row[5]) if row[5] else {},
                }
                for row in results
                if float(row[6]) >= min_score
            ]

            if semantic_results:
                return semantic_results
        # Reached on error or when nothing cleared min_score.
        return None
1090 async def _text_reflection_search(
1091 self,
1092 query: str,
1093 limit: int,
1094 project: str | None,
1095 tags: list[str] | None,
1096 ) -> list[dict[str, Any]]:
1097 """Fallback text search for reflections."""
1098 sql = "SELECT id, content, project, tags, timestamp, metadata FROM reflections"
1099 params: list[Any] = []
1101 where_clauses = []
1102 if project is not None:
1103 where_clauses.append("project = ?")
1104 params.append(project)
1106 if tags:
1107 tag_clauses = " OR ".join(["list_contains(tags, ?)"] * len(tags))
1108 where_clauses.append(f"({tag_clauses})")
1109 params.extend(tags)
1111 if where_clauses:
1112 sql += " WHERE " + " AND ".join(where_clauses)
1114 sql += " ORDER BY timestamp DESC"
1115 results = await self._execute_query(sql, params or None)
1117 search_terms = query.lower().split()
1118 matches = []
1119 for row in results:
1120 content = _decode_text_from_db(row[1])
1121 combined_text = f"{content.lower()} {' '.join(list(row[3] or [])).lower()}"
1122 score = (
1123 sum(1 for term in search_terms if term in combined_text)
1124 / len(search_terms)
1125 if search_terms
1126 else 1.0
1127 )
1129 if score > 0:
1130 matches.append(
1131 {
1132 "id": row[0],
1133 "content": content,
1134 "score": score,
1135 "project": row[2],
1136 "tags": list(row[3]) if row[3] else [],
1137 "timestamp": row[4],
1138 "metadata": json.loads(row[5]) if row[5] else {},
1139 },
1140 )
1142 matches.sort(key=operator.itemgetter("score"), reverse=True)
1143 return matches[:limit]
1145 async def _execute_query(
1146 self,
1147 sql: str,
1148 params: list[Any] | None = None,
1149 ) -> list[Any]:
1150 """Execute a query with locking or async executor based on DB type."""
1151 params = params or []
1152 if self.is_temp_db:
1153 with self.lock:
1154 return self._get_conn().execute(sql, params).fetchall()
1156 loop = asyncio.get_event_loop()
1157 return await loop.run_in_executor(
1158 None,
1159 lambda: self._get_conn().execute(sql, params).fetchall(),
1160 )
1162 async def search_by_file(
1163 self,
1164 file_path: str,
1165 limit: int = 10,
1166 project: str | None = None,
1167 ) -> list[dict[str, Any]]:
1168 """Search conversations that mention a specific file."""
1169 sql = """
1170 SELECT id, content, project, timestamp, metadata
1171 FROM conversations
1172 WHERE content LIKE ?
1173 """
1174 params: list[Any] = [f"%{file_path}%"]
1176 if project: 1176 ↛ 1177line 1176 didn't jump to line 1177 because the condition on line 1176 was never true
1177 sql += " AND project = ?"
1178 params.append(project)
1180 sql += " ORDER BY timestamp DESC LIMIT ?"
1181 params.append(limit)
1183 # For synchronized database access in test environments using in-memory DB
1184 if self.is_temp_db: 1184 ↛ 1186line 1184 didn't jump to line 1186 because the condition on line 1184 was never true
1185 # Use lock to protect database operations for in-memory DB
1186 with self.lock:
1187 results = self._get_conn().execute(sql, params).fetchall()
1188 else:
1189 # For normal file-based DB, run in executor for thread safety
1190 results = await asyncio.get_event_loop().run_in_executor(
1191 None,
1192 lambda: self._get_conn().execute(sql, params).fetchall(),
1193 )
1195 # Build results and log access for each conversation id
1196 output = []
1197 for row in results:
1198 output.append(
1199 {
1200 "content": _decode_text_from_db(row[1]),
1201 "project": row[2],
1202 "timestamp": row[3],
1203 "metadata": json.loads(row[4]) if row[4] else {},
1204 }
1205 )
1206 from contextlib import suppress
1208 with suppress(Exception):
1209 from session_buddy.memory.persistence import (
1210 log_memory_access as _log_access,
1211 )
1213 _log_access(str(row[0]), access_type="search")
1214 return output
1216 async def get_stats(self) -> dict[str, Any]:
1217 """Get database statistics."""
1218 try:
1219 conv_count = await self._get_conversation_count()
1220 refl_count = await self._get_reflection_count()
1222 projects_rows = await self._execute_query(
1223 "SELECT DISTINCT project FROM reflections WHERE project IS NOT NULL",
1224 )
1225 projects = [row[0] for row in projects_rows if row and row[0] is not None]
1227 provider = (
1228 "onnx-runtime"
1229 if (self.onnx_session and ONNX_AVAILABLE)
1230 else "text-search-only"
1231 )
1232 return {
1233 "conversations_count": conv_count,
1234 "reflections_count": refl_count,
1235 "total_conversations": conv_count,
1236 "total_reflections": refl_count,
1237 "projects": projects,
1238 "total_projects": len(projects),
1239 "embedding_provider": provider,
1240 "embedding_dimension": self.embedding_dim,
1241 "database_path": self.db_path,
1242 }
1243 except Exception as e:
1244 return {"error": f"Failed to get stats: {e}"}
1246 async def _get_conversation_count(self) -> int:
1247 """Get the count of conversations from the database."""
1248 if self.is_temp_db:
1249 with self.lock:
1250 result = (
1251 self._get_conn()
1252 .execute(
1253 "SELECT COUNT(*) FROM conversations",
1254 )
1255 .fetchone()
1256 )
1257 return result[0] if result and result[0] else 0
1258 else:
1259 return await asyncio.get_event_loop().run_in_executor(
1260 None,
1261 lambda: (
1262 (
1263 result := self._get_conn()
1264 .execute(
1265 "SELECT COUNT(*) FROM conversations",
1266 )
1267 .fetchone()
1268 )
1269 and result[0]
1270 )
1271 or 0,
1272 )
1274 async def _get_reflection_count(self) -> int:
1275 """Get the count of reflections from the database."""
1276 if self.is_temp_db:
1277 with self.lock:
1278 result = (
1279 self._get_conn()
1280 .execute(
1281 "SELECT COUNT(*) FROM reflections",
1282 )
1283 .fetchone()
1284 )
1285 return result[0] if result and result[0] else 0
1286 else:
1287 return await asyncio.get_event_loop().run_in_executor(
1288 None,
1289 lambda: (
1290 (
1291 result := self._get_conn()
1292 .execute(
1293 "SELECT COUNT(*) FROM reflections",
1294 )
1295 .fetchone()
1296 )
1297 and result[0]
1298 )
1299 or 0,
1300 )
# Module-level singleton adapter used by the deprecated accessor functions;
# lazily created by get_reflection_database() and torn down by
# cleanup_reflection_database().
_reflection_db: ReflectionDatabaseAdapter | None = None
async def get_reflection_database() -> ReflectionDatabaseAdapter:
    """Get or create reflection database adapter instance.

    DEPRECATED: This function is deprecated and will be removed in a future release.
    Use the ReflectionDatabaseAdapter directly with dependency injection instead.
    """
    # The docstring declares this deprecated; actually emit the warning so
    # callers see it (module-level `import warnings` already exists).
    warnings.warn(
        "get_reflection_database() is deprecated; use ReflectionDatabaseAdapter "
        "with dependency injection instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    global _reflection_db
    if _reflection_db is None:
        from session_buddy.di import configure

        # Ensure the DI container is configured before the adapter starts.
        configure()
        _reflection_db = ReflectionDatabaseAdapter()
        await _reflection_db.initialize()
    return _reflection_db
def get_initialized_reflection_database() -> ReflectionDatabaseAdapter | None:
    """Return the initialized reflection database if available.

    Unlike get_reflection_database(), this never creates or initializes an
    adapter — it only exposes the module-level singleton, which is None
    until get_reflection_database() has been awaited at least once.
    """
    return _reflection_db
def cleanup_reflection_database() -> None:
    """Close and discard the global reflection database instance.

    Uses an explicit ``is not None`` check (truthiness could skip cleanup
    if the adapter ever defined __bool__/__len__), and clears the global in
    a ``finally`` so a raising close() cannot leave a stale, half-closed
    adapter behind for later callers.
    """
    global _reflection_db
    if _reflection_db is not None:
        try:
            _reflection_db.close()
        finally:
            _reflection_db = None
1336def get_current_project() -> str | None:
1337 """Get current project name from working directory."""
1338 try:
1339 cwd = Path.cwd()
1340 # Try to detect project from common indicators
1341 if (cwd / "pyproject.toml").exists() or (cwd / "package.json").exists(): 1341 ↛ 1342line 1341 didn't jump to line 1342 because the condition on line 1341 was never true
1342 return cwd.name
1343 # Fallback to directory name
1344 return cwd.name if cwd.name != "." else None
1345 except Exception:
1346 return None