Coverage for agentos/protocols/a2a_store.py: 34%

127 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2A2A Task Store — persistent task and session storage for A2A protocol. 

3 

4Backends: InMemory (default), SQLite, custom. 

5""" 

6 

7from __future__ import annotations 

8 

9import json 

10import sqlite3 

11import threading 

12import time 

13from abc import ABC, abstractmethod 

14from contextlib import contextmanager 

15from typing import Any, Dict, Iterator, List, Optional 

16 

17from agentos.protocols.a2a import A2ATask, A2ASession, TaskState 

18 

19 

class A2ATaskStore(ABC):
    """Storage interface that every A2A task backend implements.

    Concrete subclasses must override all of the abstract methods below;
    this base class itself cannot be instantiated.
    """

    @abstractmethod
    def save_task(self, task: A2ATask) -> None:
        """Persist ``task``, inserting it or overwriting the entry with its id."""

    @abstractmethod
    def get_task(self, task_id: str) -> Optional[A2ATask]:
        """Look up a task by ``task_id``; ``None`` when nothing is stored."""

    @abstractmethod
    def list_tasks(
        self,
        state: TaskState | None = None,
        limit: int = 100,
        offset: int = 0,
        agent: str = "",
    ) -> list[A2ATask]:
        """Return up to ``limit`` tasks after skipping ``offset``.

        Args:
            state: When given, only tasks in this state are returned.
            limit: Maximum number of tasks in the result.
            offset: Number of matching tasks to skip.
            agent: When non-empty, only tasks targeting this agent are returned.
        """

    @abstractmethod
    def delete_task(self, task_id: str) -> bool:
        """Remove the task with ``task_id``; ``True`` only if one was removed."""

    @abstractmethod
    def cleanup_terminal(
        self,
        max_age_seconds: float = 3600.0,
    ) -> int:
        """Drop terminal tasks older than ``max_age_seconds``; return how many."""

    @abstractmethod
    def count(self, state: TaskState | None = None) -> int:
        """Return how many tasks are stored, restricted to ``state`` if given."""

62 

class InMemoryTaskStore(A2ATaskStore):
    """Fast, non-persistent task store for development/testing.

    Tasks are kept in a plain dict keyed by ``task_id`` and guarded by a
    ``threading.Lock``. Nothing is persisted: the data disappears with the
    instance. Stored task objects are returned by reference, not copied.
    """

    def __init__(self):
        self._tasks: dict[str, A2ATask] = {}
        self._lock = threading.Lock()

    def save_task(self, task: A2ATask) -> None:
        """Insert ``task``, replacing any stored task with the same ``task_id``."""
        with self._lock:
            self._tasks[task.task_id] = task

    def get_task(self, task_id: str) -> Optional[A2ATask]:
        """Return the task stored under ``task_id``, or ``None`` if absent."""
        with self._lock:
            return self._tasks.get(task_id)

    def list_tasks(
        self,
        state: TaskState | None = None,
        limit: int = 100,
        offset: int = 0,
        agent: str = "",
    ) -> list[A2ATask]:
        """Return a page of tasks, most recently updated first.

        Args:
            state: If not ``None``, keep only tasks whose ``state`` equals it.
            limit: Maximum number of tasks to return. A negative value means
                "no limit", as with SQLite's ``LIMIT``.
            offset: Number of matching tasks to skip. Negative values are
                treated as 0, as with SQLite's ``OFFSET``.
            agent: If non-empty, keep only tasks whose
                ``meta["target_agent"]`` equals it.

        Returns:
            A new list. The task objects in it are the stored ones, not copies.
        """
        with self._lock:
            tasks = [
                t
                for t in self._tasks.values()
                if (state is None or t.state == state)
                and (not agent or t.meta.get("target_agent") == agent)
            ]
        # Newest-first, like SqliteTaskStore's "ORDER BY updated DESC", so
        # offset/limit paging gives the same pages regardless of backend.
        tasks.sort(key=lambda t: t._updated, reverse=True)
        offset = max(offset, 0)
        end = None if limit < 0 else offset + limit
        return tasks[offset:end]

    def delete_task(self, task_id: str) -> bool:
        """Remove the task with ``task_id``; return ``True`` if it existed."""
        with self._lock:
            if task_id in self._tasks:
                del self._tasks[task_id]
                return True
            return False

    def cleanup_terminal(self, max_age_seconds: float = 3600.0) -> int:
        """Delete terminal tasks last updated more than ``max_age_seconds`` ago.

        A task is terminal when ``task.is_terminal()`` is true. Its age is
        ``time.time() - task._updated``.

        Returns:
            The number of tasks removed.
        """
        now = time.time()
        with self._lock:
            expired = [
                tid
                for tid, t in self._tasks.items()
                if t.is_terminal() and (now - t._updated) > max_age_seconds
            ]
            for tid in expired:
                del self._tasks[tid]
        return len(expired)

    def count(self, state: TaskState | None = None) -> int:
        """Count stored tasks, optionally only those whose ``state`` matches."""
        # Count in place rather than via list_tasks(limit=999999): that capped
        # the result and built (and sorted) a throwaway list just for its length.
        with self._lock:
            if state is None:
                return len(self._tasks)
            return sum(1 for t in self._tasks.values() if t.state == state)

118 

119 

class SqliteTaskStore(A2ATaskStore):
    """Persistent SQLite-backed task store for production use.

    File-backed databases use one connection per thread, kept in a
    ``threading.local``. In-memory databases (``":memory:"`` or ``""``) are
    private to the connection that opened them. For those, a single
    connection is shared by all threads, and access to it is serialized by
    a lock.
    """

    SCHEMA = """
    CREATE TABLE IF NOT EXISTS a2a_tasks (
        task_id TEXT PRIMARY KEY,
        state TEXT NOT NULL DEFAULT 'submitted',
        input_json TEXT,
        output_json TEXT,
        artifacts_json TEXT DEFAULT '[]',
        error TEXT,
        meta_json TEXT DEFAULT '{}',
        created REAL NOT NULL,
        updated REAL NOT NULL,
        agent TEXT DEFAULT ''
    );
    CREATE INDEX IF NOT EXISTS idx_a2a_state ON a2a_tasks(state);
    CREATE INDEX IF NOT EXISTS idx_a2a_agent ON a2a_tasks(agent);
    CREATE INDEX IF NOT EXISTS idx_a2a_updated ON a2a_tasks(updated);
    """

    # db_path values for which sqlite3 creates a database that exists only
    # inside the single connection that opened it.
    _PRIVATE_DB_PATHS = (":memory:", "")

    def __init__(self, db_path: str = ":memory:"):
        self.db_path = db_path
        self._local = threading.local()
        self._shared_lock = threading.RLock()
        # Each new connection to ":memory:" (or "") opens a separate, empty
        # database. Per-thread connections would therefore show other threads
        # no rows and no schema ("no such table: a2a_tasks"), so such
        # databases get one shared connection instead.
        self._shared_conn: Optional[sqlite3.Connection] = (
            self._connect() if db_path in self._PRIVATE_DB_PATHS else None
        )
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection whose rows come back as ``sqlite3.Row``."""
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        return conn

    def _init_db(self):
        """Create the table and its indexes if they do not exist yet."""
        with self._conn() as conn:
            conn.executescript(self.SCHEMA)

    @contextmanager
    def _conn(self) -> Iterator[sqlite3.Connection]:
        """Yield the connection the calling thread should use.

        For in-memory databases, the shared connection is held under a lock
        for the whole ``with`` body. If the body raises, the connection is
        rolled back so no open transaction is left behind on it.
        """
        if self._shared_conn is not None:
            conn = self._shared_conn
            lock = self._shared_lock
        else:
            conn = getattr(self._local, "conn", None)
            if conn is None:
                conn = self._connect()
                self._local.conn = conn
            lock = None
        if lock is not None:
            lock.acquire()
        try:
            yield conn
        except Exception:
            conn.rollback()
            raise
        finally:
            if lock is not None:
                lock.release()

    def save_task(self, task: A2ATask) -> None:
        """Insert or replace ``task`` and commit.

        ``input``, ``output``, ``artifacts`` and ``meta`` are stored as JSON.
        The ``agent`` column is copied from ``meta["target_agent"]`` so it
        can be filtered on.
        """
        with self._conn() as conn:
            conn.execute(
                """INSERT OR REPLACE INTO a2a_tasks
                (task_id, state, input_json, output_json, artifacts_json,
                 error, meta_json, created, updated, agent)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                (
                    task.task_id,
                    task.state.value,
                    json.dumps(task.input.to_dict()) if task.input else None,
                    json.dumps(task.output.to_dict()) if task.output else None,
                    json.dumps([a.to_dict() for a in task.artifacts]),
                    task.error,
                    json.dumps(task.meta),
                    task._created,
                    task._updated,
                    task.meta.get("target_agent", ""),
                ),
            )
            conn.commit()

    def get_task(self, task_id: str) -> Optional[A2ATask]:
        """Return the task with ``task_id``, or ``None`` if there is no row."""
        with self._conn() as conn:
            row = conn.execute(
                "SELECT * FROM a2a_tasks WHERE task_id = ?", (task_id,)
            ).fetchone()
            if row is None:
                return None
            return self._row_to_task(row)

    def list_tasks(
        self,
        state: TaskState | None = None,
        limit: int = 100,
        offset: int = 0,
        agent: str = "",
    ) -> list[A2ATask]:
        """Return a page of tasks ordered by ``updated`` descending.

        ``state`` (if not ``None``) and ``agent`` (if non-empty) are applied
        as equality filters on the ``state`` and ``agent`` columns.
        """
        query = "SELECT * FROM a2a_tasks WHERE 1=1"
        params: list[Any] = []
        if state is not None:
            query += " AND state = ?"
            params.append(state.value)
        if agent:
            query += " AND agent = ?"
            params.append(agent)
        query += " ORDER BY updated DESC LIMIT ? OFFSET ?"
        params.extend([limit, offset])

        with self._conn() as conn:
            rows = conn.execute(query, params).fetchall()
            return [self._row_to_task(r) for r in rows]

    def delete_task(self, task_id: str) -> bool:
        """Delete the row for ``task_id``; return ``True`` if a row was removed."""
        with self._conn() as conn:
            cur = conn.execute("DELETE FROM a2a_tasks WHERE task_id = ?", (task_id,))
            conn.commit()
            return cur.rowcount > 0

    def cleanup_terminal(self, max_age_seconds: float = 3600.0) -> int:
        """Delete terminal tasks whose ``updated`` is older than the cutoff.

        Returns:
            The number of rows deleted.
        """
        cutoff = time.time() - max_age_seconds
        # NOTE(review): the terminal states are hard-coded here, whereas
        # InMemoryTaskStore asks task.is_terminal(). Confirm this list matches
        # the TaskState values that is_terminal() treats as terminal.
        with self._conn() as conn:
            cur = conn.execute(
                """DELETE FROM a2a_tasks
                WHERE state IN ('completed', 'failed', 'cancelled')
                AND updated < ?""",
                (cutoff,),
            )
            conn.commit()
            return cur.rowcount

    def count(self, state: TaskState | None = None) -> int:
        """Count rows, optionally only those whose ``state`` matches."""
        query = "SELECT COUNT(*) FROM a2a_tasks"
        params: list[Any] = []
        if state is not None:
            query += " WHERE state = ?"
            params.append(state.value)
        with self._conn() as conn:
            return conn.execute(query, params).fetchone()[0]

    def _row_to_task(self, row) -> A2ATask:
        """Rebuild an ``A2ATask`` from an ``a2a_tasks`` row (``sqlite3.Row``)."""
        # Imported lazily, as in the original module layout.
        from agentos.protocols.a2a import A2AMessage, A2AArtifact

        task = A2ATask(
            task_id=row["task_id"],
            state=TaskState(row["state"]),
            error=row["error"],
            meta=json.loads(row["meta_json"] or "{}"),
            _created=row["created"],
            _updated=row["updated"],
        )
        if row["input_json"]:
            task.input = A2AMessage.from_dict(json.loads(row["input_json"]))
        if row["output_json"]:
            task.output = A2AMessage.from_dict(json.loads(row["output_json"]))
        task.artifacts = [
            A2AArtifact.from_dict(a)
            for a in json.loads(row["artifacts_json"] or "[]")
        ]
        return task