Coverage for agentos/memory/session.py: 32%

226 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Persistent Thread Context (PTC) — Session Manager. 

3 

4OpenClaw-style long-running session management: 

5 - Heartbeat: periodic ping to keep sessions alive 

6 - Auto-suspend: idle sessions that exceed TTL 

7 - State recovery: resume a session exactly where it left off 

8 - Cross-session memory: carry context across disconnected sessions 

9 - Event hooks: on_suspend, on_resume, on_expire 

10 

11Design: 

12 SessionManager 

13 └─ Session (per-thread lifecycle) 

14 ├─ heartbeat() — keep alive 

15 ├─ suspend() — save state, pause 

16 ├─ resume() — restore state, continue 

17 └─ expire() — cleanup after TTL 

18""" 

19 

from __future__ import annotations

import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Coroutine, Optional

import aiosqlite

32 

33 

34# ── Session Models ── 

35 

class SessionStatus(str, Enum):
    """Lifecycle states of a session.

    Members are also ``str`` instances; ``.value`` is the text stored in the
    ``status`` column of the sessions table.
    """

    # Currently running.
    ACTIVE = "active"
    # Alive, but nothing has happened recently.
    IDLE = "idle"
    # Paused with its state saved.
    SUSPENDED = "suspended"
    # Timed out and cleaned up.
    EXPIRED = "expired"
    # Crashed, but its state was saved.
    ERROR = "error"

42 

43 

@dataclass
class SessionState:
    """Snapshot of everything a session carries across suspend/resume.

    The snapshot round-trips through JSON via ``to_json`` / ``from_json``.
    """

    conversation_history: list[dict] = field(default_factory=list)
    working_memory: dict[str, Any] = field(default_factory=dict)
    agent_context: dict[str, Any] = field(default_factory=dict)
    tool_state: dict[str, Any] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_json(self) -> str:
        """Serialize the snapshot to a JSON object string.

        Non-ASCII text is written verbatim. Any value that ``json`` cannot
        encode natively is converted with ``str()`` rather than rejected.
        """
        payload: dict[str, Any] = {}
        for key in (
            "conversation_history",
            "working_memory",
            "agent_context",
            "tool_state",
            "metadata",
        ):
            payload[key] = getattr(self, key)
        return json.dumps(payload, ensure_ascii=False, default=str)

    @classmethod
    def from_json(cls, data: str) -> "SessionState":
        """Rebuild a snapshot from a JSON object string.

        Keys absent from the JSON fall back to an empty list/dict. Unknown
        keys are ignored.
        """
        raw = json.loads(data)
        fallbacks = {
            "conversation_history": [],
            "working_memory": {},
            "agent_context": {},
            "tool_state": {},
            "metadata": {},
        }
        return cls(**{key: raw.get(key, empty) for key, empty in fallbacks.items()})

73 

74 

@dataclass
class Session:
    """Lifecycle record for one PTC session (a single conversational thread).

    All timestamps are ``time.time()`` epoch seconds. The timing settings are
    durations in seconds.
    """

    # -- identity --
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:16])
    name: str = ""
    user_id: str = "default"
    status: SessionStatus = SessionStatus.ACTIVE

    # -- timestamps --
    created_at: float = field(default_factory=time.time)
    last_heartbeat: float = field(default_factory=time.time)
    last_activity: float = field(default_factory=time.time)

    # -- payload carried across suspend/resume --
    state: SessionState = field(default_factory=SessionState)

    # -- timing configuration --
    heartbeat_interval: float = 30.0  # pause between background heartbeats
    idle_timeout: float = 300.0  # inactivity (5 min) after which the manager suspends
    absolute_ttl: float = 86400.0  # hard lifetime cap (24 h)
    max_history_turns: int = 1000  # NOTE(review): not enforced anywhere in this module

    # -- internal --
    # NOTE(review): SessionManager tracks heartbeat tasks in its own dict, so
    # this field looks unused; confirm before relying on it.
    _heartbeat_task: Optional[asyncio.Task] = None

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since ``created_at``."""
        now = time.time()
        return now - self.created_at

    @property
    def idle_seconds(self) -> float:
        """Seconds elapsed since ``last_activity``."""
        now = time.time()
        return now - self.last_activity

    @property
    def is_expired(self) -> bool:
        """True once the session's age strictly exceeds ``absolute_ttl``."""
        return self.absolute_ttl < self.age_seconds

    def to_dict(self) -> dict:
        """Return the persisted columns as a flat dict.

        ``status`` is its string value and ``state`` is a JSON string.
        ``max_history_turns`` is not included.
        """
        return dict(
            id=self.id,
            name=self.name,
            user_id=self.user_id,
            status=self.status.value,
            created_at=self.created_at,
            last_heartbeat=self.last_heartbeat,
            last_activity=self.last_activity,
            heartbeat_interval=self.heartbeat_interval,
            idle_timeout=self.idle_timeout,
            absolute_ttl=self.absolute_ttl,
            state=self.state.to_json(),
        )

125 

126 

127# ── Session Manager ── 

128 

class SessionManager:
    """Manage PTC sessions with heartbeat, suspend/resume, and persistence.

    Live sessions are kept in an in-memory map and mirrored to a SQLite
    ``sessions`` table (one row per session, state stored as JSON). This lets
    a session that is no longer in memory be reloaded by id.

    Usage:
        manager = SessionManager(db_path="~/.agentos/sessions.db")

        # Create a new session
        session = await manager.create(name="research-thread")

        # Heartbeat loop (runs in background)
        await manager.start_heartbeat(session)

        # Suspend on idle
        await manager.suspend(session.id)

        # Resume later — state restored. Call start_heartbeat() again if the
        # background loop should keep running.
        session = await manager.resume(session.id)

        # Hooks
        manager.on_suspend(lambda s: print(f"{s.name} suspended"))
        manager.on_resume(lambda s: print(f"{s.name} resumed"))
    """

    # Hook and persistence failures are best-effort (they must not break the
    # lifecycle call or the heartbeat loop), but they are logged, not hidden.
    _log = logging.getLogger(__name__)

    # Single source of truth for the table layout and column order.
    _CREATE_TABLE_SQL = """
        CREATE TABLE IF NOT EXISTS sessions (
            id TEXT PRIMARY KEY,
            name TEXT,
            user_id TEXT,
            status TEXT,
            created_at REAL,
            last_heartbeat REAL,
            last_activity REAL,
            heartbeat_interval REAL,
            idle_timeout REAL,
            absolute_ttl REAL,
            state TEXT
        )
    """
    _COLUMNS = (
        "id, name, user_id, status, created_at, last_heartbeat, last_activity, "
        "heartbeat_interval, idle_timeout, absolute_ttl, state"
    )
    _UPSERT_SQL = (
        f"INSERT OR REPLACE INTO sessions ({_COLUMNS}) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
    )
    _SELECT_SQL = f"SELECT {_COLUMNS} FROM sessions WHERE id = ?"

    def __init__(
        self,
        db_path: str = "",
        max_concurrent: int = 100,
    ):
        """Configure the manager and make sure the DB directory exists.

        Args:
            db_path: SQLite file path; a leading ``~`` is expanded. Defaults to
                ``~/.agentos/sessions.db``. The parent directory is created.
            max_concurrent: Number of in-memory sessions after which
                ``create()`` expires the least recently active one.
        """
        path = Path(db_path).expanduser() if db_path else Path.home() / ".agentos" / "sessions.db"
        path.parent.mkdir(parents=True, exist_ok=True)
        self._db_path = str(path)
        self._max_concurrent = max_concurrent

        self._sessions: dict[str, Session] = {}
        self._heartbeat_tasks: dict[str, asyncio.Task] = {}
        self._hooks: dict[str, list[Callable]] = {
            "create": [],
            "suspend": [],
            "resume": [],
            "expire": [],
            "heartbeat_missed": [],
        }

    # ── Hooks ──

    def on(self, event: str):
        """Decorator: register a hook for ``event`` (any event name is accepted)."""
        def decorator(fn):
            self._hooks.setdefault(event, []).append(fn)
            return fn
        return decorator

    def on_create(self, fn: Callable[[Session], Any]):
        """Register a hook fired after ``create()``. Returns ``fn``."""
        self._hooks["create"].append(fn)
        return fn

    def on_suspend(self, fn: Callable[[Session], Any]):
        """Register a hook fired after ``suspend()``. Returns ``fn``."""
        self._hooks["suspend"].append(fn)
        return fn

    def on_resume(self, fn: Callable[[Session], Any]):
        """Register a hook fired after ``resume()``. Returns ``fn``."""
        self._hooks["resume"].append(fn)
        return fn

    def on_expire(self, fn: Callable[[Session], Any]):
        """Register a hook fired after ``expire()``. Returns ``fn``."""
        self._hooks["expire"].append(fn)
        return fn

    async def _fire(self, event: str, session: Session) -> None:
        """Call every hook registered for ``event``, awaiting coroutine results.

        A failing hook is logged and skipped, so it cannot block the remaining
        hooks or the lifecycle call that triggered it.
        """
        # Iterate a copy so a hook that registers another hook is safe.
        for hook in list(self._hooks.get(event, [])):
            try:
                result = hook(session)
                if asyncio.iscoroutine(result):
                    await result
            except Exception:
                self._log.exception("session hook %r failed for event %r", hook, event)

    # ── Session Lifecycle ──

    async def create(
        self,
        name: str = "",
        user_id: str = "default",
        heartbeat_interval: float = 30.0,
        idle_timeout: float = 300.0,
        absolute_ttl: float = 86400.0,
    ) -> Session:
        """Create, persist and register a new PTC session.

        If the manager already holds ``max_concurrent`` sessions, the one with
        the oldest ``last_activity`` is expired first.
        """
        # Guard on non-empty: with max_concurrent <= 0, min() of nothing would raise.
        if self._sessions and len(self._sessions) >= self._max_concurrent:
            oldest = min(self._sessions.values(), key=lambda s: s.last_activity)
            await self.expire(oldest.id)

        session = Session(
            name=name or f"session-{uuid.uuid4().hex[:6]}",
            user_id=user_id,
            heartbeat_interval=heartbeat_interval,
            idle_timeout=idle_timeout,
            absolute_ttl=absolute_ttl,
        )

        self._sessions[session.id] = session
        await self._persist(session)
        await self._fire("create", session)

        return session

    def _stop_heartbeat(self, session_id: str) -> None:
        """Forget the heartbeat task for ``session_id`` and cancel it, unless it is us.

        ``suspend()``/``expire()`` are also called from inside the heartbeat
        loop. Cancelling the *current* task would abort the rest of that call
        (persist + hooks) at its next ``await``. The loop returns by itself
        right after, so it must not be cancelled.
        """
        task = self._heartbeat_tasks.pop(session_id, None)
        if task is not None and task is not asyncio.current_task():
            task.cancel()

    async def suspend(self, session_id: str) -> bool:
        """Suspend a session: mark it SUSPENDED, stop its heartbeat, persist it.

        Returns False if the session is not in memory.
        """
        session = self._sessions.get(session_id)
        if session is None:
            return False

        session.status = SessionStatus.SUSPENDED
        self._stop_heartbeat(session_id)

        await self._persist(session)
        await self._fire("suspend", session)

        return True

    async def resume(self, session_id: str) -> Optional[Session]:
        """Resume a session, loading it from SQLite if it is not in memory.

        The session is marked ACTIVE, its activity/heartbeat timestamps are
        refreshed, and the new status is persisted. The background heartbeat
        is NOT restarted; call ``start_heartbeat()`` for that.

        Returns None if the session is unknown or EXPIRED. A session past its
        ``absolute_ttl`` is expired instead of being revived.
        """
        session = self._sessions.get(session_id)

        # Try loading from DB if not in memory.
        if session is None:
            session = await self._load_from_db(session_id)
            if session is None:
                return None

        if session.status == SessionStatus.EXPIRED:
            return None

        self._sessions[session.id] = session
        if session.is_expired:
            # Past its absolute lifetime: retire it rather than revive it.
            await self.expire(session.id)
            return None

        now = time.time()
        session.status = SessionStatus.ACTIVE
        session.last_activity = now
        session.last_heartbeat = now

        await self._persist(session)
        await self._fire("resume", session)

        return session

    async def expire(self, session_id: str) -> bool:
        """Permanently expire a session: drop it from memory, stop its heartbeat, persist.

        Returns False if the session is not in memory.
        """
        session = self._sessions.pop(session_id, None)
        if session is None:
            return False

        session.status = SessionStatus.EXPIRED
        self._stop_heartbeat(session_id)

        await self._persist(session)
        await self._fire("expire", session)

        return True

    async def destroy(self, session_id: str) -> bool:
        """Hard-delete a session from memory and the DB (fires expire hooks first)."""
        await self.expire(session_id)
        async with aiosqlite.connect(self._db_path) as db:
            # The table may not exist yet if nothing was ever persisted.
            await self._ensure_schema(db)
            await db.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
            await db.commit()
        return True

    # ── Heartbeat ──

    async def start_heartbeat(self, session: Session) -> None:
        """Start a background loop that heartbeats, persists and retires ``session``.

        Every ``heartbeat_interval`` seconds, the loop does the following:
        - refreshes ``last_heartbeat``;
        - expires the session once it is past ``absolute_ttl``;
        - suspends it once idle longer than ``idle_timeout``;
        - otherwise persists its current snapshot.

        It is a no-op if a heartbeat task for this session is still running.
        """
        existing = self._heartbeat_tasks.get(session.id)
        if existing is not None and not existing.done():
            return

        async def _loop() -> None:
            while True:
                await asyncio.sleep(session.heartbeat_interval)

                if session.id not in self._sessions:
                    return

                session.last_heartbeat = time.time()

                # Absolute TTL takes precedence: an over-age session must be
                # expired, not merely suspended (which would leave it resumable).
                if session.is_expired:
                    await self.expire(session.id)
                    return

                # Idle too long → suspend.
                if session.idle_seconds > session.idle_timeout:
                    await self.suspend(session.id)
                    return

                # Re-persist state snapshot.
                await self._persist(session)

        self._heartbeat_tasks[session.id] = asyncio.create_task(_loop())

    async def heartbeat(self, session_id: str) -> bool:
        """Manual heartbeat ping; also counts as activity.

        A SUSPENDED session is resumed. Returns False if the session is not in
        memory, or if resuming it failed (e.g. it had outlived its TTL).
        """
        session = self._sessions.get(session_id)
        if session is None:
            return False

        now = time.time()
        session.last_heartbeat = now
        session.last_activity = now

        if session.status == SessionStatus.SUSPENDED:
            return await self.resume(session_id) is not None

        return True

    # ── State Management ──

    async def save_state(self, session_id: str, state: SessionState) -> bool:
        """Replace a session's state snapshot and persist it immediately."""
        session = self._sessions.get(session_id)
        if session is None:
            return False

        session.state = state
        session.last_activity = time.time()
        await self._persist(session)
        return True

    async def get_state(self, session_id: str) -> Optional[SessionState]:
        """Return a session's state snapshot, falling back to the DB copy."""
        session = self._sessions.get(session_id)
        if session is not None:
            return session.state

        session = await self._load_from_db(session_id)
        return session.state if session else None

    async def add_context(self, session_id: str, key: str, value: Any) -> bool:
        """Set ``key`` in the session's working memory.

        The change is in memory only; it is written to SQLite on the next
        persist (heartbeat tick, ``save_state``, suspend, ...).
        """
        session = self._sessions.get(session_id)
        if session is None:
            return False
        session.state.working_memory[key] = value
        session.last_activity = time.time()
        return True

    # ── Query ──

    def get(self, session_id: str) -> Optional[Session]:
        """Get an in-memory session by ID."""
        return self._sessions.get(session_id)

    def _filter_sessions(self, statuses: tuple, user_id: str) -> list[Session]:
        """In-memory sessions whose status is in ``statuses``, newest activity first."""
        matches = [
            s for s in self._sessions.values()
            if s.status in statuses and (not user_id or s.user_id == user_id)
        ]
        return sorted(matches, key=lambda s: s.last_activity, reverse=True)

    def list_active(self, user_id: str = "") -> list[Session]:
        """List all active/idle sessions, optionally filtered by user."""
        return self._filter_sessions((SessionStatus.ACTIVE, SessionStatus.IDLE), user_id)

    def list_suspended(self, user_id: str = "") -> list[Session]:
        """List suspended sessions, optionally filtered by user."""
        return self._filter_sessions((SessionStatus.SUSPENDED,), user_id)

    async def count(self) -> int:
        """Total sessions in memory."""
        return len(self._sessions)

    # ── Monitor ──

    async def monitor(self) -> dict[str, Any]:
        """Get a monitoring snapshot of all in-memory sessions."""
        active = 0
        idle = 0
        suspended = 0

        for s in self._sessions.values():
            if s.status == SessionStatus.ACTIVE:
                active += 1
            elif s.status == SessionStatus.IDLE:
                idle += 1
            elif s.status == SessionStatus.SUSPENDED:
                suspended += 1

        return {
            "total": len(self._sessions),
            "active": active,
            "idle": idle,
            "suspended": suspended,
            "heartbeat_tasks": len(self._heartbeat_tasks),
        }

    # ── Persistence ──

    async def _ensure_schema(self, db) -> None:
        """Create the ``sessions`` table on an open connection if it is missing."""
        await db.execute(self._CREATE_TABLE_SQL)

    async def _persist(self, session: Session) -> None:
        """Upsert the session row in SQLite.

        This is best-effort: failures are logged rather than raised, so
        callers such as the heartbeat loop keep running.
        """
        try:
            async with aiosqlite.connect(self._db_path) as db:
                await self._ensure_schema(db)
                await db.execute(self._UPSERT_SQL, (
                    session.id, session.name, session.user_id,
                    session.status.value, session.created_at,
                    session.last_heartbeat, session.last_activity,
                    session.heartbeat_interval, session.idle_timeout,
                    session.absolute_ttl, session.state.to_json(),
                ))
                await db.commit()
        except Exception:
            self._log.exception("failed to persist session %s", session.id)

    async def _load_from_db(self, session_id: str) -> Optional[Session]:
        """Load one session from SQLite.

        Returns None if there is no row or it cannot be read or decoded (the
        failure is logged). Non-expired sessions are registered in memory.
        EXPIRED rows are returned but not re-registered, so they do not
        re-enter the live session map.
        """
        try:
            async with aiosqlite.connect(self._db_path) as db:
                await self._ensure_schema(db)
                cursor = await db.execute(self._SELECT_SQL, (session_id,))
                row = await cursor.fetchone()
            if not row:
                return None

            session = Session(
                id=row[0], name=row[1], user_id=row[2],
                status=SessionStatus(row[3]),
                created_at=row[4], last_heartbeat=row[5],
                last_activity=row[6], heartbeat_interval=row[7],
                idle_timeout=row[8], absolute_ttl=row[9],
                state=SessionState.from_json(row[10]),
            )
        except Exception:
            self._log.exception("failed to load session %s", session_id)
            return None

        if session.status != SessionStatus.EXPIRED:
            self._sessions[session.id] = session
        return session