Coverage for src / kemi / background_tasks.py: 99%
222 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1"""Background task management for async embedding and indexing operations.
3This module provides a BackgroundTaskManager that handles long-running operations
4like batch embedding and FTS index rebuilding as background tasks.
5"""
7import asyncio
8import enum
9import logging
10import time
11import uuid
12from dataclasses import dataclass
13from threading import Lock
14from typing import Any
16logger = logging.getLogger(__name__)
@enum.unique
class TaskType(enum.Enum):
    """Kinds of work that can be recorded as a background task.

    Each member's value is the string placed in the ``"task_type"`` field of
    a serialized task record.
    """

    # Embed and store a batch of content strings.
    EMBED_BATCH = "embed_batch"
    # Rebuild the full-text-search index (all users, or a single user).
    REBUILD_FTS_INDEX = "rebuild_fts_index"
    # NOTE(review): presumably an embedding migration; confirm where (or
    # whether) this kind of task is actually submitted.
    MIGRATE_EMBEDDINGS = "migrate_embeddings"
    # Delete memories whose TTL has expired.
    TTL_SWEEP = "ttl_sweep"
@enum.unique
class TaskStatus(enum.Enum):
    """Lifecycle state of a background task.

    Each member's value is the string placed in the ``"status"`` field of a
    serialized task record.
    """

    # Registered, but its coroutine has not started executing yet.
    PENDING = "pending"
    # Currently executing.
    RUNNING = "running"
    # Finished without raising.
    COMPLETED = "completed"
    # Finished with an error, or cancelled while still pending.
    FAILED = "failed"
@dataclass
class TaskResult:
    """Mutable record describing one background task and its outcome."""

    # Identity and lifecycle state.
    task_id: str
    task_type: TaskType
    status: TaskStatus
    # Timestamps are time.time() values (seconds since the epoch).
    created_at: float
    started_at: float | None = None
    completed_at: float | None = None
    # Outcome: ``result`` on success, ``error`` (a message) on failure.
    result: Any = None
    error: str | None = None
    # Fraction of the work done, from 0.0 to 1.0.
    progress: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view of the record for API responses.

        The two enum fields are replaced by their string values; every other
        field is copied as-is.
        """
        payload: dict[str, Any] = {
            "task_id": self.task_id,
            "task_type": self.task_type.value,
            "status": self.status.value,
        }
        passthrough = (
            "created_at",
            "started_at",
            "completed_at",
            "result",
            "error",
            "progress",
        )
        for attr in passthrough:
            payload[attr] = getattr(self, attr)
        return payload
class BackgroundTaskManager:
    """Manages background tasks for long-running operations.

    Tasks run in a dedicated event loop on a background thread.
    This allows non-blocking API responses while heavy operations
    like embedding batches or index rebuilding run asynchronously.

    The loop thread is started lazily by the first ``submit_*`` call. The
    task table and the running counter are touched from both the submitting
    thread and the loop thread, so every access goes through one lock.

    Args:
        max_concurrent_tasks: Maximum number of tasks that can run simultaneously.
            The limit is checked at submit time against the number of tasks
            that have actually started running.
        max_task_history: Number of task records kept before the oldest
            completed/failed records are pruned.
    """

    def __init__(self, max_concurrent_tasks: int = 3, max_task_history: int = 1000) -> None:
        self._max_concurrent = max_concurrent_tasks
        self._max_task_history = max_task_history
        self._tasks: dict[str, TaskResult] = {}
        self._lock = Lock()
        self._loop: asyncio.AbstractEventLoop | None = None
        self._thread: Any = None
        self._running_count = 0

    # ------------------------------------------------------------------
    # Event loop management
    # ------------------------------------------------------------------

    def _ensure_loop_started(self) -> None:
        """Start the background event loop if not already running."""
        import threading

        # Check-and-start happens under the lock so that two threads
        # submitting at the same time cannot each spawn a loop thread.
        with self._lock:
            if self._loop is not None:
                return

            # The loop object is created here and handed to the thread, so it
            # is available as soon as this method returns (no polling).
            # Callbacks queued before run_forever() starts are kept and run
            # once the loop is spinning.
            loop = asyncio.new_event_loop()

            def run_loop() -> None:
                asyncio.set_event_loop(loop)
                try:
                    loop.run_forever()
                finally:
                    # Reached after shutdown() stops the loop; release its
                    # resources instead of leaving an unclosed loop behind.
                    loop.close()

            worker = threading.Thread(target=run_loop, daemon=True)
            worker.start()
            self._loop = loop
            self._thread = worker

    def _get_loop(self) -> asyncio.AbstractEventLoop:
        """Get the background event loop, starting it on first use.

        Raises:
            RuntimeError: If the loop was shut down between being started and
                being read here.
        """
        self._ensure_loop_started()
        loop = self._loop
        if loop is None:
            raise RuntimeError("Background event loop is not running")
        return loop

    # ------------------------------------------------------------------
    # Shared task lifecycle helpers
    # ------------------------------------------------------------------

    def _register_task(self, task_type: TaskType) -> str:
        """Enforce the concurrency limit and create a PENDING task record.

        The limit check and the insertion happen in one critical section.

        Args:
            task_type: Kind of task being registered.

        Returns:
            The new task's ID.

        Raises:
            RuntimeError: If max concurrent tasks limit reached.
        """
        with self._lock:
            if self._running_count >= self._max_concurrent:
                raise RuntimeError(
                    f"Max concurrent tasks ({self._max_concurrent}) reached. "
                    "Wait for a task to complete before submitting more."
                )

            task_id = str(uuid.uuid4())
            self._tasks[task_id] = TaskResult(
                task_id=task_id,
                task_type=task_type,
                status=TaskStatus.PENDING,
                created_at=time.time(),
            )
            self._cleanup_old_tasks()
        return task_id

    async def _run_tracked(self, task_id: str, label: str, work: Any) -> None:
        """Drive one task through PENDING -> RUNNING -> COMPLETED/FAILED.

        Args:
            task_id: ID of the record created at submit time.
            label: Task name used in log messages (e.g. ``"ttl_sweep"``).
            work: Zero-argument callable returning an awaitable that resolves
                to ``(result, success_log_message)``.
        """
        with self._lock:
            task = self._tasks.get(task_id)
            if task is None or task.status != TaskStatus.PENDING:
                # Cancelled (and possibly pruned) before it got to run.
                task = None
            else:
                task.status = TaskStatus.RUNNING
                task.started_at = time.time()
                self._running_count += 1

        if task is None:
            logger.info(f"Skipping {label} task {task_id}: no longer pending")
            return

        # From here on the record object itself is used rather than looking
        # it up again: only COMPLETED/FAILED records are ever pruned, so a
        # RUNNING record stays in the table for the duration.
        try:
            result, message = await work()

            with self._lock:
                task.status = TaskStatus.COMPLETED
                task.result = result
                task.completed_at = time.time()
                task.progress = 1.0

            logger.info(message)

        except Exception as e:
            logger.error(f"Failed {label} task {task_id}: {e}")
            with self._lock:
                task.status = TaskStatus.FAILED
                task.error = str(e)
                task.completed_at = time.time()

        finally:
            with self._lock:
                self._running_count -= 1

    # ------------------------------------------------------------------
    # Batch embedding
    # ------------------------------------------------------------------

    def submit_embed_batch(
        self,
        user_id: str,
        contents: list[str],
        importance: float = 0.5,
        namespace: str = "default",
    ) -> str:
        """Submit a batch embedding task to run in background.

        Args:
            user_id: User ID for the memories.
            contents: List of content strings to embed and store.
            importance: Importance value for all memories.
            namespace: Memory namespace.

        Returns:
            task_id that can be used to check status.

        Raises:
            RuntimeError: If max concurrent tasks limit reached.
        """
        task_id = self._register_task(TaskType.EMBED_BATCH)

        # Submit to event loop
        loop = self._get_loop()
        coro = self._run_embed_batch(task_id, user_id, contents, importance, namespace)
        asyncio.run_coroutine_threadsafe(coro, loop)

        logger.info(f"Submitted embed_batch task {task_id} with {len(contents)} items")
        return task_id

    async def _run_embed_batch(
        self,
        task_id: str,
        user_id: str,
        contents: list[str],
        importance: float,
        namespace: str,
    ) -> None:
        """Run the batch embedding task."""

        async def work() -> tuple[dict[str, Any], str]:
            # Imported inside the tracked section so that an import failure
            # marks the task FAILED instead of leaving it PENDING forever.
            from kemi import Memory
            from kemi.models import MemorySource, MemoryType

            mem = Memory()
            # Run directly on background thread - no need for asyncio.to_thread
            # NOTE(review): this call is synchronous, so the event loop (and
            # every other queued task) waits until it returns.
            stored = mem.remember_many(
                user_id=user_id,
                contents=contents,
                importance=importance,
                namespace=namespace,
                source=MemorySource.USER_STATED,
                memory_type=MemoryType.EPISODIC,
            )
            stored_count = len(stored)
            return (
                {"stored_count": stored_count, "user_id": user_id},
                # Report what was actually stored, not how many were submitted.
                f"Completed embed_batch task {task_id}: stored {stored_count} memories",
            )

        await self._run_tracked(task_id, "embed_batch", work)

    # ------------------------------------------------------------------
    # FTS index rebuild
    # ------------------------------------------------------------------

    def submit_rebuild_fts_index(self, user_id: str | None = None) -> str:
        """Submit an FTS index rebuild task to run in background.

        Args:
            user_id: Optional user ID to limit rebuild scope. If None, rebuilds all.

        Returns:
            task_id that can be used to check status.

        Raises:
            RuntimeError: If max concurrent tasks limit reached.
        """
        task_id = self._register_task(TaskType.REBUILD_FTS_INDEX)

        loop = self._get_loop()
        coro = self._run_rebuild_fts(task_id, user_id)
        asyncio.run_coroutine_threadsafe(coro, loop)

        scope = f"user {user_id}" if user_id else "all users"
        logger.info(f"Submitted rebuild_fts_index task {task_id} for {scope}")
        return task_id

    async def _run_rebuild_fts(self, task_id: str, user_id: str | None) -> None:
        """Run the FTS index rebuild task."""

        async def work() -> tuple[dict[str, Any], str]:
            from kemi import Memory

            mem = Memory()

            result: dict[str, Any]
            if hasattr(mem._store, "rebuild_fts_index"):
                # Pass user_id through so per-user rebuilds only touch that
                # user's FTS rows instead of rebuilding the whole index.
                count = await asyncio.to_thread(
                    mem._store.rebuild_fts_index,
                    user_id,
                )
                result = {
                    "rebuilt": True,
                    "count": count,
                    "user_id": user_id,
                    "scope": "user" if user_id else "all",
                }
            else:
                # Not an error: the task completes and reports that nothing
                # was rebuilt.
                result = {
                    "rebuilt": False,
                    "message": "Storage adapter does not support FTS rebuild",
                }
            return result, f"Completed rebuild_fts_index task {task_id}"

        await self._run_tracked(task_id, "rebuild_fts_index", work)

    # ------------------------------------------------------------------
    # Inspection
    # ------------------------------------------------------------------

    def get_task_status(self, task_id: str) -> TaskResult | None:
        """Get the status of a task.

        Args:
            task_id: The task ID returned from submit_*.

        Returns:
            TaskResult if found, None otherwise. The returned object is the
            live record, so its fields keep changing while the task runs.
        """
        with self._lock:
            return self._tasks.get(task_id)

    def list_tasks(
        self,
        status: TaskStatus | None = None,
        limit: int = 50,
    ) -> list[TaskResult]:
        """List all tasks, optionally filtered by status.

        Args:
            status: Optional filter by task status.
            limit: Maximum number of tasks to return.

        Returns:
            List of TaskResult objects, newest first.
        """
        with self._lock:
            snapshot = list(self._tasks.values())

        if status is not None:
            snapshot = [t for t in snapshot if t.status == status]

        # Sort by created_at descending
        snapshot.sort(key=lambda t: t.created_at, reverse=True)
        return snapshot[:limit]

    def get_stats(self) -> dict[str, Any]:
        """Get task manager statistics.

        Returns:
            Dict with counts of pending, running, completed, failed tasks,
            the total number of records, and the configured concurrency limit.
        """
        with self._lock:
            tasks = list(self._tasks.values())

        pending = sum(1 for t in tasks if t.status == TaskStatus.PENDING)
        running = sum(1 for t in tasks if t.status == TaskStatus.RUNNING)
        completed = sum(1 for t in tasks if t.status == TaskStatus.COMPLETED)
        failed = sum(1 for t in tasks if t.status == TaskStatus.FAILED)

        return {
            "total_tasks": len(tasks),
            "pending": pending,
            "running": running,
            "completed": completed,
            "failed": failed,
            "max_concurrent": self._max_concurrent,
        }

    # ------------------------------------------------------------------
    # Shutdown, pruning, cancellation
    # ------------------------------------------------------------------

    def shutdown(self) -> None:
        """Stop the background event loop.

        The loop is asked to stop and is closed by its own thread once
        ``run_forever`` returns. This method does not wait for that, nor for
        tasks that are still running. A later ``submit_*`` call starts a
        fresh loop.
        """
        with self._lock:
            loop, self._loop = self._loop, None
            self._thread = None

        if loop is not None:
            # Schedule loop stop
            loop.call_soon_threadsafe(loop.stop)
        logger.info("Background task manager shutdown complete")

    def _cleanup_old_tasks(self) -> None:
        """Remove old completed/failed tasks if history limit exceeded.

        Must be called with ``self._lock`` held. Pending and running records
        are never removed, so the table can stay above the limit when too few
        records have finished.
        """
        excess = len(self._tasks) - self._max_task_history
        if excess <= 0:
            return

        # Get completed/failed tasks sorted by completion time
        finished = [
            (tid, t.completed_at or t.created_at)
            for tid, t in self._tasks.items()
            if t.status in (TaskStatus.COMPLETED, TaskStatus.FAILED)
        ]
        finished.sort(key=lambda item: item[1])  # Sort by time ascending

        # Remove oldest until under limit
        for tid, _ in finished[:excess]:
            del self._tasks[tid]

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a pending task.

        The record is marked FAILED with the error "Cancelled by user". When
        the already-scheduled coroutine is reached by the loop it sees that
        the record is no longer pending and exits without doing any work.

        Note: Running tasks cannot be cancelled mid-execution.

        Args:
            task_id: The task ID to cancel.

        Returns:
            True if cancelled, False if not found or no longer pending.
        """
        with self._lock:
            task = self._tasks.get(task_id)
            if task is None:
                return False
            if task.status == TaskStatus.PENDING:
                task.status = TaskStatus.FAILED
                task.error = "Cancelled by user"
                task.completed_at = time.time()
                return True
            return False

    # ------------------------------------------------------------------
    # TTL sweep
    # ------------------------------------------------------------------

    def submit_ttl_sweep(
        self,
        user_id: str | None = None,
        namespace: str | None = None,
    ) -> str:
        """Submit a TTL sweep task to delete expired memories in the background.

        Args:
            user_id: If provided, only sweep this user's expired memories.
                If None, sweep all users.
            namespace: If provided, only sweep this namespace.

        Returns:
            task_id that can be used to check status.

        Raises:
            RuntimeError: If max concurrent tasks limit reached.
        """
        task_id = self._register_task(TaskType.TTL_SWEEP)

        loop = self._get_loop()
        coro = self._run_ttl_sweep(task_id, user_id, namespace)
        asyncio.run_coroutine_threadsafe(coro, loop)

        scope = f"user {user_id}" if user_id else "all users"
        logger.info(f"Submitted ttl_sweep task {task_id} for {scope}")
        return task_id

    async def _run_ttl_sweep(
        self,
        task_id: str,
        user_id: str | None,
        namespace: str | None,
    ) -> None:
        """Run the TTL sweep task."""

        async def work() -> tuple[dict[str, Any], str]:
            from kemi import Memory

            mem = Memory()
            deleted = await asyncio.to_thread(
                mem.prune_expired,
                user_id,
                namespace,
            )
            return (
                {
                    "deleted_count": deleted,
                    "user_id": user_id,
                    "namespace": namespace,
                },
                f"Completed ttl_sweep task {task_id}: deleted {deleted} memories",
            )

        await self._run_tracked(task_id, "ttl_sweep", work)
# Process-wide manager instance; created on the first get_task_manager() call.
_task_manager: BackgroundTaskManager | None = None


def get_task_manager() -> BackgroundTaskManager:
    """Return the shared BackgroundTaskManager, building it on first use.

    On first use the concurrency limit is read from the
    ``KEMI_MAX_BACKGROUND_TASKS`` environment variable (default ``"3"``) and
    converted with ``int``; a non-numeric value raises ``ValueError``.
    """
    global _task_manager
    if _task_manager is not None:
        return _task_manager

    import os

    limit = int(os.environ.get("KEMI_MAX_BACKGROUND_TASKS", "3"))
    _task_manager = BackgroundTaskManager(max_concurrent_tasks=limit)
    return _task_manager