Coverage for src / kemi / background_tasks.py: 99%

222 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1"""Background task management for async embedding and indexing operations. 

2 

3This module provides a BackgroundTaskManager that handles long-running operations 

4like batch embedding and FTS index rebuilding as background tasks. 

5""" 

6 

7import asyncio 

8import enum 

9import logging 

10import time 

11import uuid 

12from dataclasses import dataclass 

13from threading import Lock 

14from typing import Any 

15 

# Module logger; task submission, completion and failure events are reported through it.
logger = logging.getLogger(__name__)

17 

18 

class TaskType(enum.Enum):
    """Types of background tasks supported.

    The string value of each member is what ``TaskResult.to_dict`` reports
    under the ``"task_type"`` key.
    """

    # Batch of content strings stored via Memory.remember_many (submit_embed_batch).
    EMBED_BATCH = "embed_batch"
    # Full-text-search index rebuild, optionally limited to one user (submit_rebuild_fts_index).
    REBUILD_FTS_INDEX = "rebuild_fts_index"
    # NOTE(review): no submit_* method for this type is visible in this module;
    # confirm whether it is produced elsewhere or is reserved for future use.
    MIGRATE_EMBEDDINGS = "migrate_embeddings"
    # Deletion of expired memories via Memory.prune_expired (submit_ttl_sweep).
    TTL_SWEEP = "ttl_sweep"

26 

27 

class TaskStatus(enum.Enum):
    """Status of a background task.

    Lifecycle: a task is PENDING when submitted, RUNNING once its coroutine
    starts on the background loop, and ends as COMPLETED or FAILED. There is
    no separate "cancelled" state: BackgroundTaskManager.cancel_task records
    a cancelled task as FAILED with an explanatory error message.
    """

    # Registered, waiting for its coroutine to start.
    PENDING = "pending"
    # Coroutine has started executing the work.
    RUNNING = "running"
    # Work finished; TaskResult.result holds its output.
    COMPLETED = "completed"
    # Work raised, or the task was cancelled; TaskResult.error holds the reason.
    FAILED = "failed"

35 

36 

@dataclass
class TaskResult:
    """Record describing one background task: identity, lifecycle and outcome.

    BackgroundTaskManager fills the timestamps from ``time.time()`` (seconds
    since the epoch); they stay ``None`` until the corresponding lifecycle
    point is reached.
    """

    # Unique identifier handed back to the submitter.
    task_id: str
    # What kind of work this task performs.
    task_type: TaskType
    # Current lifecycle state.
    status: TaskStatus
    # When the task was submitted.
    created_at: float
    # When the task's coroutine began executing, if it has.
    started_at: float | None = None
    # When the task reached COMPLETED or FAILED, if it has.
    completed_at: float | None = None
    # Output of the work on success.
    result: Any = None
    # Failure / cancellation reason, if any.
    error: str | None = None
    # Completion fraction between 0.0 and 1.0.
    progress: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Serialize this record into a plain dict for API responses.

        Enum members are replaced by their string values; all other fields
        are passed through unchanged, in declaration order.
        """
        payload: dict[str, Any] = dict(
            task_id=self.task_id,
            task_type=self.task_type.value,
            status=self.status.value,
            created_at=self.created_at,
            started_at=self.started_at,
            completed_at=self.completed_at,
            result=self.result,
            error=self.error,
            progress=self.progress,
        )
        return payload

64 

65 

class BackgroundTaskManager:
    """Manages background tasks for long-running operations.

    Each task is a coroutine scheduled on a dedicated asyncio event loop that
    runs on a daemon thread, so API calls that submit work return
    immediately. The blocking part of every task (importing ``kemi``,
    constructing a ``Memory`` and calling into it) is pushed to a worker
    thread with ``asyncio.to_thread``. One slow task therefore never stalls
    the loop, and admitted tasks can actually run side by side.

    A task holds one concurrency slot from the moment it is submitted until
    it reaches COMPLETED or FAILED. The slot count is derived from the task
    statuses themselves, so a burst of submissions cannot exceed the limit
    before any of them has started.

    Args:
        max_concurrent_tasks: Maximum number of tasks that may be pending or
            running at the same time.
        max_task_history: Number of task records to keep before the oldest
            finished ones are evicted (unfinished tasks are never evicted).
    """

    # Terminal statuses: a task in one of these never changes again and no
    # longer occupies a concurrency slot.
    _FINISHED_STATUSES = (TaskStatus.COMPLETED, TaskStatus.FAILED)

    def __init__(self, max_concurrent_tasks: int = 3, max_task_history: int = 1000) -> None:
        self._max_concurrent = max_concurrent_tasks
        self._max_task_history = max_task_history
        self._tasks: dict[str, TaskResult] = {}
        # Guards _tasks and every TaskResult stored in it.
        self._lock = Lock()
        # Guards _loop/_thread so concurrent first submissions share one loop.
        self._loop_lock = Lock()
        self._loop: asyncio.AbstractEventLoop | None = None
        self._thread: Any = None

    # ------------------------------------------------------------------
    # Event loop management
    # ------------------------------------------------------------------

    def _ensure_loop_started(self) -> asyncio.AbstractEventLoop:
        """Start the background event loop if not already running and return it.

        Safe to call from several threads at once: only one loop/thread pair
        is ever created per start.
        """
        import threading

        with self._loop_lock:
            if self._loop is not None:
                return self._loop

            # Create the loop here so it can be published before the thread
            # runs; callbacks queued before run_forever() starts are kept.
            loop = asyncio.new_event_loop()

            def run_loop() -> None:
                asyncio.set_event_loop(loop)
                try:
                    loop.run_forever()
                finally:
                    # Wait for in-flight worker threads so their completion
                    # callbacks do not target a closed loop, then release the
                    # loop's resources.
                    try:
                        loop.run_until_complete(loop.shutdown_default_executor())
                    finally:
                        loop.close()

            thread = threading.Thread(
                target=run_loop, name="kemi-background-tasks", daemon=True
            )
            thread.start()
            self._loop = loop
            self._thread = thread
            return loop

    def _get_loop(self) -> asyncio.AbstractEventLoop:
        """Get the background event loop, starting it on first use."""
        return self._ensure_loop_started()

    # ------------------------------------------------------------------
    # Shared task bookkeeping
    # ------------------------------------------------------------------

    def _active_count_locked(self) -> int:
        """Count tasks still holding a concurrency slot. Caller must hold ``_lock``."""
        return sum(
            1 for task in self._tasks.values() if task.status not in self._FINISHED_STATUSES
        )

    def _submit(self, task_type: TaskType, make_coro: Any) -> str:
        """Register a PENDING task and schedule its coroutine on the loop.

        Args:
            task_type: Type recorded on the new task.
            make_coro: Callable taking the new task_id and returning the
                coroutine to run. It is only called once a slot has been
                reserved, so rejected submissions never create a coroutine.

        Returns:
            The new task_id.

        Raises:
            RuntimeError: If max concurrent tasks limit reached.
        """
        task_id = str(uuid.uuid4())

        with self._lock:
            # Check and register in one critical section so two submitters
            # cannot both take the last free slot.
            if self._active_count_locked() >= self._max_concurrent:
                raise RuntimeError(
                    f"Max concurrent tasks ({self._max_concurrent}) reached. "
                    "Wait for a task to complete before submitting more."
                )
            self._tasks[task_id] = TaskResult(
                task_id=task_id,
                task_type=task_type,
                status=TaskStatus.PENDING,
                created_at=time.time(),
            )
            self._cleanup_old_tasks()

        coro = make_coro(task_id)
        try:
            asyncio.run_coroutine_threadsafe(coro, self._get_loop())
        except Exception as exc:
            # Scheduling failed: don't leave an un-awaited coroutine behind or
            # a PENDING record that would hold its slot forever.
            coro.close()
            self._finish(task_id, TaskStatus.FAILED, error=f"Failed to schedule task: {exc}")
            raise
        return task_id

    def _mark_running(self, task_id: str) -> bool:
        """Move a PENDING task to RUNNING.

        Returns:
            False if the task is unknown or no longer PENDING (for example it
            was cancelled), in which case its work must not be started.
        """
        with self._lock:
            task = self._tasks.get(task_id)
            if task is None or task.status is not TaskStatus.PENDING:
                return False
            task.status = TaskStatus.RUNNING
            task.started_at = time.time()
            return True

    @staticmethod
    def _finish_locked(
        task: TaskResult,
        status: TaskStatus,
        *,
        result: Any = None,
        error: str | None = None,
    ) -> None:
        """Put ``task`` into a terminal status. Caller must hold ``_lock``."""
        task.status = status
        task.completed_at = time.time()
        if status is TaskStatus.COMPLETED:
            task.result = result
            task.progress = 1.0
        else:
            task.error = error

    def _finish(
        self,
        task_id: str,
        status: TaskStatus,
        *,
        result: Any = None,
        error: str | None = None,
    ) -> bool:
        """Finalize a task unless it already reached a terminal status.

        Returns:
            True if this call recorded the outcome. False if the task is
            unknown, or was already finalized by cancel_task()/shutdown().
        """
        with self._lock:
            task = self._tasks.get(task_id)
            if task is None or task.status in self._FINISHED_STATUSES:
                return False
            self._finish_locked(task, status, result=result, error=error)
            return True

    async def _execute(self, task_id: str, label: str, work: Any) -> Any:
        """Run the zero-argument callable ``work`` in a worker thread and record the outcome.

        Args:
            task_id: Task being executed.
            label: Task-type name used in log messages.
            work: Blocking callable returning the task's result dict.

        Returns:
            The work's result if it completed and was recorded; None if the
            task was skipped (cancelled/unknown), failed, or had already been
            finalized by shutdown().
        """
        if not self._mark_running(task_id):
            logger.debug("Skipping %s task %s: it is no longer pending", label, task_id)
            return None

        try:
            result = await asyncio.to_thread(work)
        except asyncio.CancelledError:
            # CancelledError is not an Exception subclass; still close the record.
            self._finish(task_id, TaskStatus.FAILED, error="Cancelled")
            raise
        except Exception as exc:
            logger.exception("Failed %s task %s: %s", label, task_id, exc)
            self._finish(task_id, TaskStatus.FAILED, error=str(exc))
            return None

        if not self._finish(task_id, TaskStatus.COMPLETED, result=result):
            logger.warning("Discarding result of %s task %s: already finalized", label, task_id)
            return None
        return result

    # ------------------------------------------------------------------
    # Embedding batches
    # ------------------------------------------------------------------

    def submit_embed_batch(
        self,
        user_id: str,
        contents: list[str],
        importance: float = 0.5,
        namespace: str = "default",
    ) -> str:
        """Submit a batch embedding task to run in background.

        Args:
            user_id: User ID for the memories.
            contents: List of content strings to embed and store. A copy is
                taken, so later mutation by the caller does not affect the task.
            importance: Importance value for all memories.
            namespace: Memory namespace.

        Returns:
            task_id that can be used to check status.

        Raises:
            RuntimeError: If max concurrent tasks limit reached.
        """
        contents = list(contents)
        task_id = self._submit(
            TaskType.EMBED_BATCH,
            lambda tid: self._run_embed_batch(tid, user_id, contents, importance, namespace),
        )
        logger.info("Submitted embed_batch task %s with %d items", task_id, len(contents))
        return task_id

    async def _run_embed_batch(
        self,
        task_id: str,
        user_id: str,
        contents: list[str],
        importance: float,
        namespace: str,
    ) -> None:
        """Run the batch embedding task."""

        def embed() -> dict[str, Any]:
            # Imports live here so an import failure is recorded as FAILED
            # instead of leaving the task PENDING forever.
            from kemi import Memory
            from kemi.models import MemorySource, MemoryType

            # Memory is created and used on the same worker thread.
            mem = Memory()
            stored = mem.remember_many(
                user_id=user_id,
                contents=contents,
                importance=importance,
                namespace=namespace,
                source=MemorySource.USER_STATED,
                memory_type=MemoryType.EPISODIC,
            )
            return {"stored_count": len(stored), "user_id": user_id}

        result = await self._execute(task_id, "embed_batch", embed)
        if result is not None:
            logger.info(
                "Completed embed_batch task %s: stored %d memories",
                task_id,
                result["stored_count"],
            )

    # ------------------------------------------------------------------
    # FTS index rebuild
    # ------------------------------------------------------------------

    def submit_rebuild_fts_index(self, user_id: str | None = None) -> str:
        """Submit an FTS index rebuild task to run in background.

        Args:
            user_id: Optional user ID to limit rebuild scope. If None, rebuilds all.

        Returns:
            task_id that can be used to check status.

        Raises:
            RuntimeError: If max concurrent tasks limit reached.
        """
        task_id = self._submit(
            TaskType.REBUILD_FTS_INDEX,
            lambda tid: self._run_rebuild_fts(tid, user_id),
        )
        scope = f"user {user_id}" if user_id else "all users"
        logger.info("Submitted rebuild_fts_index task %s for %s", task_id, scope)
        return task_id

    async def _run_rebuild_fts(self, task_id: str, user_id: str | None) -> None:
        """Run the FTS index rebuild task."""

        def rebuild() -> dict[str, Any]:
            from kemi import Memory

            mem = Memory()
            if not hasattr(mem._store, "rebuild_fts_index"):
                return {
                    "rebuilt": False,
                    "message": "Storage adapter does not support FTS rebuild",
                }
            # Pass user_id through so per-user rebuilds only touch that
            # user's FTS rows instead of rebuilding the whole index.
            count = mem._store.rebuild_fts_index(user_id)
            return {
                "rebuilt": True,
                "count": count,
                "user_id": user_id,
                "scope": "user" if user_id else "all",
            }

        if await self._execute(task_id, "rebuild_fts_index", rebuild) is not None:
            logger.info("Completed rebuild_fts_index task %s", task_id)

    # ------------------------------------------------------------------
    # TTL sweep
    # ------------------------------------------------------------------

    def submit_ttl_sweep(
        self,
        user_id: str | None = None,
        namespace: str | None = None,
    ) -> str:
        """Submit a TTL sweep task to delete expired memories in the background.

        Args:
            user_id: If provided, only sweep this user's expired memories.
                If None, sweep all users.
            namespace: If provided, only sweep this namespace.

        Returns:
            task_id that can be used to check status.

        Raises:
            RuntimeError: If max concurrent tasks limit reached.
        """
        task_id = self._submit(
            TaskType.TTL_SWEEP,
            lambda tid: self._run_ttl_sweep(tid, user_id, namespace),
        )
        scope = f"user {user_id}" if user_id else "all users"
        logger.info("Submitted ttl_sweep task %s for %s", task_id, scope)
        return task_id

    async def _run_ttl_sweep(
        self,
        task_id: str,
        user_id: str | None,
        namespace: str | None,
    ) -> None:
        """Run the TTL sweep task."""

        def sweep() -> dict[str, Any]:
            from kemi import Memory

            mem = Memory()
            deleted = mem.prune_expired(user_id, namespace)
            return {
                "deleted_count": deleted,
                "user_id": user_id,
                "namespace": namespace,
            }

        result = await self._execute(task_id, "ttl_sweep", sweep)
        if result is not None:
            logger.info(
                "Completed ttl_sweep task %s: deleted %s memories",
                task_id,
                result["deleted_count"],
            )

    # ------------------------------------------------------------------
    # Queries and control
    # ------------------------------------------------------------------

    def get_task_status(self, task_id: str) -> TaskResult | None:
        """Get the status of a task.

        The returned object is the live record; the background loop keeps
        updating it as the task progresses.

        Args:
            task_id: The task ID returned from submit_*.

        Returns:
            TaskResult if found, None otherwise.
        """
        with self._lock:
            return self._tasks.get(task_id)

    def list_tasks(
        self,
        status: TaskStatus | None = None,
        limit: int = 50,
    ) -> list[TaskResult]:
        """List all tasks, optionally filtered by status.

        Args:
            status: Optional filter by task status.
            limit: Maximum number of tasks to return.

        Returns:
            List of TaskResult objects, newest first.
        """
        with self._lock:
            tasks = [t for t in self._tasks.values() if status is None or t.status == status]

        tasks.sort(key=lambda t: t.created_at, reverse=True)
        return tasks[:limit]

    def get_stats(self) -> dict[str, Any]:
        """Get task manager statistics.

        Returns:
            Dict with counts of pending, running, completed, failed tasks.
        """
        with self._lock:
            statuses = [t.status for t in self._tasks.values()]

        return {
            "total_tasks": len(statuses),
            "pending": statuses.count(TaskStatus.PENDING),
            "running": statuses.count(TaskStatus.RUNNING),
            "completed": statuses.count(TaskStatus.COMPLETED),
            "failed": statuses.count(TaskStatus.FAILED),
            "max_concurrent": self._max_concurrent,
        }

    def shutdown(self) -> None:
        """Shut down the background event loop.

        Every task that has not finished is marked FAILED and releases its
        concurrency slot. Work already handed to a worker thread is not
        interrupted, but its outcome is discarded. A later submit_* call
        starts a fresh loop.
        """
        with self._loop_lock:
            loop, self._loop = self._loop, None
            self._thread = None
        if loop is None:
            return

        with self._lock:
            for task in self._tasks.values():
                if task.status not in self._FINISHED_STATUSES:
                    self._finish_locked(
                        task,
                        TaskStatus.FAILED,
                        error="Task manager shut down before task completed",
                    )
        # Stop only after finalizing, so nothing the loop still runs while
        # draining will start new work.
        loop.call_soon_threadsafe(loop.stop)
        logger.info("Background task manager shutdown complete")

    def _cleanup_old_tasks(self) -> None:
        """Remove old completed/failed tasks if history limit exceeded.

        Caller must hold ``_lock``. Unfinished tasks are never removed, so
        the history may temporarily exceed the limit.
        """
        if len(self._tasks) <= self._max_task_history:
            return

        # Finished tasks, oldest completion first.
        old_tasks = sorted(
            (
                (t.completed_at or t.created_at, tid)
                for tid, t in self._tasks.items()
                if t.status in self._FINISHED_STATUSES
            ),
            key=lambda item: item[0],
        )

        excess = len(self._tasks) - self._max_task_history
        for _, tid in old_tasks[:excess]:
            del self._tasks[tid]

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a pending task.

        A cancelled task is recorded as FAILED ("Cancelled by user") and its
        work is never started. Running tasks cannot be cancelled mid-execution.

        Args:
            task_id: The task ID to cancel.

        Returns:
            True if cancelled, False if not found or not pending.
        """
        with self._lock:
            task = self._tasks.get(task_id)
            if task is None or task.status is not TaskStatus.PENDING:
                return False
            self._finish_locked(task, TaskStatus.FAILED, error="Cancelled by user")
            return True

505 

# Global task manager instance, created lazily by get_task_manager().
_task_manager: BackgroundTaskManager | None = None
# Serializes first-time creation so concurrent callers share one manager
# (and therefore one task registry and one background loop).
_task_manager_lock = Lock()


def get_task_manager() -> BackgroundTaskManager:
    """Get or create the global task manager instance.

    The concurrency limit is read from the ``KEMI_MAX_BACKGROUND_TASKS``
    environment variable (default ``3``) when the manager is first created.

    Returns:
        The process-wide BackgroundTaskManager.

    Raises:
        ValueError: If ``KEMI_MAX_BACKGROUND_TASKS`` is not an integer.
    """
    global _task_manager
    # Fast path: no locking once the manager exists.
    if _task_manager is not None:
        return _task_manager

    import os

    with _task_manager_lock:
        # Re-check: another thread may have created it while we waited.
        if _task_manager is None:
            max_concurrent = int(os.environ.get("KEMI_MAX_BACKGROUND_TASKS", "3"))
            _task_manager = BackgroundTaskManager(max_concurrent_tasks=max_concurrent)
    return _task_manager
516 return _task_manager