Coverage for agentos/protocols/a2a_store.py: 34%
127 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2A2A Task Store — persistent task and session storage for A2A protocol.
4Backends: InMemory (default), SQLite, custom.
5"""
7from __future__ import annotations
9import json
10import sqlite3
11import threading
12import time
13from abc import ABC, abstractmethod
14from contextlib import contextmanager
15from typing import Any, Dict, Iterator, List, Optional
17from agentos.protocols.a2a import A2ATask, A2ASession, TaskState
class A2ATaskStore(ABC):
    """Interface implemented by every A2A task persistence backend.

    Subclasses must override all six abstract methods; the base versions
    have no behavior of their own and simply return ``None``.
    """

    @abstractmethod
    def save_task(self, task: A2ATask) -> None:
        """Persist *task*, creating it or overwriting the stored copy."""

    @abstractmethod
    def get_task(self, task_id: str) -> Optional[A2ATask]:
        """Look up a task by *task_id*; ``None`` when it is unknown."""

    @abstractmethod
    def list_tasks(
        self,
        state: TaskState | None = None,
        limit: int = 100,
        offset: int = 0,
        agent: str = "",
    ) -> list[A2ATask]:
        """Return up to *limit* tasks after skipping *offset*.

        A non-``None`` *state* and a non-empty *agent* each narrow the
        result to matching tasks.
        """

    @abstractmethod
    def delete_task(self, task_id: str) -> bool:
        """Remove the task with *task_id*; report whether one was removed."""

    @abstractmethod
    def cleanup_terminal(
        self,
        max_age_seconds: float = 3600.0,
    ) -> int:
        """Purge terminal tasks older than *max_age_seconds*; return how many."""

    @abstractmethod
    def count(self, state: TaskState | None = None) -> int:
        """Return the number of stored tasks, optionally only those in *state*."""
class InMemoryTaskStore(A2ATaskStore):
    """Fast, non-persistent task store for development/testing.

    Tasks are kept in a dict keyed by ``task_id`` and every access to that
    dict is guarded by a lock.  Task objects are stored by reference, so
    mutating a task after ``save_task`` also changes what the store returns.
    """

    def __init__(self):
        self._tasks: dict[str, A2ATask] = {}
        self._lock = threading.Lock()

    def save_task(self, task: A2ATask) -> None:
        """Insert *task*, replacing any stored task with the same ID."""
        with self._lock:
            self._tasks[task.task_id] = task

    def get_task(self, task_id: str) -> Optional[A2ATask]:
        """Return the task stored under *task_id*, or ``None``."""
        with self._lock:
            return self._tasks.get(task_id)

    def list_tasks(
        self,
        state: TaskState | None = None,
        limit: int = 100,
        offset: int = 0,
        agent: str = "",
    ) -> list[A2ATask]:
        """Return a page of tasks, most recently updated first.

        Args:
            state: Only include tasks in this state (``None`` = any state).
            limit: Maximum number of tasks to return; a negative value means
                "no limit", mirroring SQLite's ``LIMIT`` semantics.
            offset: Number of matching tasks to skip; negative values are
                treated as 0, mirroring SQLite's ``OFFSET`` semantics.
            agent: Only include tasks whose ``meta["target_agent"]`` equals
                this value (empty string = any agent).
        """
        with self._lock:
            tasks = list(self._tasks.values())
        if state is not None:
            tasks = [t for t in tasks if t.state == state]
        if agent:
            tasks = [t for t in tasks if t.meta.get("target_agent") == agent]
        # Newest first, matching SqliteTaskStore's "ORDER BY updated DESC" so
        # offset/limit pagination behaves the same on both backends.  The sort
        # is stable, so ties keep insertion order.
        tasks.sort(key=lambda t: t._updated, reverse=True)
        start = max(offset, 0)
        if limit < 0:
            return tasks[start:]
        return tasks[start : start + limit]

    def delete_task(self, task_id: str) -> bool:
        """Delete the task with *task_id*; return True if it existed."""
        with self._lock:
            if task_id in self._tasks:
                del self._tasks[task_id]
                return True
            return False

    def cleanup_terminal(self, max_age_seconds: float = 3600.0) -> int:
        """Drop terminal tasks last updated more than *max_age_seconds* ago.

        Returns:
            The number of tasks removed.
        """
        now = time.time()
        with self._lock:
            to_del = [
                tid
                for tid, t in self._tasks.items()
                if t.is_terminal() and (now - t._updated) > max_age_seconds
            ]
            for tid in to_del:
                del self._tasks[tid]
            return len(to_del)

    def count(self, state: TaskState | None = None) -> int:
        """Count stored tasks, optionally only those in *state*."""
        with self._lock:
            if state is None:
                return len(self._tasks)
            # Count directly instead of materialising a capped list_tasks()
            # page, which silently truncated at 999999 tasks.
            return sum(1 for t in self._tasks.values() if t.state == state)
class SqliteTaskStore(A2ATaskStore):
    """Persistent SQLite-backed task store for production use.

    All operations go through a single shared ``sqlite3.Connection`` that is
    opened lazily and serialized by a re-entrant lock.  One shared connection
    (instead of one per thread) is required for the default ``":memory:"``
    path: every SQLite connection to ``":memory:"`` opens its own private,
    empty database, so per-thread connections would neither see the schema
    nor each other's data.

    Args:
        db_path: Path of the database file, or ``":memory:"`` (the default)
            for a database that lives only as long as the store's connection.
    """

    SCHEMA = """
    CREATE TABLE IF NOT EXISTS a2a_tasks (
        task_id TEXT PRIMARY KEY,
        state TEXT NOT NULL DEFAULT 'submitted',
        input_json TEXT,
        output_json TEXT,
        artifacts_json TEXT DEFAULT '[]',
        error TEXT,
        meta_json TEXT DEFAULT '{}',
        created REAL NOT NULL,
        updated REAL NOT NULL,
        agent TEXT DEFAULT ''
    );
    CREATE INDEX IF NOT EXISTS idx_a2a_state ON a2a_tasks(state);
    CREATE INDEX IF NOT EXISTS idx_a2a_agent ON a2a_tasks(agent);
    CREATE INDEX IF NOT EXISTS idx_a2a_updated ON a2a_tasks(updated);
    """

    def __init__(self, db_path: str = ":memory:"):
        self.db_path = db_path
        # RLock so a helper that already holds the connection can re-enter.
        self._lock = threading.RLock()
        self._connection: sqlite3.Connection | None = None
        self._init_db()

    def _init_db(self):
        """Open the connection eagerly; opening also applies ``SCHEMA``."""
        with self._conn():
            pass

    def _open(self) -> sqlite3.Connection:
        """Open a new connection to ``db_path`` and ensure the schema exists."""
        # check_same_thread=False: the one connection is shared by all
        # threads; self._lock serializes every use of it.
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        try:
            conn.row_factory = sqlite3.Row
            # Idempotent (IF NOT EXISTS); re-running it on every open keeps a
            # store usable after close(), even for ":memory:".
            conn.executescript(self.SCHEMA)
        except BaseException:
            conn.close()
            raise
        return conn

    @contextmanager
    def _conn(self) -> Iterator[sqlite3.Connection]:
        """Yield the shared connection while holding the store lock.

        If the ``with`` body raises, any open transaction is rolled back so
        the connection is not left holding a half-finished write.
        """
        with self._lock:
            if self._connection is None:
                self._connection = self._open()
            conn = self._connection
            try:
                yield conn
            except BaseException:
                conn.rollback()
                raise

    def close(self) -> None:
        """Close the shared connection.

        A later operation transparently reopens it.  For ``":memory:"``
        databases, closing discards every stored task.
        """
        with self._lock:
            if self._connection is not None:
                self._connection.close()
                self._connection = None

    def save_task(self, task: A2ATask) -> None:
        """Insert *task*, replacing any row with the same ``task_id``."""
        with self._conn() as conn:
            conn.execute(
                """INSERT OR REPLACE INTO a2a_tasks
                   (task_id, state, input_json, output_json, artifacts_json,
                    error, meta_json, created, updated, agent)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                (
                    task.task_id,
                    task.state.value,
                    json.dumps(task.input.to_dict()) if task.input else None,
                    json.dumps(task.output.to_dict()) if task.output else None,
                    json.dumps([a.to_dict() for a in task.artifacts]),
                    task.error,
                    json.dumps(task.meta),
                    task._created,
                    task._updated,
                    # Denormalised into its own indexed column for list_tasks.
                    task.meta.get("target_agent", ""),
                ),
            )
            conn.commit()

    def get_task(self, task_id: str) -> Optional[A2ATask]:
        """Return the task stored under *task_id*, or ``None``."""
        with self._conn() as conn:
            row = conn.execute(
                "SELECT * FROM a2a_tasks WHERE task_id = ?", (task_id,)
            ).fetchone()
        if row is None:
            return None
        return self._row_to_task(row)

    def list_tasks(
        self,
        state: TaskState | None = None,
        limit: int = 100,
        offset: int = 0,
        agent: str = "",
    ) -> list[A2ATask]:
        """Return a page of tasks, most recently updated first.

        Args:
            state: Only include tasks in this state (``None`` = any state).
            limit: Maximum number of rows (SQLite ``LIMIT`` semantics).
            offset: Number of matching rows to skip (SQLite ``OFFSET``).
            agent: Only include tasks whose stored ``agent`` column equals
                this value (empty string = any agent).
        """
        query = "SELECT * FROM a2a_tasks WHERE 1=1"
        params: list[Any] = []
        if state is not None:
            query += " AND state = ?"
            params.append(state.value)
        if agent:
            query += " AND agent = ?"
            params.append(agent)
        query += " ORDER BY updated DESC LIMIT ? OFFSET ?"
        params.extend([limit, offset])

        with self._conn() as conn:
            rows = conn.execute(query, params).fetchall()
        return [self._row_to_task(r) for r in rows]

    def delete_task(self, task_id: str) -> bool:
        """Delete the task with *task_id*; return True if a row was removed."""
        with self._conn() as conn:
            cur = conn.execute("DELETE FROM a2a_tasks WHERE task_id = ?", (task_id,))
            conn.commit()
            return cur.rowcount > 0

    def cleanup_terminal(self, max_age_seconds: float = 3600.0) -> int:
        """Delete terminal tasks last updated more than *max_age_seconds* ago.

        Terminality is decided by ``A2ATask.is_terminal()`` (the same test
        InMemoryTaskStore uses) rather than a hard-coded list of state
        strings that could drift from ``TaskState``.

        Returns:
            The number of rows deleted.
        """
        cutoff = time.time() - max_age_seconds
        with self._conn() as conn:
            stale = [
                (row["task_id"], cutoff)
                for row in conn.execute(
                    "SELECT * FROM a2a_tasks WHERE updated < ?", (cutoff,)
                )
                if self._row_to_task(row).is_terminal()
            ]
            if not stale:
                return 0
            # Re-check the age so a row refreshed by another writer between
            # the SELECT and the DELETE is kept.
            cur = conn.executemany(
                "DELETE FROM a2a_tasks WHERE task_id = ? AND updated < ?",
                stale,
            )
            conn.commit()
            # For executemany, sqlite3 sums rowcount over all parameter sets.
            return cur.rowcount

    def count(self, state: TaskState | None = None) -> int:
        """Count stored tasks, optionally only those in *state*."""
        query = "SELECT COUNT(*) FROM a2a_tasks"
        params: list[Any] = []
        if state is not None:
            query += " WHERE state = ?"
            params.append(state.value)
        with self._conn() as conn:
            return conn.execute(query, params).fetchone()[0]

    def _row_to_task(self, row) -> A2ATask:
        """Rebuild an ``A2ATask`` from one ``a2a_tasks`` row."""
        # Imported lazily, presumably to avoid an import cycle with the
        # protocol module — TODO confirm.
        from agentos.protocols.a2a import A2AMessage, A2AArtifact

        task = A2ATask(
            task_id=row["task_id"],
            state=TaskState(row["state"]),
            error=row["error"],
            meta=json.loads(row["meta_json"] or "{}"),
            _created=row["created"],
            _updated=row["updated"],
        )
        if row["input_json"]:
            task.input = A2AMessage.from_dict(json.loads(row["input_json"]))
        if row["output_json"]:
            task.output = A2AMessage.from_dict(json.loads(row["output_json"]))
        task.artifacts = [
            A2AArtifact.from_dict(a)
            for a in json.loads(row["artifacts_json"] or "[]")
        ]
        return task