Coverage for src / kemi / api_server.py: 83%

820 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1"""FastAPI REST API server for kemi memory. 

2 

3Optional dependency: install with `pip install fastapi uvicorn` 

4 

5Usage: 

6 from kemi.api_server import create_app 

7 app = create_app() 

8 uvicorn.run(app, host="0.0.0.0", port=8000) 

9 

10Rate Limiting: 

11 Rate limiting is disabled by default. To enable, set KEMI_RATE_LIMIT_ENABLED=true (also accepts 1/yes); optional tuning: 

12 - KEMI_RATE_LIMIT_REQUESTS: max requests per window (default 100) 

13 - KEMI_RATE_LIMIT_WINDOW: time window in seconds (default 60) 

14 

15Security: 

16 CORS and security headers are disabled by default. To enable, set: 

17 - KEMI_CORS_ORIGINS: comma-separated list of allowed origins 

18 

19API Key Authentication (multi-tenancy): 

20 Disabled by default for backward compatibility. To enable, set: 

21 - KEMI_API_KEY_REQUIRED=true 

22 

23 When the X-API-Key header is sent, the key is hashed with SHA-256 and 

24 looked up in the `api_keys` table. The associated user_id is injected 

25 into request.state.user_id. Endpoints then enforce that any user_id 

26 in the body/path matches the authenticated user, preventing cross- 

27 tenant access. Manage keys via the /api/keys endpoints or 

28 `kemi api-key ...` CLI commands. 

29""" 

30 

31import logging 

32import os 

33import sqlite3 

34import time 

35from collections import defaultdict 

36from contextlib import asynccontextmanager 

37from datetime import datetime, timezone 

38from threading import Lock 

39from typing import Any 

40 

41from kemi import Memory 

42from kemi.models import LifecycleState, MemorySource, MemoryType 

43from kemi.webhooks import WebhookConfig, WebhookEventType, WebhookStore, RetryConfig 

44 

45logger = logging.getLogger(__name__) 

46 

47try: 

48 from fastapi import FastAPI, HTTPException, Request 

49 from fastapi.middleware.cors import CORSMiddleware 

50 from fastapi.middleware.trustedhost import TrustedHostMiddleware 

51 from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse 

52 from pydantic import BaseModel, Field 

53 

54 _FASTAPI_AVAILABLE = True 

55except ImportError: # pragma: no cover 

56 _FASTAPI_AVAILABLE = False 

57 BaseModel = object # type: ignore[no-redef,assignment] 

58 

59 def Field(*a: Any, **kw: Any) -> None: # type: ignore[no-redef] # noqa: N802 

60 return None 

61 

62 

63# Endpoints that never require auth, even when KEMI_API_KEY_REQUIRED=true. 

64# /health is a liveness probe; /api/keys POST bootstraps a brand-new user's 

65# first key (you can't authenticate before you have a key). 

66_AUTH_EXEMPT_PATHS = frozenset({"/health", "/api/keys"}) 

67 

68# Subset of /api/keys paths that are exempt: only the POST (create) and 

69# the health-style subpaths. GET/DELETE still require auth so that keys 

70# can't be enumerated by an anonymous caller. 

71_AUTH_EXEMPT_PREFIXES = tuple() # handled per-route below 

72 

73 

74def _api_key_required() -> bool: 

75 """Whether the server requires X-API-Key on all non-exempt endpoints.""" 

76 return os.environ.get("KEMI_API_KEY_REQUIRED", "false").lower() in ("true", "1", "yes") 

77 

78 

def _is_exempt(path: str, method: str) -> bool:
    """Tell whether a request may skip the API-key requirement.

    A request is exempt when its path is listed in ``_AUTH_EXEMPT_PATHS``
    (any method), or when it is exactly ``POST /api/keys`` -- the bootstrap
    endpoint for creating a key. GET/DELETE are not covered by that second
    rule.
    """
    is_key_bootstrap = path == "/api/keys" and method == "POST"
    return path in _AUTH_EXEMPT_PATHS or is_key_bootstrap

87 

88 

def _resolve_user_id(request: Request, claimed_user_id: str | None) -> str:
    """Work out which user_id an endpoint should act on.

    ``request.state.user_id`` is read as the authenticated identity (it is
    treated as absent when missing or None).

    - Authenticated: the claimed id must be None or equal to the
      authenticated id; anything else is rejected so one tenant cannot act
      as another. The authenticated id is returned.
    - Unauthenticated: the claimed id is returned unchanged, and must be
      provided.

    Raises:
        HTTPException: 403 on a mismatch with the authenticated user,
            400 when no identity is available at all.
    """
    authenticated = getattr(request.state, "user_id", None)

    if authenticated is not None:
        if claimed_user_id is None or claimed_user_id == authenticated:
            return authenticated
        raise HTTPException(
            status_code=403,
            detail="user_id does not match authenticated user",
        )

    if claimed_user_id is not None:
        return claimed_user_id

    raise HTTPException(
        status_code=400,
        detail="user_id is required",
    )

113 

114 

def _require_admin(request: Request) -> str:
    """Insist on an authenticated caller and hand back its user_id.

    Only authentication is checked here (``request.state.user_id`` must be
    set and not None). Despite the name, no role or privilege check is
    performed; callers must do any further authorization themselves.

    Raises:
        HTTPException: 401 when the request carries no authenticated user.
    """
    caller = getattr(request.state, "user_id", None)
    if caller is not None:
        return caller
    raise HTTPException(status_code=401, detail="Authentication required")

126 

127 

class RateLimiter:
    """Simple in-memory rate limiter for API endpoints.

    Keeps, per client key, the wall-clock timestamps of the requests that
    were allowed during the last ``window_seconds`` seconds. A request is
    allowed while fewer than ``requests_per_window`` timestamps remain in
    that window. All state is guarded by a single lock.
    """

    def __init__(
        self,
        requests_per_window: int = 100,
        window_seconds: int = 60,
    ) -> None:
        self._requests_per_window = requests_per_window
        self._window_seconds = window_seconds
        # key -> timestamps (time.time()) of allowed requests in the window
        self._requests: dict[str, list[float]] = defaultdict(list)
        self._lock = Lock()

    def is_allowed(self, key: str) -> bool:
        """Check if a request from the given key is allowed.

        An allowed request is recorded; a rejected one is not.

        Args:
            key: Identifier for the client (e.g., IP address, user_id).

        Returns:
            True if the request is allowed, False if rate limited.
        """
        now = time.time()
        window_start = now - self._window_seconds

        with self._lock:
            # .get() avoids the defaultdict inserting an empty list as a
            # side effect of the lookup.
            recent = [ts for ts in self._requests.get(key, ()) if ts > window_start]

            if len(recent) >= self._requests_per_window:
                if recent:
                    self._requests[key] = recent
                else:
                    # Only reachable with a non-positive limit; don't keep
                    # an empty entry around for the key.
                    self._requests.pop(key, None)
                return False

            recent.append(now)
            self._requests[key] = recent
            return True

    def get_retry_after(self, key: str) -> int:
        """Get seconds until the rate limit resets for this key.

        Returns 0 when the key is currently under its limit. This call
        never records a request and never creates state for unknown keys.
        """
        now = time.time()
        window_start = now - self._window_seconds

        with self._lock:
            recent = [ts for ts in self._requests.get(key, ()) if ts > window_start]

        if len(recent) < self._requests_per_window:
            return 0

        if not recent:
            # Non-positive limit: nothing is ever recorded, so there is no
            # oldest timestamp to wait for. Previously min([]) raised
            # ValueError here; report one full window instead.
            return max(1, int(self._window_seconds))

        oldest = min(recent)
        return int(self._window_seconds - (now - oldest)) + 1

178 

179 

# Process-wide rate limiter instance, built on first use by _get_rate_limiter().
_rate_limiter: RateLimiter | None = None


def _get_rate_limiter() -> RateLimiter | None:
    """Get or create the rate limiter based on environment config.

    Returns None unless KEMI_RATE_LIMIT_ENABLED is "true", "1" or "yes"
    (case-insensitive). When enabled, the limiter is built once from
    KEMI_RATE_LIMIT_REQUESTS (default 100) and KEMI_RATE_LIMIT_WINDOW
    (default 60) and cached for later calls. While disabled nothing is
    cached, so the environment is re-read on each call.
    """
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter

    enabled = os.environ.get("KEMI_RATE_LIMIT_ENABLED", "false").lower() in ("true", "1", "yes")
    if not enabled:
        return None

    def _int_setting(name: str, default: int) -> int:
        """Read an integer env var, falling back to *default* if unusable."""
        raw = os.environ.get(name)
        if raw is None:
            return default
        try:
            return int(raw)
        except ValueError:
            # This function runs on the request path; a malformed value used
            # to raise ValueError on every request instead of being reported.
            logger.warning("Ignoring invalid %s=%r; using default %d", name, raw, default)
            return default

    requests = _int_setting("KEMI_RATE_LIMIT_REQUESTS", 100)
    window = _int_setting("KEMI_RATE_LIMIT_WINDOW", 60)

    _rate_limiter = RateLimiter(requests_per_window=requests, window_seconds=window)
    logger.info("Rate limiting enabled: %s requests per %s seconds", requests, window)
    return _rate_limiter

201 

202 

def _check_rate_limit(client_key: str) -> tuple[bool, int]:
    """Apply the configured rate limit to one request from ``client_key``.

    Returns:
        ``(True, 0)`` when the request may proceed -- including when no
        limiter is configured -- otherwise ``(False, retry_after_seconds)``.
    """
    limiter = _get_rate_limiter()
    if limiter is not None and not limiter.is_allowed(client_key):
        return False, limiter.get_retry_after(client_key)
    return True, 0

218 

219 

# Cached APIKeyManager, built on first use from the singleton memory's storage.
_api_key_manager: Any = None
_api_key_manager_lock = Lock()


def _get_api_key_manager() -> Any:
    """Return the cached API-key manager, building it on first use.

    The manager is derived from the ``_store`` attribute of the singleton
    Memory. Resolution order:

    1. If the store exposes a callable ``get_api_key_manager``, its result
       is used as-is (even if that result is None).
    2. Otherwise, if the store exposes a callable ``_get_connection``, an
       ``APIKeyManager`` is built around that connection.
    3. Otherwise None is returned (e.g. a mock store); nothing is cached,
       so the next call tries again.
    """
    global _api_key_manager

    # Fast path: already built, no need to take the lock.
    cached = _api_key_manager
    if cached is not None:
        return cached

    with _api_key_manager_lock:
        # Re-check under the lock in case another thread built it meanwhile.
        if _api_key_manager is None:
            store = getattr(_get_memory_singleton(), "_store", None)
            factory = getattr(store, "get_api_key_manager", None)
            if callable(factory):
                _api_key_manager = factory()
            else:
                # Last resort: wrap the store's own connection.
                connect = getattr(store, "_get_connection", None)
                if callable(connect):
                    from kemi.api_keys import APIKeyManager

                    _api_key_manager = APIKeyManager(connection=connect())
        return _api_key_manager

253 

254 

def _reset_api_key_manager() -> None:
    """Drop the cached API-key manager so the next lookup rebuilds it.

    Takes the manager lock while clearing. Intended for tests that swap
    storage backends.
    """
    global _api_key_manager
    with _api_key_manager_lock:
        _api_key_manager = None

260 

261 

# Body of POST /remember. The fields are forwarded to Memory.remember()
# after user_id has been resolved against the authenticated user.
class RememberRequest(BaseModel):
    user_id: str = Field(..., min_length=1)  # required, non-empty
    content: str = Field(..., min_length=1)  # required, non-empty
    importance: float = Field(0.5, ge=0.0, le=1.0)  # constrained to [0, 1]
    source: str = "user_stated"  # parsed as a MemorySource value; invalid -> 400
    tags: list[str] | None = None
    namespace: str = "default"
    session_id: str | None = None
    memory_type: str = "episodic"  # parsed as a MemoryType value; invalid -> 400
    confidence: float = Field(1.0, ge=0.0, le=1.0)  # constrained to [0, 1]

272 

273 

# Body shared by POST /recall, /recall/stream and /recall-explain.
# /recall-explain does not forward max_tokens or hybrid_search.
class RecallRequest(BaseModel):
    user_id: str = Field(..., min_length=1)  # required, non-empty
    query: str = Field(..., min_length=1)  # required, non-empty
    top_k: int = Field(5, ge=1)  # at least 1; no upper bound is enforced here
    max_tokens: int | None = None
    namespace: str = "default"
    session_id: str | None = None
    hybrid_search: bool | None = None  # None is passed through to Memory unchanged

282 

283 

# Body of PATCH /memories/{memory_id}. Every field is optional and is
# forwarded to Memory.update() as given (None included).
class UpdateRequest(BaseModel):
    content: str | None = None
    importance: float | None = Field(None, ge=0.0, le=1.0)  # [0, 1] when given
    confidence: float | None = Field(None, ge=0.0, le=1.0)  # [0, 1] when given
    memory_type: str | None = None  # parsed as a MemoryType value when non-empty

289 

290 

# Body of POST /prune; forwarded to Memory.prune(). No range validation is
# applied to the numeric fields at this layer.
class PruneRequest(BaseModel):
    max_age_days: float | None = None
    min_importance: float | None = None
    lifecycle_states: list[str] | None = None  # each parsed as LifecycleState; invalid -> 400
    namespace: str = "default"

296 

297 

# Body of POST /consolidate/{user_id}; forwarded to Memory.consolidate().
class ConsolidateRequest(BaseModel):
    namespace: str = "default"
    min_memories: int = 5
    max_age_days: float = 30.0

302 

303 

# Body of POST /topics/{user_id}; forwarded to Memory.cluster_topics().
class TopicsRequest(BaseModel):
    n_clusters: int = 3
    namespace: str = "default"

307 

308 

# Body of POST /graph/{user_id}; forwarded to Memory.get_memory_graph().
class GraphRequest(BaseModel):
    namespace: str = "default"

311 

312 

# Body of POST /feedback/{user_id}; forwarded to Memory.feedback().
class FeedbackRequest(BaseModel):
    memory_id: str  # required
    helpful: bool = True
    namespace: str = "default"

317 

318 

class BatchRememberRequest(BaseModel):
    """Request for background batch remember operation."""

    # Used by POST /tasks/embed-batch, which submits the batch to the
    # background task manager and returns a task_id.
    user_id: str = Field(..., min_length=1)  # required, non-empty
    contents: list[str] = Field(..., min_length=1)  # required, at least one item
    importance: float = Field(0.5, ge=0.0, le=1.0)  # one value applied to the whole batch
    namespace: str = "default"

326 

327 

class RebuildFTSRequest(BaseModel):
    """Request for background FTS index rebuild."""

    # NOTE(review): presumably the body of POST /tasks/rebuild-fts -- the
    # handler is outside the visible section; confirm.
    user_id: str | None = None  # Optional: rebuild for specific user only

332 

333 

class AdminFTSStatsRequest(BaseModel):
    """Request for FTS index statistics."""

    # NOTE(review): the endpoint consuming this model is outside the visible
    # section; confirm how a None user_id is interpreted there.
    user_id: str | None = None  # Optional: get stats for specific user only

338 

339 

class AdminFTSRepairRequest(BaseModel):
    """Request for FTS index integrity repair."""

    # NOTE(review): the endpoint consuming this model is outside the visible
    # section.
    verify_only: bool = False  # If True, only verify without repairing

344 

345 

class AuditLogRequest(BaseModel):
    """Request to log an audit entry."""

    # NOTE(review): the endpoint consuming this model is outside the visible
    # section; field meanings below are limited to what the declarations show.
    user_id: str = Field(..., min_length=1)  # required, non-empty
    operation: str = Field(..., min_length=1)  # required, non-empty
    status: str = "success"  # free-form string; not validated here
    details: dict[str, Any] | None = None
    memory_id: str | None = None
    namespace: str = "default"
    client_ip: str | None = None
    user_agent: str | None = None
    duration_ms: float | None = None

358 

359 

class AuditQueryRequest(BaseModel):
    """Request to query the audit trail."""

    # All filter fields are optional. NOTE(review): presumably None means
    # "do not filter on this field" -- the handler is outside the visible
    # section; confirm.
    user_id: str | None = None
    operation: str | None = None
    status: str | None = None
    memory_id: str | None = None
    namespace: str | None = None
    start_time: str | None = None  # plain string; format is not validated here
    end_time: str | None = None  # plain string; format is not validated here
    limit: int = Field(100, ge=1, le=10000)  # page size, 1..10000
    offset: int = Field(0, ge=0)  # non-negative

372 

373 

class AuditExportRequest(BaseModel):
    """Request to export audit entries."""

    # NOTE(review): the endpoint consuming this model is outside the visible
    # section.
    start_time: str | None = None  # plain string; format is not validated here
    end_time: str | None = None  # plain string; format is not validated here
    user_id: str | None = None

380 

381 

class AdaptiveAnalyzeRequest(BaseModel):
    """Request to analyze a query for adaptive retrieval."""

    query: str = Field(..., min_length=1)  # required, non-empty

386 

387 

class EnableFeatureRequest(BaseModel):
    """Request to enable or disable a feature."""

    # NOTE(review): which feature(s) retention_days and auto_purge apply to
    # is decided by handlers outside the visible section.
    enable: bool = True
    retention_days: int = Field(365, ge=1)  # at least 1
    auto_purge: bool = True

394 

395 

class CreateAPIKeyRequest(BaseModel):
    """Request to create a new API key."""

    user_id: str = Field(..., min_length=1)  # required, non-empty
    name: str = Field(..., min_length=1, max_length=200)  # 1..200 characters
    # None leaves the field unset; otherwise 1..36500.
    # NOTE(review): presumably None means the key never expires -- confirm
    # against the /api/keys handler.
    expires_in_days: int | None = Field(None, ge=1, le=36500)

402 

403 

# Global memory instance for lifespan management
_memory_instance: Memory | None = None


def _get_memory_singleton() -> Memory:
    """Get or create a singleton Memory instance.

    On first use, makes sure the parent directory of the database path
    (KEMI_DB_PATH, defaulting to ~/.kemi/memories.db) exists, then builds
    a default ``Memory()``.

    Note that the path is only used here to create the directory; it is
    not passed to ``Memory``.
    NOTE(review): presumably Memory() resolves KEMI_DB_PATH itself -- confirm.
    """
    global _memory_instance
    if _memory_instance is not None:
        return _memory_instance

    db_path = os.environ.get("KEMI_DB_PATH", os.path.expanduser("~/.kemi/memories.db"))
    parent_dir = os.path.dirname(db_path)
    # A path with no directory part (e.g. "memories.db") has an empty
    # dirname, and os.makedirs("") raises FileNotFoundError.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    _memory_instance = Memory()
    return _memory_instance

418 

419 

@asynccontextmanager
async def lifespan(app: FastAPI) -> Any:
    """Lifespan context manager for graceful startup and shutdown.

    On startup: initialize the singleton memory instance.
    On shutdown: close the storage adapter (if it has ``close``), drop the
    singleton, and clear the cached API-key manager.

    The shutdown steps run in a ``finally`` block so they also execute when
    an exception or cancellation is delivered at the ``yield``.
    """
    global _memory_instance

    # Startup
    logger.info("Starting kemi API server...")
    _get_memory_singleton()
    db_path = os.environ.get("KEMI_DB_PATH", "~/.kemi/memories.db")
    logger.info("Memory instance initialized with DB: %s", db_path)

    try:
        yield  # Application runs here
    finally:
        # Shutdown
        logger.info("Shutting down kemi API server...")
        if _memory_instance is not None:
            # Close the storage adapter connection (best effort: a failure
            # here is logged and must not abort the rest of the shutdown).
            try:
                store = getattr(_memory_instance, "_store", None)
                if store is not None and hasattr(store, "close"):
                    store.close()
                    logger.info("Database connections closed")
            except Exception as e:
                logger.exception("Error closing database connections: %s", e)
            _memory_instance = None

        # The cached API-key manager was built from the store that was just
        # closed; drop it so a later startup in this process rebuilds it.
        _reset_api_key_manager()

        logger.info("kemi API server shutdown complete")

450 

451 

452def create_app(memory: Memory | None = None) -> Any: 

453 """Create a FastAPI application wrapping a kemi Memory instance. 

454 

455 Args: 

456 memory: Optional pre-configured Memory instance. 

457 If None, creates a default Memory. 

458 

459 Returns: 

460 FastAPI app instance. 

461 """ 

462 if not _FASTAPI_AVAILABLE: 

463 raise ImportError( 

464 "FastAPI is required for the API server. Install with: pip install fastapi uvicorn" 

465 ) 

466 

467 app = FastAPI( 

468 title="kemi API", 

469 version="0.3.0", 

470 lifespan=lifespan, 

471 ) 

472 

473 # Configure CORS if origins are specified 

474 cors_origins = os.environ.get("KEMI_CORS_ORIGINS", "") 

475 if cors_origins: 

476 origins = [o.strip() for o in cors_origins.split(",") if o.strip()] 

477 if origins: 

478 app.add_middleware( 

479 CORSMiddleware, 

480 allow_origins=origins, 

481 allow_credentials=True, 

482 allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE"], 

483 allow_headers=["Authorization", "Content-Type"], 

484 ) 

485 logger.info(f"CORS enabled for origins: {origins}") 

486 

487 # Add security headers middleware for production 

488 trusted_hosts = os.environ.get("KEMI_TRUSTED_HOSTS", "") 

489 if trusted_hosts: 

490 hosts = [h.strip() for h in trusted_hosts.split(",") if h.strip()] 

491 if hosts: 

492 app.add_middleware(TrustedHostMiddleware, allowed_hosts=hosts) 

493 logger.info(f"Trusted host middleware enabled for: {hosts}") 

494 

495 @app.middleware("http") 

496 async def api_key_middleware(request: Request, call_next: Any) -> Any: 

497 """Validate X-API-Key and inject request.state.user_id. 

498 

499 Behaviour: 

500 - If X-API-Key is provided, hash it and look it up. On success 

501 the key's user_id is attached to request.state. On failure 

502 (unknown / expired / revoked) return 401. 

503 - If X-API-Key is absent: 

504 - Endpoints in _AUTH_EXEMPT_PATHS are always allowed. 

505 - All other endpoints return 401 when KEMI_API_KEY_REQUIRED=true. 

506 - When KEMI_API_KEY_REQUIRED=false (default), the request 

507 proceeds unauthenticated for backward compatibility. 

508 """ 

509 # Don't try to validate against a missing storage adapter 

510 # (e.g. in some test harnesses); behave as if no auth is available. 

511 header = request.headers.get("X-API-Key") 

512 path = request.url.path 

513 method = request.method 

514 

515 if header: 

516 manager = _get_api_key_manager() 

517 if manager is None: 

518 return JSONResponse( 

519 status_code=501, 

520 content={"detail": "API key authentication not supported by this storage"}, 

521 ) 

522 key = manager.lookup(header) 

523 if key is None: 

524 return JSONResponse( 

525 status_code=401, 

526 content={"detail": "Invalid or expired API key"}, 

527 ) 

528 request.state.user_id = key.user_id 

529 request.state.api_key_id = key.key_id 

530 else: 

531 if _api_key_required() and not _is_exempt(path, method): 

532 return JSONResponse( 

533 status_code=401, 

534 content={"detail": "X-API-Key header required"}, 

535 ) 

536 

537 return await call_next(request) 

538 

539 # Expose auth state on the request state for endpoints that want to 

540 # introspect it (e.g. /api/keys GET to scope listings). 

541 @app.middleware("http") 

542 async def _ensure_state_defaults(request: Request, call_next: Any) -> Any: 

543 if not hasattr(request.state, "user_id"): 

544 request.state.user_id = None 

545 if not hasattr(request.state, "api_key_id"): 

546 request.state.api_key_id = None 

547 return await call_next(request) 

548 

549 mem = memory or _get_memory_singleton() 

550 

551 @app.post("/remember") 

552 async def remember(req: RememberRequest, request: Request) -> dict[str, Any]: 

553 effective_user = _resolve_user_id(request, req.user_id) 

554 # Rate limit check 

555 allowed, retry_after = _check_rate_limit(effective_user) 

556 if not allowed: 

557 raise HTTPException( 

558 status_code=429, 

559 detail=( 

560 f"Rate limit exceeded. Retry after {retry_after} seconds." 

561 ), 

562 headers={"Retry-After": str(retry_after)}, 

563 ) 

564 

565 try: 

566 source = MemorySource(req.source) 

567 mtype = MemoryType(req.memory_type) 

568 except ValueError as err: 

569 raise HTTPException(status_code=400, detail=str(err)) from err 

570 

571 mid = mem.remember( 

572 user_id=effective_user, 

573 content=req.content, 

574 importance=req.importance, 

575 source=source, 

576 tags=req.tags, 

577 namespace=req.namespace, 

578 session_id=req.session_id, 

579 memory_type=mtype, 

580 confidence=req.confidence, 

581 ) 

582 return {"memory_id": mid} 

583 

584 @app.post("/recall") 

585 async def recall(req: RecallRequest, request: Request) -> dict[str, Any]: 

586 effective_user = _resolve_user_id(request, req.user_id) 

587 # Rate limit check 

588 allowed, retry_after = _check_rate_limit(effective_user) 

589 if not allowed: 

590 raise HTTPException( 

591 status_code=429, 

592 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.", 

593 headers={"Retry-After": str(retry_after)}, 

594 ) 

595 

596 try: 

597 results = mem.recall( 

598 user_id=effective_user, 

599 query=req.query, 

600 top_k=req.top_k, 

601 max_tokens=req.max_tokens, 

602 namespace=req.namespace, 

603 session_id=req.session_id, 

604 hybrid_search=req.hybrid_search, 

605 ) 

606 except ValueError as e: 

607 raise HTTPException(status_code=400, detail=str(e)) from e 

608 

609 return { 

610 "results": [ 

611 { 

612 "memory_id": r.memory_id, 

613 "content": r.content, 

614 "score": r.score, 

615 "importance": r.importance, 

616 "lifecycle_state": r.lifecycle_state.value, 

617 "created_at": r.created_at.isoformat() if r.created_at else None, 

618 "tags": r.tags, 

619 "memory_type": r.memory_type.value, 

620 "confidence": r.confidence, 

621 "session_id": r.session_id, 

622 "namespace": r.namespace, 

623 "version": r.version, 

624 } 

625 for r in results 

626 ] 

627 } 

628 

629 @app.post("/recall/stream") 

630 async def recall_stream(req: RecallRequest, request: Request) -> StreamingResponse: 

631 """Stream recall results as Server-Sent Events.""" 

632 import json 

633 

634 effective_user = _resolve_user_id(request, req.user_id) 

635 

636 # Rate limit check 

637 allowed, retry_after = _check_rate_limit(effective_user) 

638 if not allowed: 

639 raise HTTPException( 

640 status_code=429, 

641 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.", 

642 headers={"Retry-After": str(retry_after)}, 

643 ) 

644 

645 async def _generate() -> Any: 

646 count = 0 

647 try: 

648 stream = mem.recall_stream( 

649 user_id=effective_user, 

650 query=req.query, 

651 top_k=req.top_k, 

652 max_tokens=req.max_tokens, 

653 namespace=req.namespace, 

654 session_id=req.session_id, 

655 hybrid_search=req.hybrid_search, 

656 ) 

657 async for result in stream: 

658 count += 1 

659 payload = { 

660 "memory_id": result.memory_id, 

661 "content": result.content, 

662 "score": result.score, 

663 } 

664 yield f"data: {json.dumps(payload)}\n\n" 

665 except ValueError as e: 

666 yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n" 

667 return 

668 

669 yield f"event: done\ndata: {json.dumps({'total': count})}\n\n" 

670 

671 return StreamingResponse( 

672 _generate(), 

673 media_type="text/event-stream", 

674 ) 

675 

676 @app.post("/recall-explain") 

677 async def recall_explain(req: RecallRequest, request: Request) -> dict[str, Any]: 

678 effective_user = _resolve_user_id(request, req.user_id) 

679 # Rate limit check 

680 allowed, retry_after = _check_rate_limit(effective_user) 

681 if not allowed: 

682 raise HTTPException( 

683 status_code=429, 

684 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.", 

685 headers={"Retry-After": str(retry_after)}, 

686 ) 

687 

688 try: 

689 explained = mem.recall_explain( 

690 user_id=effective_user, 

691 query=req.query, 

692 top_k=req.top_k, 

693 namespace=req.namespace, 

694 session_id=req.session_id, 

695 ) 

696 except ValueError as e: 

697 raise HTTPException(status_code=400, detail=str(e)) from e 

698 

699 return { 

700 "results": [ 

701 { 

702 "memory": { 

703 "memory_id": item["memory"].memory_id, 

704 "content": item["memory"].content, 

705 "score": item["memory"].score, 

706 }, 

707 "explanation": item["explanation"], 

708 } 

709 for item in explained 

710 ] 

711 } 

712 

713 @app.post("/forget") 

714 async def forget( 

715 request: Request, 

716 user_id: str, 

717 memory_id: str | None = None, 

718 ) -> dict[str, Any]: 

719 effective_user = _resolve_user_id(request, user_id) 

720 # Rate limit check 

721 allowed, retry_after = _check_rate_limit(effective_user) 

722 if not allowed: 

723 raise HTTPException( 

724 status_code=429, 

725 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.", 

726 headers={"Retry-After": str(retry_after)}, 

727 ) 

728 

729 count = mem.forget(effective_user, memory_id) 

730 return {"deleted": count} 

731 

732 @app.patch("/memories/{memory_id}") 

733 async def update_memory( 

734 memory_id: str, 

735 req: UpdateRequest, 

736 request: Request, 

737 ) -> dict[str, Any]: 

738 # We need the user_id to rate-limit per-user; look up the memory. 

739 existing = mem._store.get(memory_id) 

740 if existing is None: 

741 raise HTTPException(status_code=404, detail=f"Memory not found: {memory_id}") 

742 effective_user = _resolve_user_id(request, existing.user_id) 

743 allowed, retry_after = _check_rate_limit(effective_user) 

744 if not allowed: 

745 raise HTTPException( 

746 status_code=429, 

747 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.", 

748 headers={"Retry-After": str(retry_after)}, 

749 ) 

750 

751 try: 

752 mtype = None 

753 if req.memory_type: 

754 mtype = MemoryType(req.memory_type) 

755 except ValueError as e: 

756 raise HTTPException(status_code=400, detail=str(e)) from e 

757 

758 try: 

759 mem.update( 

760 memory_id=memory_id, 

761 content=req.content, 

762 importance=req.importance, 

763 confidence=req.confidence, 

764 memory_type=mtype, 

765 ) 

766 except ValueError as e: 

767 raise HTTPException(status_code=404, detail=str(e)) from e 

768 return {"memory_id": memory_id, "status": "updated"} 

769 

770 @app.post("/prune") 

771 async def prune( 

772 request: Request, 

773 user_id: str, 

774 req: PruneRequest, 

775 ) -> dict[str, Any]: 

776 effective_user = _resolve_user_id(request, user_id) 

777 allowed, retry_after = _check_rate_limit(effective_user) 

778 if not allowed: 

779 raise HTTPException( 

780 status_code=429, 

781 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.", 

782 headers={"Retry-After": str(retry_after)}, 

783 ) 

784 

785 lifecycle_filter = None 

786 if req.lifecycle_states: 

787 try: 

788 lifecycle_filter = [LifecycleState(s) for s in req.lifecycle_states] 

789 except ValueError as e: 

790 raise HTTPException(status_code=400, detail=str(e)) from e 

791 

792 deleted = mem.prune( 

793 user_id=effective_user, 

794 max_age_days=req.max_age_days, 

795 min_importance=req.min_importance, 

796 lifecycle_states=lifecycle_filter, 

797 namespace=req.namespace, 

798 ) 

799 return {"deleted": deleted} 

800 

801 @app.get("/stats/{user_id}") 

802 async def stats(user_id: str, request: Request) -> dict[str, Any]: 

803 effective_user = _resolve_user_id(request, user_id) 

804 # Rate limit check 

805 allowed, retry_after = _check_rate_limit(effective_user) 

806 if not allowed: 

807 raise HTTPException( 

808 status_code=429, 

809 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.", 

810 headers={"Retry-After": str(retry_after)}, 

811 ) 

812 

813 try: 

814 return mem.stats(effective_user) 

815 except ValueError as e: 

816 raise HTTPException(status_code=400, detail=str(e)) from e 

817 

818 @app.get("/users") 

819 async def list_users(request: Request) -> dict[str, Any]: 

820 # If authed, restrict to the caller's own user_id. 

821 authed = getattr(request.state, "user_id", None) 

822 users = mem.list_users() 

823 if authed is not None: 

824 users = [u for u in users if u == authed] 

825 return {"users": users} 

826 

827 @app.post("/consolidate/{user_id}") 

828 async def consolidate_user( 

829 user_id: str, 

830 req: ConsolidateRequest, 

831 request: Request, 

832 ) -> dict[str, Any]: 

833 effective_user = _resolve_user_id(request, user_id) 

834 allowed, retry_after = _check_rate_limit(effective_user) 

835 if not allowed: 

836 raise HTTPException( 

837 status_code=429, 

838 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.", 

839 headers={"Retry-After": str(retry_after)}, 

840 ) 

841 try: 

842 mid = mem.consolidate( 

843 user_id=effective_user, 

844 namespace=req.namespace, 

845 min_memories=req.min_memories, 

846 max_age_days=req.max_age_days, 

847 ) 

848 except ValueError as e: 

849 raise HTTPException(status_code=400, detail=str(e)) from e 

850 if mid: 

851 return {"consolidated_memory_id": mid} 

852 return {"message": "No consolidation needed"} 

853 

854 @app.post("/topics/{user_id}") 

855 async def topics_user( 

856 user_id: str, 

857 req: TopicsRequest, 

858 request: Request, 

859 ) -> dict[str, Any]: 

860 effective_user = _resolve_user_id(request, user_id) 

861 allowed, retry_after = _check_rate_limit(effective_user) 

862 if not allowed: 

863 raise HTTPException( 

864 status_code=429, 

865 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.", 

866 headers={"Retry-After": str(retry_after)}, 

867 ) 

868 try: 

869 clusters = mem.cluster_topics( 

870 user_id=effective_user, 

871 n_clusters=req.n_clusters, 

872 namespace=req.namespace, 

873 ) 

874 except ValueError as e: 

875 raise HTTPException(status_code=400, detail=str(e)) from e 

876 return { 

877 "topics": { 

878 label: [ 

879 { 

880 "memory_id": m.memory_id, 

881 "content": m.content, 

882 "importance": m.importance, 

883 } 

884 for m in mems 

885 ] 

886 for label, mems in clusters.items() 

887 } 

888 } 

889 

890 @app.post("/graph/{user_id}") 

891 async def graph_user( 

892 user_id: str, 

893 req: GraphRequest, 

894 request: Request, 

895 ) -> dict[str, Any]: 

896 effective_user = _resolve_user_id(request, user_id) 

897 allowed, retry_after = _check_rate_limit(effective_user) 

898 if not allowed: 

899 raise HTTPException( 

900 status_code=429, 

901 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.", 

902 headers={"Retry-After": str(retry_after)}, 

903 ) 

904 try: 

905 graph_data = mem.get_memory_graph( 

906 user_id=effective_user, 

907 namespace=req.namespace, 

908 ) 

909 except ValueError as e: 

910 raise HTTPException(status_code=400, detail=str(e)) from e 

911 return graph_data 

912 

913 @app.post("/feedback/{user_id}") 

914 async def feedback_user( 

915 user_id: str, 

916 req: FeedbackRequest, 

917 request: Request, 

918 ) -> dict[str, Any]: 

919 effective_user = _resolve_user_id(request, user_id) 

920 allowed, retry_after = _check_rate_limit(effective_user) 

921 if not allowed: 

922 raise HTTPException( 

923 status_code=429, 

924 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.", 

925 headers={"Retry-After": str(retry_after)}, 

926 ) 

927 try: 

928 mem.feedback( 

929 user_id=effective_user, 

930 memory_id=req.memory_id, 

931 helpful=req.helpful, 

932 namespace=req.namespace, 

933 ) 

934 except (ValueError, AttributeError) as e: 

935 raise HTTPException(status_code=400, detail=str(e)) from e 

936 return {"status": "ok", "memory_id": req.memory_id, "helpful": req.helpful} 

937 

@app.get("/health")
async def health() -> dict[str, Any]:
    """Report liveness of the storage connection and the embedding adapter.

    Returns:
        A dict with three keys:
        ``status`` - "ok" when no probe flagged a problem, else "degraded";
        ``components`` - per-component status dicts ("database", "embedding");
        ``timestamp`` - current UTC time in ISO format.
    """
    report: dict[str, Any] = {}
    healthy = True
    mem = memory or _get_memory_singleton()

    # --- database probe -------------------------------------------------
    try:
        store = getattr(mem, "_store", None)
        if store:
            conn = store._get_connection()  # type: ignore[attr-defined]
            # A trivial round trip is enough to prove the connection works.
            conn.execute("SELECT 1").fetchone()
            report["database"] = {"status": "healthy", "type": "sqlite"}
        else:
            report["database"] = {"status": "unknown", "message": "Storage not initialized"}
            healthy = False
    except (sqlite3.Error, AttributeError) as exc:
        report["database"] = {"status": "unhealthy", "error": str(exc)}
        healthy = False

    # --- embedding adapter probe ----------------------------------------
    try:
        adapter = getattr(mem, "_embed", None)
        if adapter is None:
            # A missing adapter is reported but does not degrade the status.
            report["embedding"] = {
                "status": "not_configured",
                "message": "No embedding adapter",
            }
        else:
            report["embedding"] = {"status": "healthy", "adapter": type(adapter).__name__}
    except AttributeError as exc:
        report["embedding"] = {"status": "unhealthy", "error": str(exc)}
        healthy = False

    return {
        "status": "ok" if healthy else "degraded",
        "components": report,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

986 

987 # Background Task Endpoints 

988 

@app.post("/tasks/embed-batch")
async def submit_embed_batch_task(
    req: BatchRememberRequest, request: Request
) -> dict[str, Any]:
    """Queue a batch-embedding job on the background task manager.

    The call returns as soon as the job is submitted; the returned
    ``task_id`` can be polled at ``/tasks/{task_id}``.

    Args:
        req: Batch request; ``user_id``, ``contents``, ``importance`` and
            ``namespace`` are forwarded to the task manager.
        request: Incoming request, used to resolve the effective user.

    Returns:
        ``{"task_id": <id>, "status": "pending"}``.

    Raises:
        HTTPException: 429 when the task manager raises ``RuntimeError``.
    """
    effective_user = _resolve_user_id(request, req.user_id)

    from kemi.background_tasks import get_task_manager

    manager = get_task_manager()
    try:
        new_task_id = manager.submit_embed_batch(
            user_id=effective_user,
            contents=req.contents,
            importance=req.importance,
            namespace=req.namespace,
        )
    except RuntimeError as err:
        raise HTTPException(status_code=429, detail=str(err)) from err

    return {"task_id": new_task_id, "status": "pending"}

1018 

@app.post("/tasks/rebuild-fts")
async def submit_rebuild_fts_task(
    req: RebuildFTSRequest, request: Request
) -> dict[str, Any]:
    """Queue an FTS index rebuild on the background task manager.

    The call returns as soon as the job is submitted; the returned
    ``task_id`` can be polled at ``/tasks/{task_id}``.

    Args:
        req: Rebuild request with an optional ``user_id`` filter.
        request: Incoming request; ``request.state.user_id`` (when set)
            identifies the authenticated caller.

    Returns:
        ``{"task_id": <id>, "status": "pending"}``.

    Raises:
        HTTPException: 403 when the body names a user other than the
            authenticated one; 429 when the task manager raises
            ``RuntimeError``.
    """
    caller = getattr(request.state, "user_id", None)
    requested = req.user_id

    target_user: str | None
    if caller is None:
        # No authenticated identity: honour whatever the body asked for.
        target_user = requested
    elif requested is None or requested == caller:
        # Authenticated callers are always scoped to their own user id,
        # which prevents cross-tenant rebuilds.
        target_user = caller
    else:
        raise HTTPException(
            status_code=403,
            detail="user_id does not match authenticated user",
        )

    from kemi.background_tasks import get_task_manager

    manager = get_task_manager()
    try:
        new_task_id = manager.submit_rebuild_fts_index(user_id=target_user)
    except RuntimeError as err:
        raise HTTPException(status_code=429, detail=str(err)) from err

    return {"task_id": new_task_id, "status": "pending"}

1056 

@app.get("/tasks/stats")
async def get_task_stats() -> dict[str, Any]:
    """Return the background task manager's statistics.

    Returns:
        Whatever ``get_stats()`` on the task manager yields, unchanged.
    """
    from kemi.background_tasks import get_task_manager

    return get_task_manager().get_stats()

1068 

@app.get("/tasks/{task_id}")
async def get_task_status(task_id: str) -> dict[str, Any]:
    """Look up a single background task by id.

    Args:
        task_id: Identifier handed out by one of the submit endpoints.

    Returns:
        The task serialised with its ``to_dict()`` method.

    Raises:
        HTTPException: 404 when the task manager has no such task.
    """
    from kemi.background_tasks import get_task_manager

    found = get_task_manager().get_task_status(task_id)
    if found is not None:
        return found.to_dict()

    raise HTTPException(status_code=404, detail=f"Task not found: {task_id}")

1088 

@app.get("/tasks")
async def list_tasks(
    status: str | None = None,
    limit: int = 50,
) -> dict[str, Any]:
    """List background tasks, optionally filtered by status.

    Args:
        status: Optional status name; it is converted with ``TaskStatus``
            and rejected when that conversion fails. An empty string is
            treated the same as no filter.
        limit: Maximum number of tasks requested from the task manager.

    Returns:
        ``{"tasks": [...], "stats": ...}`` where each task is serialised via
        ``to_dict()`` and ``stats`` comes from the manager's ``get_stats()``.

    Raises:
        HTTPException: 400 when ``status`` is not a valid ``TaskStatus``.
    """
    from kemi.background_tasks import TaskStatus, get_task_manager

    manager = get_task_manager()

    wanted = None
    if status:
        try:
            wanted = TaskStatus(status)
        except ValueError as exc:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Invalid status: {status}. "
                    "Valid values: pending, running, completed, failed"
                ),
            ) from exc

    matching = manager.list_tasks(status=wanted, limit=limit)
    serialised = [task.to_dict() for task in matching]
    return {"tasks": serialised, "stats": manager.get_stats()}

1123 

@app.delete("/tasks/{task_id}")
async def cancel_task(task_id: str) -> dict[str, Any]:
    """Ask the task manager to cancel a background task.

    Args:
        task_id: Identifier of the task to cancel.

    Returns:
        ``{"task_id": <id>, "cancelled": True}`` on success.

    Raises:
        HTTPException: 400 when the task manager reports that the task was
            not cancelled (the error text covers both "not found" and
            "already running").
    """
    from kemi.background_tasks import get_task_manager

    if get_task_manager().cancel_task(task_id):
        return {"task_id": task_id, "cancelled": True}

    raise HTTPException(
        status_code=400,
        detail="Cannot cancel task: not found or already running",
    )

1148 

1149 # Admin Endpoints for Index Maintenance 

1150 

@app.post("/admin/fts/rebuild")
async def admin_rebuild_fts() -> dict[str, Any]:
    """Rebuild the whole FTS index synchronously.

    The rebuild runs inline in the request, so the response is only sent
    once the storage adapter's ``rebuild_fts_index()`` has returned.
    ``/tasks/rebuild-fts`` is the background alternative.

    Returns:
        ``{"status": "completed", "memories_indexed": <count>, "scope": "all"}``.

    Raises:
        HTTPException: 501 when the storage adapter has no
            ``rebuild_fts_index`` attribute; 500 when the rebuild raises.
    """
    memory_obj = _get_memory_singleton()
    store = memory_obj._store

    if not hasattr(store, "rebuild_fts_index"):
        raise HTTPException(
            status_code=501,
            detail="Storage adapter does not support FTS index rebuild",
        )

    try:
        indexed = store.rebuild_fts_index()
    except Exception as exc:
        logger.error(f"FTS rebuild failed: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {
        "status": "completed",
        "memories_indexed": indexed,
        "scope": "all",
    }

1180 

@app.get("/admin/fts/stats")
async def admin_fts_stats(
    user_id: str | None = None, request: Request = None  # type: ignore[assignment]
) -> dict[str, Any]:
    """Compare row counts between ``memories`` and ``memories_fts``.

    Args:
        user_id: Optional user to scope the comparison to. An empty string
            is treated the same as no filter.
        request: Incoming request; when it carries an authenticated
            ``state.user_id`` the stats are forced to that user.

    Returns:
        Global counts (``fts_total_entries``, ``total_memories``, ``in_sync``,
        ``sync_gap``) or, when a user is in scope, the per-user variant
        (``user_id``, ``user_memories``, ``user_fts_entries`` ...).

    Raises:
        HTTPException: 403 when ``user_id`` differs from the authenticated
            user; 500 when any query fails.
    """
    caller = getattr(request.state, "user_id", None) if request is not None else None
    if caller is not None:
        if user_id is not None and user_id != caller:
            raise HTTPException(
                status_code=403,
                detail="user_id does not match authenticated user",
            )
        # Authenticated callers only ever see their own numbers.
        user_id = caller

    memory_obj = _get_memory_singleton()

    try:
        conn = memory_obj._store._get_connection()  # type: ignore[attr-defined]

        def _scalar(sql: str, params: tuple = ()) -> Any:
            """Run a single-value query and return its first column."""
            return conn.execute(sql, params).fetchone()[0]

        fts_total = _scalar("SELECT COUNT(*) FROM memories_fts")

        if not user_id:
            mem_total = _scalar("SELECT COUNT(*) FROM memories")
            return {
                "fts_total_entries": fts_total,
                "total_memories": mem_total,
                # With no memories there is nothing that could be out of sync.
                "in_sync": mem_total <= 0 or fts_total == mem_total,
                "sync_gap": mem_total - fts_total,
            }

        mem_total = _scalar("SELECT COUNT(*) FROM memories WHERE user_id = ?", (user_id,))
        fts_user = _scalar(
            "SELECT COUNT(*) FROM memories_fts WHERE user_id = ?", (user_id,)
        )
        return {
            "fts_total_entries": fts_total,
            "user_id": user_id,
            "user_memories": mem_total,
            "user_fts_entries": fts_user,
            "in_sync": mem_total <= 0 or fts_user == mem_total,
            "sync_gap": mem_total - fts_user,
        }
    except Exception as exc:
        logger.error(f"FTS stats failed: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

1249 

@app.post("/admin/fts/verify")
async def admin_fts_verify(req: AdminFTSRepairRequest) -> dict[str, Any]:
    """Verify that ``memories`` and ``memories_fts`` hold the same ids.

    Both id sets are loaded and compared. When they differ and
    ``req.verify_only`` is false, orphaned FTS rows (ids present in
    ``memories_fts`` but not in ``memories``) are deleted.

    Args:
        req: Request whose ``verify_only`` flag disables the repair step.

    Returns:
        Dict with ``status`` ("ok", "degraded" or "repaired"), the two
        totals, ``in_sync``, ``missing_from_fts`` and ``orphaned_in_fts``
        counts, plus ``repaired_orphaned`` / ``auto_repaired`` when the
        repair branch ran.

    Raises:
        HTTPException: 500 when any query fails.
    """
    # Deletes are issued in batches so the number of bound parameters stays
    # below SQLite's per-statement variable limit (999 on older builds).
    delete_batch_size = 500

    mem = _get_memory_singleton()

    try:
        conn = mem._store._get_connection()  # type: ignore[attr-defined]

        memory_ids = {row[0] for row in conn.execute("SELECT memory_id FROM memories")}
        fts_ids = {row[0] for row in conn.execute("SELECT memory_id FROM memories_fts")}

        missing_from_fts = memory_ids - fts_ids
        orphaned_in_fts = fts_ids - memory_ids

        in_sync = not missing_from_fts and not orphaned_in_fts

        result: dict[str, Any] = {
            "status": "ok" if in_sync else "degraded",
            "total_memories": len(memory_ids),
            "total_fts_entries": len(fts_ids),
            "in_sync": in_sync,
            "missing_from_fts": len(missing_from_fts),
            "orphaned_in_fts": len(orphaned_in_fts),
        }

        if not in_sync and not req.verify_only:
            if orphaned_in_fts:
                orphan_list = list(orphaned_in_fts)
                for start in range(0, len(orphan_list), delete_batch_size):
                    batch = orphan_list[start:start + delete_batch_size]
                    placeholders = ",".join("?" * len(batch))
                    conn.execute(
                        f"DELETE FROM memories_fts WHERE memory_id IN ({placeholders})",
                        batch,
                    )
                # Persist the deletions; without an explicit commit the
                # DELETE stays in an open transaction on this connection.
                conn.commit()
                result["repaired_orphaned"] = len(orphaned_in_fts)

            # NOTE(review): only orphaned rows are repaired here. Memories
            # missing from the FTS table are left as-is, yet the response
            # is still marked "repaired" — confirm whether that is intended
            # or whether a rebuild should be triggered for them.
            result["auto_repaired"] = True
            result["status"] = "repaired"

        return result

    except Exception as e:
        logger.error(f"FTS verify failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

1308 

@app.get("/admin/health")
async def admin_health() -> dict[str, Any]:
    """Detailed health report covering storage, embeddings and tasks.

    Each of the three probes is isolated: a failure in one is recorded as
    ``{"status": "unhealthy", "error": ...}`` and the others still run.

    Returns:
        Dict with ``status`` ("ok" only when every component reports
        "healthy", otherwise "degraded"), ``components`` and a UTC ISO
        ``timestamp``.
    """
    memory_obj = _get_memory_singleton()
    report: dict[str, Any] = {}

    # --- database --------------------------------------------------------
    try:
        conn = memory_obj._store._get_connection()  # type: ignore[attr-defined]
        conn.execute("SELECT 1").fetchone()

        memory_count = conn.execute("SELECT COUNT(*) FROM memories").fetchone()[0]
        fts_count = conn.execute("SELECT COUNT(*) FROM memories_fts").fetchone()[0]
        user_count = conn.execute(
            "SELECT COUNT(DISTINCT user_id) FROM memories"
        ).fetchone()[0]

        report["database"] = {
            "status": "healthy",
            "type": "sqlite",
            "total_memories": memory_count,
            "total_users": user_count,
            "fts_entries": fts_count,
            "fts_in_sync": memory_count == fts_count,
        }
    except Exception as exc:
        report["database"] = {"status": "unhealthy", "error": str(exc)}

    # --- embedding adapter -----------------------------------------------
    try:
        adapter = getattr(memory_obj, "_embed", None)
        if adapter is None:
            report["embedding"] = {"status": "not_configured"}
        else:
            report["embedding"] = {
                "status": "healthy",
                "adapter": type(adapter).__name__,
            }
            # Circuit breaker state is included only when the adapter
            # exposes it.
            if hasattr(adapter, "get_circuit_breaker_state"):
                report["embedding"]["circuit_breaker"] = (
                    adapter.get_circuit_breaker_state()
                )
    except Exception as exc:
        report["embedding"] = {"status": "unhealthy", "error": str(exc)}

    # --- background task manager -----------------------------------------
    try:
        from kemi.background_tasks import get_task_manager

        task_stats = get_task_manager().get_stats()
        report["task_manager"] = {
            "status": "healthy",
            **{
                key: task_stats[key]
                for key in ("pending", "running", "completed", "failed")
            },
        }
    except Exception as exc:
        report["task_manager"] = {"status": "unhealthy", "error": str(exc)}

    # Anything other than "healthy" (including "not_configured") degrades
    # the overall status.
    everything_ok = all(entry.get("status") == "healthy" for entry in report.values())

    return {
        "status": "ok" if everything_ok else "degraded",
        "components": report,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

1401 

1402 # Observability / Metrics Endpoints 

1403 

@app.get("/metrics")
async def get_metrics(output_format: str = "json") -> Any:
    """Expose collected metrics as JSON or Prometheus text.

    Args:
        output_format: "prometheus" (compared case-insensitively) selects
            the plain-text exposition; any other value returns the JSON
            metrics.

    Returns:
        The value of ``mem.get_metrics()``, or a ``PlainTextResponse``
        wrapping ``mem.get_metrics_prometheus()``.

    Raises:
        HTTPException: 503 when the requested metrics source returns
            ``None``.
    """
    json_metrics = mem.get_metrics()
    if json_metrics is None:
        raise HTTPException(
            status_code=503,
            detail="Metrics collector not available",
        )

    if output_format.lower() != "prometheus":
        return json_metrics

    prometheus_text = mem.get_metrics_prometheus()
    if prometheus_text is None:
        raise HTTPException(
            status_code=503,
            detail="Metrics collector not available",
        )
    return PlainTextResponse(content=prometheus_text, media_type="text/plain")

1431 

1432 # Audit Trail Endpoints 

1433 

@app.post("/audit/log")
async def audit_log(req: AuditLogRequest, request: Request) -> dict[str, Any]:
    """Append one entry to the audit trail.

    The entry is attributed to the user returned by ``_resolve_user_id``;
    the remaining request fields are forwarded to ``log_operation``.

    Returns:
        ``{"entry_id": <id>, "status": "logged"}``.

    Raises:
        HTTPException: 503 when no audit trail is attached to the memory
            instance; 500 when logging the operation raises.
    """
    effective_user = _resolve_user_id(request, req.user_id)

    trail = getattr(mem, "_audit_trail", None)
    if trail is None:
        raise HTTPException(
            status_code=503,
            detail="Audit trail not enabled. Use POST /admin/enable-audit first.",
        )

    try:
        entry_id = trail.log_operation(
            user_id=effective_user,
            operation=req.operation,
            details=req.details or {},
            memory_id=req.memory_id,
            namespace=req.namespace,
            status=req.status,
            client_ip=req.client_ip,
            user_agent=req.user_agent,
            duration_ms=req.duration_ms,
        )
    except Exception as exc:
        logger.error(f"Audit log failed: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"entry_id": entry_id, "status": "logged"}

1464 

@app.post("/audit/query")
async def audit_query(
    req: AuditQueryRequest, request: Request
) -> dict[str, Any]:
    """Search the audit trail using the filters in the request body.

    Returns:
        Dict with ``entries`` (each serialised via ``to_dict()``), ``count``
        (number of entries in this response) and the echoed ``limit`` /
        ``offset``.

    Raises:
        HTTPException: 503 when no audit trail is attached; 403 when the
            body's ``user_id`` differs from the authenticated user; 500
            when the query raises.
    """
    trail = getattr(mem, "_audit_trail", None)
    if trail is None:
        raise HTTPException(
            status_code=503,
            detail="Audit trail not enabled. Use POST /admin/enable-audit first.",
        )

    caller = getattr(request.state, "user_id", None)
    requested = req.user_id
    if caller is None:
        scoped_user = requested
    elif requested is None or requested == caller:
        # An authenticated tenant may only read its own audit entries.
        scoped_user = caller
    else:
        raise HTTPException(
            status_code=403,
            detail="user_id does not match authenticated user",
        )

    try:
        found = trail.query(
            user_id=scoped_user,
            operation=req.operation,
            status=req.status,
            memory_id=req.memory_id,
            namespace=req.namespace,
            start_time=req.start_time,
            end_time=req.end_time,
            limit=req.limit,
            offset=req.offset,
        )
        payload = [entry.to_dict() for entry in found]
        return {
            "entries": payload,
            "count": len(found),
            "limit": req.limit,
            "offset": req.offset,
        }
    except Exception as exc:
        logger.error(f"Audit query failed: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

1513 

@app.get("/audit/stats")
async def audit_stats() -> dict[str, Any]:
    """Return the audit trail's own statistics.

    Returns:
        Whatever the audit trail's ``get_stats()`` yields, unchanged.

    Raises:
        HTTPException: 503 when no audit trail is attached; 500 when
            ``get_stats()`` raises.
    """
    trail = getattr(mem, "_audit_trail", None)
    if trail is None:
        raise HTTPException(
            status_code=503,
            detail="Audit trail not enabled. Use POST /admin/enable-audit first.",
        )

    try:
        stats = trail.get_stats()
    except Exception as exc:
        logger.error(f"Audit stats failed: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return stats  # type: ignore[no-any-return]

1532 

@app.post("/audit/export")
async def audit_export(
    req: AuditExportRequest, request: Request
) -> dict[str, Any]:
    """Export audit entries for a time range.

    Returns:
        ``{"entries": [...], "count": <len>}`` with the entries exactly as
        returned by the audit trail's ``export()``.

    Raises:
        HTTPException: 503 when no audit trail is attached; 403 when the
            body's ``user_id`` differs from the authenticated user; 500
            when the export raises.
    """
    trail = getattr(mem, "_audit_trail", None)
    if trail is None:
        raise HTTPException(
            status_code=503,
            detail="Audit trail not enabled. Use POST /admin/enable-audit first.",
        )

    caller = getattr(request.state, "user_id", None)
    requested = req.user_id
    if caller is None:
        scoped_user = requested
    elif requested is None or requested == caller:
        # Authenticated callers can only export their own entries.
        scoped_user = caller
    else:
        raise HTTPException(
            status_code=403,
            detail="user_id does not match authenticated user",
        )

    try:
        exported = trail.export(
            start_time=req.start_time,
            end_time=req.end_time,
            user_id=scoped_user,
        )
        return {"entries": exported, "count": len(exported)}
    except Exception as exc:
        logger.error(f"Audit export failed: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

1569 

1570 # Adaptive Retrieval Endpoints 

1571 

@app.post("/adaptive/analyze")
async def adaptive_analyze(req: AdaptiveAnalyzeRequest) -> dict[str, Any]:
    """Classify a query with the adaptive retriever.

    Returns:
        Dict exposing the fields of the profile returned by
        ``analyze_query``: the query text, its type, confidence, a few
        lexical features and the recommended weights.

    Raises:
        HTTPException: 503 when no adaptive retriever is attached; 500 when
            analysis (or reading the profile) raises.
    """
    retriever = getattr(mem, "_adaptive_retriever", None)
    if retriever is None:
        raise HTTPException(
            status_code=503,
            detail="Adaptive retrieval not enabled. Use POST /admin/enable-adaptive first.",
        )

    try:
        profile = retriever.analyze_query(req.query)
        response: dict[str, Any] = {
            "query": profile.query,
            "query_type": profile.query_type.value,
        }
        # The remaining fields are copied across under the same names.
        for field_name in (
            "confidence",
            "word_count",
            "keyword_density",
            "specificity",
            "has_question_mark",
            "has_named_entity_hint",
            "recommended_weights",
        ):
            response[field_name] = getattr(profile, field_name)
        return response
    except Exception as exc:
        logger.error(f"Adaptive analyze failed: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

1601 

@app.get("/adaptive/user-profile/{user_id}")
async def adaptive_user_profile(
    user_id: str, request: Request
) -> dict[str, Any]:
    """Fetch the adaptive retriever's profile for a user.

    The path ``user_id`` is first passed through ``_resolve_user_id``.

    Returns:
        Whatever the retriever's ``get_user_profile()`` yields, unchanged.

    Raises:
        HTTPException: 503 when no adaptive retriever is attached; 500 when
            the lookup raises.
    """
    effective_user = _resolve_user_id(request, user_id)

    retriever = getattr(mem, "_adaptive_retriever", None)
    if retriever is None:
        raise HTTPException(
            status_code=503,
            detail="Adaptive retrieval not enabled. Use POST /admin/enable-adaptive first.",
        )

    try:
        user_profile = retriever.get_user_profile(effective_user)
    except Exception as exc:
        logger.error(f"Adaptive user profile failed: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return user_profile  # type: ignore[no-any-return]

1623 

1624 # Admin Feature Toggle Endpoints 

1625 

@app.post("/admin/enable-audit")
async def admin_enable_audit(req: EnableFeatureRequest) -> dict[str, Any]:
    """Switch the audit trail on or off.

    Enabling calls ``mem.enable_audit_trail`` with the requested retention
    settings; disabling simply clears ``mem._audit_trail``.

    Returns:
        Dict echoing ``audit_trail_enabled``, ``retention_days`` and
        ``auto_purge`` from the request.

    Raises:
        HTTPException: 501 when the memory instance has no
            ``enable_audit_trail`` attribute; 500 when toggling raises.
    """
    if not hasattr(mem, "enable_audit_trail"):
        raise HTTPException(
            status_code=501,
            detail="Memory instance does not support audit trail",
        )

    try:
        if not req.enable:
            mem._audit_trail = None
        else:
            mem.enable_audit_trail(
                retention_days=req.retention_days,
                auto_purge=req.auto_purge,
            )

        summary: dict[str, Any] = {"audit_trail_enabled": req.enable}
        summary["retention_days"] = req.retention_days
        summary["auto_purge"] = req.auto_purge
        return summary
    except Exception as exc:
        logger.error(f"Enable audit trail failed: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

1656 

@app.post("/admin/enable-adaptive")
async def admin_enable_adaptive(req: EnableFeatureRequest) -> dict[str, Any]:
    """Switch adaptive retrieval on or off.

    Returns:
        ``{"adaptive_retrieval_enabled": <req.enable>}``.

    Raises:
        HTTPException: 501 when the memory instance has no
            ``enable_adaptive_retrieval`` attribute; 500 when toggling
            raises.
    """
    supported = hasattr(mem, "enable_adaptive_retrieval")
    if not supported:
        raise HTTPException(
            status_code=501,
            detail="Memory instance does not support adaptive retrieval",
        )

    try:
        mem.enable_adaptive_retrieval(enable=req.enable)
        return {"adaptive_retrieval_enabled": req.enable}
    except Exception as exc:
        logger.error(f"Enable adaptive retrieval failed: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc

1676 

1677 # API Key Management Endpoints 

1678 

@app.post("/api/keys")
async def create_api_key(
    req: CreateAPIKeyRequest, request: Request
) -> dict[str, Any]:
    """Issue a new API key for ``req.user_id``.

    The response is produced with ``to_dict(include_secret=True)``, so this
    is the one place the raw key is handed back to the caller.

    An authenticated caller may only create keys for its own user id. With
    no authenticated identity the body's ``user_id`` is used as given,
    which is the bootstrap path for a brand-new tenant.

    Raises:
        HTTPException: 501 when no API key manager is available; 403 when
            the body's ``user_id`` differs from the authenticated user;
            400 when key creation raises ``ValueError`` or ``RuntimeError``.
    """
    manager = _get_api_key_manager()
    if manager is None:
        raise HTTPException(
            status_code=501,
            detail="API key management not supported by this storage",
        )

    caller = getattr(request.state, "user_id", None)
    if caller is not None and caller != req.user_id:
        raise HTTPException(
            status_code=403,
            detail="user_id does not match authenticated user",
        )

    from kemi.api_keys import make_expiry

    # A falsy ``expires_in_days`` (None or 0) means the key never expires.
    if req.expires_in_days:
        expires_at = make_expiry(req.expires_in_days)
    else:
        expires_at = None

    try:
        new_key = manager.create_key(
            user_id=req.user_id,
            name=req.name,
            expires_at=expires_at,
        )
    except (ValueError, RuntimeError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return new_key.to_dict(include_secret=True)

1720 

@app.get("/api/keys")
async def list_api_keys(request: Request) -> dict[str, Any]:
    """List API keys, scoped to the caller when authenticated.

    Returns:
        ``{"keys": [...], "count": <len>}`` with each key serialised via
        ``to_dict()``.

    Raises:
        HTTPException: 501 when no API key manager is available.
    """
    manager = _get_api_key_manager()
    if manager is None:
        raise HTTPException(
            status_code=501,
            detail="API key management not supported by this storage",
        )

    # The authenticated user id (or None) is passed straight through as the
    # filter: authenticated callers get their own keys, unauthenticated
    # requests get the unfiltered, admin-style listing.
    found = manager.list_keys(user_id=getattr(request.state, "user_id", None))
    serialised = [key.to_dict() for key in found]
    return {"keys": serialised, "count": len(found)}

1741 

@app.delete("/api/keys/{key_id}")
async def revoke_api_key(key_id: str, request: Request) -> dict[str, Any]:
    """Revoke the API key with the given id.

    When the caller is authenticated the key is looked up first and must
    belong to that caller; unauthenticated requests skip the ownership
    check.

    Returns:
        ``{"key_id": <id>, "revoked": True}``.

    Raises:
        HTTPException: 501 when no API key manager is available; 404 when
            the key does not exist or ``revoke`` reports failure; 403 when
            the key belongs to a different user.
    """
    manager = _get_api_key_manager()
    if manager is None:
        raise HTTPException(
            status_code=501,
            detail="API key management not supported by this storage",
        )

    caller = getattr(request.state, "user_id", None)
    if caller is not None:
        record = manager.get(key_id)
        if record is None:
            raise HTTPException(status_code=404, detail="Key not found")
        if record.user_id != caller:
            raise HTTPException(
                status_code=403,
                detail="Cannot revoke a key belonging to another user",
            )

    if manager.revoke(key_id):
        return {"key_id": key_id, "revoked": True}

    raise HTTPException(
        status_code=404,
        detail="Key not found or already revoked",
    )

1772 

1773 # Memory Version History Endpoint 

1774 

@app.get("/memories/{memory_id}/history")
async def get_memory_history(memory_id: str, request: Request, limit: int = 100) -> dict[str, Any]:
    """Return the stored version snapshots of one memory.

    Args:
        memory_id: Memory whose history is requested.
        request: Incoming request (not consulted by this handler).
        limit: Maximum number of snapshots requested from ``mem.get_history``.

    Returns:
        ``{"memory_id": ..., "versions": [...], "count": <len>}``.

    Raises:
        HTTPException: 501 when configuring versioning or reading the
            history raises ``RuntimeError``.
    """
    # NOTE(review): ``request`` is unused, so no check ties ``memory_id`` to
    # the authenticated user in this handler — confirm that ownership is
    # enforced elsewhere.

    def _as_payload(snap: Any) -> dict[str, Any]:
        """Flatten one history snapshot into a JSON-friendly dict."""
        changed_at = snap.changed_at.isoformat() if snap.changed_at else None
        return {
            "version": snap.version,
            "content": snap.content,
            "importance": snap.importance,
            "tags": snap.tags,
            "memory_type": snap.memory_type,
            "confidence": snap.confidence,
            "namespace": snap.namespace,
            "source": snap.source,
            "changed_at": changed_at,
            "changed_by": snap.changed_by,
        }

    try:
        mem.configure_versioning()
        snapshots = mem.get_history(memory_id, limit=limit)
    except RuntimeError as exc:
        raise HTTPException(status_code=501, detail=str(exc)) from exc

    versions = [_as_payload(snap) for snap in snapshots]
    return {
        "memory_id": memory_id,
        "versions": versions,
        "count": len(snapshots),
    }

1803 

1804 # Webhook Management Endpoints 

1805 

class CreateWebhookRequest(BaseModel):
    """Request body accepted by ``POST /webhooks``."""

    # Target URL; must be a non-empty string.
    url: str = Field(..., min_length=1)
    # Event names; at least one is required.
    events: list[str] = Field(..., min_length=1)
    # Secret stored with the webhook; defaults to the empty string.
    secret: str = ""
    # Whether the webhook is created in the active state.
    active: bool = True

1811 

class UpdateWebhookRequest(BaseModel):
    """Webhook fields that may be supplied individually.

    Every field defaults to ``None``.
    NOTE(review): presumably ``None`` means "leave unchanged" in a partial
    update — no endpoint in this section uses the model, so confirm.
    """

    url: str | None = None
    events: list[str] | None = None
    secret: str | None = None
    active: bool | None = None

1817 

def _get_webhook_store() -> WebhookStore | None:
    """Build a ``WebhookStore`` on the active memory's database file.

    A new store is constructed on every call from ``mem._store._db_path``.

    Returns:
        The store, or ``None`` when the storage adapter exposes no
        ``_db_path`` or the store cannot be constructed. Callers translate
        ``None`` into a 501 response.
    """
    try:
        db_path = mem._store._db_path  # type: ignore[attr-defined]
        return WebhookStore(db_path=db_path)
    except Exception as exc:
        # Best effort by design: any failure means "webhooks unavailable".
        # Log it so the cause is not lost entirely.
        logger.debug("Webhook store unavailable: %s", exc)
        return None

1825 

@app.post("/webhooks", status_code=201)
async def create_webhook(req: CreateWebhookRequest, request: Request) -> dict[str, Any]:
    """Register a new webhook.

    Event names from the body are converted with
    ``WebhookEventType.from_string`` before the configuration is stored.

    Returns:
        Dict with the new ``webhook_id`` plus the echoed ``url``, ``events``
        and ``active`` values. The secret is not included.

    Raises:
        HTTPException: 501 when no webhook store is available; 400 when an
            event name is rejected with ``ValueError``.
    """
    store = _get_webhook_store()
    if store is None:
        raise HTTPException(
            status_code=501,
            detail="Webhook store not available (storage adapter does not expose db_path)",
        )

    try:
        parsed_events = [WebhookEventType.from_string(name) for name in req.events]
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # The id is passed as an empty string; the id returned by the store is
    # what gets reported to the client.
    config = WebhookConfig(
        webhook_id="",
        url=req.url,
        events=parsed_events,
        secret=req.secret,
        active=req.active,
    )
    new_id = store.create(config)

    return {
        "webhook_id": new_id,
        "url": req.url,
        "events": req.events,
        "active": req.active,
    }

1850 

@app.get("/webhooks")
async def list_webhooks(request: Request) -> dict[str, Any]:
    """List every registered webhook, active or not.

    Returns:
        ``{"webhooks": [...], "count": <len>}``; each item carries
        ``webhook_id``, ``url``, ``events`` (as their enum values) and
        ``active``. Secrets are not included.

    Raises:
        HTTPException: 501 when no webhook store is available.
    """
    store = _get_webhook_store()
    if store is None:
        raise HTTPException(status_code=501, detail="Webhook store not available")

    registered = store.list_all(active_only=False)

    listing: list[dict[str, Any]] = []
    for cfg in registered:
        listing.append(
            {
                "webhook_id": cfg.webhook_id,
                "url": cfg.url,
                "events": [event.value for event in cfg.events],
                "active": cfg.active,
            }
        )

    return {"webhooks": listing, "count": len(registered)}

1871 

@app.delete("/webhooks/{webhook_id}")
async def delete_webhook(webhook_id: str, request: Request) -> dict[str, Any]:
    """Remove a webhook configuration by id.

    Returns:
        ``{"webhook_id": <id>, "deleted": True}``.

    Raises:
        HTTPException: 501 when no webhook store is available; 404 when the
            store reports that nothing was deleted.
    """
    store = _get_webhook_store()
    if store is None:
        raise HTTPException(status_code=501, detail="Webhook store not available")

    if store.delete(webhook_id):
        return {"webhook_id": webhook_id, "deleted": True}

    raise HTTPException(status_code=404, detail=f"Webhook not found: {webhook_id}")

1882 

1883 # Admin endpoint: list users with their memory counts 

1884 

@app.get("/admin/users")
async def admin_list_users(request: Request) -> dict[str, Any]:
    """List users together with their memory counts.

    When the request carries an authenticated ``state.user_id`` only that
    user's row is returned; otherwise every user from ``mem.list_users()``
    is listed.

    Returns:
        ``{"users": [...], "count": <len>}``. Each row has ``user_id`` and
        ``memory_count``, plus ``last_active`` when the storage adapter
        provides a ``get_last_active`` method that returns a value.
    """
    authed = getattr(request.state, "user_id", None)
    all_users = mem.list_users()
    if authed is not None:
        users = [u for u in all_users if u == authed]
    else:
        users = all_users

    store = getattr(mem, "_store", None)
    # Resolve the optional accessor once instead of once per user.
    get_last_active = getattr(store, "get_last_active", None)

    rows: list[dict[str, Any]] = []
    for uid in users:
        try:
            count = store.count(uid) if store is not None else 0
        except Exception:  # pragma: no cover - defensive
            count = 0
        row: dict[str, Any] = {"user_id": uid, "memory_count": count}

        # Optional metadata is best effort: a failing lookup must not fail
        # the whole response, so it is guarded like the count above.
        last_active = None
        if get_last_active is not None:
            try:
                last_active = get_last_active(uid)
            except Exception as exc:  # pragma: no cover - defensive
                logger.debug("get_last_active failed for %s: %s", uid, exc)
        if last_active is not None:
            row["last_active"] = last_active
        rows.append(row)
    return {"users": rows, "count": len(rows)}

1914 

1915 return app