Coverage for src / kemi / webhooks.py: 95%

171 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1"""Webhook callbacks for memory lifecycle events. 

2 

3Event types: memory.remembered, memory.updated, memory.forgotten, 

4memory.deleted, memory.conflict, memory.consolidated. 

5 

6Requires httpx (dev dependency) for HTTP dispatch. 

7""" 

8 

import contextlib
import hashlib
import hmac
import json
import logging
import sqlite3
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any

20 

# Module-level logger, named after this module; used for dispatch and retry diagnostics.
logger = logging.getLogger(__name__)

22 

23# --------------------------------------------------------------------------- 

24# Event types 

25# --------------------------------------------------------------------------- 

26 

27 

class WebhookEventType(str, Enum):
    """Enum of supported memory lifecycle event types.

    Members are also ``str`` instances, so their values can be compared with
    and serialised as plain strings.
    """

    REMEMBERED = "memory.remembered"
    UPDATED = "memory.updated"
    FORGOTTEN = "memory.forgotten"  # soft delete (lifecycle transition)
    DELETED = "memory.deleted"  # hard delete
    CONFLICT = "memory.conflict"
    CONSOLIDATED = "memory.consolidated"

    @classmethod
    def from_string(cls, value: str) -> "WebhookEventType":
        """Parse a string to an event type.

        Args:
            value: The event type string, e.g. ``"memory.updated"``.

        Returns:
            The matching enum member.

        Raises:
            ValueError: If *value* is not a known event type. The message
                lists every valid value.
        """
        try:
            return cls(value)
        except ValueError:
            valid = ", ".join(e.value for e in cls)
            # ``from None`` hides the enum's own lookup error so callers see a
            # single, self-explanatory exception instead of a chained one.
            raise ValueError(f"Invalid event type '{value}'. Valid: {valid}") from None

46 

47 

48# --------------------------------------------------------------------------- 

49# Retry config 

50# --------------------------------------------------------------------------- 

51 

52 

@dataclass
class RetryConfig:
    """Configuration for webhook retry with exponential backoff."""

    # Number of delivery attempts the dispatcher loops over.
    max_retries: int = 5
    # Delay before the first retry (attempt 0).
    base_delay_seconds: float = 1.0
    # Upper bound applied to every computed delay.
    max_delay_seconds: float = 60.0
    # Factor by which the delay grows per attempt.
    backoff_multiplier: float = 2.0

    def delay(self, attempt: int) -> float:
        """Compute delay in seconds for a given attempt (0-indexed).

        The delay is ``base_delay_seconds * backoff_multiplier ** attempt``,
        capped at ``max_delay_seconds``.

        Args:
            attempt: Zero-based index of the attempt that just failed.

        Returns:
            The number of seconds to wait before the next attempt.
        """
        try:
            delay = self.base_delay_seconds * (self.backoff_multiplier ** attempt)
        except OverflowError:
            # float ** int raises once the result exceeds the float range;
            # a delay that large is by definition beyond the cap.
            return self.max_delay_seconds
        return min(delay, self.max_delay_seconds)

66 

67 

68# --------------------------------------------------------------------------- 

69# Webhook config model 

70# --------------------------------------------------------------------------- 

71 

72 

@dataclass
class WebhookConfig:
    """Settings describing one registered webhook endpoint."""

    # Identifier of the webhook; the store generates a UUID when it is empty.
    webhook_id: str
    # Endpoint the dispatcher POSTs payloads to.
    url: str
    # Event types this endpoint subscribes to.
    events: list[WebhookEventType] = field(default_factory=list)
    # Key used for the HMAC signature sent with each request.
    secret: str = ""
    # Inactive webhooks never match any event.
    active: bool = True
    # Backoff settings applied when a delivery attempt fails.
    retry_config: RetryConfig = field(default_factory=RetryConfig)

    def matches_event(self, event: WebhookEventType) -> bool:
        """Tell whether *event* should be delivered to this webhook.

        A webhook matches only when it is active and *event* is one of the
        event types it subscribes to.
        """
        if self.active:
            return event in self.events
        return self.active

87 

88 

89# --------------------------------------------------------------------------- 

90# Payload building 

91# --------------------------------------------------------------------------- 

92 

93 

def build_payload(
    event: WebhookEventType,
    memory_id: str,
    user_id: str,
    snapshot: dict[str, Any] | None = None,
    previous_state: dict[str, Any] | None = None,
    extra: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Assemble the JSON-serialisable body delivered to webhook endpoints.

    The result always carries ``event``, ``memory_id``, ``user_id`` and a
    UTC ISO-8601 ``timestamp``. Optional sections are added only when given.

    Args:
        event: The lifecycle event type.
        memory_id: ID of the memory that triggered the event.
        user_id: ID of the memory's owner.
        snapshot: Current snapshot of the memory fields (optional).
        previous_state: Snapshot of the memory *before* the change (only for
            ``memory.updated``).
        extra: Additional key/value pairs merged in last, so they override
            any key already present in the payload.

    Returns:
        A new dict holding the payload.
    """
    body: dict[str, Any] = {
        "event": event.value,
        "memory_id": memory_id,
        "user_id": user_id,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    # Optional sections keep their fixed order: snapshot, then previous_state.
    optional_sections = (
        ("snapshot", snapshot),
        ("previous_state", previous_state),
    )
    for key, section in optional_sections:
        if section is not None:
            body[key] = section
    if extra:
        body.update(extra)
    return body

126 

127 

128# --------------------------------------------------------------------------- 

129# HMAC signature 

130# --------------------------------------------------------------------------- 

131 

132 

def sign_payload(payload: dict[str, Any], secret: str) -> str:
    """Return the ``X-Kemi-Signature`` header value for *payload*.

    The payload is serialised to canonical JSON (keys sorted, no whitespace
    between tokens), UTF-8 encoded, and authenticated with HMAC-SHA256 keyed
    by the UTF-8 encoded *secret*.

    Args:
        payload: The JSON-serialisable webhook payload.
        secret: The webhook's shared secret.

    Returns:
        The lowercase hexadecimal digest of the MAC.
    """
    canonical_json = json.dumps(payload, separators=(",", ":"), sort_keys=True)
    key_bytes = secret.encode("utf-8")
    message_bytes = canonical_json.encode("utf-8")
    return hmac.new(key_bytes, message_bytes, digestmod=hashlib.sha256).hexdigest()

142 

143 

144# --------------------------------------------------------------------------- 

145# Persistent storage 

146# --------------------------------------------------------------------------- 

147 

# Schema of the ``webhooks`` table, created on demand by WebhookStore.
# The retry_* columns hold the RetryConfig fields of each webhook, and the
# created_at / updated_at columns hold UTC ISO-8601 timestamp strings.
_WEBHOOKS_DDL = """
CREATE TABLE IF NOT EXISTS webhooks (
    webhook_id TEXT PRIMARY KEY,
    url TEXT NOT NULL,
    events TEXT NOT NULL, -- JSON array of event type strings
    secret TEXT NOT NULL DEFAULT '',
    active INTEGER NOT NULL DEFAULT 1,
    retry_max INTEGER NOT NULL DEFAULT 5,
    retry_base_delay REAL NOT NULL DEFAULT 1.0,
    retry_max_delay REAL NOT NULL DEFAULT 60.0,
    retry_multiplier REAL NOT NULL DEFAULT 2.0,
    created_at TEXT NOT NULL,
    updated_at TEXT NOT NULL
);
"""

163 

164 

class WebhookStore:
    """Persistent storage for webhook configurations.

    Uses a separate ``webhooks`` table in the same SQLite database as the
    main memory store (or any other SQLite database path).

    Every operation opens its own short-lived connection and closes it
    before returning, so no connection outlives the call that created it.
    """

    def __init__(self, db_path: str) -> None:
        """Remember *db_path* and make sure the ``webhooks`` table exists."""
        self._db_path = db_path
        self._init_schema()

    def _get_connection(self) -> sqlite3.Connection:
        """Open a new connection with row access by name and WAL journaling.

        The caller owns the returned connection and must close it.
        """
        conn = sqlite3.connect(self._db_path)
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
        except Exception:
            # Do not leak the handle when the connection cannot be set up.
            conn.close()
            raise
        return conn

    def _init_schema(self) -> None:
        """Create the ``webhooks`` table if it does not exist yet."""
        # ``with conn`` only commits or rolls back; ``closing`` is what
        # actually releases the connection.
        with contextlib.closing(self._get_connection()) as conn, conn:
            conn.execute(_WEBHOOKS_DDL)

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def create(self, cfg: WebhookConfig) -> str:
        """Persist a new webhook config. Returns the webhook_id.

        When ``cfg.webhook_id`` is empty a UUID4 is generated and written
        back onto *cfg* before the row is inserted.
        """
        if not cfg.webhook_id:
            cfg.webhook_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc).isoformat()
        events_json = json.dumps([e.value for e in cfg.events])
        with contextlib.closing(self._get_connection()) as conn, conn:
            conn.execute(
                """INSERT INTO webhooks
                   (webhook_id, url, events, secret, active,
                    retry_max, retry_base_delay, retry_max_delay,
                    retry_multiplier, created_at, updated_at)
                   VALUES (?, ?, ?, ?, ?,
                           ?, ?, ?,
                           ?, ?, ?)""",
                (
                    cfg.webhook_id,
                    cfg.url,
                    events_json,
                    cfg.secret,
                    1 if cfg.active else 0,
                    cfg.retry_config.max_retries,
                    cfg.retry_config.base_delay_seconds,
                    cfg.retry_config.max_delay_seconds,
                    cfg.retry_config.backoff_multiplier,
                    now,
                    now,
                ),
            )
        return cfg.webhook_id

    def get(self, webhook_id: str) -> WebhookConfig | None:
        """Retrieve a webhook config by ID, or ``None`` if it does not exist."""
        with contextlib.closing(self._get_connection()) as conn:
            row = conn.execute(
                "SELECT * FROM webhooks WHERE webhook_id = ?",
                (webhook_id,),
            ).fetchone()
        if row is None:
            return None
        return self._row_to_config(row)

    def list_all(self, active_only: bool = True) -> list[WebhookConfig]:
        """Return all webhook configs ordered by creation time.

        Args:
            active_only: When true (the default) only active webhooks are
                returned.
        """
        if active_only:
            query = "SELECT * FROM webhooks WHERE active = 1 ORDER BY created_at"
        else:
            query = "SELECT * FROM webhooks ORDER BY created_at"
        with contextlib.closing(self._get_connection()) as conn:
            rows = conn.execute(query).fetchall()
        return [self._row_to_config(r) for r in rows]

    def list_for_event(self, event: WebhookEventType) -> list[WebhookConfig]:
        """Return active webhooks that subscribe to *event*."""
        all_webhooks = self.list_all(active_only=True)
        return [w for w in all_webhooks if w.matches_event(event)]

    def delete(self, webhook_id: str) -> bool:
        """Delete a webhook config. Returns True if a row was deleted."""
        with contextlib.closing(self._get_connection()) as conn, conn:
            cursor = conn.execute(
                "DELETE FROM webhooks WHERE webhook_id = ?",
                (webhook_id,),
            )
            return cursor.rowcount > 0

    def update(self, cfg: WebhookConfig) -> bool:
        """Update an existing webhook config. Returns True if updated.

        The row is located by ``cfg.webhook_id``; every other column is
        overwritten from *cfg* and ``updated_at`` is set to the current time.
        """
        now = datetime.now(timezone.utc).isoformat()
        events_json = json.dumps([e.value for e in cfg.events])
        with contextlib.closing(self._get_connection()) as conn, conn:
            cursor = conn.execute(
                """UPDATE webhooks SET
                   url = ?, events = ?, secret = ?, active = ?,
                   retry_max = ?, retry_base_delay = ?, retry_max_delay = ?,
                   retry_multiplier = ?, updated_at = ?
                   WHERE webhook_id = ?""",
                (
                    cfg.url,
                    events_json,
                    cfg.secret,
                    1 if cfg.active else 0,
                    cfg.retry_config.max_retries,
                    cfg.retry_config.base_delay_seconds,
                    cfg.retry_config.max_delay_seconds,
                    cfg.retry_config.backoff_multiplier,
                    now,
                    cfg.webhook_id,
                ),
            )
            return cursor.rowcount > 0

    @staticmethod
    def _row_to_config(row: sqlite3.Row) -> WebhookConfig:
        """Convert a ``webhooks`` table row into a WebhookConfig.

        A malformed ``events`` column, or individual event strings that are
        not known event types, are logged and skipped so that one bad row
        cannot make every listing fail.
        """
        events_str = row["events"]
        try:
            events_list = json.loads(events_str) if events_str else []
        except json.JSONDecodeError:
            logger.warning(
                "Webhook %s has an unparseable events column; treating it as empty",
                row["webhook_id"],
            )
            events_list = []
        if not isinstance(events_list, list):
            logger.warning(
                "Webhook %s events column is not a JSON array; treating it as empty",
                row["webhook_id"],
            )
            events_list = []

        events: list[WebhookEventType] = []
        for raw_event in events_list:
            try:
                events.append(WebhookEventType(raw_event))
            except ValueError:
                logger.warning(
                    "Webhook %s subscribes to unknown event type %r; ignoring it",
                    row["webhook_id"], raw_event,
                )

        return WebhookConfig(
            webhook_id=row["webhook_id"],
            url=row["url"],
            events=events,
            secret=row["secret"],
            active=bool(row["active"]),
            retry_config=RetryConfig(
                max_retries=row["retry_max"],
                base_delay_seconds=row["retry_base_delay"],
                max_delay_seconds=row["retry_max_delay"],
                backoff_multiplier=row["retry_multiplier"],
            ),
        )

305 

306 

307# --------------------------------------------------------------------------- 

308# Dispatcher 

309# --------------------------------------------------------------------------- 

310 

311 

class WebhookDispatcher:
    """Dispatches webhook callbacks to registered endpoints.

    Supports both synchronous (blocking) and asynchronous (coroutine-based)
    dispatch with automatic retry via exponential backoff.

    The request body is the canonical JSON form of the payload (sorted keys,
    no whitespace), i.e. the exact text covered by ``X-Kemi-Signature``, so a
    receiver can verify the signature against the raw request body.
    """

    # Per-request timeout in seconds, applied to every HTTP call.
    _TIMEOUT_SECONDS = 10.0

    def __init__(self, store: WebhookStore) -> None:
        """Create a dispatcher that looks up subscribers in *store*."""
        self._store = store

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_request(wh: WebhookConfig, payload: dict[str, Any]) -> tuple[str, dict[str, str]]:
        """Return the ``(body, headers)`` pair to POST to *wh*."""
        # Must mirror the serialisation sign_payload() signs, otherwise the
        # signature header would not match the bytes on the wire.
        body = json.dumps(payload, sort_keys=True, separators=(",", ":"))
        headers = {
            "Content-Type": "application/json",
            "X-Kemi-Signature": sign_payload(payload, wh.secret),
            "User-Agent": "kemi-webhook/1.0",
        }
        return body, headers

    @staticmethod
    def _attempt_count(wh: WebhookConfig) -> int:
        """Return how many delivery attempts to make for *wh* (at least one)."""
        # A non-positive max_retries must not silently suppress delivery.
        return max(1, wh.retry_config.max_retries)

    @staticmethod
    def _success_result(wh: WebhookConfig, status_code: int) -> dict[str, Any]:
        """Build the result dict for a delivered webhook."""
        return {
            "webhook_id": wh.webhook_id,
            "url": wh.url,
            "status_code": status_code,
            "success": True,
        }

    @staticmethod
    def _failure_result(wh: WebhookConfig, error: str | None) -> dict[str, Any]:
        """Build the result dict for a webhook whose attempts all failed."""
        return {
            "webhook_id": wh.webhook_id,
            "url": wh.url,
            "error": error,
            "success": False,
        }

    # ------------------------------------------------------------------
    # Synchronous dispatch
    # ------------------------------------------------------------------

    def dispatch_sync(self, payload: dict[str, Any], event: WebhookEventType) -> list[dict[str, Any]]:
        """Dispatch *payload* synchronously to all subscribers of *event*.

        This blocks until each webhook call completes (including retries);
        webhooks are called one after another.
        Returns a list of result dicts (one per webhook).
        """
        return [
            self._call_with_retry_sync(wh, payload)
            for wh in self._store.list_for_event(event)
        ]

    def _call_with_retry_sync(
        self,
        wh: WebhookConfig,
        payload: dict[str, Any],
    ) -> dict[str, Any]:
        """Call a single webhook synchronously with retry logic.

        Returns a result dict with ``success`` set to True (plus
        ``status_code``) on the first successful response, or False (plus the
        last ``error``) once every attempt has failed.
        """
        import httpx

        body, headers = self._build_request(wh, payload)
        attempts = self._attempt_count(wh)

        last_error: str | None = None
        for attempt in range(attempts):
            try:
                with httpx.Client(timeout=self._TIMEOUT_SECONDS) as client:
                    resp = client.post(wh.url, content=body, headers=headers)
                if resp.is_success:
                    logger.info(
                        "Webhook %s dispatched to %s (attempt %d)",
                        wh.webhook_id, wh.url, attempt + 1,
                    )
                    return self._success_result(wh, resp.status_code)
                last_error = f"HTTP {resp.status_code}: {resp.text[:200]}"
                logger.warning(
                    "Webhook %s attempt %d failed: %s",
                    wh.webhook_id, attempt + 1, last_error,
                )
            except Exception as exc:
                # Best-effort delivery: any transport error counts as a
                # failed attempt and is retried like an HTTP error.
                last_error = str(exc)
                logger.warning(
                    "Webhook %s attempt %d error: %s",
                    wh.webhook_id, attempt + 1, last_error,
                )

            # No point sleeping after the final attempt.
            if attempt < attempts - 1:
                delay = wh.retry_config.delay(attempt)
                logger.info("Retrying webhook %s in %.1fs...", wh.webhook_id, delay)
                time.sleep(delay)

        logger.error(
            "Webhook %s failed after %d retries: %s",
            wh.webhook_id, attempts, last_error,
        )
        return self._failure_result(wh, last_error)

    # ------------------------------------------------------------------
    # Asynchronous dispatch
    # ------------------------------------------------------------------

    async def dispatch_async(
        self,
        payload: dict[str, Any],
        event: WebhookEventType,
    ) -> list[dict[str, Any]]:
        """Dispatch *payload* asynchronously to all subscribers of *event*.

        All webhooks are called concurrently, each running its own retry
        loop; the coroutine completes once every call has finished.
        Returns a list of result dicts (one per webhook).
        """
        import asyncio

        calls = [
            self._call_with_retry_async(wh, payload)
            for wh in self._store.list_for_event(event)
        ]
        return await asyncio.gather(*calls)

    async def _call_with_retry_async(
        self,
        wh: WebhookConfig,
        payload: dict[str, Any],
    ) -> dict[str, Any]:
        """Call a single webhook asynchronously with retry logic.

        Async counterpart of ``_call_with_retry_sync`` with the same result
        dict shape; waits between attempts without blocking the event loop.
        """
        import asyncio
        import httpx

        body, headers = self._build_request(wh, payload)
        attempts = self._attempt_count(wh)

        last_error: str | None = None
        for attempt in range(attempts):
            try:
                async with httpx.AsyncClient(timeout=self._TIMEOUT_SECONDS) as client:
                    resp = await client.post(wh.url, content=body, headers=headers)
                if resp.is_success:
                    logger.info(
                        "Webhook %s dispatched to %s (attempt %d)",
                        wh.webhook_id, wh.url, attempt + 1,
                    )
                    return self._success_result(wh, resp.status_code)
                last_error = f"HTTP {resp.status_code}: {resp.text[:200]}"
                logger.warning(
                    "Webhook %s attempt %d failed: %s",
                    wh.webhook_id, attempt + 1, last_error,
                )
            except Exception as exc:
                # Best-effort delivery: any transport error counts as a
                # failed attempt and is retried like an HTTP error.
                last_error = str(exc)
                logger.warning(
                    "Webhook %s attempt %d error: %s",
                    wh.webhook_id, attempt + 1, last_error,
                )

            # No point sleeping after the final attempt.
            if attempt < attempts - 1:
                delay = wh.retry_config.delay(attempt)
                logger.info("Retrying webhook %s in %.1fs...", wh.webhook_id, delay)
                await asyncio.sleep(delay)

        logger.error(
            "Webhook %s failed after %d retries: %s",
            wh.webhook_id, attempts, last_error,
        )
        return self._failure_result(wh, last_error)