Coverage for src/kemi/webhooks.py: 95%
171 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1"""Webhook callbacks for memory lifecycle events.
3Event types: memory.remembered, memory.updated, memory.forgotten,
4memory.deleted, memory.conflict, memory.consolidated.
6Requires httpx (dev dependency) for HTTP dispatch.
7"""
import hashlib
import hmac
import json
import logging
import sqlite3
import time
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any
# Module-scoped logger; records are emitted under this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)
23# ---------------------------------------------------------------------------
24# Event types
25# ---------------------------------------------------------------------------
class WebhookEventType(str, Enum):
    """Enum of supported memory lifecycle event types.

    Members mix in ``str``, so each compares equal to its string value
    (e.g. ``WebhookEventType.UPDATED == "memory.updated"``).
    """

    REMEMBERED = "memory.remembered"
    UPDATED = "memory.updated"
    FORGOTTEN = "memory.forgotten"  # soft delete (lifecycle transition)
    DELETED = "memory.deleted"  # hard delete
    CONFLICT = "memory.conflict"
    CONSOLIDATED = "memory.consolidated"

    @classmethod
    def from_string(cls, value: str) -> "WebhookEventType":
        """Parse *value* into an event type.

        Args:
            value: An event type string such as ``"memory.updated"``.

        Returns:
            The matching :class:`WebhookEventType` member.

        Raises:
            ValueError: If *value* is not a known event type; the message
                lists every valid value.
        """
        try:
            return cls(value)
        except ValueError:
            valid = ", ".join(e.value for e in cls)
            # ``from None``: the internal Enum lookup error carries no extra
            # information, so suppress it from the chained traceback.
            raise ValueError(
                f"Invalid event type '{value}'. Valid: {valid}"
            ) from None
48# ---------------------------------------------------------------------------
49# Retry config
50# ---------------------------------------------------------------------------
@dataclass
class RetryConfig:
    """Configuration for webhook retry with exponential backoff."""

    # Used by the dispatcher as the total number of HTTP attempts per call.
    max_retries: int = 5
    # Delay returned for attempt 0, i.e. the wait before the second attempt.
    base_delay_seconds: float = 1.0
    # Upper bound on any single delay returned by ``delay``.
    max_delay_seconds: float = 60.0
    # Factor by which the delay grows with each further attempt.
    backoff_multiplier: float = 2.0

    def delay(self, attempt: int) -> float:
        """Return the delay in seconds to wait after *attempt* (0-indexed).

        The delay is ``base_delay_seconds * backoff_multiplier ** attempt``,
        capped at ``max_delay_seconds``. If the uncapped value is too large
        to represent as a float, the cap is returned instead of raising
        ``OverflowError``.
        """
        try:
            raw = self.base_delay_seconds * (self.backoff_multiplier ** attempt)
        except OverflowError:
            # e.g. 2.0 ** 1100 overflows a float; the cap applies regardless.
            return self.max_delay_seconds
        return min(raw, self.max_delay_seconds)
68# ---------------------------------------------------------------------------
69# Webhook config model
70# ---------------------------------------------------------------------------
@dataclass
class WebhookConfig:
    """A single registered webhook endpoint and its delivery settings."""

    # Primary key; WebhookStore.create assigns a UUID when this is empty.
    webhook_id: str
    # Endpoint that payloads are POSTed to.
    url: str
    # Event types this webhook subscribes to.
    events: list[WebhookEventType] = field(default_factory=list)
    # Key used for the HMAC ``X-Kemi-Signature`` header.
    secret: str = ""
    # Inactive webhooks match no event at all.
    active: bool = True
    retry_config: RetryConfig = field(default_factory=RetryConfig)

    def matches_event(self, event: WebhookEventType) -> bool:
        """Report whether this webhook should receive *event*.

        An inactive webhook never matches; an active one matches only the
        event types listed in ``events``.
        """
        if not self.active:
            return False
        return event in self.events
89# ---------------------------------------------------------------------------
90# Payload building
91# ---------------------------------------------------------------------------
def build_payload(
    event: WebhookEventType,
    memory_id: str,
    user_id: str,
    snapshot: dict[str, Any] | None = None,
    previous_state: dict[str, Any] | None = None,
    extra: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Assemble the JSON-serialisable body delivered to webhook endpoints.

    The keys ``event``, ``memory_id``, ``user_id`` and ``timestamp`` (the
    current UTC time in ISO-8601 form) are always present.

    Args:
        event: The lifecycle event type; its ``.value`` is stored.
        memory_id: ID of the memory that triggered the event.
        user_id: ID of the memory's owner.
        snapshot: Current snapshot of the memory fields (optional).
        previous_state: Snapshot of the memory *before* the change (only for
            ``memory.updated``).
        extra: Any additional key/value pairs to merge into the payload.

    Returns:
        A new dict. ``snapshot``/``previous_state`` are included only when
        not ``None``. *extra* is merged last, so its keys override any of
        the keys above.
    """
    now_iso = datetime.now(timezone.utc).isoformat()
    payload: dict[str, Any] = dict(
        event=event.value,
        memory_id=memory_id,
        user_id=user_id,
        timestamp=now_iso,
    )
    optional_sections = (("snapshot", snapshot), ("previous_state", previous_state))
    payload.update(
        (key, value) for key, value in optional_sections if value is not None
    )
    payload.update(extra or {})
    return payload
128# ---------------------------------------------------------------------------
129# HMAC signature
130# ---------------------------------------------------------------------------
def sign_payload(payload: dict[str, Any], secret: str) -> str:
    """Return the ``X-Kemi-Signature`` header value for *payload*.

    The payload is serialised canonically (keys sorted, no whitespace around
    separators) and authenticated with HMAC-SHA256 keyed by *secret*. The
    result is the lowercase hexadecimal digest.
    """
    canonical = json.dumps(payload, separators=(",", ":"), sort_keys=True)
    key_bytes = secret.encode("utf-8")
    digest = hmac.new(key_bytes, canonical.encode("utf-8"), digestmod=hashlib.sha256)
    return digest.hexdigest()
144# ---------------------------------------------------------------------------
145# Persistent storage
146# ---------------------------------------------------------------------------
# Schema for persisted webhook configurations. ``events`` holds a JSON array
# of event type strings, ``active`` is stored as 0/1, the ``retry_*`` columns
# mirror RetryConfig's fields, and the timestamps are ISO-8601 text.
_WEBHOOKS_DDL = """
CREATE TABLE IF NOT EXISTS webhooks (
    webhook_id TEXT PRIMARY KEY,
    url TEXT NOT NULL,
    events TEXT NOT NULL, -- JSON array of event type strings
    secret TEXT NOT NULL DEFAULT '',
    active INTEGER NOT NULL DEFAULT 1,
    retry_max INTEGER NOT NULL DEFAULT 5,
    retry_base_delay REAL NOT NULL DEFAULT 1.0,
    retry_max_delay REAL NOT NULL DEFAULT 60.0,
    retry_multiplier REAL NOT NULL DEFAULT 2.0,
    created_at TEXT NOT NULL,
    updated_at TEXT NOT NULL
);
"""
class WebhookStore:
    """Persistent storage for webhook configurations.

    Uses a separate ``webhooks`` table in the same SQLite database as the
    main memory store (or any other SQLite database path).

    Every operation opens its own connection and closes it before returning.
    NOTE: because of that, a ``":memory:"`` path does not work: each SQLite
    in-memory connection is a separate, empty database.
    """

    def __init__(self, db_path: str) -> None:
        self._db_path = db_path
        self._init_schema()

    def _get_connection(self) -> sqlite3.Connection:
        """Open a new connection with ``sqlite3.Row`` rows and WAL journaling."""
        conn = sqlite3.connect(self._db_path)
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
        except Exception:
            conn.close()  # don't leak the handle if setup fails
            raise
        return conn

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection inside a transaction, then always close it.

        ``with sqlite3.Connection`` only commits or rolls back; it does not
        close the connection. The explicit ``close()`` is what prevents
        leaking one connection (and file handle) per call.
        """
        conn = self._get_connection()
        try:
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            conn.close()

    def _init_schema(self) -> None:
        """Create the ``webhooks`` table if it does not exist yet."""
        with self._connect() as conn:
            conn.execute(_WEBHOOKS_DDL)

    # ------------------------------------------------------------------
    # Serialisation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _events_to_json(events: list[WebhookEventType]) -> str:
        """Encode *events* as the JSON array stored in the ``events`` column."""
        return json.dumps([e.value for e in events])

    @staticmethod
    def _retry_columns(retry: RetryConfig) -> tuple[Any, Any, Any, Any]:
        """Flatten *retry* into the ``retry_*`` column values, in DDL order."""
        return (
            retry.max_retries,
            retry.base_delay_seconds,
            retry.max_delay_seconds,
            retry.backoff_multiplier,
        )

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def create(self, cfg: WebhookConfig) -> str:
        """Persist a new webhook config and return its ``webhook_id``.

        If ``cfg.webhook_id`` is empty, a random UUID4 string is generated
        and assigned to *cfg* in place. ``created_at`` and ``updated_at`` are
        both set to the current UTC time.

        Raises:
            sqlite3.IntegrityError: If a webhook with this ID already exists.
        """
        if not cfg.webhook_id:
            cfg.webhook_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc).isoformat()
        with self._connect() as conn:
            conn.execute(
                """INSERT INTO webhooks
                   (webhook_id, url, events, secret, active,
                    retry_max, retry_base_delay, retry_max_delay,
                    retry_multiplier, created_at, updated_at)
                   VALUES (?, ?, ?, ?, ?,
                           ?, ?, ?,
                           ?, ?, ?)""",
                (
                    cfg.webhook_id,
                    cfg.url,
                    self._events_to_json(cfg.events),
                    cfg.secret,
                    1 if cfg.active else 0,
                    *self._retry_columns(cfg.retry_config),
                    now,
                    now,
                ),
            )
        return cfg.webhook_id

    def get(self, webhook_id: str) -> WebhookConfig | None:
        """Return the webhook config with *webhook_id*, or ``None`` if absent."""
        with self._connect() as conn:
            row = conn.execute(
                "SELECT * FROM webhooks WHERE webhook_id = ?",
                (webhook_id,),
            ).fetchone()
        return None if row is None else self._row_to_config(row)

    def list_all(self, active_only: bool = True) -> list[WebhookConfig]:
        """Return webhook configs ordered by ``created_at``.

        Args:
            active_only: When true (the default), only rows with
                ``active = 1`` are returned.
        """
        query = "SELECT * FROM webhooks"
        if active_only:
            query += " WHERE active = 1"
        query += " ORDER BY created_at"
        with self._connect() as conn:
            rows = conn.execute(query).fetchall()
        return [self._row_to_config(r) for r in rows]

    def list_for_event(self, event: WebhookEventType) -> list[WebhookConfig]:
        """Return active webhooks that subscribe to *event*."""
        return [w for w in self.list_all(active_only=True) if w.matches_event(event)]

    def delete(self, webhook_id: str) -> bool:
        """Delete a webhook config. Returns True if a row was deleted."""
        with self._connect() as conn:
            deleted = conn.execute(
                "DELETE FROM webhooks WHERE webhook_id = ?",
                (webhook_id,),
            ).rowcount
        return deleted > 0

    def update(self, cfg: WebhookConfig) -> bool:
        """Overwrite the stored row for ``cfg.webhook_id`` with *cfg*.

        ``updated_at`` is set to the current UTC time; ``created_at`` is left
        unchanged. Returns True if a row was updated, i.e. the ID existed.
        """
        now = datetime.now(timezone.utc).isoformat()
        with self._connect() as conn:
            updated = conn.execute(
                """UPDATE webhooks SET
                   url = ?, events = ?, secret = ?, active = ?,
                   retry_max = ?, retry_base_delay = ?, retry_max_delay = ?,
                   retry_multiplier = ?, updated_at = ?
                   WHERE webhook_id = ?""",
                (
                    cfg.url,
                    self._events_to_json(cfg.events),
                    cfg.secret,
                    1 if cfg.active else 0,
                    *self._retry_columns(cfg.retry_config),
                    now,
                    cfg.webhook_id,
                ),
            ).rowcount
        return updated > 0

    @staticmethod
    def _row_to_config(row: sqlite3.Row) -> WebhookConfig:
        """Convert a ``webhooks`` row into a :class:`WebhookConfig`.

        The ``events`` column is decoded best-effort. Malformed JSON, a
        non-list value, or event strings this version does not recognise are
        dropped with a warning instead of raising. This way one bad row
        cannot make ``list_all``/``list_for_event`` fail for every webhook.
        """
        webhook_id = row["webhook_id"]
        events_str = row["events"]
        decoded: Any = []
        if events_str:
            try:
                decoded = json.loads(events_str)
            except json.JSONDecodeError:
                logger.warning("Webhook %s: malformed events column ignored", webhook_id)
        if not isinstance(decoded, list):
            logger.warning("Webhook %s: events column is not a JSON array", webhook_id)
            decoded = []

        events: list[WebhookEventType] = []
        for raw in decoded:
            try:
                events.append(WebhookEventType(raw))
            except ValueError:
                logger.warning(
                    "Webhook %s: ignoring unknown event type %r", webhook_id, raw
                )

        return WebhookConfig(
            webhook_id=webhook_id,
            url=row["url"],
            events=events,
            secret=row["secret"],
            active=bool(row["active"]),
            retry_config=RetryConfig(
                max_retries=row["retry_max"],
                base_delay_seconds=row["retry_base_delay"],
                max_delay_seconds=row["retry_max_delay"],
                backoff_multiplier=row["retry_multiplier"],
            ),
        )
307# ---------------------------------------------------------------------------
308# Dispatcher
309# ---------------------------------------------------------------------------
class WebhookDispatcher:
    """Dispatches webhook callbacks to registered endpoints.

    Supports synchronous (blocking) and asynchronous (awaitable) dispatch.
    Each retries failed calls with exponential backoff.

    The request body is the canonical JSON encoding of the payload (sorted
    keys, no whitespace), which is the exact form ``sign_payload`` signs. A
    receiver can therefore verify ``X-Kemi-Signature`` by computing
    HMAC-SHA256 over the raw request body it received.
    """

    # Per-request timeout, in seconds, passed to the httpx client.
    _TIMEOUT_SECONDS = 10.0

    def __init__(self, store: WebhookStore) -> None:
        self._store = store

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_request(
        payload: dict[str, Any], secret: str
    ) -> tuple[bytes, dict[str, str]]:
        """Return ``(body, headers)`` for POSTing *payload* signed with *secret*."""
        # Must serialise exactly like sign_payload: the signature is only
        # verifiable if it covers the bytes that are actually sent.
        body = json.dumps(payload, sort_keys=True, separators=(",", ":"))
        headers = {
            "Content-Type": "application/json",
            "X-Kemi-Signature": sign_payload(payload, secret),
            "User-Agent": "kemi-webhook/1.0",
        }
        return body.encode("utf-8"), headers

    @staticmethod
    def _attempt_count(wh: WebhookConfig) -> int:
        """Total HTTP attempts for *wh*; at least one even if max_retries <= 0."""
        return max(1, wh.retry_config.max_retries)

    @staticmethod
    def _success_result(wh: WebhookConfig, status_code: int) -> dict[str, Any]:
        """Result dict for a call that received a successful (2xx) response."""
        return {
            "webhook_id": wh.webhook_id,
            "url": wh.url,
            "status_code": status_code,
            "success": True,
        }

    @staticmethod
    def _failure_result(wh: WebhookConfig, error: str | None) -> dict[str, Any]:
        """Result dict for a call that failed on every attempt."""
        return {
            "webhook_id": wh.webhook_id,
            "url": wh.url,
            "error": error,
            "success": False,
        }

    # ------------------------------------------------------------------
    # Synchronous dispatch
    # ------------------------------------------------------------------

    def dispatch_sync(self, payload: dict[str, Any], event: WebhookEventType) -> list[dict[str, Any]]:
        """Dispatch *payload* synchronously to all subscribers of *event*.

        Webhooks are called one after another. This blocks until each call
        completes, including retries and backoff sleeps. Returns one result
        dict per webhook, in the order the store lists them.
        """
        return [
            self._call_with_retry_sync(wh, payload)
            for wh in self._store.list_for_event(event)
        ]

    def _call_with_retry_sync(
        self,
        wh: WebhookConfig,
        payload: dict[str, Any],
    ) -> dict[str, Any]:
        """Deliver *payload* to *wh*, retrying with blocking sleeps.

        Returns a success dict on the first 2xx response. Otherwise, once all
        attempts are used, returns a failure dict carrying the last error.
        """
        import httpx  # dev dependency (see module docstring), imported lazily

        body, headers = self._build_request(payload, wh.secret)
        attempts = self._attempt_count(wh)
        last_error: str | None = None
        for attempt in range(attempts):
            try:
                with httpx.Client(timeout=self._TIMEOUT_SECONDS) as client:
                    resp = client.post(wh.url, content=body, headers=headers)
                if resp.is_success:
                    logger.info(
                        "Webhook %s dispatched to %s (attempt %d)",
                        wh.webhook_id, wh.url, attempt + 1,
                    )
                    return self._success_result(wh, resp.status_code)
                last_error = f"HTTP {resp.status_code}: {resp.text[:200]}"
                logger.warning(
                    "Webhook %s attempt %d failed: %s",
                    wh.webhook_id, attempt + 1, last_error,
                )
            except Exception as exc:  # a failed attempt is logged and retried
                last_error = str(exc)
                logger.warning(
                    "Webhook %s attempt %d error: %s",
                    wh.webhook_id, attempt + 1, last_error,
                )

            if attempt < attempts - 1:
                delay = wh.retry_config.delay(attempt)
                logger.info("Retrying webhook %s in %.1fs...", wh.webhook_id, delay)
                time.sleep(delay)

        logger.error(
            "Webhook %s failed after %d attempts: %s",
            wh.webhook_id, attempts, last_error,
        )
        return self._failure_result(wh, last_error)

    # ------------------------------------------------------------------
    # Asynchronous dispatch
    # ------------------------------------------------------------------

    async def dispatch_async(
        self,
        payload: dict[str, Any],
        event: WebhookEventType,
    ) -> list[dict[str, Any]]:
        """Dispatch *payload* concurrently to all subscribers of *event*.

        Each webhook call, with its own retry loop, runs as a coroutine under
        ``asyncio.gather``. Awaiting this method completes only once every
        call has either succeeded or used up its attempts; it does not return
        early. Returns one result dict per webhook, in the order the store
        lists them.
        """
        import asyncio
        import httpx  # dev dependency (see module docstring), imported lazily

        async def _call(wh: WebhookConfig) -> dict[str, Any]:
            body, headers = self._build_request(payload, wh.secret)
            attempts = self._attempt_count(wh)
            last_error: str | None = None
            for attempt in range(attempts):
                try:
                    async with httpx.AsyncClient(timeout=self._TIMEOUT_SECONDS) as client:
                        resp = await client.post(wh.url, content=body, headers=headers)
                    if resp.is_success:
                        logger.info(
                            "Webhook %s dispatched to %s (attempt %d)",
                            wh.webhook_id, wh.url, attempt + 1,
                        )
                        return self._success_result(wh, resp.status_code)
                    last_error = f"HTTP {resp.status_code}: {resp.text[:200]}"
                    logger.warning(
                        "Webhook %s attempt %d failed: %s",
                        wh.webhook_id, attempt + 1, last_error,
                    )
                except Exception as exc:  # a failed attempt is logged and retried
                    last_error = str(exc)
                    logger.warning(
                        "Webhook %s attempt %d error: %s",
                        wh.webhook_id, attempt + 1, last_error,
                    )

                if attempt < attempts - 1:
                    delay = wh.retry_config.delay(attempt)
                    logger.info("Retrying webhook %s in %.1fs...", wh.webhook_id, delay)
                    await asyncio.sleep(delay)

            logger.error(
                "Webhook %s failed after %d attempts: %s",
                wh.webhook_id, attempts, last_error,
            )
            return self._failure_result(wh, last_error)

        tasks = [_call(wh) for wh in self._store.list_for_event(event)]
        return await asyncio.gather(*tasks)