Coverage for src \ truenex_memory \ store \ repository.py: 91%

245 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 10:21 +0200

1"""Repository for local documents, chunks, memories and retrieval logs.""" 

2 

3from __future__ import annotations 

4 

5from pathlib import Path 

6from datetime import datetime, timezone 

7import json 

8import sqlite3 

9import uuid 

10 

11from truenex_memory.core.chunker import TextChunk, content_hash 

12from truenex_memory.retrieval.semantic import Embedder, VectorMatch, VectorPoint, VectorStore, chunk_point_id 

13from truenex_memory.store.qdrant_store import VectorSearchHit 

14from truenex_memory.store.models import MemoryNode, RetrievalLog, SearchHit, VALID_STATUSES 

15from truenex_memory.retrieval.scoring import tokenize_set 

16from truenex_memory.store.sqlite import connect, initialize_schema 

17 

18 

# Memory statuses returned by searches unless inactive memories are requested.
ACTIVE_STATUSES = ("active", "unverified")
# Version written into export payloads and required when importing one.
EXPORT_VERSION = "1"
# Project id written to documents, memory nodes and retrieval logs.
PROJECT_ID = "default"
# Tables serialised by export_data and restored by import_data, in import order.
EXPORT_TABLES = (
    "documents",
    "chunks",
    "memory_nodes",
    "edges",
    "retrieval_logs",
    "schema_migrations",
)

23 

24 

class MemoryRepository:
    """SQLite-backed local repository.

    Persists documents, document chunks, memory nodes and retrieval logs in the
    SQLite database at ``db_path``. Every data-access method calls
    ``initialize()`` before opening its own connection.

    When an ``embedder`` is configured, every indexed chunk is embedded and the
    vector is stored in SQLite (``embedding_vector_json``). If a
    ``vector_store`` is configured as well, the vectors are also pushed to it.
    ``search()`` tries semantic retrieval first and falls back to lexical
    scoring when that yields nothing.
    """

    def __init__(
        self,
        db_path: Path,
        *,
        embedder: Embedder | None = None,
        vector_store: VectorStore | None = None,
    ) -> None:
        self.db_path = db_path
        self.embedder = embedder
        self.vector_store = vector_store
        # Id of the retrieval log written by the most recent ``search()`` call.
        self.last_trace_id: str | None = None

    def initialize(self) -> None:
        """Open the database and apply ``initialize_schema`` to it."""
        with connect(self.db_path) as conn:
            initialize_schema(conn)

    def add_memory(
        self,
        content: str,
        *,
        memory_type: str = "note",
        title: str | None = None,
        status: str = "active",
        source_kind: str = "manual",
        source_document_id: str | None = None,
        source_chunk_id: str | None = None,
        source_path: str | None = None,
        created_by: str = "user",
        model_name: str | None = None,
        confidence: float | None = None,
    ) -> str:
        """Insert a new memory node and return its generated id.

        Args:
            content: Memory text. Surrounding whitespace is stripped before storing.
            memory_type: Value stored in the ``type`` column.
            title: Explicit title. When falsy, the first line of the content
                (truncated) is used instead.
            status: Initial status; must be one of ``VALID_STATUSES``.
            source_kind, source_document_id, source_chunk_id, source_path,
            created_by, model_name, confidence: Provenance fields stored as given.

        Returns:
            The new memory id (``mem_`` prefix).

        Raises:
            ValueError: If ``status`` is not valid or ``content`` is blank.
        """
        if status not in VALID_STATUSES:
            raise ValueError(f"invalid status {status!r}, expected one of {sorted(VALID_STATUSES)}")
        clean_content = content.strip()
        if not clean_content:
            raise ValueError("memory content cannot be empty")
        # Generate id/timestamp only once the input is known to be valid.
        now = _now_sql()
        memory_id = _new_id("mem")
        self.initialize()
        with connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT INTO memory_nodes (
                    id, project_id, type, title, content, status, source_kind,
                    source_document_id, source_chunk_id, source_path,
                    content_hash, created_by, model_name, confidence,
                    created_at, updated_at
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    memory_id,
                    PROJECT_ID,
                    memory_type,
                    title or _title_from_content(clean_content),
                    clean_content,
                    status,
                    source_kind,
                    source_document_id,
                    source_chunk_id,
                    source_path,
                    content_hash(clean_content),
                    created_by,
                    model_name,
                    confidence,
                    now,
                    now,
                ),
            )
            conn.commit()
        return memory_id

    def find_memory_by_content_hash(self, hash_value: str) -> MemoryNode | None:
        """Return the oldest memory node of this project with the given content hash.

        Ties on ``created_at`` are broken by ``id``. Returns ``None`` when no
        memory node matches.
        """
        self.initialize()
        with connect(self.db_path) as conn:
            row = conn.execute(
                """
                SELECT * FROM memory_nodes
                WHERE project_id = ? AND content_hash = ?
                ORDER BY created_at, id
                LIMIT 1
                """,
                (PROJECT_ID, hash_value),
            ).fetchone()
            return _memory_node_from_row(row) if row is not None else None

    def upsert_document(self, path: Path, relative_path: str, chunks: list[TextChunk], *, source_type: str | None = None) -> str:
        """Insert or refresh a document and replace all of its chunks.

        Args:
            path: File on disk; its text is read only to compute the document hash.
            relative_path: Logical path stored as the document path. The document
                id is derived from it, so re-indexing the same path updates the
                existing row.
            chunks: Chunks to store. Any previously stored chunks of the document
                are deleted first.
            source_type: Stored on every chunk row.

        Returns:
            The document id (``doc_`` prefix).
        """
        text = path.read_text(encoding="utf-8", errors="replace")
        doc_id = "doc_" + content_hash(relative_path)[:24]
        filename = _filename_from_logical_path(relative_path, fallback=path)
        now = _now_sql()
        self.initialize()
        with connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT INTO documents (
                    id, project_id, path, filename, content_hash,
                    last_indexed_at, created_at, updated_at
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(id) DO UPDATE SET
                    filename=excluded.filename,
                    content_hash=excluded.content_hash,
                    last_indexed_at=excluded.last_indexed_at,
                    updated_at=excluded.updated_at
                """,
                (
                    doc_id,
                    PROJECT_ID,
                    relative_path,
                    filename,
                    content_hash(text),
                    now,
                    now,
                    now,
                ),
            )
            # Chunks are replaced wholesale. NOTE(review): points previously pushed
            # to the vector store for chunk indexes that no longer exist are not
            # removed here; semantic search skips them because their chunk row is
            # gone, but they may still occupy top_k slots.
            conn.execute("DELETE FROM chunks WHERE document_id = ?", (doc_id,))
            vector_points: list[VectorPoint] = []
            for chunk in chunks:
                chunk_id = f"{doc_id}_chunk_{chunk.index}"
                embedding_vector = self.embedder.embed(chunk.content) if self.embedder is not None else None
                point_id = chunk_point_id(chunk_id) if embedding_vector is not None else None
                conn.execute(
                    """
                    INSERT INTO chunks (
                        id, document_id, chunk_index, heading_path, content,
                        content_hash, token_count, qdrant_point_id, embedding_model,
                        embedding_vector_json, source_type, created_at, updated_at
                    )
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        chunk_id,
                        doc_id,
                        chunk.index,
                        chunk.heading_path,
                        chunk.content,
                        chunk.content_hash,
                        chunk.token_count,
                        point_id,
                        self.embedder.model_name if self.embedder is not None else None,
                        json.dumps(embedding_vector) if embedding_vector is not None else None,
                        source_type,
                        now,
                        now,
                    ),
                )
                if point_id is not None and embedding_vector is not None:
                    vector_points.append(
                        VectorPoint(
                            point_id=point_id,
                            vector=embedding_vector,
                            payload={"chunk_id": chunk_id, "document_id": doc_id},
                        )
                    )
            # Vectors are pushed before the SQLite commit; if the push raises,
            # the commit below is not reached.
            if vector_points and self.vector_store is not None:
                self.vector_store.upsert(vector_points)
            conn.commit()
        return doc_id

    def search(self, query: str, *, top_k: int = 5, include_inactive: bool = False) -> list[SearchHit]:
        """Search chunks (and, lexically, memory nodes) for ``query``.

        Semantic chunk matches are used when available. Otherwise memory nodes
        (token overlap) and chunks (BM25) are scored lexically and merged. The
        best ``top_k`` hits are recorded as a retrieval log whose id is stored
        in ``self.last_trace_id``.

        Args:
            query: Free-text query. A query with no tokens returns ``[]``
                without logging.
            top_k: Maximum number of hits; must be at least 1.
            include_inactive: Also consider memory nodes whose status is not in
                ``ACTIVE_STATUSES`` (lexical path only).

        Raises:
            ValueError: If ``top_k`` is less than 1.
        """
        # A non-positive top_k previously sliced from the end of the hit list.
        if top_k < 1:
            raise ValueError("top_k must be greater than zero")
        tokens = tokenize_set(query)
        if not tokens:
            return []
        self.initialize()
        with connect(self.db_path) as conn:
            hits = self._search_semantic_chunks(conn, query, top_k)
            if not hits:
                hits = _search_memories(conn, tokens, include_inactive)
                hits.extend(_search_chunks(conn, tokens))
            hits.sort(key=lambda item: item.score, reverse=True)
            results = hits[:top_k]
            self.last_trace_id = self._record_retrieval_log(conn, query, top_k, results)
            conn.commit()
        return results

    def stats(self) -> dict[str, int]:
        """Return row counts for the documents, chunks, memory node and retrieval log tables."""
        self.initialize()
        with connect(self.db_path) as conn:
            return {
                "documents": conn.execute("SELECT COUNT(*) FROM documents").fetchone()[0],
                "chunks": conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0],
                "memory_nodes": conn.execute("SELECT COUNT(*) FROM memory_nodes").fetchone()[0],
                "retrieval_logs": conn.execute("SELECT COUNT(*) FROM retrieval_logs").fetchone()[0],
            }

    def list_memory_nodes(self, *, status: str | None = None) -> list[MemoryNode]:
        """Return memory nodes ordered by ``created_at`` then ``id``, optionally filtered by status.

        Raises:
            ValueError: If ``status`` is given but not in ``VALID_STATUSES``.
        """
        if status is not None and status not in VALID_STATUSES:
            raise ValueError(f"invalid status {status!r}, expected one of {sorted(VALID_STATUSES)}")
        self.initialize()
        with connect(self.db_path) as conn:
            if status is not None:
                rows = conn.execute(
                    "SELECT * FROM memory_nodes WHERE status = ? ORDER BY created_at, id",
                    (status,),
                ).fetchall()
            else:
                rows = conn.execute("SELECT * FROM memory_nodes ORDER BY created_at, id").fetchall()
            return [_memory_node_from_row(row) for row in rows]

    def set_memory_status(self, memory_id: str, status: str) -> None:
        """Change a memory node's status and bump its ``updated_at``.

        Raises:
            ValueError: If ``status`` is not in ``VALID_STATUSES``.
            LookupError: If no memory node has id ``memory_id``.
        """
        if status not in VALID_STATUSES:
            raise ValueError(
                f"invalid status {status!r}, expected one of {sorted(VALID_STATUSES)}"
            )
        self.initialize()
        with connect(self.db_path) as conn:
            cursor = conn.execute(
                "UPDATE memory_nodes SET status = ?, updated_at = ? WHERE id = ?",
                (status, _now_sql(), memory_id),
            )
            if cursor.rowcount == 0:
                raise LookupError(f"memory node not found: {memory_id!r}")
            conn.commit()

    def export_data(self) -> dict[str, object]:
        """Return every row of the exported tables, tagged with the export version and project id."""
        self.initialize()
        with connect(self.db_path) as conn:
            return {
                "memory_export_version": EXPORT_VERSION,
                "project_id": PROJECT_ID,
                "documents": _rows(conn, "documents"),
                "chunks": _rows(conn, "chunks"),
                "memory_nodes": _rows(conn, "memory_nodes"),
                "edges": _rows(conn, "edges"),
                "retrieval_logs": _rows(conn, "retrieval_logs"),
                "schema_migrations": _rows(conn, "schema_migrations"),
            }

    def import_data(self, payload: dict[str, object]) -> None:
        """Upsert the rows of a payload produced by ``export_data``.

        Tables are processed in ``EXPORT_TABLES`` order; a table missing from
        the payload is treated as empty. Everything is committed once at the end.

        Raises:
            ValueError: On an unsupported export version, or when a table is
                not a list or a row is not a dict.
        """
        if str(payload.get("memory_export_version")) != EXPORT_VERSION:
            raise ValueError("unsupported memory export version")
        self.initialize()
        with connect(self.db_path) as conn:
            for table in EXPORT_TABLES:
                rows = payload.get(table, [])
                if not isinstance(rows, list):
                    raise ValueError(f"invalid export table: {table}")
                for row in rows:
                    if not isinstance(row, dict):
                        raise ValueError(f"invalid row in table: {table}")
                    _upsert_row(conn, table, row)
            conn.commit()

    def list_retrieval_logs(self, *, limit: int = 20) -> list[RetrievalLog]:
        """Return up to ``limit`` retrieval logs, newest first.

        Raises:
            ValueError: If ``limit`` is less than 1.
        """
        if limit < 1:
            raise ValueError("limit must be greater than zero")
        self.initialize()
        with connect(self.db_path) as conn:
            # ``id`` breaks timestamp ties so the ordering is deterministic.
            rows = conn.execute(
                "SELECT * FROM retrieval_logs ORDER BY created_at DESC, id DESC LIMIT ?",
                (limit,),
            ).fetchall()
            return [_retrieval_log_from_row(row) for row in rows]

    def get_retrieval_log(self, trace_id: str) -> RetrievalLog | None:
        """Return the retrieval log with id ``trace_id``, or ``None`` if it does not exist."""
        self.initialize()
        with connect(self.db_path) as conn:
            row = conn.execute(
                "SELECT * FROM retrieval_logs WHERE id = ?", (trace_id,)
            ).fetchone()
            if row is None:
                return None
            return _retrieval_log_from_row(row)

    def _record_retrieval_log(
        self,
        conn: sqlite3.Connection,
        query: str,
        top_k: int,
        results: list[SearchHit],
    ) -> str:
        """Insert a retrieval log row for a search and return its new id.

        Does not commit; the caller owns the transaction. NOTE(review): relies
        on ``SearchHit`` instances exposing ``__dict__`` with JSON-serialisable
        values — confirm if SearchHit ever switches to slots.
        """
        trace_id = _new_id("ret")
        conn.execute(
            """
            INSERT INTO retrieval_logs (
                id, project_id, query, top_k, result_count, results_json, created_at
            )
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (
                trace_id,
                PROJECT_ID,
                query,
                top_k,
                len(results),
                json.dumps([hit.__dict__ for hit in results], sort_keys=True),
                _now_sql(),
            ),
        )
        return trace_id

    def _semantic_enabled(self) -> bool:
        """Return True when both an embedder and a vector store are configured."""
        return self.embedder is not None and self.vector_store is not None

    def _search_semantic_chunks(
        self,
        conn: sqlite3.Connection,
        query: str,
        top_k: int,
    ) -> list[SearchHit]:
        """Embed ``query`` and turn the nearest chunk vectors into search hits.

        The vector store is consulted first. If it yields nothing (including on
        error), embeddings stored in SQLite are scanned instead. Returns ``[]``
        without an embedder. Matches whose chunk row is gone, or whose document
        is marked missing/skipped in ``source_ledger``, are dropped.
        """
        if self.embedder is None:
            return []
        query_vector = self.embedder.embed(query)
        matches = self._vector_store_matches(query_vector, top_k)
        if not matches:
            matches = _sqlite_vector_matches(conn, query_vector, top_k)
        if not matches:
            return []

        hits: list[SearchHit] = []
        for match in matches:
            row = conn.execute(
                """
                SELECT c.*, d.path
                FROM chunks c
                JOIN documents d ON d.id = c.document_id
                LEFT JOIN source_ledger sl ON sl.source_path_or_alias = d.path
                WHERE c.qdrant_point_id = ?
                  AND (sl.source_id IS NULL OR sl.status NOT IN ('missing', 'skipped'))
                """,
                (match.point_id,),
            ).fetchone()
            if row is None:
                continue
            hits.append(
                SearchHit(
                    title=row["heading_path"] or Path(row["path"]).name,
                    content=row["content"],
                    source_path=row["path"],
                    heading_path=row["heading_path"],
                    memory_type="document_chunk",
                    status="active",
                    score=match.score,
                )
            )
        return hits

    def _vector_store_matches(self, query_vector: list[float], top_k: int) -> list[VectorMatch]:
        """Query the configured vector store, returning ``[]`` if there is none or it fails."""
        if self.vector_store is None:
            return []
        try:
            matches = self.vector_store.search(query_vector, top_k=top_k)
        except Exception:
            # Deliberately best-effort: the caller falls back to the vectors
            # stored in SQLite, but record why instead of failing silently.
            import logging

            logging.getLogger(__name__).warning(
                "vector store search failed; falling back to SQLite embeddings",
                exc_info=True,
            )
            return []
        return [_coerce_vector_match(match) for match in matches]

376 

377 

def _search_memories(
    conn: sqlite3.Connection, tokens: set[str], include_inactive: bool
) -> list[SearchHit]:
    """Score memory nodes by the fraction of query tokens found in their title + content.

    Args:
        conn: Open connection whose rows support access by column name.
        tokens: Tokenised query. An empty set yields no hits.
        include_inactive: When False, only nodes whose status is in
            ``ACTIVE_STATUSES`` are considered.

    Returns:
        One hit per node with a positive score (rounded to 4 decimals), unsorted.
    """
    if not tokens:
        return []
    if include_inactive:
        rows = conn.execute("SELECT * FROM memory_nodes").fetchall()
    else:
        # One placeholder per status keeps the query valid if ACTIVE_STATUSES changes.
        placeholders = ", ".join("?" for _ in ACTIVE_STATUSES)
        rows = conn.execute(
            f"SELECT * FROM memory_nodes WHERE status IN ({placeholders})",
            tuple(ACTIVE_STATUSES),
        ).fetchall()
    hits: list[SearchHit] = []
    for row in rows:
        # NULL title/content must not be tokenised as the literal word "None".
        text = f"{row['title'] or ''} {row['content'] or ''}"
        overlap = tokens & tokenize_set(text)
        score = round(len(overlap) / len(tokens), 4)
        if score > 0:
            hits.append(
                SearchHit(
                    title=row["title"],
                    content=row["content"],
                    source_path=row["source_path"],
                    heading_path=None,
                    memory_type=row["type"],
                    status=row["status"],
                    score=score,
                )
            )
    return hits

404 

405 

def _search_chunks(conn: sqlite3.Connection, tokens: set[str]) -> list[SearchHit]:
    """Rank stored document chunks against the query tokens with BM25.

    Chunks whose document is marked 'missing' or 'skipped' in ``source_ledger``
    are excluded. Each positive BM25 score is multiplied by
    ``source_boost(source_type)`` and rounded to 6 decimals. Hits are returned
    in row order, unsorted.
    """
    from truenex_memory.retrieval.scoring import BM25, source_boost, tokenize

    candidate_rows = conn.execute(
        """
        SELECT c.*, d.path
        FROM chunks c
        JOIN documents d ON d.id = c.document_id
        LEFT JOIN source_ledger sl ON sl.source_path_or_alias = d.path
        WHERE sl.source_id IS NULL OR sl.status NOT IN ('missing', 'skipped')
        """
    ).fetchall()
    if not candidate_rows:
        return []

    corpus = [tokenize(str(row["content"] or "")) for row in candidate_rows]
    raw_scores = BM25(corpus).get_scores(list(tokens))

    results = []
    for row, raw in zip(candidate_rows, raw_scores):
        if raw <= 0:
            continue
        # Older rows/schemas may lack the column entirely.
        source_type = row["source_type"] if "source_type" in row.keys() else None
        path_value = row["path"]
        heading = row["heading_path"]
        results.append(
            SearchHit(
                title=str(heading or Path(str(path_value)).name),
                content=str(row["content"] or ""),
                source_path=None if path_value is None else str(path_value),
                heading_path=None if heading is None else str(heading),
                memory_type="document_chunk",
                status="active",
                score=round(raw * source_boost(source_type), 6),
            )
        )
    return results

444 

445 

def _sqlite_vector_matches(
    conn: sqlite3.Connection, query_vector: list[float], top_k: int
) -> list[VectorMatch]:
    """Brute-force nearest-chunk search over embeddings stored in SQLite.

    Scans every chunk that has a point id and a stored embedding, skipping
    chunks whose document is marked 'missing' or 'skipped' in ``source_ledger``.
    Corrupt embeddings (invalid JSON, not a list, or non-numeric entries) are
    skipped rather than aborting the search.

    Returns:
        At most ``top_k`` matches with positive similarity (rounded to 4
        decimals), best first. A non-positive ``top_k`` yields ``[]``.
    """
    rows = conn.execute(
        """
        SELECT c.qdrant_point_id, c.embedding_vector_json
        FROM chunks c
        JOIN documents d ON d.id = c.document_id
        LEFT JOIN source_ledger sl ON sl.source_path_or_alias = d.path
        WHERE c.qdrant_point_id IS NOT NULL
          AND c.embedding_vector_json IS NOT NULL
          AND (sl.source_id IS NULL OR sl.status NOT IN ('missing', 'skipped'))
        """
    ).fetchall()
    matches: list[VectorMatch] = []
    for row in rows:
        try:
            raw_vector = json.loads(row["embedding_vector_json"])
        except json.JSONDecodeError:
            continue
        if not isinstance(raw_vector, list):
            continue
        try:
            vector = [float(value) for value in raw_vector]
        except (TypeError, ValueError):
            # Nested lists or non-numeric strings: treat like undecodable JSON.
            continue
        score = _cosine(query_vector, vector)
        if score > 0:
            matches.append(VectorMatch(point_id=row["qdrant_point_id"], score=round(score, 4)))
    matches.sort(key=lambda item: item.score, reverse=True)
    # Clamp so a negative top_k cannot slice from the end of the list.
    return matches[: max(top_k, 0)]

473 

474 

def _coerce_vector_match(match: object) -> VectorMatch:
    """Normalise a vector-store search result into a ``VectorMatch``.

    Accepts a ``VectorMatch`` (returned as is) or a ``VectorSearchHit``.
    Otherwise it duck-types any object exposing ``point_id`` (or ``id``) and
    ``score``. A missing score defaults to 0.0.
    """
    if isinstance(match, VectorMatch):
        return match
    if isinstance(match, VectorSearchHit):
        return VectorMatch(point_id=match.id, score=match.score)
    # Explicit None checks: a valid but falsy id (e.g. integer 0) must not be
    # discarded the way an ``or`` fallback would discard it.
    point_id = getattr(match, "point_id", None)
    if point_id is None:
        point_id = getattr(match, "id", None)
    score = getattr(match, "score", 0.0)
    return VectorMatch(point_id=str(point_id), score=float(score))

483 

484 

485def _cosine(left: list[float], right: list[float]) -> float: 

486 if len(left) != len(right) or not left: 

487 return 0.0 

488 return sum(a * b for a, b in zip(left, right, strict=True)) 

489 

490 

def _rows(conn: sqlite3.Connection, table: str) -> list[dict[str, object]]:
    """Return every row of ``table`` as a ``{column: value}`` dict.

    ``table`` is interpolated into the SQL unquoted, so it must be a trusted
    name. ``dict(row)`` assumes a mapping row factory such as ``sqlite3.Row``.
    """
    cursor = conn.execute(f"SELECT * FROM {table}")
    return list(map(dict, cursor.fetchall()))

494 

495 

def _upsert_row(conn: sqlite3.Connection, table: str, row: dict[str, object]) -> None:
    """Insert ``row`` into ``table``, replacing any existing row that conflicts on a key.

    Rows come from import payloads, which may be read from untrusted files.
    Column names are therefore checked against the table's real columns and
    quoted before being placed in the SQL text. Values are always bound as
    parameters.

    Raises:
        ValueError: If ``row`` is empty or names a column ``table`` does not have.
    """
    # SQLite identifiers are ASCII case-insensitive, so compare lower-cased names.
    known_columns = {
        str(info[0]).lower()
        for info in conn.execute("SELECT name FROM pragma_table_info(?)", (table,)).fetchall()
    }
    columns = list(row.keys())
    if not columns:
        raise ValueError(f"empty row in table: {table}")
    unknown = [
        column
        for column in columns
        if not isinstance(column, str) or column.lower() not in known_columns
    ]
    if unknown:
        raise ValueError(f"unknown columns for table {table}: {sorted(map(str, unknown))}")

    def quote(identifier: str) -> str:
        return '"' + identifier.replace('"', '""') + '"'

    placeholders = ", ".join("?" for _ in columns)
    column_sql = ", ".join(quote(column) for column in columns)
    conn.execute(
        f"INSERT OR REPLACE INTO {quote(table)} ({column_sql}) VALUES ({placeholders})",
        [row[column] for column in columns],
    )

504 

505 

def _title_from_content(content: str) -> str:
    """Derive a title from the first line of ``content``, truncated to 80 characters.

    Falls back to "Untitled memory" when the content is empty or its first
    line is blank. (Previously, empty content raised IndexError.)
    """
    lines = content.splitlines()
    first_line = lines[0].strip() if lines else ""
    return first_line[:80] or "Untitled memory"

509 

510 

def _filename_from_logical_path(relative_path: str, *, fallback: Path) -> str:
    """Return the last component of a logical path, or ``fallback.name`` if there is none.

    Backslashes are treated as separators. Surrounding whitespace and trailing
    separators are ignored.
    """
    normalized = relative_path.strip().replace("\\", "/").rstrip("/")
    _, _, name = normalized.rpartition("/")
    return name if name else fallback.name

517 

518 

def _memory_node_from_row(row: sqlite3.Row) -> MemoryNode:
    """Build a ``MemoryNode`` from a ``memory_nodes`` row, one field per same-named column."""
    field_names = (
        "id",
        "project_id",
        "type",
        "title",
        "content",
        "status",
        "source_kind",
        "source_document_id",
        "source_chunk_id",
        "source_path",
        "content_hash",
        "created_by",
        "model_name",
        "confidence",
        "created_at",
        "updated_at",
    )
    return MemoryNode(**{name: row[name] for name in field_names})

538 

539 

def _retrieval_log_from_row(row: sqlite3.Row) -> RetrievalLog:
    """Build a ``RetrievalLog`` from a ``retrieval_logs`` row, one field per same-named column."""
    field_names = (
        "id",
        "project_id",
        "query",
        "top_k",
        "result_count",
        "results_json",
        "created_at",
    )
    return RetrievalLog(**{name: row[name] for name in field_names})

550 

551 

def _new_id(prefix: str) -> str:
    """Return a fresh random id of the form ``<prefix>_<32 lowercase hex chars>``."""
    suffix = uuid.uuid4().hex
    return "_".join((prefix, suffix))

554 

555 

def _now_sql() -> str:
    """Return the current UTC time as an ISO-8601 string carrying a ``+00:00`` offset."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()