Coverage for src \ truenex_memory \ store \ repository.py: 91%
245 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 10:21 +0200
1"""Repository for local documents, chunks, memories and retrieval logs."""
3from __future__ import annotations
5from pathlib import Path
6from datetime import datetime, timezone
7import json
8import sqlite3
9import uuid
11from truenex_memory.core.chunker import TextChunk, content_hash
12from truenex_memory.retrieval.semantic import Embedder, VectorMatch, VectorPoint, VectorStore, chunk_point_id
13from truenex_memory.store.qdrant_store import VectorSearchHit
14from truenex_memory.store.models import MemoryNode, RetrievalLog, SearchHit, VALID_STATUSES
15from truenex_memory.retrieval.scoring import tokenize_set
16from truenex_memory.store.sqlite import connect, initialize_schema
# Memory statuses that search() includes unless include_inactive=True.
ACTIVE_STATUSES = ("active", "unverified")
# Version tag written by export_data and required by import_data.
EXPORT_VERSION = "1"
# Project id stamped on every row this module inserts.
PROJECT_ID = "default"
# Tables serialized by export_data and restored, in this order, by import_data.
EXPORT_TABLES = (
    "documents",
    "chunks",
    "memory_nodes",
    "edges",
    "retrieval_logs",
    "schema_migrations",
)
class MemoryRepository:
    """SQLite-backed local repository for documents, chunks, memories and retrieval logs.

    Public methods call :meth:`initialize` before touching the database, so the
    schema is created lazily on first use. When an ``embedder`` is configured,
    chunks are embedded in :meth:`upsert_document` and :meth:`search` tries
    semantic retrieval first; an optional ``vector_store`` receives the chunk
    vectors and is queried before the embeddings stored in SQLite.
    """

    def __init__(
        self,
        db_path: Path,
        *,
        embedder: Embedder | None = None,
        vector_store: VectorStore | None = None,
    ) -> None:
        self.db_path = db_path
        self.embedder = embedder
        self.vector_store = vector_store
        # Id of the retrieval log written by the most recent logged search().
        self.last_trace_id: str | None = None

    def initialize(self) -> None:
        """Ensure the database schema exists (delegates to ``initialize_schema``)."""
        with connect(self.db_path) as conn:
            initialize_schema(conn)

    def add_memory(
        self,
        content: str,
        *,
        memory_type: str = "note",
        title: str | None = None,
        status: str = "active",
        source_kind: str = "manual",
        source_document_id: str | None = None,
        source_chunk_id: str | None = None,
        source_path: str | None = None,
        created_by: str = "user",
        model_name: str | None = None,
        confidence: float | None = None,
    ) -> str:
        """Insert a new memory node and return its generated ``mem_`` id.

        ``content`` is stripped before storage and its content hash is stored
        alongside it (see :meth:`find_memory_by_content_hash`). When ``title``
        is falsy it is derived from the first line of the content (at most 80
        characters).

        Raises:
            ValueError: if ``status`` is not in ``VALID_STATUSES`` or
                ``content`` is empty after stripping.
        """
        if status not in VALID_STATUSES:
            raise ValueError(f"invalid status {status!r}, expected one of {sorted(VALID_STATUSES)}")
        clean_content = content.strip()
        if not clean_content:
            raise ValueError("memory content cannot be empty")
        now = _now_sql()
        memory_id = _new_id("mem")
        self.initialize()
        with connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT INTO memory_nodes (
                    id, project_id, type, title, content, status, source_kind,
                    source_document_id, source_chunk_id, source_path,
                    content_hash, created_by, model_name, confidence,
                    created_at, updated_at
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    memory_id,
                    PROJECT_ID,
                    memory_type,
                    title or _title_from_content(clean_content),
                    clean_content,
                    status,
                    source_kind,
                    source_document_id,
                    source_chunk_id,
                    source_path,
                    content_hash(clean_content),
                    created_by,
                    model_name,
                    confidence,
                    now,
                    now,
                ),
            )
            conn.commit()
        return memory_id

    def find_memory_by_content_hash(self, hash_value: str) -> MemoryNode | None:
        """Return the oldest memory node in this project with ``hash_value``, or ``None``."""
        self.initialize()
        with connect(self.db_path) as conn:
            row = conn.execute(
                """
                SELECT * FROM memory_nodes
                WHERE project_id = ? AND content_hash = ?
                ORDER BY created_at, id
                LIMIT 1
                """,
                (PROJECT_ID, hash_value),
            ).fetchone()
        return _memory_node_from_row(row) if row is not None else None

    def upsert_document(
        self,
        path: Path,
        relative_path: str,
        chunks: list[TextChunk],
        *,
        source_type: str | None = None,
    ) -> str:
        """Index ``path`` under ``relative_path``, replacing its previously stored chunks.

        The document id is derived from ``relative_path``, so re-indexing the
        same logical path updates the existing document row. All existing
        chunks of the document are deleted and ``chunks`` inserted in their
        place. If an embedder is configured each chunk is embedded and the
        vector stored as JSON; if a vector store is also configured the
        vectors are upserted there before the SQLite transaction commits.

        NOTE: vector-store points for chunks that no longer exist are not
        removed here; semantic search skips point ids with no chunk row.

        Returns:
            The document id.
        """
        text = path.read_text(encoding="utf-8", errors="replace")
        doc_id = "doc_" + content_hash(relative_path)[:24]
        filename = _filename_from_logical_path(relative_path, fallback=path)
        now = _now_sql()
        embedder = self.embedder
        embedding_model = embedder.model_name if embedder is not None else None
        self.initialize()
        with connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT INTO documents (
                    id, project_id, path, filename, content_hash,
                    last_indexed_at, created_at, updated_at
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(id) DO UPDATE SET
                    filename=excluded.filename,
                    content_hash=excluded.content_hash,
                    last_indexed_at=excluded.last_indexed_at,
                    updated_at=excluded.updated_at
                """,
                (
                    doc_id,
                    PROJECT_ID,
                    relative_path,
                    filename,
                    content_hash(text),
                    now,
                    now,
                    now,
                ),
            )
            conn.execute("DELETE FROM chunks WHERE document_id = ?", (doc_id,))
            vector_points: list[VectorPoint] = []
            for chunk in chunks:
                chunk_id = f"{doc_id}_chunk_{chunk.index}"
                embedding_vector = embedder.embed(chunk.content) if embedder is not None else None
                # Chunks only get a point id when they have an embedding to search against.
                point_id = chunk_point_id(chunk_id) if embedding_vector is not None else None
                conn.execute(
                    """
                    INSERT INTO chunks (
                        id, document_id, chunk_index, heading_path, content,
                        content_hash, token_count, qdrant_point_id, embedding_model,
                        embedding_vector_json, source_type, created_at, updated_at
                    )
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        chunk_id,
                        doc_id,
                        chunk.index,
                        chunk.heading_path,
                        chunk.content,
                        chunk.content_hash,
                        chunk.token_count,
                        point_id,
                        embedding_model,
                        json.dumps(embedding_vector) if embedding_vector is not None else None,
                        source_type,
                        now,
                        now,
                    ),
                )
                if point_id is not None and embedding_vector is not None:
                    vector_points.append(
                        VectorPoint(
                            point_id=point_id,
                            vector=embedding_vector,
                            payload={"chunk_id": chunk_id, "document_id": doc_id},
                        )
                    )
            if vector_points and self.vector_store is not None:
                self.vector_store.upsert(vector_points)
            conn.commit()
        return doc_id

    def search(self, query: str, *, top_k: int = 5, include_inactive: bool = False) -> list[SearchHit]:
        """Search for ``query`` and record the retrieval in ``retrieval_logs``.

        Semantic chunk search is tried first; only when it yields nothing are
        lexical memory matches and BM25 chunk matches combined. At most
        ``top_k`` hits are returned, highest score first, and the id of the
        recorded log is stored in ``last_trace_id``. A query with no tokens
        returns ``[]`` without writing a log.

        Raises:
            ValueError: if ``top_k`` is less than 1 (a negative value would
                otherwise silently slice the tail off the results).
        """
        if top_k < 1:
            raise ValueError("top_k must be greater than zero")
        tokens = tokenize_set(query)
        if not tokens:
            return []
        self.initialize()
        with connect(self.db_path) as conn:
            hits = self._search_semantic_chunks(conn, query, top_k)
            if not hits:
                hits = _search_memories(conn, tokens, include_inactive)
                hits.extend(_search_chunks(conn, tokens))
            hits.sort(key=lambda item: item.score, reverse=True)
            results = hits[:top_k]
            self.last_trace_id = self._record_retrieval_log(conn, query, top_k, results)
            conn.commit()
        return results

    def stats(self) -> dict[str, int]:
        """Return row counts for documents, chunks, memory nodes and retrieval logs."""
        self.initialize()
        with connect(self.db_path) as conn:
            return {
                "documents": conn.execute("SELECT COUNT(*) FROM documents").fetchone()[0],
                "chunks": conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0],
                "memory_nodes": conn.execute("SELECT COUNT(*) FROM memory_nodes").fetchone()[0],
                "retrieval_logs": conn.execute("SELECT COUNT(*) FROM retrieval_logs").fetchone()[0],
            }

    def list_memory_nodes(self, *, status: str | None = None) -> list[MemoryNode]:
        """List memory nodes ordered by creation time, optionally filtered by ``status``.

        Raises:
            ValueError: if ``status`` is given and not in ``VALID_STATUSES``.
        """
        if status is not None and status not in VALID_STATUSES:
            raise ValueError(f"invalid status {status!r}, expected one of {sorted(VALID_STATUSES)}")
        self.initialize()
        with connect(self.db_path) as conn:
            if status is not None:
                rows = conn.execute(
                    "SELECT * FROM memory_nodes WHERE status = ? ORDER BY created_at, id",
                    (status,),
                ).fetchall()
            else:
                rows = conn.execute("SELECT * FROM memory_nodes ORDER BY created_at, id").fetchall()
        return [_memory_node_from_row(row) for row in rows]

    def set_memory_status(self, memory_id: str, status: str) -> None:
        """Set the status of memory ``memory_id`` and bump its ``updated_at``.

        Raises:
            ValueError: if ``status`` is not in ``VALID_STATUSES``.
            LookupError: if no memory node has id ``memory_id`` (raised
                before the update is committed).
        """
        if status not in VALID_STATUSES:
            raise ValueError(
                f"invalid status {status!r}, expected one of {sorted(VALID_STATUSES)}"
            )
        self.initialize()
        with connect(self.db_path) as conn:
            cursor = conn.execute(
                "UPDATE memory_nodes SET status = ?, updated_at = ? WHERE id = ?",
                (status, _now_sql(), memory_id),
            )
            if cursor.rowcount == 0:
                raise LookupError(f"memory node not found: {memory_id!r}")
            conn.commit()

    def export_data(self) -> dict[str, object]:
        """Dump every table in ``EXPORT_TABLES`` plus version/project metadata."""
        self.initialize()
        with connect(self.db_path) as conn:
            return {
                "memory_export_version": EXPORT_VERSION,
                "project_id": PROJECT_ID,
                # Same table list import_data() iterates, so the two stay in sync.
                **{table: _rows(conn, table) for table in EXPORT_TABLES},
            }

    def import_data(self, payload: dict[str, object]) -> None:
        """Upsert the rows of an :meth:`export_data` payload into this database.

        Tables missing from the payload are skipped. All tables are written on
        one connection and committed once at the end.

        Raises:
            ValueError: if the export version is unsupported, a table entry is
                not a list, or a row is not a dict.
        """
        if str(payload.get("memory_export_version")) != EXPORT_VERSION:
            raise ValueError("unsupported memory export version")
        self.initialize()
        with connect(self.db_path) as conn:
            for table in EXPORT_TABLES:
                rows = payload.get(table, [])
                if not isinstance(rows, list):
                    raise ValueError(f"invalid export table: {table}")
                for row in rows:
                    if not isinstance(row, dict):
                        raise ValueError(f"invalid row in table: {table}")
                    _upsert_row(conn, table, row)
            conn.commit()

    def list_retrieval_logs(self, *, limit: int = 20) -> list[RetrievalLog]:
        """Return up to ``limit`` retrieval logs, newest first.

        Raises:
            ValueError: if ``limit`` is less than 1.
        """
        if limit < 1:
            raise ValueError("limit must be greater than zero")
        self.initialize()
        with connect(self.db_path) as conn:
            rows = conn.execute(
                "SELECT * FROM retrieval_logs ORDER BY created_at DESC LIMIT ?",
                (limit,),
            ).fetchall()
        return [_retrieval_log_from_row(row) for row in rows]

    def get_retrieval_log(self, trace_id: str) -> RetrievalLog | None:
        """Return the retrieval log with id ``trace_id``, or ``None`` if absent."""
        self.initialize()
        with connect(self.db_path) as conn:
            row = conn.execute(
                "SELECT * FROM retrieval_logs WHERE id = ?", (trace_id,)
            ).fetchone()
        if row is None:
            return None
        return _retrieval_log_from_row(row)

    def _record_retrieval_log(
        self,
        conn: sqlite3.Connection,
        query: str,
        top_k: int,
        results: list[SearchHit],
    ) -> str:
        """Insert a retrieval log row on ``conn`` (caller commits) and return its id."""
        trace_id = _new_id("ret")
        conn.execute(
            """
            INSERT INTO retrieval_logs (
                id, project_id, query, top_k, result_count, results_json, created_at
            )
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (
                trace_id,
                PROJECT_ID,
                query,
                top_k,
                len(results),
                # NOTE(review): relies on SearchHit instances having a __dict__.
                json.dumps([hit.__dict__ for hit in results], sort_keys=True),
                _now_sql(),
            ),
        )
        return trace_id

    def _semantic_enabled(self) -> bool:
        """Return True when both an embedder and a vector store are configured."""
        return self.embedder is not None and self.vector_store is not None

    def _search_semantic_chunks(
        self,
        conn: sqlite3.Connection,
        query: str,
        top_k: int,
    ) -> list[SearchHit]:
        """Embed ``query`` and map nearest-neighbour matches back to chunk rows.

        Matches come from the vector store, or from the embeddings stored in
        SQLite when the store returns nothing. Point ids without a chunk row,
        and chunks whose document is marked 'missing' or 'skipped' in
        ``source_ledger``, are dropped. Returns ``[]`` without an embedder.
        """
        embedder = self.embedder
        if embedder is None:
            return []
        query_vector = embedder.embed(query)
        matches = self._vector_store_matches(query_vector, top_k)
        if not matches:
            matches = _sqlite_vector_matches(conn, query_vector, top_k)
        if not matches:
            return []

        hits: list[SearchHit] = []
        for match in matches:
            row = conn.execute(
                """
                SELECT c.*, d.path
                FROM chunks c
                JOIN documents d ON d.id = c.document_id
                LEFT JOIN source_ledger sl ON sl.source_path_or_alias = d.path
                WHERE c.qdrant_point_id = ?
                  AND (sl.source_id IS NULL OR sl.status NOT IN ('missing', 'skipped'))
                """,
                (match.point_id,),
            ).fetchone()
            if row is None:
                continue
            hits.append(
                SearchHit(
                    title=row["heading_path"] or Path(row["path"]).name,
                    content=row["content"],
                    source_path=row["path"],
                    heading_path=row["heading_path"],
                    memory_type="document_chunk",
                    status="active",
                    score=match.score,
                )
            )
        return hits

    def _vector_store_matches(self, query_vector: list[float], top_k: int) -> list[VectorMatch]:
        """Query the configured vector store; ``[]`` if none is configured or it fails."""
        if self.vector_store is None:
            return []
        try:
            matches = self.vector_store.search(query_vector, top_k=top_k)
        except Exception:
            # Deliberately best-effort: an empty result makes the caller fall
            # back to the embeddings stored in SQLite.
            return []
        return [_coerce_vector_match(match) for match in matches]
def _search_memories(
    conn: sqlite3.Connection, tokens: set[str], include_inactive: bool
) -> list[SearchHit]:
    """Score memory nodes by how many query tokens their title and content share.

    The score is the fraction of ``tokens`` found in ``"<title> <content>"``,
    rounded to 4 decimals; nodes scoring 0 are dropped. Unless
    ``include_inactive`` is set, only nodes whose status is in
    ``ACTIVE_STATUSES`` are considered.
    """
    if include_inactive:
        sql, params = "SELECT * FROM memory_nodes", ()
    else:
        sql, params = "SELECT * FROM memory_nodes WHERE status IN (?, ?)", ACTIVE_STATUSES
    candidates = conn.execute(sql, params).fetchall()

    token_total = len(tokens)
    matched_hits = []
    for node in candidates:
        shared = tokens & tokenize_set(f"{node['title']} {node['content']}")
        # Guard against division by zero for an empty token set.
        fraction = round(len(shared) / token_total, 4) if token_total else 0.0
        if not fraction > 0:
            continue
        matched_hits.append(
            SearchHit(
                title=node["title"],
                content=node["content"],
                source_path=node["source_path"],
                heading_path=None,
                memory_type=node["type"],
                status=node["status"],
                score=fraction,
            )
        )
    return matched_hits
def _search_chunks(conn: sqlite3.Connection, tokens: set[str]) -> list[SearchHit]:
    """Rank stored document chunks against the query tokens with BM25.

    Chunks whose document is marked 'missing' or 'skipped' in
    ``source_ledger`` are excluded. Each positive BM25 score is multiplied by
    ``source_boost`` of the chunk's ``source_type`` and rounded to 6 decimals.
    Chunks scoring 0 or less are dropped; the result is not sorted.
    """
    from truenex_memory.retrieval.scoring import BM25, tokenize, source_boost
    candidate_rows = conn.execute(
        """
        SELECT c.*, d.path
        FROM chunks c
        JOIN documents d ON d.id = c.document_id
        LEFT JOIN source_ledger sl ON sl.source_path_or_alias = d.path
        WHERE sl.source_id IS NULL OR sl.status NOT IN ('missing', 'skipped')
        """
    ).fetchall()
    if not candidate_rows:
        return []

    query_terms = list(tokens)
    corpus = [tokenize(str(row["content"] or "")) for row in candidate_rows]
    raw_scores = BM25(corpus).get_scores(query_terms)

    ranked: list[SearchHit] = []
    for row, raw_score in zip(candidate_rows, raw_scores):
        if raw_score <= 0:
            continue
        # Older rows may predate the source_type column.
        kind = row["source_type"] if "source_type" in row.keys() else None
        boosted = round(raw_score * source_boost(kind), 6)
        heading = row["heading_path"]
        doc_path = row["path"]
        ranked.append(
            SearchHit(
                title=str(heading or Path(str(doc_path)).name),
                content=str(row["content"] or ""),
                source_path=None if doc_path is None else str(doc_path),
                heading_path=None if heading is None else str(heading),
                memory_type="document_chunk",
                status="active",
                score=boosted,
            )
        )
    return ranked
def _sqlite_vector_matches(
    conn: sqlite3.Connection, query_vector: list[float], top_k: int
) -> list[VectorMatch]:
    """Brute-force similarity search over the chunk embeddings stored in SQLite.

    Only chunks that have both a point id and a stored embedding, and whose
    document is not marked 'missing' or 'skipped' in ``source_ledger``, are
    considered. Embeddings that are not valid JSON, not a JSON list, or
    contain non-numeric elements are skipped rather than aborting the search.

    Args:
        conn: Open database connection.
        query_vector: Embedding of the query.
        top_k: Maximum number of matches to return.

    Returns:
        Matches with a positive ``_cosine`` score (rounded to 4 decimals),
        best first, truncated to ``top_k``.
    """
    rows = conn.execute(
        """
        SELECT c.qdrant_point_id, c.embedding_vector_json
        FROM chunks c
        JOIN documents d ON d.id = c.document_id
        LEFT JOIN source_ledger sl ON sl.source_path_or_alias = d.path
        WHERE c.qdrant_point_id IS NOT NULL
          AND c.embedding_vector_json IS NOT NULL
          AND (sl.source_id IS NULL OR sl.status NOT IN ('missing', 'skipped'))
        """
    ).fetchall()
    matches: list[VectorMatch] = []
    for row in rows:
        try:
            decoded = json.loads(row["embedding_vector_json"])
            if not isinstance(decoded, list):
                continue
            vector = [float(value) for value in decoded]
        except (TypeError, ValueError):
            # ValueError covers json.JSONDecodeError and non-numeric strings;
            # TypeError covers nulls/nested lists inside the stored vector.
            # One corrupt row must not break search for the rest.
            continue
        score = _cosine(query_vector, vector)
        if score > 0:
            matches.append(VectorMatch(point_id=row["qdrant_point_id"], score=round(score, 4)))
    matches.sort(key=lambda item: item.score, reverse=True)
    return matches[:top_k]
def _coerce_vector_match(match: object) -> VectorMatch:
    """Normalize one vector-store search result into a ``VectorMatch``.

    ``VectorMatch`` instances pass through unchanged and ``VectorSearchHit``
    is converted field-by-field. Any other object is duck-typed: its
    ``point_id`` attribute is used if truthy, otherwise its ``id``. The id is
    converted with ``str`` and ``score`` (default 0.0) with ``float``.
    """
    if isinstance(match, VectorMatch):
        return match
    if isinstance(match, VectorSearchHit):
        return VectorMatch(point_id=match.id, score=match.score)
    identifier = getattr(match, "point_id", None)
    if not identifier:
        identifier = getattr(match, "id", None)
    # NOTE(review): when neither attribute is set this yields point_id "None".
    raw_score = getattr(match, "score", 0.0)
    return VectorMatch(point_id=str(identifier), score=float(raw_score))
485def _cosine(left: list[float], right: list[float]) -> float:
486 if len(left) != len(right) or not left:
487 return 0.0
488 return sum(a * b for a, b in zip(left, right, strict=True))
def _rows(conn: sqlite3.Connection, table: str) -> list[dict[str, object]]:
    """Return every row of ``table`` as a plain dict keyed by column name.

    ``table`` is interpolated into the SQL unescaped, so it must be a trusted
    identifier. Assumes ``conn`` yields mapping-like rows such as
    ``sqlite3.Row`` (presumably configured by ``connect``; verify).
    """
    cursor = conn.execute(f"SELECT * FROM {table}")
    return list(map(dict, cursor.fetchall()))
def _upsert_row(conn: sqlite3.Connection, table: str, row: dict[str, object]) -> None:
    """Insert ``row`` into ``table``, replacing any row with a conflicting key.

    ``row`` typically comes from an imported export payload, i.e. untrusted
    input. Its keys are therefore checked against the table's real columns
    (``PRAGMA table_info``, case-insensitively as SQLite does) and quoted
    before being placed in the SQL. Values are always bound as parameters.

    Raises:
        ValueError: if ``table`` does not exist, ``row`` is empty, or ``row``
            has a key that is not a column of ``table``.
    """

    def quote(identifier: str) -> str:
        # SQLite identifier quoting: wrap in double quotes, double embedded ones.
        return '"' + identifier.replace('"', '""') + '"'

    table_sql = quote(table)
    # Lower-cased column name -> declared name; info[1] is the column name.
    known_columns = {
        str(info[1]).lower(): str(info[1])
        for info in conn.execute(f"PRAGMA table_info({table_sql})").fetchall()
    }
    if not known_columns:
        raise ValueError(f"unknown table: {table}")
    if not row:
        raise ValueError(f"empty row in table: {table}")

    keys = list(row.keys())
    column_sql_parts = []
    for key in keys:
        if not isinstance(key, str) or key.lower() not in known_columns:
            raise ValueError(f"unknown column {key!r} in table: {table}")
        column_sql_parts.append(quote(known_columns[key.lower()]))

    column_sql = ", ".join(column_sql_parts)
    placeholders = ", ".join("?" for _ in keys)
    conn.execute(
        f"INSERT OR REPLACE INTO {table_sql} ({column_sql}) VALUES ({placeholders})",
        [row[key] for key in keys],
    )
def _title_from_content(content: str) -> str:
    """Derive a display title from the first line of ``content``.

    The first line is stripped and truncated to 80 characters. Returns
    ``"Untitled memory"`` when ``content`` is empty (previously an
    ``IndexError``) or its first line is blank.
    """
    lines = content.splitlines()
    first_line = lines[0].strip() if lines else ""
    return first_line[:80] or "Untitled memory"
def _filename_from_logical_path(relative_path: str, *, fallback: Path) -> str:
    """Return the last segment of a logical path using ``/`` or ``\\`` separators.

    Surrounding whitespace and trailing separators are ignored; when no
    segment remains, ``fallback.name`` is returned instead.
    """
    normalized = relative_path.strip().replace("\\", "/").rstrip("/")
    # rpartition yields "" for an empty string, which falls through to fallback.
    last_segment = normalized.rpartition("/")[2]
    return last_segment or fallback.name
def _memory_node_from_row(row: sqlite3.Row) -> MemoryNode:
    """Build a ``MemoryNode`` from a ``memory_nodes`` row.

    Each field is read from the column of the same name; extra columns in the
    row are ignored.
    """
    field_names = (
        "id",
        "project_id",
        "type",
        "title",
        "content",
        "status",
        "source_kind",
        "source_document_id",
        "source_chunk_id",
        "source_path",
        "content_hash",
        "created_by",
        "model_name",
        "confidence",
        "created_at",
        "updated_at",
    )
    return MemoryNode(**{name: row[name] for name in field_names})
def _retrieval_log_from_row(row: sqlite3.Row) -> RetrievalLog:
    """Build a ``RetrievalLog`` from a ``retrieval_logs`` row (same-named columns)."""
    field_names = (
        "id",
        "project_id",
        "query",
        "top_k",
        "result_count",
        "results_json",
        "created_at",
    )
    return RetrievalLog(**{name: row[name] for name in field_names})
def _new_id(prefix: str) -> str:
    """Return a fresh id of the form ``<prefix>_<32 hex chars>`` from a random UUID4."""
    random_hex = uuid.uuid4().hex
    return prefix + "_" + random_hex
def _now_sql() -> str:
    """Return the current UTC time as an ISO-8601 string (``+00:00`` offset)."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()