Coverage for src / kemi / versions.py: 96%
234 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1"""Memory versioning and undo: keep a history of memory edits and support rollback.
3Every call to :meth:`Memory.update` can optionally be recorded in a versions
4table. Users can then:
5- List all past versions of a memory
6- Preview any past version
7- Roll back to a previous version
8- Diff two versions to see what changed
10This is useful for:
11- Debugging: understand how a memory evolved over time
12- Undo: revert to a known-good state after accidental edits
13- Audit: track when and how memory content changed
15Schema: a separate ``memory_versions`` table stores snapshots of memory
16fields before each update. The current state is always in ``memory_versions``
17with ``version = current_version``; older snapshots have ``version < current_version``.
18"""
20from __future__ import annotations
22import json
23import logging
24import struct
25from dataclasses import dataclass
26from datetime import datetime, timezone
27from typing import Any
29from kemi.models import MemoryObject, MemorySource, MemoryType
# Public API. DiffResult is exported because MemoryVersionStore.diff and
# diff_memories both return it; enable_versioning is the documented entry
# point for creating a version store next to a Memory instance.
__all__ = [
    "MemoryVersionStore",
    "VersionSnapshot",
    "RollbackResult",
    "DiffResult",
    "diff_memories",
    "enable_versioning",
]

logger = logging.getLogger(__name__)
41def _pack_embedding(embedding: list[float] | None) -> bytes | None:
42 """Pack a list of floats into 8 bytes per float for exact round-trip.
44 Float32 is too imprecise for values like ``0.1`` (``0.10000000149...``),
45 so we store as little-endian float64 instead.
46 """
47 if not embedding:
48 return None
49 return struct.pack(f"<{len(embedding)}d", *embedding)
52def _unpack_embedding(blob: bytes | None) -> list[float] | None:
53 """Unpack a float64 blob back into a list of Python floats."""
54 if not blob:
55 return None
56 if len(blob) % 8 != 0:
57 # Fall back to float32 in case an older row was written with the
58 # original 4-byte-per-float encoding.
59 if len(blob) % 4 == 0:
60 return list(struct.unpack(f"<{len(blob) // 4}f", blob))
61 return None
62 return list(struct.unpack(f"<{len(blob) // 8}d", blob))
64# ---------------------------------------------------------------------------
65# Data types
66# ---------------------------------------------------------------------------
@dataclass
class VersionSnapshot:
    """A snapshot of a memory at a point in time.

    Instances are built from one ``memory_versions`` row by
    ``MemoryVersionStore._row_to_snapshot``. In those instances:

    - ``metadata`` and ``tags`` are already JSON-decoded.
    - ``embedding`` has been unpacked from its binary blob.
    - ``memory_type`` and ``source`` hold the raw stored strings, not enum members.
    """

    version: int  # version number (1 = original, increments per edit)
    memory_id: str
    content: str
    embedding: list[float] | None  # None when no embedding blob was stored
    importance: float
    metadata: dict[str, Any]
    tags: list[str]
    memory_type: str  # stored string value; column default is 'episodic'
    confidence: float
    session_id: str | None
    namespace: str
    source: str  # stored string value; column default is 'user_stated'
    changed_at: datetime  # parsed from the ISO-8601 text in the changed_at column
    changed_by: str | None  # "update", "import", "consolidate", etc.
@dataclass
class RollbackResult:
    """Result of a rollback operation.

    Mind the field naming. ``from_version`` is the historical version whose
    content was restored (the rollback target). ``to_version`` is the new
    version number under which the restored state was recorded.
    """

    memory_id: str
    from_version: int  # version whose content was restored (the rollback target)
    to_version: int  # newly written version holding the restored state
    rolled_back_at: datetime  # timezone-aware UTC time the rollback finished
@dataclass
class DiffResult:
    """Diff between two memory versions (or two in-memory MemoryObjects).

    Only fields whose values differ appear in ``field_changes``. An empty dict
    means none of the compared fields changed.
    """

    memory_id: str
    from_version: int
    to_version: int
    field_changes: dict[str, tuple[Any, Any]]  # field → (old, new)
109# ---------------------------------------------------------------------------
110# Version store (stored alongside the main SQLite adapter)
111# ---------------------------------------------------------------------------
# Snapshot table: one row per (memory_id, version), each a full copy of the
# memory's versioned fields. Embeddings are stored as little-endian float64
# blobs (see _pack_embedding), and metadata/tags are stored as JSON text.
# NOTE(review): the composite PRIMARY KEY already indexes
# (memory_id, version); the extra DESC index is likely redundant — confirm
# before removing.
_VERSION_TABLE_DDL = """
CREATE TABLE IF NOT EXISTS memory_versions (
    memory_id TEXT NOT NULL,
    version INTEGER NOT NULL,
    content TEXT NOT NULL,
    embedding BLOB,
    importance REAL NOT NULL DEFAULT 0.5,
    metadata TEXT NOT NULL DEFAULT '{}',
    tags TEXT NOT NULL DEFAULT '[]',
    memory_type TEXT NOT NULL DEFAULT 'episodic',
    confidence REAL NOT NULL DEFAULT 1.0,
    session_id TEXT,
    namespace TEXT NOT NULL DEFAULT 'default',
    source TEXT NOT NULL DEFAULT 'user_stated',
    changed_at TEXT NOT NULL,
    changed_by TEXT,
    PRIMARY KEY (memory_id, version)
);
CREATE INDEX IF NOT EXISTS idx_versions_memory
    ON memory_versions(memory_id, version DESC);
"""

# Field-level change log (one row per changed field between two versions).
# NOTE(review): MemoryVersionStore creates this table, but nothing in this
# module inserts into or reads from it — confirm whether another module
# populates it.
_CHANGE_TABLE_DDL = """
CREATE TABLE IF NOT EXISTS memory_change_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    memory_id TEXT NOT NULL,
    from_version INTEGER NOT NULL,
    to_version INTEGER NOT NULL,
    changed_at TEXT NOT NULL,
    changed_by TEXT,
    field TEXT NOT NULL,
    old_value TEXT,
    new_value TEXT
);
"""
class MemoryVersionStore:
    """Manages memory version history and rollback operations.

    Uses separate SQLite tables (``memory_versions`` and
    ``memory_change_log``) within the same database as the main
    memory store. Every public method opens its own connection and closes it
    before returning. The write paths run inside ``BEGIN IMMEDIATE``
    transactions, so allocating a version number and inserting the row cannot
    interleave with another writer.

    Usage::

        vs = MemoryVersionStore(db_path="~/.kemi/memories.db")
        vs.record_version(memory_obj, changed_by="update")
        snapshots = vs.list_versions("mem-123")
        result = vs.rollback("mem-123", target_version=2, store=storage_adapter)
    """

    def __init__(self, db_path: str | None = None) -> None:
        """Create the store and make sure its tables exist.

        Args:
            db_path: Path to the SQLite database file; ``~`` is expanded.
                Defaults to ``~/.kemi/memories.db``.
        """
        import os
        from pathlib import Path

        if db_path is None:
            db_path = os.path.join(os.path.expanduser("~"), ".kemi", "memories.db")
        path = Path(db_path).expanduser()
        # sqlite3.connect() does not create missing parent directories, so a
        # first run against the default ~/.kemi location would fail. This is
        # best-effort: if creation fails, connect() reports the error as before.
        try:
            path.parent.mkdir(parents=True, exist_ok=True)
        except OSError:
            pass
        self._db_path = str(path)
        self._ensure_tables()

    def _get_connection(self) -> Any:
        """Open a new connection whose rows are ``sqlite3.Row`` objects."""
        import sqlite3

        conn = sqlite3.connect(self._db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def _ensure_tables(self) -> None:
        """Create the version and change-log tables if they do not exist."""
        conn = self._get_connection()
        try:
            conn.executescript(_VERSION_TABLE_DDL)
            conn.executescript(_CHANGE_TABLE_DDL)
            conn.commit()
        finally:
            conn.close()

    # -------------------------------------------------------------------------
    # Recording
    # -------------------------------------------------------------------------

    @staticmethod
    def _enum_value(value: Any) -> str:
        """Return ``value.value`` for enum-like objects, else ``str(value)``."""
        return value.value if hasattr(value, "value") else str(value)

    def _insert_snapshot(
        self,
        cursor: Any,
        memory: MemoryObject,
        version: int,
        changed_by: str | None,
        *,
        replace: bool = False,
    ) -> None:
        """Insert one ``memory_versions`` row for *memory* at *version*.

        Args:
            cursor: Cursor on a connection with an open write transaction.
            memory: Memory whose fields are snapshotted. ``memory.version`` is
                ignored; *version* is written instead.
            version: Version number for the new row.
            changed_by: Label stored in the ``changed_by`` column.
            replace: If true, use ``INSERT OR REPLACE`` so an existing row at
                the same ``(memory_id, version)`` is overwritten. Otherwise a
                collision raises ``sqlite3.IntegrityError``.
        """
        # verb comes from a fixed pair of literals, never from caller input.
        verb = "INSERT OR REPLACE" if replace else "INSERT"
        cursor.execute(
            f"""
            {verb} INTO memory_versions
                (memory_id, version, content, embedding, importance,
                 metadata, tags, memory_type, confidence, session_id,
                 namespace, source, changed_at, changed_by)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                memory.memory_id,
                version,
                memory.content,
                _pack_embedding(memory.embedding),
                memory.importance,
                json.dumps(memory.metadata or {}),
                json.dumps(memory.tags or []),
                self._enum_value(memory.memory_type),
                memory.confidence,
                memory.session_id,
                memory.namespace,
                self._enum_value(memory.source),
                datetime.now(timezone.utc).isoformat(),
                changed_by,
            ),
        )

    def _next_version_number(
        self,
        conn: Any,
        memory_id: str,
        memory: MemoryObject,
    ) -> int:
        """Compute the version number for a new snapshot of *memory_id*.

        The caller's ``memory.version`` is used when it is an integer strictly
        greater than every stored version. This preserves caller-chosen,
        possibly non-contiguous, numbers that advance the sequence. Otherwise
        ``MAX(version) + 1`` is returned (``1`` when no versions exist).

        An unused number *below* the maximum is deliberately not reused. Such a
        gap can be left by :meth:`prune_versions` or by a stale
        ``memory.version``. Filling it would make a newer snapshot sort as
        older in :meth:`list_versions`, hide it from
        :meth:`get_latest_version_number`, and make it the first row pruned.

        Must run inside the caller's write transaction so the ``MAX`` read and
        the following INSERT are atomic.
        """
        cursor = conn.cursor()
        cursor.execute(
            "SELECT MAX(version) FROM memory_versions WHERE memory_id = ?",
            (memory_id,),
        )
        row = cursor.fetchone()
        current_max = row[0] if row and row[0] is not None else 0

        requested = memory.version
        if isinstance(requested, int) and requested > current_max:
            return requested
        return current_max + 1

    def record_version(
        self,
        memory: MemoryObject,
        *,
        changed_by: str = "update",
    ) -> int:
        """Record a new version snapshot of a memory.

        Uses the caller's ``memory.version`` when it advances the sequence.
        Otherwise the next free number (``MAX(version) + 1``) is used. This
        covers a collision from concurrent writes, or a caller that did not
        increment ``memory.version``. After the insert commits, the chosen
        number is written back to ``memory.version``.

        Args:
            memory: The current MemoryObject to snapshot.
            changed_by: Label describing what operation triggered this snapshot
                (e.g., "update", "import", "consolidate").

        Returns:
            The version number that was written.
        """
        conn = self._get_connection()
        try:
            # BEGIN IMMEDIATE takes SQLite's RESERVED lock up front. This
            # serialises concurrent writers, so the MAX(version) read and the
            # INSERT below cannot race with another writer.
            conn.execute("BEGIN IMMEDIATE")
            try:
                next_version = self._next_version_number(conn, memory.memory_id, memory)
                self._insert_snapshot(conn.cursor(), memory, next_version, changed_by)
                conn.commit()
            except Exception:
                conn.rollback()
                raise
        finally:
            conn.close()

        # Only mutate the caller's object once the row is durable.
        memory.version = next_version
        return next_version

    def record_before_update(
        self,
        memory_before: MemoryObject,
        memory_after: MemoryObject,
        *,
        changed_by: str = "update",
    ) -> int:
        """Record both the pre-update and post-update snapshots atomically.

        The pre-update snapshot is written at ``memory_before.version`` with
        ``changed_by`` set to ``"pre-" + changed_by``. If a row already exists
        at that ``(memory_id, version)``, it is replaced, including its
        ``changed_at`` and ``changed_by``.

        The post-update snapshot is written at the number chosen by
        :meth:`_next_version_number`. That number is assigned to
        ``memory_after.version`` after commit.

        Both inserts share one ``BEGIN IMMEDIATE`` transaction: they cannot
        interleave with another writer, and either both land or neither does.

        Args:
            memory_before: State of memory before the update.
            memory_after: State of memory after the update.
            changed_by: Operation label.

        Returns:
            The new version number of memory_after.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")
            try:
                cursor = conn.cursor()
                # Re-recording a pre-update snapshot at the same version is
                # treated as idempotent, hence the upsert.
                self._insert_snapshot(
                    cursor,
                    memory_before,
                    memory_before.version,
                    "pre-" + changed_by,
                    replace=True,
                )
                post_version = self._next_version_number(
                    conn, memory_after.memory_id, memory_after
                )
                self._insert_snapshot(cursor, memory_after, post_version, changed_by)
                conn.commit()
            except Exception:
                conn.rollback()
                raise
        finally:
            conn.close()

        memory_after.version = post_version
        return post_version

    def prune_versions(self, memory_id: str, keep_count: int) -> int:
        """Prune old versions, keeping only the most recent N versions.

        Args:
            memory_id: Memory whose versions to prune.
            keep_count: Number of recent versions to keep. Values <= 0 are
                ignored (nothing is deleted).

        Returns:
            Number of versions deleted.
        """
        if keep_count <= 0:
            return 0
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")
            try:
                cursor = conn.cursor()
                # The keep_count-th newest version is the oldest one retained.
                cursor.execute(
                    "SELECT version FROM memory_versions WHERE memory_id = ? "
                    "ORDER BY version DESC LIMIT 1 OFFSET ?",
                    (memory_id, keep_count - 1),
                )
                row = cursor.fetchone()
                if row is None:
                    # Fewer than keep_count versions exist: nothing to prune.
                    conn.commit()
                    return 0
                # Versions are unique per memory_id, so every row strictly
                # older than the cutoff is exactly the surplus. A range
                # predicate avoids binding one parameter per deleted row, which
                # could exceed SQLite's host-parameter limit.
                cursor.execute(
                    "DELETE FROM memory_versions WHERE memory_id = ? AND version < ?",
                    (memory_id, row[0]),
                )
                deleted = cursor.rowcount
                conn.commit()
                return deleted
            except Exception:
                conn.rollback()
                raise
        finally:
            conn.close()

    def verify_sequential_versions(self, memory_id: str) -> bool:
        """Verify that version numbers for a memory form a contiguous sequence.

        Returns True if the versions are exactly 1, 2, 3, ... with no gaps (and
        also when no versions exist). The check also returns False once
        :meth:`prune_versions` has removed version 1. Useful as an integrity
        check after concurrent writes.
        """
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT version FROM memory_versions WHERE memory_id = ? ORDER BY version ASC",
                (memory_id,),
            )
            versions = [r[0] for r in cursor.fetchall()]
            return versions == list(range(1, len(versions) + 1))
        finally:
            conn.close()

    # -------------------------------------------------------------------------
    # Querying
    # -------------------------------------------------------------------------

    def list_versions(self, memory_id: str) -> list[VersionSnapshot]:
        """Return all version snapshots for a memory, newest first.

        Args:
            memory_id: ID of the memory.

        Returns:
            List of VersionSnapshot objects, sorted by version descending
            (empty if the memory has no recorded versions).
        """
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT memory_id, version, content, embedding, importance,
                       metadata, tags, memory_type, confidence, session_id,
                       namespace, source, changed_at, changed_by
                FROM memory_versions
                WHERE memory_id = ?
                ORDER BY version DESC
                """,
                (memory_id,),
            )
            return [self._row_to_snapshot(row) for row in cursor.fetchall()]
        finally:
            conn.close()

    def get_version(self, memory_id: str, version: int) -> VersionSnapshot | None:
        """Retrieve a specific version of a memory.

        Args:
            memory_id: ID of the memory.
            version: Version number to retrieve.

        Returns:
            VersionSnapshot if found, None otherwise.
        """
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT memory_id, version, content, embedding, importance,
                       metadata, tags, memory_type, confidence, session_id,
                       namespace, source, changed_at, changed_by
                FROM memory_versions
                WHERE memory_id = ? AND version = ?
                """,
                (memory_id, version),
            )
            row = cursor.fetchone()
            return self._row_to_snapshot(row) if row else None
        finally:
            conn.close()

    def get_latest_version_number(self, memory_id: str) -> int | None:
        """Return the highest version number for a memory, or None if no versions exist."""
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT MAX(version) FROM memory_versions WHERE memory_id = ?",
                (memory_id,),
            )
            row = cursor.fetchone()
            return row[0] if row and row[0] is not None else None
        finally:
            conn.close()

    # -------------------------------------------------------------------------
    # Rollback
    # -------------------------------------------------------------------------

    def rollback(
        self,
        memory_id: str,
        target_version: int,
        store: Any,
        *,
        changed_by: str = "rollback",
    ) -> RollbackResult | None:
        """Roll a memory back to a specific version.

        Reconstructs the MemoryObject from the version snapshot and writes it
        back to the storage adapter. The restored state is recorded as a fresh,
        monotonically increasing version (``MAX(version) + 1``) rather than by
        reusing the old number, which preserves the full audit trail.

        The new history row is allocated and inserted inside one
        ``BEGIN IMMEDIATE`` transaction, so it cannot collide with a
        concurrent writer. ``store.update`` is called after that commit. If the
        update raises, the history row is removed again and the exception
        propagates.

        Args:
            memory_id: ID of the memory to roll back.
            target_version: Version number to roll back to.
            store: The StorageAdapter to write the rolled-back memory to. It
                must provide ``get(memory_id)`` and ``update(memory)``.
            changed_by: Label for the rollback operation.

        Returns:
            RollbackResult if successful, or None if the target version or the
            live memory was not found. ``from_version`` is the restored
            (target) version; ``to_version`` is the newly written version.
        """
        snapshot = self.get_version(memory_id, target_version)
        if snapshot is None:
            return None

        # The live memory supplies the fields that are not versioned
        # (user_id, created_at, lifecycle_state).
        current = store.get(memory_id)
        if current is None:
            return None

        rolled_back = MemoryObject(
            memory_id=memory_id,
            user_id=current.user_id,
            content=snapshot.content,
            embedding=snapshot.embedding,
            score=0.0,
            created_at=current.created_at,
            last_accessed_at=datetime.now(timezone.utc),
            source=MemorySource(snapshot.source),
            importance=snapshot.importance,
            lifecycle_state=current.lifecycle_state,
            metadata=json.loads(snapshot.metadata) if isinstance(snapshot.metadata, str) else (snapshot.metadata or {}),
            embedding_dim=len(snapshot.embedding) if snapshot.embedding else None,
            tags=json.loads(snapshot.tags) if isinstance(snapshot.tags, str) else (snapshot.tags or []),
            confidence=snapshot.confidence,
            memory_type=MemoryType(snapshot.memory_type),
            session_id=snapshot.session_id,
            namespace=snapshot.namespace,
            version=current.version,
        )

        # Allocate the version number and write the history row atomically.
        # Computing MAX + 1 outside the lock and then using INSERT OR REPLACE
        # could silently overwrite a snapshot written concurrently.
        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")
            try:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT MAX(version) FROM memory_versions WHERE memory_id = ?",
                    (memory_id,),
                )
                row = cursor.fetchone()
                latest = row[0] if row and row[0] is not None else None
                # latest is None only if every snapshot vanished after
                # get_version() above; fall back to the live version then.
                new_version = current.version if latest is None else latest + 1
                self._insert_snapshot(cursor, rolled_back, new_version, changed_by)
                conn.commit()
            except Exception:
                conn.rollback()
                raise
        finally:
            conn.close()

        rolled_back.version = new_version

        # Call the store only after our transaction has closed: the adapter
        # may use the same database file and would block on our lock.
        try:
            store.update(rolled_back)
        except Exception:
            # Keep history consistent with the store: drop the row recorded
            # for a rollback that never reached it.
            try:
                self._delete_version(memory_id, new_version)
            except Exception:
                logger.warning(
                    "Could not remove version %s of memory %s after a failed rollback",
                    new_version,
                    memory_id,
                    exc_info=True,
                )
            raise

        # Sync the fetched memory object so callers holding it see the new value.
        current.version = new_version

        logger.info(
            "Rolled back memory %s from version %s to version %s",
            memory_id,
            snapshot.version,
            new_version,
        )

        return RollbackResult(
            memory_id=memory_id,
            from_version=snapshot.version,
            to_version=new_version,
            rolled_back_at=datetime.now(timezone.utc),
        )

    def _delete_version(self, memory_id: str, version: int) -> None:
        """Delete the single ``(memory_id, version)`` snapshot row, if present."""
        conn = self._get_connection()
        try:
            conn.execute(
                "DELETE FROM memory_versions WHERE memory_id = ? AND version = ?",
                (memory_id, version),
            )
            conn.commit()
        finally:
            conn.close()

    # -------------------------------------------------------------------------
    # Diff
    # -------------------------------------------------------------------------

    def diff(
        self,
        memory_id: str,
        from_version: int,
        to_version: int,
    ) -> DiffResult | None:
        """Show what changed between two versions of a memory.

        Compares content, importance, metadata, tags, memory_type, confidence
        and session_id. Namespace, source and embedding are not compared.
        Values are compared after normalisation (JSON with sorted keys for
        dicts/lists, ``str()`` otherwise) but are reported unnormalised.

        Args:
            memory_id: ID of the memory.
            from_version: Starting version number.
            to_version: Ending version number.

        Returns:
            DiffResult listing all field changes, or None if either version not found.
        """
        snap_from = self.get_version(memory_id, from_version)
        snap_to = self.get_version(memory_id, to_version)

        if snap_from is None or snap_to is None:
            return None

        changes: dict[str, tuple[Any, Any]] = {}
        fields = [
            ("content", snap_from.content, snap_to.content),
            ("importance", snap_from.importance, snap_to.importance),
            ("metadata", snap_from.metadata, snap_to.metadata),
            ("tags", snap_from.tags, snap_to.tags),
            ("memory_type", snap_from.memory_type, snap_to.memory_type),
            ("confidence", snap_from.confidence, snap_to.confidence),
            ("session_id", snap_from.session_id, snap_to.session_id),
        ]

        for field_name, old_val, new_val in fields:
            if self._normalize_field_value(old_val) != self._normalize_field_value(new_val):
                changes[field_name] = (old_val, new_val)

        return DiffResult(
            memory_id=memory_id,
            from_version=from_version,
            to_version=to_version,
            field_changes=changes,
        )

    # -------------------------------------------------------------------------
    # Helpers
    # -------------------------------------------------------------------------

    def _row_to_snapshot(self, row) -> VersionSnapshot:
        """Convert a ``memory_versions`` row into a VersionSnapshot.

        JSON-decodes metadata/tags, unpacks the embedding blob and parses
        ``changed_at``. ``memory_type`` and ``source`` stay raw strings.
        """
        return VersionSnapshot(
            memory_id=row["memory_id"],
            version=row["version"],
            content=row["content"],
            embedding=_unpack_embedding(row["embedding"]),
            importance=row["importance"],
            metadata=json.loads(row["metadata"]) if isinstance(row["metadata"], str) else row["metadata"],
            tags=json.loads(row["tags"]) if isinstance(row["tags"], str) else row["tags"],
            memory_type=row["memory_type"],
            confidence=row["confidence"],
            session_id=row["session_id"],
            namespace=row["namespace"],
            source=row["source"],
            changed_at=datetime.fromisoformat(row["changed_at"]),
            changed_by=row["changed_by"],
        )

    def _normalize_field_value(self, value: Any) -> str:
        """Normalize a field value for comparison (sorted-key JSON or ``str``)."""
        if isinstance(value, (dict, list)):
            return json.dumps(value, sort_keys=True)
        return str(value)
def diff_memories(
    before: MemoryObject,
    after: MemoryObject,
) -> DiffResult:
    """Diff two memory objects and return field-level changes.

    Convenience function that doesn't need a version store. It is useful for
    previewing what an update would change.

    It compares the same fields as ``MemoryVersionStore.diff``: content,
    importance, confidence, tags, metadata, memory_type and session_id.
    Namespace, source and embedding are not compared. Values are compared
    with plain ``!=``. An attribute missing on either object is treated as
    ``None``.

    Args:
        before: Memory state before the change.
        after: Memory state after the change.

    Returns:
        DiffResult mapping each changed field to its ``(old, new)`` pair,
        tagged with ``before.memory_id`` and both objects' ``version``.
    """
    compared_fields = (
        "content",
        "importance",
        "confidence",
        "tags",
        "metadata",
        "memory_type",
        "session_id",
    )
    changes: dict[str, tuple[Any, Any]] = {}
    for name in compared_fields:
        old_val = getattr(before, name, None)
        new_val = getattr(after, name, None)
        if old_val != new_val:
            changes[name] = (old_val, new_val)

    return DiffResult(
        memory_id=before.memory_id,
        from_version=before.version,
        to_version=after.version,
        field_changes=changes,
    )
801# ---------------------------------------------------------------------------
802# Attach versioning to Memory core (patch-in hooks)
803# ---------------------------------------------------------------------------
def enable_versioning(memory: Any, db_path: str | None = None) -> MemoryVersionStore:
    """Create a MemoryVersionStore to use alongside a Memory instance.

    The returned store can record, list, diff and roll back memory versions.

    Args:
        memory: The Memory instance that versioning is intended for. This
            function does not currently read or modify it.
        db_path: Optional path to the SQLite DB. ``None`` lets
            ``MemoryVersionStore`` pick its default, ``~/.kemi/memories.db``.

    Returns:
        A ready-to-use MemoryVersionStore.

    Usage::

        vs = enable_versioning(memory)
        vs.record_version(updated_memory, changed_by="update")
        snapshots = vs.list_versions("mem-123")
        vs.rollback("mem-123", target_version=2, store=memory._store)
    """
    # MemoryVersionStore.__init__ resolves None to the same default path.
    return MemoryVersionStore(db_path=db_path)