Coverage for src / kemi / audit.py: 88%
177 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1"""Audit trail for kemi memory operations.
3Provides compliance-grade operation logging with:
4- Complete CRUD operation audit trail
5- Queryable audit log (by user, operation type, time range)
6- Retention policy support
7- Export capability for compliance audits
9Every memory mutation (remember, forget, update, prune, migrate, etc.)
10is logged with timestamp, user ID, operation type, and details.
12Stored in a separate SQLite table `audit_log` for clean separation.
13Zero external dependencies beyond the existing SQLite adapter.
15Usage:
16 from kemi.audit import AuditTrail
18 audit = AuditTrail(db_connection)
19 audit.log_operation("alice", "remember", {"memory_id": "abc123"})
20 entries = audit.query(user_id="alice", operation="remember")
21"""
23import json
24import logging
25import sqlite3
26import time
27from dataclasses import dataclass, field
28from datetime import datetime, timezone
29from typing import Any
# Module-level logger; handlers and levels are left to the application.
logger = logging.getLogger(__name__)

# Schema version for future migrations.
# This value is written into the `schema_version` column of every inserted
# row. The `DEFAULT 1` in the DDL below is a separate hard-coded literal, so
# keep the two in sync when bumping the version.
AUDIT_SCHEMA_VERSION = 1

# DDL for the audit table and its indexes. Every statement uses
# IF NOT EXISTS, so the script can be executed on every start-up.
# `timestamp` is a TEXT column: range filters and ordering compare the stored
# strings, not parsed datetimes.
# `details` is a TEXT column that holds a JSON document ('{}' by default).
AUDIT_SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS audit_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TEXT NOT NULL,
    user_id TEXT NOT NULL,
    operation TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'success',
    details TEXT NOT NULL DEFAULT '{}',
    memory_id TEXT,
    namespace TEXT DEFAULT 'default',
    client_ip TEXT,
    user_agent TEXT,
    duration_ms REAL,
    schema_version INTEGER DEFAULT 1
);

CREATE INDEX IF NOT EXISTS idx_audit_user_id ON audit_log(user_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_log(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_operation ON audit_log(operation);
CREATE INDEX IF NOT EXISTS idx_audit_user_time ON audit_log(user_id, timestamp);
"""
@dataclass
class AuditEntry:
    """One row of the audit log, as returned by audit queries.

    Field names mirror the columns of the ``audit_log`` table; ``details``
    holds the decoded JSON payload rather than the stored string.
    """

    id: int = 0
    timestamp: str = ""
    user_id: str = ""
    operation: str = ""
    status: str = "success"
    details: dict[str, Any] = field(default_factory=dict)
    memory_id: str | None = None
    namespace: str = "default"
    client_ip: str | None = None
    user_agent: str | None = None
    duration_ms: float | None = None
    schema_version: int = AUDIT_SCHEMA_VERSION

    def to_dict(self) -> dict[str, Any]:
        """Return the entry as a plain dict.

        ``schema_version`` is not part of the result. ``details`` is
        returned by reference, not copied.
        """
        exported_keys = (
            "id",
            "timestamp",
            "user_id",
            "operation",
            "status",
            "details",
            "memory_id",
            "namespace",
            "client_ip",
            "user_agent",
            "duration_ms",
        )
        return {key: getattr(self, key) for key in exported_keys}
class AuditTrail:
    """Audit trail for compliance-grade operation logging.

    Features:
    - Automatic schema creation
    - Batch logging support
    - Retention policy (auto-purge old entries)
    - Query by user, operation, time range, status
    - Export to JSON for compliance audits

    The SQLite connection is supplied by the caller. This class calls
    ``commit()`` on it after its own writes and never closes it.
    """

    # Single definition of the INSERT so log_operation and
    # log_operation_batch cannot drift apart in column order.
    _INSERT_SQL = """INSERT INTO audit_log
                   (timestamp, user_id, operation, status, details, memory_id,
                    namespace, client_ip, user_agent, duration_ms, schema_version)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"""

    # Column list used by query(); _row_to_entry relies on this exact order.
    _SELECT_COLUMNS = (
        "id, timestamp, user_id, operation, status, details, "
        "memory_id, namespace, client_ip, user_agent, duration_ms, "
        "schema_version"
    )

    def __init__(
        self,
        db_connection: sqlite3.Connection,
        retention_days: int = 365,
        auto_purge: bool = True,
    ) -> None:
        """Initialize audit trail.

        Creates the ``audit_log`` table and indexes if they are missing.

        Args:
            db_connection: SQLite connection to use.
            retention_days: Number of days to keep audit entries (default 365).
                A value <= 0 disables purging of old entries.
            auto_purge: If True, automatically purge old entries on
                log_operation / log_operation_batch (throttled).
        """
        self._conn = db_connection
        self._retention_days = retention_days
        self._auto_purge = auto_purge

        # Throttle auto-purge: run at most every 5 minutes or every 100 writes
        self._writes_since_purge: int = 0
        self._last_purge_time: float = time.time()
        self._purge_interval_seconds: float = 300.0  # 5 minutes
        self._purge_write_threshold: int = 100

        self._ensure_schema()
        logger.info(
            f"Audit trail initialized (retention: {retention_days}d, auto_purge: {auto_purge})"
        )

    def _ensure_schema(self) -> None:
        """Create audit log table and indexes if they don't exist."""
        self._conn.executescript(AUDIT_SCHEMA_SQL)
        self._conn.commit()

    def log_operation(
        self,
        user_id: str,
        operation: str,
        details: dict[str, Any] | None = None,
        memory_id: str | None = None,
        namespace: str = "default",
        status: str = "success",
        client_ip: str | None = None,
        user_agent: str | None = None,
        duration_ms: float | None = None,
    ) -> int:
        """Log a memory operation to the audit trail.

        The entry is timestamped with the current UTC time and committed
        immediately.

        Args:
            user_id: User who performed the operation.
            operation: Operation type (remember, recall, forget, update, etc.).
            details: Additional operation details as dict. Values that are not
                JSON-serializable are stored as ``str(value)``.
            memory_id: ID of the memory involved (if applicable).
            namespace: Memory namespace.
            status: Operation status (success, error, denied).
            client_ip: Client IP address (for API usage).
            user_agent: Client user agent string.
            duration_ms: Operation duration in milliseconds.

        Returns:
            The ID of the new audit entry (0 if SQLite reports no row id).

        Raises:
            sqlite3.Error: If the insert or commit fails.
        """
        self._maybe_purge()

        timestamp = datetime.now(timezone.utc).isoformat()
        details_json = json.dumps(details or {}, default=str)

        try:
            cursor = self._conn.execute(
                self._INSERT_SQL,
                (
                    timestamp,
                    user_id,
                    operation,
                    status,
                    details_json,
                    memory_id,
                    namespace,
                    client_ip,
                    user_agent,
                    duration_ms,
                    AUDIT_SCHEMA_VERSION,
                ),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Failed to write audit entry: {e}")
            raise

        entry_id = cursor.lastrowid or 0
        self._writes_since_purge += 1
        logger.debug(f"Audit: {operation} by {user_id} (entry {entry_id})")
        return entry_id

    def log_operation_batch(
        self,
        entries: list[dict[str, Any]],
    ) -> int:
        """Log multiple operations atomically.

        All entries share one timestamp. Every row is built and validated
        before anything is written, so a malformed entry can no longer leave
        a partially inserted batch pending on the connection.

        Args:
            entries: List of dicts with keys matching log_operation params.
                ``user_id`` and ``operation`` are required; a missing or
                ``None`` ``details`` is stored as an empty object.

        Returns:
            Number of entries logged.

        Raises:
            KeyError: If an entry lacks ``user_id`` or ``operation`` (raised
                before any row is written).
            sqlite3.Error: If the insert or commit fails; the batch is rolled
                back first.
        """
        self._maybe_purge()

        timestamp = datetime.now(timezone.utc).isoformat()

        # Build every parameter tuple up front: KeyError / serialization
        # errors surface here, before the database has been touched.
        rows = [
            (
                timestamp,
                entry["user_id"],
                entry["operation"],
                entry.get("status", "success"),
                # `or {}` mirrors log_operation: explicit None becomes '{}'
                # instead of the JSON literal 'null'.
                json.dumps(entry.get("details") or {}, default=str),
                entry.get("memory_id"),
                entry.get("namespace", "default"),
                entry.get("client_ip"),
                entry.get("user_agent"),
                entry.get("duration_ms"),
                AUDIT_SCHEMA_VERSION,
            )
            for entry in entries
        ]

        try:
            self._conn.executemany(self._INSERT_SQL, rows)
            self._conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Failed to write batch audit entries: {e}")
            self._conn.rollback()
            raise

        count = len(rows)
        self._writes_since_purge += count
        return count

    @staticmethod
    def _row_to_entry(row: Any) -> AuditEntry:
        """Convert one row in ``_SELECT_COLUMNS`` order into an AuditEntry."""
        try:
            details = json.loads(row[5])
        except (json.JSONDecodeError, TypeError):
            # Unreadable payloads must not make the whole query fail.
            details = {}
        if details is None:
            # A stored JSON 'null' is treated as "no details".
            details = {}

        return AuditEntry(
            id=row[0],
            timestamp=row[1],
            user_id=row[2],
            operation=row[3],
            status=row[4],
            details=details,
            memory_id=row[6],
            namespace=row[7],
            client_ip=row[8],
            user_agent=row[9],
            duration_ms=row[10],
            schema_version=row[11],
        )

    def query(
        self,
        user_id: str | None = None,
        operation: str | None = None,
        status: str | None = None,
        memory_id: str | None = None,
        namespace: str | None = None,
        start_time: str | None = None,
        end_time: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[AuditEntry]:
        """Query audit trail with flexible filters.

        A filter is applied whenever its argument is not ``None``; an empty
        string is a real filter value and is no longer silently ignored.
        Results are ordered newest first, with the row id as tie-breaker so
        that entries sharing a timestamp page deterministically.

        Args:
            user_id: Filter by user ID.
            operation: Filter by operation type.
            status: Filter by status (success, error, denied).
            memory_id: Filter by memory ID.
            namespace: Filter by namespace.
            start_time: ISO timestamp for start of range (inclusive).
                Compared as text against the stored timestamp strings.
            end_time: ISO timestamp for end of range (inclusive).
                Compared as text against the stored timestamp strings.
            limit: Maximum number of entries to return. A negative value
                means no upper bound (SQLite LIMIT semantics).
            offset: Offset for pagination.

        Returns:
            List of matching AuditEntry objects.

        Raises:
            sqlite3.Error: If the query fails.
        """
        candidate_filters = (
            ("user_id = ?", user_id),
            ("operation = ?", operation),
            ("status = ?", status),
            ("memory_id = ?", memory_id),
            ("namespace = ?", namespace),
            ("timestamp >= ?", start_time),
            ("timestamp <= ?", end_time),
        )
        active = [(clause, value) for clause, value in candidate_filters if value is not None]
        conditions = [clause for clause, _ in active]
        params: list[Any] = [value for _, value in active]

        # Only the fixed fragments above are interpolated into the SQL text;
        # every caller-supplied value is bound as a parameter.
        where_clause = " AND ".join(conditions) if conditions else "1=1"
        query_sql = (
            f"SELECT {self._SELECT_COLUMNS} "
            f"FROM audit_log WHERE {where_clause} "
            "ORDER BY timestamp DESC, id DESC LIMIT ? OFFSET ?"
        )
        params.extend([limit, offset])

        try:
            cursor = self._conn.execute(query_sql, params)
            return [self._row_to_entry(row) for row in cursor.fetchall()]
        except sqlite3.Error as e:
            logger.error(f"Failed to query audit trail: {e}")
            raise

    def get_user_activity(
        self,
        user_id: str,
        days: int = 30,
    ) -> dict[str, Any]:
        """Get activity summary for a user.

        Args:
            user_id: User ID to query.
            days: Number of days to look back.

        Returns:
            Dict with keys ``user_id``, ``period_days``, ``total_operations``,
            ``operation_counts`` (operation -> count within the window) and
            ``last_activity``. ``last_activity`` is the user's most recent
            timestamp over the whole log, not just the window, or ``None``
            if the user has no entries.

        Raises:
            sqlite3.Error: If a query fails.
        """
        from datetime import timedelta

        start_time = (datetime.now(timezone.utc) - timedelta(days=days)).isoformat()

        try:
            cursor = self._conn.execute(
                """SELECT operation, COUNT(*) as count
                   FROM audit_log
                   WHERE user_id = ? AND timestamp >= ?
                   GROUP BY operation""",
                (user_id, start_time),
            )
            operation_counts = {row[0]: row[1] for row in cursor.fetchall()}

            cursor = self._conn.execute(
                """SELECT MAX(timestamp) FROM audit_log WHERE user_id = ?""",
                (user_id,),
            )
            last_activity = cursor.fetchone()[0]
        except sqlite3.Error as e:
            logger.error(f"Failed to get user activity: {e}")
            raise

        return {
            "user_id": user_id,
            "period_days": days,
            # The grouped counts partition the same rows a separate COUNT(*)
            # would scan, so their sum is the window total.
            "total_operations": sum(operation_counts.values()),
            "operation_counts": operation_counts,
            "last_activity": last_activity,
        }

    def get_stats(self) -> dict[str, Any]:
        """Get overall audit trail statistics.

        Returns:
            Dict with ``total_entries``, ``unique_users``, ``first_entry``
            and ``last_entry`` (oldest / newest timestamp, ``None`` when the
            log is empty) and the configured ``retention_days``.

        Raises:
            sqlite3.Error: If the query fails.
        """
        try:
            # One statement so all four figures describe the same table state.
            cursor = self._conn.execute(
                "SELECT COUNT(*), COUNT(DISTINCT user_id), "
                "MIN(timestamp), MAX(timestamp) FROM audit_log"
            )
            total_entries, unique_users, first_entry, last_entry = cursor.fetchone()
        except sqlite3.Error as e:
            logger.error(f"Failed to get audit stats: {e}")
            raise

        return {
            "total_entries": total_entries,
            "unique_users": unique_users,
            "first_entry": first_entry,
            "last_entry": last_entry,
            "retention_days": self._retention_days,
        }

    def export(
        self,
        start_time: str | None = None,
        end_time: str | None = None,
        user_id: str | None = None,
    ) -> list[dict[str, Any]]:
        """Export audit entries as a list of dicts for compliance.

        Every matching entry is returned; the export is not capped, so the
        whole result set is held in memory.

        Args:
            start_time: ISO timestamp filter.
            end_time: ISO timestamp filter.
            user_id: Optional user filter.

        Returns:
            List of dicts suitable for JSON export, newest first.

        Raises:
            sqlite3.Error: If the underlying query fails.
        """
        entries = self.query(
            user_id=user_id,
            start_time=start_time,
            end_time=end_time,
            # Negative LIMIT = no upper bound in SQLite; a fixed cap would
            # silently drop rows from a compliance export.
            limit=-1,
        )
        return [e.to_dict() for e in entries]

    def _maybe_purge(self) -> None:
        """Throttled purge: only run if enough time or writes have passed."""
        if not self._auto_purge:
            return
        now = time.time()
        if (
            now - self._last_purge_time < self._purge_interval_seconds
            and self._writes_since_purge < self._purge_write_threshold
        ):
            return
        # Reset the throttle before purging so a failed purge is not retried
        # on every subsequent write.
        self._last_purge_time = now
        self._writes_since_purge = 0
        self._purge_old_entries()

    def _purge_old_entries(self) -> int:
        """Remove entries older than retention_days.

        Best-effort: a database error is logged and reported as 0 rather
        than raised, so purging never blocks the write that triggered it.

        Returns:
            Number of entries purged (0 when retention is disabled or the
            delete failed).
        """
        if self._retention_days <= 0:
            return 0

        from datetime import timedelta

        cutoff = (datetime.now(timezone.utc) - timedelta(days=self._retention_days)).isoformat()

        try:
            cursor = self._conn.execute(
                "DELETE FROM audit_log WHERE timestamp < ?",
                (cutoff,),
            )
            self._conn.commit()
            deleted = cursor.rowcount
            if deleted > 0:
                logger.info(f"Purged {deleted} old audit entries (cutoff: {cutoff})")
            return deleted
        except sqlite3.Error as e:
            logger.error(f"Failed to purge old audit entries: {e}")
            return 0

    def purge_all(self) -> int:
        """Purge all audit entries. Use with caution.

        A database error is logged and reported as 0 rather than raised.

        Returns:
            Number of entries purged.
        """
        try:
            cursor = self._conn.execute("DELETE FROM audit_log")
            self._conn.commit()
            deleted = cursor.rowcount
            logger.warning(f"Purged ALL audit entries: {deleted}")
            return deleted
        except sqlite3.Error as e:
            logger.error(f"Failed to purge all audit entries: {e}")
            return 0

    def close(self) -> None:
        """Close the audit trail (no-op, connection managed externally)."""