Coverage for src / kemi / api_server.py: 83%
820 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1"""FastAPI REST API server for kemi memory.
3Optional dependency: install with `pip install fastapi uvicorn`
5Usage:
6 from kemi.api_server import create_app
7 app = create_app()
8 uvicorn.run(app, host="0.0.0.0", port=8000)
10Rate Limiting:
11 Rate limiting is disabled by default. To enable, set:
12 - KEMI_RATE_LIMIT_REQUESTS: max requests per window (default 100)
13 - KEMI_RATE_LIMIT_WINDOW: time window in seconds (default 60)
15Security:
16 CORS and security headers are disabled by default. To enable, set:
17 - KEMI_CORS_ORIGINS: comma-separated list of allowed origins
19API Key Authentication (multi-tenancy):
20 Disabled by default for backward compatibility. To enable, set:
21 - KEMI_API_KEY_REQUIRED=true
23 When the X-API-Key header is sent, the key is hashed with SHA-256 and
24 looked up in the `api_keys` table. The associated user_id is injected
25 into request.state.user_id. Endpoints then enforce that any user_id
26 in the body/path matches the authenticated user, preventing cross-
27 tenant access. Manage keys via the /api/keys endpoints or
28 `kemi api-key ...` CLI commands.
29"""
31import logging
32import os
33import sqlite3
34import time
35from collections import defaultdict
36from contextlib import asynccontextmanager
37from datetime import datetime, timezone
38from threading import Lock
39from typing import Any
41from kemi import Memory
42from kemi.models import LifecycleState, MemorySource, MemoryType
43from kemi.webhooks import WebhookConfig, WebhookEventType, WebhookStore, RetryConfig
45logger = logging.getLogger(__name__)
47try:
48 from fastapi import FastAPI, HTTPException, Request
49 from fastapi.middleware.cors import CORSMiddleware
50 from fastapi.middleware.trustedhost import TrustedHostMiddleware
51 from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
52 from pydantic import BaseModel, Field
54 _FASTAPI_AVAILABLE = True
55except ImportError: # pragma: no cover
56 _FASTAPI_AVAILABLE = False
57 BaseModel = object # type: ignore[no-redef,assignment]
59 def Field(*a: Any, **kw: Any) -> None: # type: ignore[no-redef] # noqa: N802
60 return None
63# Endpoints that never require auth, even when KEMI_API_KEY_REQUIRED=true.
64# /health is a liveness probe; /api/keys POST bootstraps a brand-new user's
65# first key (you can't authenticate before you have a key).
66_AUTH_EXEMPT_PATHS = frozenset({"/health", "/api/keys"})
68# Subset of /api/keys paths that are exempt: only the POST (create) and
69# the health-style subpaths. GET/DELETE still require auth so that keys
70# can't be enumerated by an anonymous caller.
71_AUTH_EXEMPT_PREFIXES = tuple() # handled per-route below
74def _api_key_required() -> bool:
75 """Whether the server requires X-API-Key on all non-exempt endpoints."""
76 return os.environ.get("KEMI_API_KEY_REQUIRED", "false").lower() in ("true", "1", "yes")
79def _is_exempt(path: str, method: str) -> bool:
80 """Whether (path, method) is exempt from API-key requirement."""
81 if path in _AUTH_EXEMPT_PATHS:
82 return True
83 # POST /api/keys is the bootstrap endpoint; GET/DELETE are not.
84 if path == "/api/keys" and method == "POST":
85 return True
86 return False
89def _resolve_user_id(request: Request, claimed_user_id: str | None) -> str:
90 """Return the effective user_id, enforcing isolation when authed.
92 If a valid X-API-Key was presented, ``request.state.user_id`` is set.
93 - When authed, ``claimed_user_id`` must match it (or be None); a
94 mismatch raises 403 so a tenant cannot impersonate another.
95 - When not authed, the caller's claimed_user_id is passed through.
97 Returns the user_id that the endpoint should use.
98 """
99 authed = getattr(request.state, "user_id", None)
100 if authed is None:
101 if claimed_user_id is None:
102 raise HTTPException(
103 status_code=400,
104 detail="user_id is required",
105 )
106 return claimed_user_id
107 if claimed_user_id is not None and claimed_user_id != authed:
108 raise HTTPException(
109 status_code=403,
110 detail="user_id does not match authenticated user",
111 )
112 return authed
115def _require_admin(request: Request) -> str:
116 """Require authentication; return the authed user_id.
118 Used for endpoints that should only be reachable by an authenticated
119 caller (e.g. listing other users' keys). The caller is responsible
120 for any further authorization (e.g. role checks).
121 """
122 authed = getattr(request.state, "user_id", None)
123 if authed is None:
124 raise HTTPException(status_code=401, detail="Authentication required")
125 return authed
128class RateLimiter:
129 """Simple in-memory rate limiter for API endpoints.
131 Uses a sliding window approach with per-key counters.
132 """
134 def __init__(
135 self,
136 requests_per_window: int = 100,
137 window_seconds: int = 60,
138 ) -> None:
139 self._requests_per_window = requests_per_window
140 self._window_seconds = window_seconds
141 self._requests: dict[str, list[float]] = defaultdict(list)
142 self._lock = Lock()
144 def is_allowed(self, key: str) -> bool:
145 """Check if a request from the given key is allowed.
147 Args:
148 key: Identifier for the client (e.g., IP address, user_id).
150 Returns:
151 True if the request is allowed, False if rate limited.
152 """
153 now = time.time()
154 window_start = now - self._window_seconds
156 with self._lock:
157 # Clean up old timestamps
158 self._requests[key] = [ts for ts in self._requests[key] if ts > window_start]
160 if len(self._requests[key]) >= self._requests_per_window:
161 return False
163 self._requests[key].append(now)
164 return True
166 def get_retry_after(self, key: str) -> int:
167 """Get seconds until the rate limit resets for this key."""
168 now = time.time()
169 window_start = now - self._window_seconds
171 with self._lock:
172 valid_requests = [ts for ts in self._requests[key] if ts > window_start]
173 if len(valid_requests) < self._requests_per_window:
174 return 0
176 oldest = min(valid_requests)
177 return int(self._window_seconds - (now - oldest)) + 1
180# Global rate limiter instance (lazily initialized)
181_rate_limiter: RateLimiter | None = None
184def _get_rate_limiter() -> RateLimiter | None:
185 """Get or create the rate limiter based on environment config."""
186 global _rate_limiter
187 if _rate_limiter is not None:
188 return _rate_limiter
190 # Check if rate limiting is enabled via environment variables
191 enabled = os.environ.get("KEMI_RATE_LIMIT_ENABLED", "false").lower() in ("true", "1", "yes")
192 if not enabled:
193 return None
195 requests = int(os.environ.get("KEMI_RATE_LIMIT_REQUESTS", "100"))
196 window = int(os.environ.get("KEMI_RATE_LIMIT_WINDOW", "60"))
198 _rate_limiter = RateLimiter(requests_per_window=requests, window_seconds=window)
199 logger.info(f"Rate limiting enabled: {requests} requests per {window} seconds")
200 return _rate_limiter
203def _check_rate_limit(client_key: str) -> tuple[bool, int]:
204 """Check rate limit for a client key.
206 Returns:
207 Tuple of (is_allowed, retry_after_seconds).
208 retry_after_seconds is 0 if allowed.
209 """
210 limiter = _get_rate_limiter()
211 if limiter is None:
212 return True, 0
214 if limiter.is_allowed(client_key):
215 return True, 0
217 return False, limiter.get_retry_after(client_key)
220# Cached APIKeyManager, lazily built from the active memory's storage.
221_api_key_manager: Any = None
222_api_key_manager_lock = Lock()
225def _get_api_key_manager() -> Any:
226 """Return a cached APIKeyManager bound to the active memory's storage.
228 Returns None when the storage adapter doesn't support API keys
229 (e.g. in-memory mock). Endpoints should treat None as 501.
230 """
231 global _api_key_manager
232 if _api_key_manager is not None:
233 return _api_key_manager
234 with _api_key_manager_lock:
235 if _api_key_manager is not None:
236 return _api_key_manager
237 mem = _get_memory_singleton()
238 store = getattr(mem, "_store", None)
239 # Prefer a real connection if the adapter exposes one; otherwise
240 # fall back to the storage adapter's helper.
241 get_mgr = getattr(store, "get_api_key_manager", None)
242 if callable(get_mgr):
243 _api_key_manager = get_mgr()
244 return _api_key_manager
245 # Last resort: try to share a SQLite connection. Mocks won't have one.
246 get_conn = getattr(store, "_get_connection", None)
247 if callable(get_conn):
248 from kemi.api_keys import APIKeyManager
250 _api_key_manager = APIKeyManager(connection=get_conn())
251 return _api_key_manager
252 return None
255def _reset_api_key_manager() -> None:
256 """Clear the cached manager. Used by tests to swap storage backends."""
257 global _api_key_manager
258 with _api_key_manager_lock:
259 _api_key_manager = None
262class RememberRequest(BaseModel):
263 user_id: str = Field(..., min_length=1)
264 content: str = Field(..., min_length=1)
265 importance: float = Field(0.5, ge=0.0, le=1.0)
266 source: str = "user_stated"
267 tags: list[str] | None = None
268 namespace: str = "default"
269 session_id: str | None = None
270 memory_type: str = "episodic"
271 confidence: float = Field(1.0, ge=0.0, le=1.0)
274class RecallRequest(BaseModel):
275 user_id: str = Field(..., min_length=1)
276 query: str = Field(..., min_length=1)
277 top_k: int = Field(5, ge=1)
278 max_tokens: int | None = None
279 namespace: str = "default"
280 session_id: str | None = None
281 hybrid_search: bool | None = None
284class UpdateRequest(BaseModel):
285 content: str | None = None
286 importance: float | None = Field(None, ge=0.0, le=1.0)
287 confidence: float | None = Field(None, ge=0.0, le=1.0)
288 memory_type: str | None = None
291class PruneRequest(BaseModel):
292 max_age_days: float | None = None
293 min_importance: float | None = None
294 lifecycle_states: list[str] | None = None
295 namespace: str = "default"
298class ConsolidateRequest(BaseModel):
299 namespace: str = "default"
300 min_memories: int = 5
301 max_age_days: float = 30.0
304class TopicsRequest(BaseModel):
305 n_clusters: int = 3
306 namespace: str = "default"
309class GraphRequest(BaseModel):
310 namespace: str = "default"
313class FeedbackRequest(BaseModel):
314 memory_id: str
315 helpful: bool = True
316 namespace: str = "default"
319class BatchRememberRequest(BaseModel):
320 """Request for background batch remember operation."""
322 user_id: str = Field(..., min_length=1)
323 contents: list[str] = Field(..., min_length=1)
324 importance: float = Field(0.5, ge=0.0, le=1.0)
325 namespace: str = "default"
328class RebuildFTSRequest(BaseModel):
329 """Request for background FTS index rebuild."""
331 user_id: str | None = None # Optional: rebuild for specific user only
334class AdminFTSStatsRequest(BaseModel):
335 """Request for FTS index statistics."""
337 user_id: str | None = None # Optional: get stats for specific user only
340class AdminFTSRepairRequest(BaseModel):
341 """Request for FTS index integrity repair."""
343 verify_only: bool = False # If True, only verify without repairing
346class AuditLogRequest(BaseModel):
347 """Request to log an audit entry."""
349 user_id: str = Field(..., min_length=1)
350 operation: str = Field(..., min_length=1)
351 status: str = "success"
352 details: dict[str, Any] | None = None
353 memory_id: str | None = None
354 namespace: str = "default"
355 client_ip: str | None = None
356 user_agent: str | None = None
357 duration_ms: float | None = None
360class AuditQueryRequest(BaseModel):
361 """Request to query the audit trail."""
363 user_id: str | None = None
364 operation: str | None = None
365 status: str | None = None
366 memory_id: str | None = None
367 namespace: str | None = None
368 start_time: str | None = None
369 end_time: str | None = None
370 limit: int = Field(100, ge=1, le=10000)
371 offset: int = Field(0, ge=0)
374class AuditExportRequest(BaseModel):
375 """Request to export audit entries."""
377 start_time: str | None = None
378 end_time: str | None = None
379 user_id: str | None = None
382class AdaptiveAnalyzeRequest(BaseModel):
383 """Request to analyze a query for adaptive retrieval."""
385 query: str = Field(..., min_length=1)
388class EnableFeatureRequest(BaseModel):
389 """Request to enable or disable a feature."""
391 enable: bool = True
392 retention_days: int = Field(365, ge=1)
393 auto_purge: bool = True
396class CreateAPIKeyRequest(BaseModel):
397 """Request to create a new API key."""
399 user_id: str = Field(..., min_length=1)
400 name: str = Field(..., min_length=1, max_length=200)
401 expires_in_days: int | None = Field(None, ge=1, le=36500)
404# Global memory instance for lifespan management
405_memory_instance: Memory | None = None
408def _get_memory_singleton() -> Memory:
409 """Get or create a singleton Memory instance."""
410 global _memory_instance
411 if _memory_instance is not None:
412 return _memory_instance
414 db_path = os.environ.get("KEMI_DB_PATH", os.path.expanduser("~/.kemi/memories.db"))
415 os.makedirs(os.path.dirname(db_path), exist_ok=True)
416 _memory_instance = Memory()
417 return _memory_instance
420@asynccontextmanager
421async def lifespan(app: FastAPI) -> Any:
422 """Lifespan context manager for graceful startup and shutdown.
424 On startup: Initialize the memory instance
425 On shutdown: Properly close database connections
426 """
427 # Startup
428 logger.info("Starting kemi API server...")
429 _get_memory_singleton()
430 db_path = os.environ.get("KEMI_DB_PATH", "~/.kemi/memories.db")
431 logger.info(f"Memory instance initialized with DB: {db_path}")
433 yield # Application runs here
435 # Shutdown
436 logger.info("Shutting down kemi API server...")
437 global _memory_instance
438 if _memory_instance is not None:
439 # Close the storage adapter connection
440 try:
441 store = getattr(_memory_instance, "_store", None)
442 if store is not None and hasattr(store, "close"):
443 store.close()
444 logger.info("Database connections closed")
445 except Exception as e:
446 logger.error(f"Error closing database connections: {e}")
447 _memory_instance = None
449 logger.info("kemi API server shutdown complete")
452def create_app(memory: Memory | None = None) -> Any:
453 """Create a FastAPI application wrapping a kemi Memory instance.
455 Args:
456 memory: Optional pre-configured Memory instance.
457 If None, creates a default Memory.
459 Returns:
460 FastAPI app instance.
461 """
462 if not _FASTAPI_AVAILABLE:
463 raise ImportError(
464 "FastAPI is required for the API server. Install with: pip install fastapi uvicorn"
465 )
467 app = FastAPI(
468 title="kemi API",
469 version="0.3.0",
470 lifespan=lifespan,
471 )
473 # Configure CORS if origins are specified
474 cors_origins = os.environ.get("KEMI_CORS_ORIGINS", "")
475 if cors_origins:
476 origins = [o.strip() for o in cors_origins.split(",") if o.strip()]
477 if origins:
478 app.add_middleware(
479 CORSMiddleware,
480 allow_origins=origins,
481 allow_credentials=True,
482 allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
483 allow_headers=["Authorization", "Content-Type"],
484 )
485 logger.info(f"CORS enabled for origins: {origins}")
487 # Add security headers middleware for production
488 trusted_hosts = os.environ.get("KEMI_TRUSTED_HOSTS", "")
489 if trusted_hosts:
490 hosts = [h.strip() for h in trusted_hosts.split(",") if h.strip()]
491 if hosts:
492 app.add_middleware(TrustedHostMiddleware, allowed_hosts=hosts)
493 logger.info(f"Trusted host middleware enabled for: {hosts}")
495 @app.middleware("http")
496 async def api_key_middleware(request: Request, call_next: Any) -> Any:
497 """Validate X-API-Key and inject request.state.user_id.
499 Behaviour:
500 - If X-API-Key is provided, hash it and look it up. On success
501 the key's user_id is attached to request.state. On failure
502 (unknown / expired / revoked) return 401.
503 - If X-API-Key is absent:
504 - Endpoints in _AUTH_EXEMPT_PATHS are always allowed.
505 - All other endpoints return 401 when KEMI_API_KEY_REQUIRED=true.
506 - When KEMI_API_KEY_REQUIRED=false (default), the request
507 proceeds unauthenticated for backward compatibility.
508 """
509 # Don't try to validate against a missing storage adapter
510 # (e.g. in some test harnesses); behave as if no auth is available.
511 header = request.headers.get("X-API-Key")
512 path = request.url.path
513 method = request.method
515 if header:
516 manager = _get_api_key_manager()
517 if manager is None:
518 return JSONResponse(
519 status_code=501,
520 content={"detail": "API key authentication not supported by this storage"},
521 )
522 key = manager.lookup(header)
523 if key is None:
524 return JSONResponse(
525 status_code=401,
526 content={"detail": "Invalid or expired API key"},
527 )
528 request.state.user_id = key.user_id
529 request.state.api_key_id = key.key_id
530 else:
531 if _api_key_required() and not _is_exempt(path, method):
532 return JSONResponse(
533 status_code=401,
534 content={"detail": "X-API-Key header required"},
535 )
537 return await call_next(request)
539 # Expose auth state on the request state for endpoints that want to
540 # introspect it (e.g. /api/keys GET to scope listings).
541 @app.middleware("http")
542 async def _ensure_state_defaults(request: Request, call_next: Any) -> Any:
543 if not hasattr(request.state, "user_id"):
544 request.state.user_id = None
545 if not hasattr(request.state, "api_key_id"):
546 request.state.api_key_id = None
547 return await call_next(request)
549 mem = memory or _get_memory_singleton()
551 @app.post("/remember")
552 async def remember(req: RememberRequest, request: Request) -> dict[str, Any]:
553 effective_user = _resolve_user_id(request, req.user_id)
554 # Rate limit check
555 allowed, retry_after = _check_rate_limit(effective_user)
556 if not allowed:
557 raise HTTPException(
558 status_code=429,
559 detail=(
560 f"Rate limit exceeded. Retry after {retry_after} seconds."
561 ),
562 headers={"Retry-After": str(retry_after)},
563 )
565 try:
566 source = MemorySource(req.source)
567 mtype = MemoryType(req.memory_type)
568 except ValueError as err:
569 raise HTTPException(status_code=400, detail=str(err)) from err
571 mid = mem.remember(
572 user_id=effective_user,
573 content=req.content,
574 importance=req.importance,
575 source=source,
576 tags=req.tags,
577 namespace=req.namespace,
578 session_id=req.session_id,
579 memory_type=mtype,
580 confidence=req.confidence,
581 )
582 return {"memory_id": mid}
584 @app.post("/recall")
585 async def recall(req: RecallRequest, request: Request) -> dict[str, Any]:
586 effective_user = _resolve_user_id(request, req.user_id)
587 # Rate limit check
588 allowed, retry_after = _check_rate_limit(effective_user)
589 if not allowed:
590 raise HTTPException(
591 status_code=429,
592 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
593 headers={"Retry-After": str(retry_after)},
594 )
596 try:
597 results = mem.recall(
598 user_id=effective_user,
599 query=req.query,
600 top_k=req.top_k,
601 max_tokens=req.max_tokens,
602 namespace=req.namespace,
603 session_id=req.session_id,
604 hybrid_search=req.hybrid_search,
605 )
606 except ValueError as e:
607 raise HTTPException(status_code=400, detail=str(e)) from e
609 return {
610 "results": [
611 {
612 "memory_id": r.memory_id,
613 "content": r.content,
614 "score": r.score,
615 "importance": r.importance,
616 "lifecycle_state": r.lifecycle_state.value,
617 "created_at": r.created_at.isoformat() if r.created_at else None,
618 "tags": r.tags,
619 "memory_type": r.memory_type.value,
620 "confidence": r.confidence,
621 "session_id": r.session_id,
622 "namespace": r.namespace,
623 "version": r.version,
624 }
625 for r in results
626 ]
627 }
629 @app.post("/recall/stream")
630 async def recall_stream(req: RecallRequest, request: Request) -> StreamingResponse:
631 """Stream recall results as Server-Sent Events."""
632 import json
634 effective_user = _resolve_user_id(request, req.user_id)
636 # Rate limit check
637 allowed, retry_after = _check_rate_limit(effective_user)
638 if not allowed:
639 raise HTTPException(
640 status_code=429,
641 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
642 headers={"Retry-After": str(retry_after)},
643 )
645 async def _generate() -> Any:
646 count = 0
647 try:
648 stream = mem.recall_stream(
649 user_id=effective_user,
650 query=req.query,
651 top_k=req.top_k,
652 max_tokens=req.max_tokens,
653 namespace=req.namespace,
654 session_id=req.session_id,
655 hybrid_search=req.hybrid_search,
656 )
657 async for result in stream:
658 count += 1
659 payload = {
660 "memory_id": result.memory_id,
661 "content": result.content,
662 "score": result.score,
663 }
664 yield f"data: {json.dumps(payload)}\n\n"
665 except ValueError as e:
666 yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
667 return
669 yield f"event: done\ndata: {json.dumps({'total': count})}\n\n"
671 return StreamingResponse(
672 _generate(),
673 media_type="text/event-stream",
674 )
676 @app.post("/recall-explain")
677 async def recall_explain(req: RecallRequest, request: Request) -> dict[str, Any]:
678 effective_user = _resolve_user_id(request, req.user_id)
679 # Rate limit check
680 allowed, retry_after = _check_rate_limit(effective_user)
681 if not allowed:
682 raise HTTPException(
683 status_code=429,
684 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
685 headers={"Retry-After": str(retry_after)},
686 )
688 try:
689 explained = mem.recall_explain(
690 user_id=effective_user,
691 query=req.query,
692 top_k=req.top_k,
693 namespace=req.namespace,
694 session_id=req.session_id,
695 )
696 except ValueError as e:
697 raise HTTPException(status_code=400, detail=str(e)) from e
699 return {
700 "results": [
701 {
702 "memory": {
703 "memory_id": item["memory"].memory_id,
704 "content": item["memory"].content,
705 "score": item["memory"].score,
706 },
707 "explanation": item["explanation"],
708 }
709 for item in explained
710 ]
711 }
713 @app.post("/forget")
714 async def forget(
715 request: Request,
716 user_id: str,
717 memory_id: str | None = None,
718 ) -> dict[str, Any]:
719 effective_user = _resolve_user_id(request, user_id)
720 # Rate limit check
721 allowed, retry_after = _check_rate_limit(effective_user)
722 if not allowed:
723 raise HTTPException(
724 status_code=429,
725 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
726 headers={"Retry-After": str(retry_after)},
727 )
729 count = mem.forget(effective_user, memory_id)
730 return {"deleted": count}
732 @app.patch("/memories/{memory_id}")
733 async def update_memory(
734 memory_id: str,
735 req: UpdateRequest,
736 request: Request,
737 ) -> dict[str, Any]:
738 # We need the user_id to rate-limit per-user; look up the memory.
739 existing = mem._store.get(memory_id)
740 if existing is None:
741 raise HTTPException(status_code=404, detail=f"Memory not found: {memory_id}")
742 effective_user = _resolve_user_id(request, existing.user_id)
743 allowed, retry_after = _check_rate_limit(effective_user)
744 if not allowed:
745 raise HTTPException(
746 status_code=429,
747 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
748 headers={"Retry-After": str(retry_after)},
749 )
751 try:
752 mtype = None
753 if req.memory_type:
754 mtype = MemoryType(req.memory_type)
755 except ValueError as e:
756 raise HTTPException(status_code=400, detail=str(e)) from e
758 try:
759 mem.update(
760 memory_id=memory_id,
761 content=req.content,
762 importance=req.importance,
763 confidence=req.confidence,
764 memory_type=mtype,
765 )
766 except ValueError as e:
767 raise HTTPException(status_code=404, detail=str(e)) from e
768 return {"memory_id": memory_id, "status": "updated"}
770 @app.post("/prune")
771 async def prune(
772 request: Request,
773 user_id: str,
774 req: PruneRequest,
775 ) -> dict[str, Any]:
776 effective_user = _resolve_user_id(request, user_id)
777 allowed, retry_after = _check_rate_limit(effective_user)
778 if not allowed:
779 raise HTTPException(
780 status_code=429,
781 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
782 headers={"Retry-After": str(retry_after)},
783 )
785 lifecycle_filter = None
786 if req.lifecycle_states:
787 try:
788 lifecycle_filter = [LifecycleState(s) for s in req.lifecycle_states]
789 except ValueError as e:
790 raise HTTPException(status_code=400, detail=str(e)) from e
792 deleted = mem.prune(
793 user_id=effective_user,
794 max_age_days=req.max_age_days,
795 min_importance=req.min_importance,
796 lifecycle_states=lifecycle_filter,
797 namespace=req.namespace,
798 )
799 return {"deleted": deleted}
801 @app.get("/stats/{user_id}")
802 async def stats(user_id: str, request: Request) -> dict[str, Any]:
803 effective_user = _resolve_user_id(request, user_id)
804 # Rate limit check
805 allowed, retry_after = _check_rate_limit(effective_user)
806 if not allowed:
807 raise HTTPException(
808 status_code=429,
809 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
810 headers={"Retry-After": str(retry_after)},
811 )
813 try:
814 return mem.stats(effective_user)
815 except ValueError as e:
816 raise HTTPException(status_code=400, detail=str(e)) from e
818 @app.get("/users")
819 async def list_users(request: Request) -> dict[str, Any]:
820 # If authed, restrict to the caller's own user_id.
821 authed = getattr(request.state, "user_id", None)
822 users = mem.list_users()
823 if authed is not None:
824 users = [u for u in users if u == authed]
825 return {"users": users}
827 @app.post("/consolidate/{user_id}")
828 async def consolidate_user(
829 user_id: str,
830 req: ConsolidateRequest,
831 request: Request,
832 ) -> dict[str, Any]:
833 effective_user = _resolve_user_id(request, user_id)
834 allowed, retry_after = _check_rate_limit(effective_user)
835 if not allowed:
836 raise HTTPException(
837 status_code=429,
838 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
839 headers={"Retry-After": str(retry_after)},
840 )
841 try:
842 mid = mem.consolidate(
843 user_id=effective_user,
844 namespace=req.namespace,
845 min_memories=req.min_memories,
846 max_age_days=req.max_age_days,
847 )
848 except ValueError as e:
849 raise HTTPException(status_code=400, detail=str(e)) from e
850 if mid:
851 return {"consolidated_memory_id": mid}
852 return {"message": "No consolidation needed"}
854 @app.post("/topics/{user_id}")
855 async def topics_user(
856 user_id: str,
857 req: TopicsRequest,
858 request: Request,
859 ) -> dict[str, Any]:
860 effective_user = _resolve_user_id(request, user_id)
861 allowed, retry_after = _check_rate_limit(effective_user)
862 if not allowed:
863 raise HTTPException(
864 status_code=429,
865 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
866 headers={"Retry-After": str(retry_after)},
867 )
868 try:
869 clusters = mem.cluster_topics(
870 user_id=effective_user,
871 n_clusters=req.n_clusters,
872 namespace=req.namespace,
873 )
874 except ValueError as e:
875 raise HTTPException(status_code=400, detail=str(e)) from e
876 return {
877 "topics": {
878 label: [
879 {
880 "memory_id": m.memory_id,
881 "content": m.content,
882 "importance": m.importance,
883 }
884 for m in mems
885 ]
886 for label, mems in clusters.items()
887 }
888 }
890 @app.post("/graph/{user_id}")
891 async def graph_user(
892 user_id: str,
893 req: GraphRequest,
894 request: Request,
895 ) -> dict[str, Any]:
896 effective_user = _resolve_user_id(request, user_id)
897 allowed, retry_after = _check_rate_limit(effective_user)
898 if not allowed:
899 raise HTTPException(
900 status_code=429,
901 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
902 headers={"Retry-After": str(retry_after)},
903 )
904 try:
905 graph_data = mem.get_memory_graph(
906 user_id=effective_user,
907 namespace=req.namespace,
908 )
909 except ValueError as e:
910 raise HTTPException(status_code=400, detail=str(e)) from e
911 return graph_data
913 @app.post("/feedback/{user_id}")
914 async def feedback_user(
915 user_id: str,
916 req: FeedbackRequest,
917 request: Request,
918 ) -> dict[str, Any]:
919 effective_user = _resolve_user_id(request, user_id)
920 allowed, retry_after = _check_rate_limit(effective_user)
921 if not allowed:
922 raise HTTPException(
923 status_code=429,
924 detail=f"Rate limit exceeded. Retry after {retry_after} seconds.",
925 headers={"Retry-After": str(retry_after)},
926 )
927 try:
928 mem.feedback(
929 user_id=effective_user,
930 memory_id=req.memory_id,
931 helpful=req.helpful,
932 namespace=req.namespace,
933 )
934 except (ValueError, AttributeError) as e:
935 raise HTTPException(status_code=400, detail=str(e)) from e
936 return {"status": "ok", "memory_id": req.memory_id, "helpful": req.helpful}
938 @app.get("/health")
939 async def health() -> dict[str, Any]:
940 """Enhanced health check endpoint.
942 Returns:
943 status: "ok" if healthy, "degraded" if issues detected
944 components: Dictionary of component statuses
945 timestamp: ISO timestamp of the health check
946 """
947 components: dict[str, Any] = {}
948 overall_healthy = True
949 mem = memory or _get_memory_singleton()
951 # Check database connectivity
952 try:
953 if hasattr(mem, "_store") and mem._store:
954 conn = mem._store._get_connection() # type: ignore[attr-defined]
955 cursor = conn.execute("SELECT 1")
956 cursor.fetchone()
957 components["database"] = {"status": "healthy", "type": "sqlite"}
958 else:
959 components["database"] = {"status": "unknown", "message": "Storage not initialized"}
960 overall_healthy = False
961 except (sqlite3.Error, AttributeError) as e:
962 components["database"] = {"status": "unhealthy", "error": str(e)}
963 overall_healthy = False
965 # Check embedding adapter availability
966 try:
967 embed = getattr(mem, "_embed", None)
968 if embed is not None:
969 components["embedding"] = {"status": "healthy", "adapter": type(embed).__name__}
970 else:
971 components["embedding"] = {
972 "status": "not_configured",
973 "message": "No embedding adapter",
974 }
975 except AttributeError as e:
976 components["embedding"] = {"status": "unhealthy", "error": str(e)}
977 overall_healthy = False
979 status = "ok" if overall_healthy else "degraded"
981 return {
982 "status": status,
983 "components": components,
984 "timestamp": datetime.now(timezone.utc).isoformat(),
985 }
987 # Background Task Endpoints
989 @app.post("/tasks/embed-batch")
990 async def submit_embed_batch_task(
991 req: BatchRememberRequest, request: Request
992 ) -> dict[str, Any]:
993 """Submit a batch embedding task to run in background.
995 This endpoint returns immediately with a task_id that can be used
996 to check the task status at /tasks/{task_id}.
998 Args:
999 req: BatchRememberRequest with user_id, contents, etc.
1001 Returns:
1002 Dict with task_id for tracking progress.
1003 """
1004 effective_user = _resolve_user_id(request, req.user_id)
1005 from kemi.background_tasks import get_task_manager
1007 task_manager = get_task_manager()
1008 try:
1009 task_id = task_manager.submit_embed_batch(
1010 user_id=effective_user,
1011 contents=req.contents,
1012 importance=req.importance,
1013 namespace=req.namespace,
1014 )
1015 return {"task_id": task_id, "status": "pending"}
1016 except RuntimeError as err:
1017 raise HTTPException(status_code=429, detail=str(err)) from err
1019 @app.post("/tasks/rebuild-fts")
1020 async def submit_rebuild_fts_task(
1021 req: RebuildFTSRequest, request: Request
1022 ) -> dict[str, Any]:
1023 """Submit an FTS index rebuild task to run in background.
1025 This endpoint returns immediately with a task_id that can be used
1026 to check the task status at /tasks/{task_id}.
1028 Args:
1029 req: RebuildFTSRequest with optional user_id filter.
1031 Returns:
1032 Dict with task_id for tracking progress.
1033 """
1034 # Optional user_id: when authed, we always scope to the authed user
1035 # regardless of what the body says (prevents cross-tenant rebuilds).
1036 authed = getattr(request.state, "user_id", None)
1037 target_user: str | None
1038 if authed is not None:
1039 if req.user_id is not None and req.user_id != authed:
1040 raise HTTPException(
1041 status_code=403,
1042 detail="user_id does not match authenticated user",
1043 )
1044 target_user = authed
1045 else:
1046 target_user = req.user_id
1048 from kemi.background_tasks import get_task_manager
1050 task_manager = get_task_manager()
1051 try:
1052 task_id = task_manager.submit_rebuild_fts_index(user_id=target_user)
1053 return {"task_id": task_id, "status": "pending"}
1054 except RuntimeError as err:
1055 raise HTTPException(status_code=429, detail=str(err)) from err
1057 @app.get("/tasks/stats")
1058 async def get_task_stats() -> dict[str, Any]:
1059 """Get background task manager statistics.
1061 Returns:
1062 Dict with counts of pending, running, completed, failed tasks.
1063 """
1064 from kemi.background_tasks import get_task_manager
1066 task_manager = get_task_manager()
1067 return task_manager.get_stats()
1069 @app.get("/tasks/{task_id}")
1070 async def get_task_status(task_id: str) -> dict[str, Any]:
1071 """Get the status of a background task.
1073 Args:
1074 task_id: The task ID returned from submit_* endpoints.
1076 Returns:
1077 Task status including progress, result, or error.
1078 """
1079 from kemi.background_tasks import get_task_manager
1081 task_manager = get_task_manager()
1082 task = task_manager.get_task_status(task_id)
1084 if task is None:
1085 raise HTTPException(status_code=404, detail=f"Task not found: {task_id}")
1087 return task.to_dict()
1089 @app.get("/tasks")
1090 async def list_tasks(
1091 status: str | None = None,
1092 limit: int = 50,
1093 ) -> dict[str, Any]:
1094 """List all background tasks.
1096 Args:
1097 status: Optional filter by status (pending, running, completed, failed).
1098 limit: Maximum number of tasks to return (default 50).
1100 Returns:
1101 List of tasks with their statuses.
1102 """
1103 from kemi.background_tasks import TaskStatus, get_task_manager
1105 task_manager = get_task_manager()
1107 filter_status = None
1108 if status:
1109 try:
1110 filter_status = TaskStatus(status)
1111 except ValueError as exc:
1112 valid_vals = "pending, running, completed, failed"
1113 raise HTTPException(
1114 status_code=400,
1115 detail=f"Invalid status: {status}. Valid values: {valid_vals}",
1116 ) from exc
1118 tasks = task_manager.list_tasks(status=filter_status, limit=limit)
1119 return {
1120 "tasks": [t.to_dict() for t in tasks],
1121 "stats": task_manager.get_stats(),
1122 }
1124 @app.delete("/tasks/{task_id}")
1125 async def cancel_task(task_id: str) -> dict[str, Any]:
1126 """Cancel a pending background task.
1128 Note: Running tasks cannot be cancelled mid-execution.
1130 Args:
1131 task_id: The task ID to cancel.
1133 Returns:
1134 Dict with success status.
1135 """
1136 from kemi.background_tasks import get_task_manager
1138 task_manager = get_task_manager()
1139 cancelled = task_manager.cancel_task(task_id)
1141 if not cancelled:
1142 raise HTTPException(
1143 status_code=400,
1144 detail="Cannot cancel task: not found or already running",
1145 )
1147 return {"task_id": task_id, "cancelled": True}
1149 # Admin Endpoints for Index Maintenance
1151 @app.post("/admin/fts/rebuild")
1152 async def admin_rebuild_fts() -> dict[str, Any]:
1153 """Admin endpoint to rebuild FTS5 index synchronously.
1155 This is a blocking operation that rebuilds the full-text search index.
1156 For large datasets, consider using the background task endpoint instead.
1158 Returns:
1159 Dict with rebuild statistics.
1160 """
1161 mem = _get_memory_singleton()
1163 if not hasattr(mem._store, "rebuild_fts_index"):
1164 raise HTTPException(
1165 status_code=501,
1166 detail="Storage adapter does not support FTS index rebuild",
1167 )
1169 try:
1170 # Rebuild entire FTS index
1171 count = mem._store.rebuild_fts_index()
1172 return {
1173 "status": "completed",
1174 "memories_indexed": count,
1175 "scope": "all",
1176 }
1177 except Exception as e:
1178 logger.error(f"FTS rebuild failed: {e}")
1179 raise HTTPException(status_code=500, detail=str(e)) from e
1181 @app.get("/admin/fts/stats")
1182 async def admin_fts_stats(
1183 user_id: str | None = None, request: Request = None # type: ignore[assignment]
1184 ) -> dict[str, Any]:
1185 """Admin endpoint to get FTS5 index statistics.
1187 Args:
1188 user_id: Optional user ID to get stats for specific user.
1190 Returns:
1191 Dict with FTS index statistics.
1192 """
1193 # If authed, restrict the stats view to the caller's own user.
1194 if request is not None:
1195 authed = getattr(request.state, "user_id", None)
1196 if authed is not None:
1197 if user_id is not None and user_id != authed:
1198 raise HTTPException(
1199 status_code=403,
1200 detail="user_id does not match authenticated user",
1201 )
1202 user_id = authed
1204 mem = _get_memory_singleton()
1206 try:
1207 conn = mem._store._get_connection() # type: ignore[attr-defined]
1209 # Get total FTS entries
1210 cursor = conn.execute("SELECT COUNT(*) FROM memories_fts")
1211 fts_total = cursor.fetchone()[0]
1213 # Get total memories in main table
1214 if user_id:
1215 cursor = conn.execute("SELECT COUNT(*) FROM memories WHERE user_id = ?", (user_id,))
1216 mem_total = cursor.fetchone()[0]
1218 # Get FTS entries for this user
1219 cursor = conn.execute(
1220 "SELECT COUNT(*) FROM memories_fts WHERE user_id = ?", (user_id,)
1221 )
1222 fts_user = cursor.fetchone()[0]
1224 in_sync = (fts_user == mem_total) if mem_total > 0 else True
1226 return {
1227 "fts_total_entries": fts_total,
1228 "user_id": user_id,
1229 "user_memories": mem_total,
1230 "user_fts_entries": fts_user,
1231 "in_sync": in_sync,
1232 "sync_gap": mem_total - fts_user,
1233 }
1234 else:
1235 cursor = conn.execute("SELECT COUNT(*) FROM memories")
1236 mem_total = cursor.fetchone()[0]
1238 in_sync = (fts_total == mem_total) if mem_total > 0 else True
1240 return {
1241 "fts_total_entries": fts_total,
1242 "total_memories": mem_total,
1243 "in_sync": in_sync,
1244 "sync_gap": mem_total - fts_total,
1245 }
1246 except Exception as e:
1247 logger.error(f"FTS stats failed: {e}")
1248 raise HTTPException(status_code=500, detail=str(e)) from e
1250 @app.post("/admin/fts/verify")
1251 async def admin_fts_verify(req: AdminFTSRepairRequest) -> dict[str, Any]:
1252 """Admin endpoint to verify FTS5 index integrity.
1254 Checks if all memories have corresponding FTS entries and vice versa.
1256 Args:
1257 req: AdminFTSRepairRequest with verify_only flag.
1259 Returns:
1260 Dict with verification results.
1261 """
1262 mem = _get_memory_singleton()
1264 try:
1265 conn = mem._store._get_connection() # type: ignore[attr-defined]
1267 # Get all memory IDs
1268 cursor = conn.execute("SELECT memory_id FROM memories")
1269 memory_ids = set(row[0] for row in cursor.fetchall())
1271 # Get all FTS IDs
1272 cursor = conn.execute("SELECT memory_id FROM memories_fts")
1273 fts_ids = set(row[0] for row in cursor.fetchall())
1275 # Find discrepancies
1276 missing_from_fts = memory_ids - fts_ids
1277 orphaned_in_fts = fts_ids - memory_ids
1279 in_sync = len(missing_from_fts) == 0 and len(orphaned_in_fts) == 0
1281 result = {
1282 "status": "ok" if in_sync else "degraded",
1283 "total_memories": len(memory_ids),
1284 "total_fts_entries": len(fts_ids),
1285 "in_sync": in_sync,
1286 "missing_from_fts": len(missing_from_fts),
1287 "orphaned_in_fts": len(orphaned_in_fts),
1288 }
1290 if not in_sync and not req.verify_only:
1291 # Auto-repair: remove orphaned FTS entries
1292 if orphaned_in_fts:
1293 placeholders = ",".join("?" * len(orphaned_in_fts))
1294 conn.execute(
1295 f"DELETE FROM memories_fts WHERE memory_id IN ({placeholders})",
1296 list(orphaned_in_fts),
1297 )
1298 result["repaired_orphaned"] = len(orphaned_in_fts)
1300 result["auto_repaired"] = True
1301 result["status"] = "repaired"
1303 return result
1305 except Exception as e:
1306 logger.error(f"FTS verify failed: {e}")
1307 raise HTTPException(status_code=500, detail=str(e)) from e
1309 @app.get("/admin/health")
1310 async def admin_health() -> dict[str, Any]:
1311 """Admin health check endpoint with detailed system status.
1313 Returns:
1314 Dict with detailed component statuses and system metrics.
1315 """
1316 mem = _get_memory_singleton()
1317 components: dict[str, Any] = {}
1319 # Database health
1320 try:
1321 conn = mem._store._get_connection() # type: ignore[attr-defined]
1322 cursor = conn.execute("SELECT 1")
1323 cursor.fetchone()
1325 # Get database stats
1326 cursor = conn.execute("SELECT COUNT(*) FROM memories")
1327 total_memories = cursor.fetchone()[0]
1329 cursor = conn.execute("SELECT COUNT(*) FROM memories_fts")
1330 fts_entries = cursor.fetchone()[0]
1332 cursor = conn.execute("SELECT COUNT(DISTINCT user_id) FROM memories")
1333 total_users = cursor.fetchone()[0]
1335 components["database"] = {
1336 "status": "healthy",
1337 "type": "sqlite",
1338 "total_memories": total_memories,
1339 "total_users": total_users,
1340 "fts_entries": fts_entries,
1341 "fts_in_sync": total_memories == fts_entries,
1342 }
1343 except Exception as e:
1344 components["database"] = {
1345 "status": "unhealthy",
1346 "error": str(e),
1347 }
1349 # Embedding adapter health
1350 try:
1351 embed = getattr(mem, "_embed", None)
1352 if embed is not None:
1353 adapter_name = type(embed).__name__
1354 components["embedding"] = {
1355 "status": "healthy",
1356 "adapter": adapter_name,
1357 }
1359 # Check circuit breaker if available
1360 if hasattr(embed, "get_circuit_breaker_state"):
1361 cb_state = embed.get_circuit_breaker_state()
1362 components["embedding"]["circuit_breaker"] = cb_state
1363 else:
1364 components["embedding"] = {
1365 "status": "not_configured",
1366 }
1367 except Exception as e:
1368 components["embedding"] = {
1369 "status": "unhealthy",
1370 "error": str(e),
1371 }
1373 # Background task manager health
1374 try:
1375 from kemi.background_tasks import get_task_manager
1377 task_manager = get_task_manager()
1378 stats = task_manager.get_stats()
1379 components["task_manager"] = {
1380 "status": "healthy",
1381 "pending": stats["pending"],
1382 "running": stats["running"],
1383 "completed": stats["completed"],
1384 "failed": stats["failed"],
1385 }
1386 except Exception as e:
1387 components["task_manager"] = {
1388 "status": "unhealthy",
1389 "error": str(e),
1390 }
1392 # Determine overall status
1393 all_healthy = all(c.get("status") == "healthy" for c in components.values())
1396 return {
1397 "status": "ok" if all_healthy else "degraded",
1398 "components": components,
1399 "timestamp": datetime.now(timezone.utc).isoformat(),
1400 }
1402 # Observability / Metrics Endpoints
1404 @app.get("/metrics")
1405 async def get_metrics(output_format: str = "json") -> Any:
1406 """Get system metrics.
1408 Args:
1409 format: Output format — "json" or "prometheus".
1411 Returns:
1412 Metrics data in the requested format.
1413 """
1414 metrics_data = mem.get_metrics()
1415 if metrics_data is None:
1416 raise HTTPException(
1417 status_code=503,
1418 detail="Metrics collector not available",
1419 )
1421 if output_format.lower() == "prometheus":
1422 prom = mem.get_metrics_prometheus()
1423 if prom is None:
1424 raise HTTPException(
1425 status_code=503,
1426 detail="Metrics collector not available",
1427 )
1428 return PlainTextResponse(content=prom, media_type="text/plain")
1430 return metrics_data
1432 # Audit Trail Endpoints
1434 @app.post("/audit/log")
1435 async def audit_log(req: AuditLogRequest, request: Request) -> dict[str, Any]:
1436 """Log an operation to the audit trail.
1438 Returns:
1439 Dict with the entry ID of the logged operation.
1440 """
1441 effective_user = _resolve_user_id(request, req.user_id)
1442 if not hasattr(mem, "_audit_trail") or mem._audit_trail is None:
1443 raise HTTPException(
1444 status_code=503,
1445 detail="Audit trail not enabled. Use POST /admin/enable-audit first.",
1446 )
1448 try:
1449 entry_id = mem._audit_trail.log_operation(
1450 user_id=effective_user,
1451 operation=req.operation,
1452 details=req.details or {},
1453 memory_id=req.memory_id,
1454 namespace=req.namespace,
1455 status=req.status,
1456 client_ip=req.client_ip,
1457 user_agent=req.user_agent,
1458 duration_ms=req.duration_ms,
1459 )
1460 return {"entry_id": entry_id, "status": "logged"}
1461 except Exception as e:
1462 logger.error(f"Audit log failed: {e}")
1463 raise HTTPException(status_code=500, detail=str(e)) from e
1465 @app.post("/audit/query")
1466 async def audit_query(
1467 req: AuditQueryRequest, request: Request
1468 ) -> dict[str, Any]:
1469 """Query the audit trail with filters.
1471 Returns:
1472 Dict with list of matching audit entries and total count.
1473 """
1474 if not hasattr(mem, "_audit_trail") or mem._audit_trail is None:
1475 raise HTTPException(
1476 status_code=503,
1477 detail="Audit trail not enabled. Use POST /admin/enable-audit first.",
1478 )
1480 # When authed, force the user_id filter to the caller's own id
1481 # so a tenant cannot read another tenant's audit entries.
1482 authed = getattr(request.state, "user_id", None)
1483 query_user = req.user_id
1484 if authed is not None:
1485 if query_user is not None and query_user != authed:
1486 raise HTTPException(
1487 status_code=403,
1488 detail="user_id does not match authenticated user",
1489 )
1490 query_user = authed
1492 try:
1493 entries = mem._audit_trail.query(
1494 user_id=query_user,
1495 operation=req.operation,
1496 status=req.status,
1497 memory_id=req.memory_id,
1498 namespace=req.namespace,
1499 start_time=req.start_time,
1500 end_time=req.end_time,
1501 limit=req.limit,
1502 offset=req.offset,
1503 )
1504 return {
1505 "entries": [e.to_dict() for e in entries],
1506 "count": len(entries),
1507 "limit": req.limit,
1508 "offset": req.offset,
1509 }
1510 except Exception as e:
1511 logger.error(f"Audit query failed: {e}")
1512 raise HTTPException(status_code=500, detail=str(e)) from e
1514 @app.get("/audit/stats")
1515 async def audit_stats() -> dict[str, Any]:
1516 """Get overall audit trail statistics.
1518 Returns:
1519 Dict with total entries, unique users, date range, retention policy.
1520 """
1521 if not hasattr(mem, "_audit_trail") or mem._audit_trail is None:
1522 raise HTTPException(
1523 status_code=503,
1524 detail="Audit trail not enabled. Use POST /admin/enable-audit first.",
1525 )
1527 try:
1528 return mem._audit_trail.get_stats() # type: ignore[no-any-return]
1529 except Exception as e:
1530 logger.error(f"Audit stats failed: {e}")
1531 raise HTTPException(status_code=500, detail=str(e)) from e
1533 @app.post("/audit/export")
1534 async def audit_export(
1535 req: AuditExportRequest, request: Request
1536 ) -> dict[str, Any]:
1537 """Export audit entries for compliance.
1539 Returns:
1540 Dict with exported entries.
1541 """
1542 if not hasattr(mem, "_audit_trail") or mem._audit_trail is None:
1543 raise HTTPException(
1544 status_code=503,
1545 detail="Audit trail not enabled. Use POST /admin/enable-audit first.",
1546 )
1548 # Same isolation rule as audit_query.
1549 authed = getattr(request.state, "user_id", None)
1550 export_user = req.user_id
1551 if authed is not None:
1552 if export_user is not None and export_user != authed:
1553 raise HTTPException(
1554 status_code=403,
1555 detail="user_id does not match authenticated user",
1556 )
1557 export_user = authed
1559 try:
1560 entries = mem._audit_trail.export(
1561 start_time=req.start_time,
1562 end_time=req.end_time,
1563 user_id=export_user,
1564 )
1565 return {"entries": entries, "count": len(entries)}
1566 except Exception as e:
1567 logger.error(f"Audit export failed: {e}")
1568 raise HTTPException(status_code=500, detail=str(e)) from e
1570 # Adaptive Retrieval Endpoints
1572 @app.post("/adaptive/analyze")
1573 async def adaptive_analyze(req: AdaptiveAnalyzeRequest) -> dict[str, Any]:
1574 """Analyze a query and return adaptive retrieval weights.
1576 Returns:
1577 Dict with query classification, confidence, and recommended weights.
1578 """
1579 if not hasattr(mem, "_adaptive_retriever") or mem._adaptive_retriever is None:
1580 raise HTTPException(
1581 status_code=503,
1582 detail="Adaptive retrieval not enabled. Use POST /admin/enable-adaptive first.",
1583 )
1585 try:
1586 profile = mem._adaptive_retriever.analyze_query(req.query)
1587 return {
1588 "query": profile.query,
1589 "query_type": profile.query_type.value,
1590 "confidence": profile.confidence,
1591 "word_count": profile.word_count,
1592 "keyword_density": profile.keyword_density,
1593 "specificity": profile.specificity,
1594 "has_question_mark": profile.has_question_mark,
1595 "has_named_entity_hint": profile.has_named_entity_hint,
1596 "recommended_weights": profile.recommended_weights,
1597 }
1598 except Exception as e:
1599 logger.error(f"Adaptive analyze failed: {e}")
1600 raise HTTPException(status_code=500, detail=str(e)) from e
1602 @app.get("/adaptive/user-profile/{user_id}")
1603 async def adaptive_user_profile(
1604 user_id: str, request: Request
1605 ) -> dict[str, Any]:
1606 """Get the adaptive query type distribution for a user.
1608 Returns:
1609 Dict with query distribution and dominant type.
1610 """
1611 effective_user = _resolve_user_id(request, user_id)
1612 if not hasattr(mem, "_adaptive_retriever") or mem._adaptive_retriever is None:
1613 raise HTTPException(
1614 status_code=503,
1615 detail="Adaptive retrieval not enabled. Use POST /admin/enable-adaptive first.",
1616 )
1618 try:
1619 return mem._adaptive_retriever.get_user_profile(effective_user) # type: ignore[no-any-return]
1620 except Exception as e:
1621 logger.error(f"Adaptive user profile failed: {e}")
1622 raise HTTPException(status_code=500, detail=str(e)) from e
1624 # Admin Feature Toggle Endpoints
1626 @app.post("/admin/enable-audit")
1627 async def admin_enable_audit(req: EnableFeatureRequest) -> dict[str, Any]:
1628 """Enable or disable the audit trail.
1630 Returns:
1631 Dict with enabled status.
1632 """
1633 if not hasattr(mem, "enable_audit_trail"):
1634 raise HTTPException(
1635 status_code=501,
1636 detail="Memory instance does not support audit trail",
1637 )
1639 try:
1640 if req.enable:
1641 mem.enable_audit_trail(
1642 retention_days=req.retention_days,
1643 auto_purge=req.auto_purge,
1644 )
1645 else:
1646 mem._audit_trail = None
1648 return {
1649 "audit_trail_enabled": req.enable,
1650 "retention_days": req.retention_days,
1651 "auto_purge": req.auto_purge,
1652 }
1653 except Exception as e:
1654 logger.error(f"Enable audit trail failed: {e}")
1655 raise HTTPException(status_code=500, detail=str(e)) from e
1657 @app.post("/admin/enable-adaptive")
1658 async def admin_enable_adaptive(req: EnableFeatureRequest) -> dict[str, Any]:
1659 """Enable or disable adaptive retrieval.
1661 Returns:
1662 Dict with enabled status.
1663 """
1664 if not hasattr(mem, "enable_adaptive_retrieval"):
1665 raise HTTPException(
1666 status_code=501,
1667 detail="Memory instance does not support adaptive retrieval",
1668 )
1670 try:
1671 mem.enable_adaptive_retrieval(enable=req.enable)
1672 return {"adaptive_retrieval_enabled": req.enable}
1673 except Exception as e:
1674 logger.error(f"Enable adaptive retrieval failed: {e}")
1675 raise HTTPException(status_code=500, detail=str(e)) from e
1677 # API Key Management Endpoints
1679 @app.post("/api/keys")
1680 async def create_api_key(
1681 req: CreateAPIKeyRequest, request: Request
1682 ) -> dict[str, Any]:
1683 """Create a new API key for a user.
1685 The raw key is returned in the response exactly once; it cannot
1686 be retrieved later. Store it securely.
1688 When the caller is authenticated, the key is bound to the
1689 caller's own user_id (a 403 is raised if the body disagrees).
1690 When unauthenticated, the body's user_id is used — this is the
1691 bootstrap path for a brand-new tenant.
1692 """
1693 manager = _get_api_key_manager()
1694 if manager is None:
1695 raise HTTPException(
1696 status_code=501,
1697 detail="API key management not supported by this storage",
1698 )
1700 authed = getattr(request.state, "user_id", None)
1701 if authed is not None and req.user_id != authed:
1702 raise HTTPException(
1703 status_code=403,
1704 detail="user_id does not match authenticated user",
1705 )
1707 from kemi.api_keys import make_expiry
1709 expires_at = make_expiry(req.expires_in_days) if req.expires_in_days else None
1710 try:
1711 key = manager.create_key(
1712 user_id=req.user_id,
1713 name=req.name,
1714 expires_at=expires_at,
1715 )
1716 except (ValueError, RuntimeError) as e:
1717 raise HTTPException(status_code=400, detail=str(e)) from e
1719 return key.to_dict(include_secret=True)
1721 @app.get("/api/keys")
1722 async def list_api_keys(request: Request) -> dict[str, Any]:
1723 """List API keys. Authenticated callers see only their own keys."""
1724 manager = _get_api_key_manager()
1725 if manager is None:
1726 raise HTTPException(
1727 status_code=501,
1728 detail="API key management not supported by this storage",
1729 )
1731 authed = getattr(request.state, "user_id", None)
1732 # When authed, always scope to the caller's user_id regardless of
1733 # any user_id query param. When unauthed, the listing is global
1734 # (admin-style view).
1735 scope_user = authed
1736 keys = manager.list_keys(user_id=scope_user)
1737 return {
1738 "keys": [k.to_dict() for k in keys],
1739 "count": len(keys),
1740 }
1742 @app.delete("/api/keys/{key_id}")
1743 async def revoke_api_key(key_id: str, request: Request) -> dict[str, Any]:
1744 """Revoke an API key by id.
1746 Authenticated callers may only revoke their own keys.
1747 """
1748 manager = _get_api_key_manager()
1749 if manager is None:
1750 raise HTTPException(
1751 status_code=501,
1752 detail="API key management not supported by this storage",
1753 )
1755 authed = getattr(request.state, "user_id", None)
1756 if authed is not None:
1757 existing = manager.get(key_id)
1758 if existing is None:
1759 raise HTTPException(status_code=404, detail="Key not found")
1760 if existing.user_id != authed:
1761 raise HTTPException(
1762 status_code=403,
1763 detail="Cannot revoke a key belonging to another user",
1764 )
1766 if not manager.revoke(key_id):
1767 raise HTTPException(
1768 status_code=404,
1769 detail="Key not found or already revoked",
1770 )
1771 return {"key_id": key_id, "revoked": True}
1773 # Memory Version History Endpoint
1775 @app.get("/memories/{memory_id}/history")
1776 async def get_memory_history(memory_id: str, request: Request, limit: int = 100) -> dict[str, Any]:
1777 """Get version history for a memory."""
1778 try:
1779 mem.configure_versioning()
1780 history = mem.get_history(memory_id, limit=limit)
1781 except RuntimeError as e:
1782 raise HTTPException(status_code=501, detail=str(e)) from e
1784 return {
1785 "memory_id": memory_id,
1786 "versions": [
1787 {
1788 "version": snap.version,
1789 "content": snap.content,
1790 "importance": snap.importance,
1791 "tags": snap.tags,
1792 "memory_type": snap.memory_type,
1793 "confidence": snap.confidence,
1794 "namespace": snap.namespace,
1795 "source": snap.source,
1796 "changed_at": snap.changed_at.isoformat() if snap.changed_at else None,
1797 "changed_by": snap.changed_by,
1798 }
1799 for snap in history
1800 ],
1801 "count": len(history),
1802 }
1804 # Webhook Management Endpoints
1806 class CreateWebhookRequest(BaseModel):
1807 url: str = Field(..., min_length=1)
1808 events: list[str] = Field(..., min_length=1)
1809 secret: str = ""
1810 active: bool = True
1812 class UpdateWebhookRequest(BaseModel):
1813 url: str | None = None
1814 events: list[str] | None = None
1815 secret: str | None = None
1816 active: bool | None = None
1818 def _get_webhook_store() -> WebhookStore | None:
1819 """Get or create a WebhookStore bound to the active memory's database."""
1820 try:
1821 db_path = mem._store._db_path # type: ignore[attr-defined]
1822 return WebhookStore(db_path=db_path)
1823 except (AttributeError, Exception):
1824 return None
1826 @app.post("/webhooks", status_code=201)
1827 async def create_webhook(req: CreateWebhookRequest, request: Request) -> dict[str, Any]:
1828 """Register a new webhook endpoint."""
1829 store = _get_webhook_store()
1830 if store is None:
1831 raise HTTPException(
1832 status_code=501,
1833 detail="Webhook store not available (storage adapter does not expose db_path)",
1834 )
1836 try:
1837 event_types = [WebhookEventType.from_string(e) for e in req.events]
1838 except ValueError as e:
1839 raise HTTPException(status_code=400, detail=str(e)) from e
1841 cfg = WebhookConfig(
1842 webhook_id="",
1843 url=req.url,
1844 events=event_types,
1845 secret=req.secret,
1846 active=req.active,
1847 )
1848 wh_id = store.create(cfg)
1849 return {"webhook_id": wh_id, "url": req.url, "events": req.events, "active": req.active}
1851 @app.get("/webhooks")
1852 async def list_webhooks(request: Request) -> dict[str, Any]:
1853 """List all registered webhooks."""
1854 store = _get_webhook_store()
1855 if store is None:
1856 raise HTTPException(status_code=501, detail="Webhook store not available")
1858 configs = store.list_all(active_only=False)
1859 return {
1860 "webhooks": [
1861 {
1862 "webhook_id": c.webhook_id,
1863 "url": c.url,
1864 "events": [e.value for e in c.events],
1865 "active": c.active,
1866 }
1867 for c in configs
1868 ],
1869 "count": len(configs),
1870 }
1872 @app.delete("/webhooks/{webhook_id}")
1873 async def delete_webhook(webhook_id: str, request: Request) -> dict[str, Any]:
1874 """Delete a webhook configuration."""
1875 store = _get_webhook_store()
1876 if store is None:
1877 raise HTTPException(status_code=501, detail="Webhook store not available")
1879 if not store.delete(webhook_id):
1880 raise HTTPException(status_code=404, detail=f"Webhook not found: {webhook_id}")
1881 return {"webhook_id": webhook_id, "deleted": True}
1883 # Admin endpoint: list users with their memory counts
1885 @app.get("/admin/users")
1886 async def admin_list_users(request: Request) -> dict[str, Any]:
1887 """List all users and their memory counts.
1889 When authenticated, only the caller's own row is returned.
1890 When unauthenticated (default in backward-compat mode), the
1891 full list is returned.
1892 """
1893 authed = getattr(request.state, "user_id", None)
1894 all_users = mem.list_users()
1895 if authed is not None:
1896 users = [u for u in all_users if u == authed]
1897 else:
1898 users = all_users
1900 store = getattr(mem, "_store", None)
1901 rows: list[dict[str, Any]] = []
1902 for uid in users:
1903 try:
1904 count = store.count(uid) if store is not None else 0
1905 except Exception: # pragma: no cover - defensive
1906 count = 0
1907 row: dict[str, Any] = {"user_id": uid, "memory_count": count}
1908 # Optional metadata — best effort, never fail the response.
1909 last_active = getattr(store, "get_last_active", lambda _u: None)(uid)
1910 if last_active is not None:
1911 row["last_active"] = last_active
1912 rows.append(row)
1913 return {"users": rows, "count": len(rows)}
1915 return app