Coverage for agentos/memory/session.py: 32%
226 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2Persistent Thread Context (PTC) — Session Manager.
4OpenClaw-style long-running session management:
5 - Heartbeat: periodic ping to keep sessions alive
6 - Auto-suspend: idle sessions that exceed TTL
7 - State recovery: resume a session exactly where it left off
8 - Cross-session memory: carry context across disconnected sessions
9 - Event hooks: on_suspend, on_resume, on_expire
11Design:
12 SessionManager
13 └─ Session (per-thread lifecycle)
14 ├─ heartbeat() — keep alive
15 ├─ suspend() — save state, pause
16 ├─ resume() — restore state, continue
17 └─ expire() — cleanup after TTL
18"""
from __future__ import annotations

import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Coroutine, Optional

import aiosqlite
34# ── Session Models ──
class SessionStatus(str, Enum):
    """Lifecycle state of a PTC session.

    Members subclass ``str``; ``.value`` is the text stored in the
    ``status`` column and parsed back with ``SessionStatus(value)``.
    """

    # Currently running
    ACTIVE = "active"
    # Alive but no recent activity
    IDLE = "idle"
    # Paused, state saved
    SUSPENDED = "suspended"
    # Timed out, cleaned up
    EXPIRED = "expired"
    # Crashed but state saved
    ERROR = "error"
@dataclass
class SessionState:
    """Serializable state snapshot for a session.

    ``to_json`` / ``from_json`` round-trip the five sections below through a
    single JSON object.  Values JSON cannot represent natively are written
    with ``str()`` (``default=str``), so they come back as strings.
    """

    conversation_history: list[dict] = field(default_factory=list)
    working_memory: dict[str, Any] = field(default_factory=dict)
    agent_context: dict[str, Any] = field(default_factory=dict)
    tool_state: dict[str, Any] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_json(self) -> str:
        """Encode the snapshot as a JSON object string (non-ASCII kept as-is)."""
        return json.dumps({
            "conversation_history": self.conversation_history,
            "working_memory": self.working_memory,
            "agent_context": self.agent_context,
            "tool_state": self.tool_state,
            "metadata": self.metadata,
        }, ensure_ascii=False, default=str)

    @classmethod
    def from_json(cls, data: Optional[str]) -> "SessionState":
        """Decode a snapshot produced by :meth:`to_json`.

        Missing sections, sections stored as JSON ``null``, a JSON ``null``
        document, and empty/``None`` input all decode to empty defaults, so
        a missing or NULL stored state degrades to an empty snapshot instead
        of making the session unloadable.

        Raises:
            json.JSONDecodeError: if *data* is not valid JSON.
            ValueError: if the decoded JSON document is not an object.
        """
        decoded = json.loads(data) if data else None
        if decoded is None:
            decoded = {}
        elif not isinstance(decoded, dict):
            raise ValueError(
                f"SessionState JSON must be an object, got {type(decoded).__name__}"
            )

        def section(key: str, factory: Callable[[], Any]) -> Any:
            # Treat an explicit null the same as an absent key.
            value = decoded.get(key)
            return factory() if value is None else value

        return cls(
            conversation_history=section("conversation_history", list),
            working_memory=section("working_memory", dict),
            agent_context=section("agent_context", dict),
            tool_state=section("tool_state", dict),
            metadata=section("metadata", dict),
        )
@dataclass
class Session:
    """A single PTC session: one conversational thread plus its lifecycle data.

    Timestamps are epoch seconds from ``time.time()``.  ``to_dict`` flattens
    the session into plain values, with ``state`` encoded as a JSON string.
    """

    # Identity
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:16])
    name: str = ""
    user_id: str = "default"
    status: SessionStatus = SessionStatus.ACTIVE

    # Timestamps (epoch seconds)
    created_at: float = field(default_factory=time.time)
    last_heartbeat: float = field(default_factory=time.time)
    last_activity: float = field(default_factory=time.time)

    # Serializable payload
    state: SessionState = field(default_factory=SessionState)

    # Lifecycle configuration (seconds, except max_history_turns)
    heartbeat_interval: float = 30.0   # gap between heartbeats
    idle_timeout: float = 300.0        # idle longer than this -> suspend (5 min)
    absolute_ttl: float = 86400.0      # hard lifetime cap (24h)
    max_history_turns: int = 1000

    # NOTE(review): SessionManager keeps its own task map; this field
    # appears unused by the code shown here.
    _heartbeat_task: Optional[asyncio.Task] = None

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since ``created_at``."""
        now = time.time()
        return now - self.created_at

    @property
    def idle_seconds(self) -> float:
        """Seconds elapsed since ``last_activity``."""
        now = time.time()
        return now - self.last_activity

    @property
    def is_expired(self) -> bool:
        """True once the session has outlived ``absolute_ttl``."""
        return self.absolute_ttl < self.age_seconds

    def to_dict(self) -> dict:
        """Return a flat dict of the session (``state`` as a JSON string)."""
        record: dict = {key: getattr(self, key) for key in ("id", "name", "user_id")}
        record["status"] = self.status.value
        for key in (
            "created_at",
            "last_heartbeat",
            "last_activity",
            "heartbeat_interval",
            "idle_timeout",
            "absolute_ttl",
        ):
            record[key] = getattr(self, key)
        record["state"] = self.state.to_json()
        return record
127# ── Session Manager ──
class SessionManager:
    """Manage PTC sessions with heartbeat, suspend/resume, and persistence.

    Usage:
        manager = SessionManager(db_path="~/.agentos/sessions.db")

        # Create a new session
        session = await manager.create(name="research-thread")

        # Heartbeat loop (runs in background)
        await manager.start_heartbeat(session)

        # Suspend on idle
        await manager.suspend(session.id)

        # Resume later — state restored; a heartbeat stopped by the
        # suspension is restarted
        session = await manager.resume(session.id)

        # Hooks
        manager.on_suspend(lambda s: print(f"{s.name} suspended"))
        manager.on_resume(lambda s: print(f"{s.name} resumed"))

    Persistence is best-effort: SQLite failures in ``_persist`` /
    ``_load_from_db`` are logged rather than raised.  ``destroy`` does
    propagate database errors.
    """

    _logger = logging.getLogger(__name__)

    # Single source of truth for the on-disk layout.  INSERT and SELECT name
    # the columns explicitly so they never depend on physical column order.
    _COLUMNS = (
        "id", "name", "user_id", "status", "created_at", "last_heartbeat",
        "last_activity", "heartbeat_interval", "idle_timeout", "absolute_ttl",
        "state",
    )
    _SCHEMA = """
        CREATE TABLE IF NOT EXISTS sessions (
            id TEXT PRIMARY KEY,
            name TEXT,
            user_id TEXT,
            status TEXT,
            created_at REAL,
            last_heartbeat REAL,
            last_activity REAL,
            heartbeat_interval REAL,
            idle_timeout REAL,
            absolute_ttl REAL,
            state TEXT
        )
    """
    _INSERT_SQL = (
        f"INSERT OR REPLACE INTO sessions ({', '.join(_COLUMNS)}) "
        f"VALUES ({', '.join('?' * len(_COLUMNS))})"
    )
    _SELECT_SQL = f"SELECT {', '.join(_COLUMNS)} FROM sessions WHERE id = ?"

    def __init__(
        self,
        db_path: str = "",
        max_concurrent: int = 100,
    ):
        """Create a manager backed by an SQLite file.

        Args:
            db_path: SQLite file path; a leading ``~`` is expanded.  Defaults
                to ``~/.agentos/sessions.db``.  Parent directories are created.
            max_concurrent: In-memory session cap; ``create`` at the cap
                expires the least recently active session first.
        """
        path = (
            Path(db_path).expanduser()
            if db_path
            else Path.home() / ".agentos" / "sessions.db"
        )
        path.parent.mkdir(parents=True, exist_ok=True)
        self._db_path = str(path)
        self._max_concurrent = max_concurrent

        self._sessions: dict[str, Session] = {}
        self._heartbeat_tasks: dict[str, asyncio.Task] = {}
        # Sessions whose running heartbeat was stopped by suspend();
        # resume() restarts the heartbeat for exactly these.
        self._paused_heartbeats: set[str] = set()
        self._hooks: dict[str, list[Callable]] = {
            "create": [],
            "suspend": [],
            "resume": [],
            "expire": [],
            "heartbeat_missed": [],
        }

    # ── Hooks ──

    def on(self, event: str):
        """Decorator: register a hook for session events."""
        def decorator(fn):
            self._hooks.setdefault(event, []).append(fn)
            return fn
        return decorator

    def on_create(self, fn: Callable[[Session], Any]):
        self._hooks["create"].append(fn)

    def on_suspend(self, fn: Callable[[Session], Any]):
        self._hooks["suspend"].append(fn)

    def on_resume(self, fn: Callable[[Session], Any]):
        self._hooks["resume"].append(fn)

    def on_expire(self, fn: Callable[[Session], Any]):
        self._hooks["expire"].append(fn)

    async def _fire(self, event: str, session: Session) -> None:
        """Invoke every hook registered for *event* (sync or coroutine hooks).

        A failing hook is logged and skipped, so it can neither abort the
        lifecycle transition nor prevent the remaining hooks from running.
        """
        # Iterate a copy: a hook may register further hooks.
        for hook in list(self._hooks.get(event, [])):
            try:
                result = hook(session)
                if asyncio.iscoroutine(result):
                    await result
            except Exception:
                self._logger.exception(
                    "session hook %r for event %r failed", hook, event
                )

    # ── Session Lifecycle ──

    async def create(
        self,
        name: str = "",
        user_id: str = "default",
        heartbeat_interval: float = 30.0,
        idle_timeout: float = 300.0,
        absolute_ttl: float = 86400.0,
    ) -> Session:
        """Create, persist and return a new PTC session.

        At the ``max_concurrent`` cap, the least recently active in-memory
        session is expired first.
        """
        if self._sessions and len(self._sessions) >= self._max_concurrent:
            oldest = min(self._sessions.values(), key=lambda s: s.last_activity)
            await self.expire(oldest.id)

        session = Session(
            name=name or f"session-{uuid.uuid4().hex[:6]}",
            user_id=user_id,
            heartbeat_interval=heartbeat_interval,
            idle_timeout=idle_timeout,
            absolute_ttl=absolute_ttl,
        )

        self._sessions[session.id] = session
        await self._persist(session)
        await self._fire("create", session)

        return session

    async def suspend(self, session_id: str) -> bool:
        """Suspend a session — save state, stop heartbeat.

        Returns False if the session is not in memory.
        """
        session = self._sessions.get(session_id)
        if session is None:
            return False

        session.status = SessionStatus.SUSPENDED
        if self._stop_heartbeat(session_id):
            self._paused_heartbeats.add(session_id)

        await self._persist(session)
        await self._fire("suspend", session)
        return True

    async def resume(self, session_id: str) -> Optional[Session]:
        """Resume a session — restore state and restart its heartbeat.

        The heartbeat is restarted only if ``suspend`` stopped a running one.
        Falls back to the database when the session is not in memory.
        Returns None if the session is unknown, already expired, or past its
        absolute TTL (in which case it is expired now).
        """
        session = self._sessions.get(session_id)

        # Try loading from DB if not in memory
        if session is None:
            session = await self._load_from_db(session_id)
            if session is None:
                return None

        if session.status == SessionStatus.EXPIRED:
            return None

        # Never revive a session that has outlived its hard lifetime.
        if session.is_expired:
            await self.expire(session.id)
            return None

        now = time.time()
        session.status = SessionStatus.ACTIVE
        session.last_activity = now
        session.last_heartbeat = now

        self._sessions[session.id] = session
        await self._persist(session)
        await self._fire("resume", session)

        if session.id in self._paused_heartbeats:
            self._paused_heartbeats.discard(session.id)
            await self.start_heartbeat(session)

        return session

    async def expire(self, session_id: str) -> bool:
        """Permanently expire a session — drop it from memory, persist status.

        Returns False if the session is not in memory.
        """
        session = self._sessions.pop(session_id, None)
        if session is None:
            return False

        session.status = SessionStatus.EXPIRED
        self._stop_heartbeat(session_id)
        self._paused_heartbeats.discard(session_id)

        await self._persist(session)
        await self._fire("expire", session)
        return True

    async def destroy(self, session_id: str) -> bool:
        """Hard delete a session from memory and DB.

        Database errors propagate.  Always returns True.
        """
        await self.expire(session_id)
        async with aiosqlite.connect(self._db_path) as db:
            # Ensure the table exists so deleting from a fresh DB cannot fail.
            await db.execute(self._SCHEMA)
            await db.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
            await db.commit()
        return True

    # ── Heartbeat ──

    def _stop_heartbeat(self, session_id: str) -> bool:
        """Detach *session_id*'s heartbeat task and cancel it.

        The task is NOT cancelled when it is the caller itself (the heartbeat
        loop suspending/expiring its own session).  Cancelling the running
        task would raise CancelledError at its next await and abort the
        persist and hooks that follow; that loop returns right afterwards.

        Returns True if a still-running task was detached.
        """
        task = self._heartbeat_tasks.pop(session_id, None)
        if task is None or task.done():
            return False
        if task is not asyncio.current_task():
            task.cancel()
        return True

    async def start_heartbeat(self, session: Session) -> None:
        """Start a background heartbeat for *session* (no-op if one is running).

        Every ``heartbeat_interval`` seconds the loop refreshes
        ``last_heartbeat``.  It then expires the session if it exceeded
        ``absolute_ttl``, suspends it if idle longer than ``idle_timeout``,
        and otherwise re-persists it.
        """
        existing = self._heartbeat_tasks.get(session.id)
        if existing is not None and not existing.done():
            return

        async def _loop() -> None:
            while True:
                await asyncio.sleep(session.heartbeat_interval)

                if session.id not in self._sessions:
                    return

                session.last_heartbeat = time.time()

                # Absolute TTL wins over idleness: an over-age session must be
                # expired, not merely suspended (and later resumable).
                if session.is_expired:
                    await self.expire(session.id)
                    return

                # Check idle timeout → suspend
                if session.idle_seconds > session.idle_timeout:
                    await self.suspend(session.id)
                    return

                # Re-persist state snapshot
                await self._persist(session)

        task = asyncio.create_task(_loop())
        self._heartbeat_tasks[session.id] = task

        def _forget(done: asyncio.Task, sid: str = session.id) -> None:
            # Drop the entry once the loop ends, unless it was already
            # replaced or removed by suspend()/expire().
            if self._heartbeat_tasks.get(sid) is done:
                del self._heartbeat_tasks[sid]

        task.add_done_callback(_forget)

    async def heartbeat(self, session_id: str) -> bool:
        """Manual heartbeat ping; also counts as activity.

        Resumes a suspended session.  Returns False if the session is not
        found, or if resuming it failed (e.g. it was past its absolute TTL).
        """
        session = self._sessions.get(session_id)
        if session is None:
            return False

        now = time.time()
        session.last_heartbeat = now
        session.last_activity = now

        # Re-activate if suspended
        if session.status == SessionStatus.SUSPENDED:
            return await self.resume(session_id) is not None

        return True

    # ── State Management ──

    async def save_state(self, session_id: str, state: SessionState) -> bool:
        """Replace and persist the state snapshot of an in-memory session."""
        session = self._sessions.get(session_id)
        if session is None:
            return False

        session.state = state
        session.last_activity = time.time()
        await self._persist(session)
        return True

    async def get_state(self, session_id: str) -> Optional[SessionState]:
        """Get the latest state snapshot (memory first, then the database)."""
        session = self._sessions.get(session_id)
        if session is not None:
            return session.state

        session = await self._load_from_db(session_id)
        return session.state if session else None

    async def add_context(self, session_id: str, key: str, value: Any) -> bool:
        """Add a key-value to the session's working memory (not persisted here)."""
        session = self._sessions.get(session_id)
        if session is None:
            return False
        session.state.working_memory[key] = value
        session.last_activity = time.time()
        return True

    # ── Query ──

    def get(self, session_id: str) -> Optional[Session]:
        """Get an in-memory session by ID."""
        return self._sessions.get(session_id)

    def _filter_sessions(
        self, statuses: tuple[SessionStatus, ...], user_id: str
    ) -> list[Session]:
        """In-memory sessions with a status in *statuses*, newest activity first."""
        matches = [
            s for s in self._sessions.values()
            if s.status in statuses and (not user_id or s.user_id == user_id)
        ]
        return sorted(matches, key=lambda s: s.last_activity, reverse=True)

    def list_active(self, user_id: str = "") -> list[Session]:
        """List all active/idle sessions, optionally filtered by user."""
        return self._filter_sessions(
            (SessionStatus.ACTIVE, SessionStatus.IDLE), user_id
        )

    def list_suspended(self, user_id: str = "") -> list[Session]:
        """List suspended sessions, optionally filtered by user."""
        return self._filter_sessions((SessionStatus.SUSPENDED,), user_id)

    async def count(self) -> int:
        """Total sessions in memory."""
        return len(self._sessions)

    # ── Monitor ──

    async def monitor(self) -> dict[str, Any]:
        """Get a monitoring snapshot of all in-memory sessions."""
        active = 0
        idle = 0
        suspended = 0

        for s in self._sessions.values():
            if s.status == SessionStatus.ACTIVE:
                active += 1
            elif s.status == SessionStatus.IDLE:
                idle += 1
            elif s.status == SessionStatus.SUSPENDED:
                suspended += 1

        return {
            "total": len(self._sessions),
            "active": active,
            "idle": idle,
            "suspended": suspended,
            "heartbeat_tasks": len(self._heartbeat_tasks),
        }

    # ── Persistence ──

    async def _persist(self, session: Session) -> None:
        """Upsert *session* into SQLite (best-effort; failures are logged)."""
        try:
            async with aiosqlite.connect(self._db_path) as db:
                await db.execute(self._SCHEMA)
                await db.execute(self._INSERT_SQL, (
                    session.id, session.name, session.user_id,
                    session.status.value, session.created_at,
                    session.last_heartbeat, session.last_activity,
                    session.heartbeat_interval, session.idle_timeout,
                    session.absolute_ttl, session.state.to_json(),
                ))
                await db.commit()
        except Exception:
            self._logger.exception("failed to persist session %s", session.id)

    async def _load_from_db(self, session_id: str) -> Optional[Session]:
        """Load a session from SQLite; None if absent or unreadable.

        Non-expired sessions are cached in memory.  Expired ones are returned
        but not cached, so they cannot be counted, evicted, or expired again.
        """
        try:
            async with aiosqlite.connect(self._db_path) as db:
                await db.execute(self._SCHEMA)
                cursor = await db.execute(self._SELECT_SQL, (session_id,))
                row = await cursor.fetchone()
            if row is None:
                return None

            (sid, name, user_id, status, created_at, last_heartbeat,
             last_activity, heartbeat_interval, idle_timeout, absolute_ttl,
             state) = row
            session = Session(
                id=sid, name=name, user_id=user_id,
                status=SessionStatus(status),
                created_at=created_at, last_heartbeat=last_heartbeat,
                last_activity=last_activity,
                heartbeat_interval=heartbeat_interval,
                idle_timeout=idle_timeout, absolute_ttl=absolute_ttl,
                state=SessionState.from_json(state),
            )
        except Exception:
            self._logger.exception("failed to load session %s", session_id)
            return None

        if session.status != SessionStatus.EXPIRED:
            self._sessions[session.id] = session
        return session
481 return None