Coverage for src/kemi/audit.py: 88%

177 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1"""Audit trail for kemi memory operations. 

2 

3Provides compliance-grade operation logging with: 

4- Complete CRUD operation audit trail 

5- Queryable audit log (by user, operation type, time range) 

6- Retention policy support 

7- Export capability for compliance audits 

8 

9Every memory mutation (remember, forget, update, prune, migrate, etc.) 

10is logged with timestamp, user ID, operation type, and details. 

11 

12Stored in a separate SQLite table `audit_log` for clean separation. 

13Zero external dependencies beyond the existing SQLite adapter. 

14 

15Usage: 

16 from kemi.audit import AuditTrail 

17 

18 audit = AuditTrail(db_connection) 

19 audit.log_operation("alice", "remember", {"memory_id": "abc123"}) 

20 entries = audit.query(user_id="alice", operation="remember") 

21""" 

22 

23import json 

24import logging 

25import sqlite3 

26import time 

27from dataclasses import dataclass, field 

28from datetime import datetime, timezone 

29from typing import Any 

30 

logger = logging.getLogger(__name__)

# Schema version for future migrations. Every audit row records the version
# it was written with; the same value is used as the column's SQL default.
AUDIT_SCHEMA_VERSION = 1

# DDL for the audit table and its indexes. Every statement uses IF NOT EXISTS,
# so running the script repeatedly (e.g. on every startup) is harmless.
# The schema_version default is interpolated from AUDIT_SCHEMA_VERSION so a
# version bump cannot leave the table default pointing at the old version.
AUDIT_SCHEMA_SQL = f"""
CREATE TABLE IF NOT EXISTS audit_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TEXT NOT NULL,
    user_id TEXT NOT NULL,
    operation TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'success',
    details TEXT NOT NULL DEFAULT '{{}}',
    memory_id TEXT,
    namespace TEXT DEFAULT 'default',
    client_ip TEXT,
    user_agent TEXT,
    duration_ms REAL,
    schema_version INTEGER DEFAULT {AUDIT_SCHEMA_VERSION}
);

CREATE INDEX IF NOT EXISTS idx_audit_user_id ON audit_log(user_id);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON audit_log(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_operation ON audit_log(operation);
CREATE INDEX IF NOT EXISTS idx_audit_user_time ON audit_log(user_id, timestamp);
"""

57 

58 

@dataclass
class AuditEntry:
    """One decoded row of the ``audit_log`` table.

    Fields mirror the table's columns one-to-one. The exception is
    ``details``, which holds the parsed payload (a dict) rather than the
    JSON text stored in the column.
    """

    # Row identifier (the table's INTEGER PRIMARY KEY); defaults to 0.
    id: int = 0
    # ISO-8601 timestamp string; AuditTrail writes these in UTC.
    timestamp: str = ""
    user_id: str = ""
    operation: str = ""
    # Outcome of the operation, e.g. "success", "error" or "denied".
    status: str = "success"
    # Free-form operation details (stored as JSON text in the table).
    details: dict[str, Any] = field(default_factory=dict)
    memory_id: str | None = None
    namespace: str = "default"
    client_ip: str | None = None
    user_agent: str | None = None
    duration_ms: float | None = None
    # Audit schema version the row was written with.
    schema_version: int = AUDIT_SCHEMA_VERSION

    def to_dict(self) -> dict[str, Any]:
        """Return the entry as a plain dict, suitable for JSON export.

        Keys follow the dataclass field order. ``schema_version`` is not
        part of the result. ``details`` is the same dict object held by the
        entry, not a copy.
        """
        exported_fields = (
            "id",
            "timestamp",
            "user_id",
            "operation",
            "status",
            "details",
            "memory_id",
            "namespace",
            "client_ip",
            "user_agent",
            "duration_ms",
        )
        return {name: getattr(self, name) for name in exported_fields}

90 

91 

class AuditTrail:
    """Audit trail for compliance-grade operation logging.

    Features:
    - Automatic schema creation
    - Batch logging support (rows are validated before any is written)
    - Retention policy (throttled auto-purge of old entries)
    - Query by user, operation, status, memory ID, namespace, time range
    - Export to JSON-ready dicts for compliance audits

    The trail writes through the connection it is given and calls
    ``commit()`` on it after each write. Uncommitted work that other code has
    pending on the same connection is therefore committed too. The
    connection remains owned by the caller; :meth:`close` does not close it.

    Timestamps are stored as UTC strings produced by
    ``datetime.isoformat()``, and time-range filters compare them as text.
    ``start_time``/``end_time`` arguments should use the same format.
    """

    # INSERT shared by single and batch logging so the column list and the
    # placeholder count cannot drift apart between the two write paths.
    _INSERT_SQL = (
        "INSERT INTO audit_log "
        "(timestamp, user_id, operation, status, details, memory_id, "
        "namespace, client_ip, user_agent, duration_ms, schema_version) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
    )

    # Column order expected by _entry_from_row.
    _SELECT_COLUMNS = (
        "id, timestamp, user_id, operation, status, details, "
        "memory_id, namespace, client_ip, user_agent, duration_ms, "
        "schema_version"
    )

    def __init__(
        self,
        db_connection: sqlite3.Connection,
        retention_days: int = 365,
        auto_purge: bool = True,
    ) -> None:
        """Initialize the audit trail and create its schema if needed.

        Args:
            db_connection: SQLite connection to use. It stays owned by the
                caller; this class never closes it.
            retention_days: Number of days to keep audit entries (default 365).
                Values <= 0 disable retention purging.
            auto_purge: If True, old entries are purged automatically when
                operations are logged (throttled, see below).
        """
        self._conn = db_connection
        self._retention_days = retention_days
        self._auto_purge = auto_purge

        # Auto-purge throttle, checked on every log call: a purge runs once
        # 5 minutes have elapsed or 100 writes have accumulated since the
        # previous purge, whichever comes first.
        self._writes_since_purge: int = 0
        self._last_purge_time: float = time.time()
        self._purge_interval_seconds: float = 300.0  # 5 minutes
        self._purge_write_threshold: int = 100

        self._ensure_schema()
        logger.info(
            "Audit trail initialized (retention: %sd, auto_purge: %s)",
            retention_days,
            auto_purge,
        )

    def _ensure_schema(self) -> None:
        """Create the audit table and indexes if they don't exist.

        Note: ``executescript`` commits any pending transaction on the
        connection before running the script.
        """
        self._conn.executescript(AUDIT_SCHEMA_SQL)
        self._conn.commit()

    @staticmethod
    def _iso_days_ago(days: int) -> str:
        """Return the UTC instant ``days`` days before now as an ISO-8601 string."""
        from datetime import timedelta

        return (datetime.now(timezone.utc) - timedelta(days=days)).isoformat()

    @staticmethod
    def _row_values(
        timestamp: str,
        user_id: str,
        operation: str,
        details: dict[str, Any] | None,
        memory_id: str | None,
        namespace: str | None,
        status: str,
        client_ip: str | None,
        user_agent: str | None,
        duration_ms: float | None,
    ) -> tuple[Any, ...]:
        """Build the parameter tuple for ``_INSERT_SQL``.

        Missing or empty ``details`` are stored as ``{}``. Values that are
        not JSON-serializable are stringified via ``default=str``.
        """
        return (
            timestamp,
            user_id,
            operation,
            status,
            json.dumps(details or {}, default=str),
            memory_id,
            namespace,
            client_ip,
            user_agent,
            duration_ms,
            AUDIT_SCHEMA_VERSION,
        )

    @staticmethod
    def _entry_from_row(row: tuple[Any, ...]) -> AuditEntry:
        """Convert a row selected with ``_SELECT_COLUMNS`` into an AuditEntry."""
        try:
            details = json.loads(row[5])
        except (json.JSONDecodeError, TypeError):
            details = {}
        if details is None:
            # A stored JSON ``null`` decodes to None; normalize it so that
            # AuditEntry.details is always a dict.
            details = {}
        return AuditEntry(
            id=row[0],
            timestamp=row[1],
            user_id=row[2],
            operation=row[3],
            status=row[4],
            details=details,
            memory_id=row[6],
            namespace=row[7],
            client_ip=row[8],
            user_agent=row[9],
            duration_ms=row[10],
            schema_version=row[11],
        )

    def log_operation(
        self,
        user_id: str,
        operation: str,
        details: dict[str, Any] | None = None,
        memory_id: str | None = None,
        namespace: str = "default",
        status: str = "success",
        client_ip: str | None = None,
        user_agent: str | None = None,
        duration_ms: float | None = None,
    ) -> int:
        """Log a memory operation to the audit trail and commit it.

        Args:
            user_id: User who performed the operation.
            operation: Operation type (remember, recall, forget, update, etc.).
            details: Additional operation details as dict (None -> ``{}``).
            memory_id: ID of the memory involved (if applicable).
            namespace: Memory namespace.
            status: Operation status (success, error, denied).
            client_ip: Client IP address (for API usage).
            user_agent: Client user agent string.
            duration_ms: Operation duration in milliseconds.

        Returns:
            The ID of the new audit entry (0 if SQLite reports no row id).

        Raises:
            sqlite3.Error: If the insert or commit fails (logged, then re-raised).
        """
        self._maybe_purge()

        params = self._row_values(
            datetime.now(timezone.utc).isoformat(),
            user_id,
            operation,
            details,
            memory_id,
            namespace,
            status,
            client_ip,
            user_agent,
            duration_ms,
        )

        try:
            cursor = self._conn.execute(self._INSERT_SQL, params)
            self._conn.commit()
        except sqlite3.Error as e:
            logger.error("Failed to write audit entry: %s", e)
            raise

        entry_id = cursor.lastrowid or 0
        self._writes_since_purge += 1
        logger.debug("Audit: %s by %s (entry %s)", operation, user_id, entry_id)
        return entry_id

    def log_operation_batch(
        self,
        entries: list[dict[str, Any]],
    ) -> int:
        """Log multiple operations as one unit of work.

        All entries share a single timestamp. Every entry is converted to a
        row before anything is written. A malformed entry (e.g. a missing
        ``"user_id"`` or ``"operation"`` key raising KeyError) therefore
        fails the whole batch without leaving earlier rows pending in an open
        transaction on the shared connection.

        NOTE: rollback-on-error relies on the connection's implicit
        transactions (the sqlite3 default). On a connection in autocommit
        mode, rows already written cannot be rolled back.

        Args:
            entries: List of dicts with keys matching log_operation params;
                ``user_id`` and ``operation`` are required.

        Returns:
            Number of entries logged.

        Raises:
            KeyError: If an entry lacks ``user_id`` or ``operation``.
            sqlite3.Error: If the insert or commit fails (after a rollback).
        """
        self._maybe_purge()

        timestamp = datetime.now(timezone.utc).isoformat()
        rows = [
            self._row_values(
                timestamp,
                entry["user_id"],
                entry["operation"],
                entry.get("details"),
                entry.get("memory_id"),
                entry.get("namespace", "default"),
                entry.get("status", "success"),
                entry.get("client_ip"),
                entry.get("user_agent"),
                entry.get("duration_ms"),
            )
            for entry in entries
        ]

        try:
            self._conn.executemany(self._INSERT_SQL, rows)
            self._conn.commit()
        except sqlite3.Error as e:
            logger.error("Failed to write batch audit entries: %s", e)
            self._conn.rollback()
            raise

        self._writes_since_purge += len(rows)
        return len(rows)

    def query(
        self,
        user_id: str | None = None,
        operation: str | None = None,
        status: str | None = None,
        memory_id: str | None = None,
        namespace: str | None = None,
        start_time: str | None = None,
        end_time: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[AuditEntry]:
        """Query the audit trail with flexible filters.

        Filters left as None (or empty) are not applied. Results are ordered
        newest first. Entries with identical timestamps (e.g. one batch) are
        ordered by descending id, so LIMIT/OFFSET paging is deterministic.

        Args:
            user_id: Filter by user ID.
            operation: Filter by operation type.
            status: Filter by status (success, error, denied).
            memory_id: Filter by memory ID.
            namespace: Filter by namespace.
            start_time: ISO timestamp for start of range (inclusive).
            end_time: ISO timestamp for end of range (inclusive).
            limit: Maximum number of entries to return. A negative value means
                no limit (SQLite LIMIT semantics).
            offset: Offset for pagination.

        Returns:
            List of matching AuditEntry objects.

        Raises:
            sqlite3.Error: If the query fails (logged, then re-raised).
        """
        # Only these fixed clause strings are ever interpolated into the SQL;
        # all caller-supplied values are bound as parameters.
        filters = (
            ("user_id = ?", user_id),
            ("operation = ?", operation),
            ("status = ?", status),
            ("memory_id = ?", memory_id),
            ("namespace = ?", namespace),
            ("timestamp >= ?", start_time),
            ("timestamp <= ?", end_time),
        )
        active = [(clause, value) for clause, value in filters if value]

        where_clause = " AND ".join(clause for clause, _ in active) or "1=1"
        params: list[Any] = [value for _, value in active]
        params.extend([limit, offset])

        query_sql = (
            f"SELECT {self._SELECT_COLUMNS} "
            f"FROM audit_log WHERE {where_clause} "
            "ORDER BY timestamp DESC, id DESC LIMIT ? OFFSET ?"
        )

        try:
            rows = self._conn.execute(query_sql, params).fetchall()
        except sqlite3.Error as e:
            logger.error("Failed to query audit trail: %s", e)
            raise

        return [self._entry_from_row(row) for row in rows]

    def get_user_activity(
        self,
        user_id: str,
        days: int = 30,
    ) -> dict[str, Any]:
        """Get an activity summary for a user.

        Args:
            user_id: User ID to query.
            days: Number of days to look back for the operation counts.

        Returns:
            Dict with ``user_id``, ``period_days``, ``total_operations`` and
            ``operation_counts`` (both limited to the look-back window), and
            ``last_activity``. ``last_activity`` is the user's latest timestamp
            over all retained entries, or None if there are none.

        Raises:
            sqlite3.Error: If a query fails (logged, then re-raised).
        """
        start_time = self._iso_days_ago(days)

        try:
            count_rows = self._conn.execute(
                """SELECT operation, COUNT(*) as count
                   FROM audit_log
                   WHERE user_id = ? AND timestamp >= ?
                   GROUP BY operation""",
                (user_id, start_time),
            ).fetchall()
            last_activity = self._conn.execute(
                "SELECT MAX(timestamp) FROM audit_log WHERE user_id = ?",
                (user_id,),
            ).fetchone()[0]
        except sqlite3.Error as e:
            logger.error("Failed to get user activity: %s", e)
            raise

        operation_counts = dict(count_rows)
        return {
            "user_id": user_id,
            "period_days": days,
            # Same window as operation_counts, so the total is their sum.
            "total_operations": sum(operation_counts.values()),
            "operation_counts": operation_counts,
            "last_activity": last_activity,
        }

    def get_stats(self) -> dict[str, Any]:
        """Get overall audit trail statistics.

        Returns:
            Dict with ``total_entries``, ``unique_users``, ``first_entry`` and
            ``last_entry`` (oldest/newest timestamps, None when empty), and the
            configured ``retention_days``.

        Raises:
            sqlite3.Error: If the query fails (logged, then re-raised).
        """
        try:
            # Aggregates without GROUP BY always yield exactly one row.
            total_entries, unique_users, first_entry, last_entry = self._conn.execute(
                "SELECT COUNT(*), COUNT(DISTINCT user_id), "
                "MIN(timestamp), MAX(timestamp) FROM audit_log"
            ).fetchone()
        except sqlite3.Error as e:
            logger.error("Failed to get audit stats: %s", e)
            raise

        return {
            "total_entries": total_entries,
            "unique_users": unique_users,
            "first_entry": first_entry,
            "last_entry": last_entry,
            "retention_days": self._retention_days,
        }

    def export(
        self,
        start_time: str | None = None,
        end_time: str | None = None,
        user_id: str | None = None,
    ) -> list[dict[str, Any]]:
        """Export all matching audit entries as a list of dicts for compliance.

        The export is complete (no row cap) and ordered newest first. All
        matching rows are loaded into memory.

        Args:
            start_time: ISO timestamp lower bound (inclusive).
            end_time: ISO timestamp upper bound (inclusive).
            user_id: Optional user filter.

        Returns:
            List of dicts suitable for JSON export.
        """
        # A negative LIMIT means "no upper bound" in SQLite, so a compliance
        # export is never silently truncated.
        entries = self.query(
            user_id=user_id,
            start_time=start_time,
            end_time=end_time,
            limit=-1,
        )
        return [entry.to_dict() for entry in entries]

    def _maybe_purge(self) -> None:
        """Run a retention purge if auto-purge is on and the throttle allows it.

        A purge runs once the time interval has elapsed or the write
        threshold has been reached since the last purge.
        """
        if not self._auto_purge:
            return
        now = time.time()
        interval_elapsed = now - self._last_purge_time >= self._purge_interval_seconds
        threshold_reached = self._writes_since_purge >= self._purge_write_threshold
        if not (interval_elapsed or threshold_reached):
            return
        self._last_purge_time = now
        self._writes_since_purge = 0
        self._purge_old_entries()

    def _purge_old_entries(self) -> int:
        """Remove entries older than ``retention_days`` and commit.

        Best-effort: database errors are logged and reported as 0 purged
        rather than raised. The write that triggered the purge therefore
        still proceeds.

        Returns:
            Number of entries purged (0 when retention is disabled).
        """
        if self._retention_days <= 0:
            return 0

        cutoff = self._iso_days_ago(self._retention_days)

        try:
            cursor = self._conn.execute(
                "DELETE FROM audit_log WHERE timestamp < ?",
                (cutoff,),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            logger.error("Failed to purge old audit entries: %s", e)
            return 0

        deleted = cursor.rowcount
        if deleted > 0:
            logger.info("Purged %s old audit entries (cutoff: %s)", deleted, cutoff)
        return deleted

    def purge_all(self) -> int:
        """Purge all audit entries. Use with caution.

        Database errors are logged and reported as 0 rather than raised.

        Returns:
            Number of entries purged.
        """
        try:
            cursor = self._conn.execute("DELETE FROM audit_log")
            self._conn.commit()
        except sqlite3.Error as e:
            logger.error("Failed to purge all audit entries: %s", e)
            return 0

        deleted = cursor.rowcount
        logger.warning("Purged ALL audit entries: %s", deleted)
        return deleted

    def close(self) -> None:
        """Close the audit trail (no-op: the connection is managed externally)."""
        pass