Coverage for agentos/queue/task_queue.py: 43%

148 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2AgentOS v0.40 Task Queue — 异步任务调度与重试。 

3支持:内存队列(开发)/ Redis队列(生产)、优先级、重试、死信队列。 

4""" 

5 

6from __future__ import annotations 

7 

8import asyncio 

9import heapq 

10import time 

11import uuid 

12from dataclasses import dataclass, field 

13from enum import Enum 

14from typing import Optional, Callable, Any 

15 

16 

class TaskState(str, Enum):
    """Lifecycle states of a scheduled task.

    Every member is also a ``str``, so it compares equal to its lowercase
    value (``TaskState.PENDING == "pending"``).
    """

    # Waiting in a queue to be picked up by a worker.
    PENDING = "pending"
    # A handler is currently executing the task.
    RUNNING = "running"
    # The handler completed without raising.
    SUCCESS = "success"
    # Terminal failure (e.g. no handler registered, or retries exhausted).
    FAILED = "failed"
    # The handler failed and a retry is scheduled after a backoff delay.
    RETRYING = "retrying"
    # Cancelled before it was executed.
    CANCELLED = "cancelled"
    # Moved to the dead-letter list.
    DEAD = "dead"

28 

29 

class TaskPriority(int, Enum):
    """Scheduling priority levels; a larger value is dequeued sooner.

    Members are ``int`` instances, so they order and compare like plain
    integers.
    """

    LOW = 0  # lowest level
    NORMAL = 50  # default for ScheduledTask.priority
    HIGH = 100
    CRITICAL = 200  # highest level

38 

39 

@dataclass(order=True)
class QueuedTask:
    """Heap node wrapping a scheduled task for use with ``heapq``.

    Nodes order by ``(priority, created_at)``:

    * ``priority`` holds the *negated* task priority, so Python's min-heap
      pops the highest priority first;
    * ``created_at`` breaks ties, so equal-priority tasks come out
      oldest-first instead of in arbitrary heap order.

    ``task`` never takes part in comparisons (the wrapped task type defines
    no ordering of its own).
    """

    priority: int  # negated TaskPriority value (min-heap -> max priority first)
    created_at: float  # tie-breaker: earlier creation time sorts first
    task: "ScheduledTask" = field(compare=False)

46 

47 

@dataclass
class ScheduledTask:
    """A unit of work submitted to a task queue.

    ``priority`` accepts either a ``TaskPriority`` member or its plain int
    value (e.g. ``100``); it is normalized to the enum member on
    construction, and an int that is not a defined level raises
    ``ValueError``.
    """

    id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    name: str = ""  # key used to look up the registered handler
    payload: dict = field(default_factory=dict)  # passed to the handler
    priority: TaskPriority = TaskPriority.NORMAL
    state: TaskState = TaskState.PENDING
    max_retries: int = 3
    retry_delay: float = 1.0  # seconds; base of the exponential retry backoff
    timeout: float = 60.0  # seconds allowed for a coroutine handler
    # NOTE(review): not read by any code in this module -- confirm its intended use.
    callback: Callable | None = field(default=None, repr=False)
    result: Any = None  # handler return value on success
    error: str = ""  # last failure message
    retry_count: int = 0  # retries scheduled so far
    created_at: float = field(default_factory=time.time)  # also used as heap tie-breaker
    started_at: float = 0  # time.time() when execution last started; 0 until then
    completed_at: float = 0  # time.time() when execution last ended; 0 until then
    _tags: dict = field(default_factory=dict)  # free-form metadata; unused in this module

    def __post_init__(self) -> None:
        # MemoryQueue.enqueue reads ``priority.value``. A caller passing the
        # natural int form (``priority=100``) used to get an AttributeError at
        # enqueue time; normalize to the enum member here instead (a no-op
        # for members, ValueError for undefined levels).
        self.priority = TaskPriority(self.priority)

67 

68 

class MemoryQueue:
    """In-memory, heap-backed task queue (the development default).

    ``_pending`` maps task id -> task and is the authoritative record of what
    is queued; it is also what ``max_size`` is checked against. ``_heap``
    orders QueuedTask nodes (negated priority, so the highest TaskPriority
    is popped first). The heap may contain stale nodes -- an id enqueued more
    than once, or a task whose state is no longer PENDING (e.g. cancelled).
    Both ``dequeue`` and ``peek`` skip those lazily.
    """

    def __init__(self, max_size: int = 10000):
        self._heap: list[QueuedTask] = []
        self._pending: dict[str, ScheduledTask] = {}
        self._dead: list[ScheduledTask] = []
        self.max_size = max_size
        self._lock = asyncio.Lock()

    async def enqueue(self, task: ScheduledTask) -> str:
        """Add *task* to the queue and return its id.

        Raises:
            RuntimeError: if ``max_size`` tasks are already held.
        """
        async with self._lock:
            if len(self._pending) >= self.max_size:
                raise RuntimeError(f"Queue full ({self.max_size})")
            self._pending[task.id] = task
            heapq.heappush(
                self._heap,
                QueuedTask(priority=-task.priority.value, created_at=task.created_at, task=task),
            )
            return task.id

    def _drop_stale_head(self) -> None:
        """Pop heap nodes until the top refers to a held, PENDING task.

        Shared by ``dequeue`` and ``peek`` so both agree on what is "next".
        A non-PENDING task reached this way is also forgotten from
        ``_pending``. Caller must hold ``self._lock``.
        """
        while self._heap:
            live = self._pending.get(self._heap[0].task.id)
            if live is not None and live.state == TaskState.PENDING:
                return
            stale = heapq.heappop(self._heap)
            self._pending.pop(stale.task.id, None)

    async def dequeue(self) -> ScheduledTask | None:
        """Remove and return the highest-priority PENDING task, or None if empty."""
        async with self._lock:
            self._drop_stale_head()
            if not self._heap:
                return None
            node = heapq.heappop(self._heap)
            return self._pending.pop(node.task.id)

    async def peek(self) -> ScheduledTask | None:
        """Return, without removing it, the task ``dequeue`` would return next.

        Stale heap nodes at the top are discarded on the way; ``dequeue``
        would skip them anyway.
        """
        async with self._lock:
            self._drop_stale_head()
            if not self._heap:
                return None
            return self._pending[self._heap[0].task.id]

    def pending_count(self) -> int:
        """Number of tasks held (the count ``max_size`` applies to).

        Cancelled tasks not yet reached by ``dequeue``/``peek`` are included.
        """
        return len(self._pending)

    def dead_count(self) -> int:
        """Number of tasks in the dead-letter list."""
        return len(self._dead)

    async def move_to_dead(self, task: ScheduledTask):
        """Mark *task* DEAD, append it to the dead-letter list, and forget it.

        Any heap node still referring to it becomes stale and is skipped later.
        """
        async with self._lock:
            task.state = TaskState.DEAD
            self._dead.append(task)
            self._pending.pop(task.id, None)

    def stats(self) -> dict:
        """Snapshot of queue sizes: pending, dead, and the configured max_size."""
        return {"pending": self.pending_count(), "dead": self.dead_count(), "max_size": self.max_size}

116 

117 

class TaskQueue:
    """Task queue manager: dispatches queued tasks to registered handlers.

    The worker loop (``start``) pulls tasks from the backing queue and runs
    each in its own asyncio task, at most ``concurrency`` at a time. A failed
    task is retried with exponential backoff, one priority level higher each
    time, up to ``max_retries`` times; after that it is moved to the backing
    queue's dead-letter list.
    """

    def __init__(self, queue: MemoryQueue | None = None, concurrency: int = 4):
        """
        Args:
            queue: backing queue; a new MemoryQueue when omitted.
            concurrency: maximum number of tasks executing at once (>= 1).

        Raises:
            ValueError: if ``concurrency`` is less than 1.
        """
        if concurrency < 1:
            # Semaphore(0) would make every task wait forever.
            raise ValueError(f"concurrency must be >= 1, got {concurrency}")
        # ``is None`` rather than ``or``: a queue object defining __len__
        # would be falsy while empty and get silently replaced.
        self._queue = queue if queue is not None else MemoryQueue()
        self._concurrency = concurrency
        self._running: set[str] = set()  # ids of tasks whose handler is executing
        self._callbacks: dict[str, Callable] = {}  # task name -> handler
        self._semaphore = asyncio.Semaphore(concurrency)
        self._running_flag = False
        # Strong references to background asyncio tasks. The event loop keeps
        # only weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._workers: set[asyncio.Task] = set()  # in-flight _execute calls
        self._timers: set[asyncio.Task] = set()  # pending retry backoffs
        # Tasks waiting out a retry backoff. They are not in the backing
        # queue, so cancel() needs this map to reach them.
        self._retrying: dict[str, ScheduledTask] = {}

    def register_callback(self, task_name: str, handler: Callable):
        """Register *handler* for tasks named *task_name*, replacing any previous one.

        The handler is called with ``task.payload``. If it returns a
        coroutine, that coroutine is awaited (subject to ``task.timeout``).
        """
        self._callbacks[task_name] = handler

    async def submit(self, task: ScheduledTask) -> str:
        """Enqueue *task* and return its id.

        Raises:
            ValueError: if no handler is registered for ``task.name``.
            RuntimeError: propagated from the backing queue when it is full
                (MemoryQueue behaviour).
        """
        if task.name not in self._callbacks:
            raise ValueError(f"No handler registered for task: {task.name}")
        return await self._queue.enqueue(task)

    async def start(self):
        """Run the worker loop until ``stop()`` is called.

        A task is dequeued only when an execution slot is free. Queued tasks
        therefore stay in the priority queue, where a later higher-priority
        submission can still overtake them, rather than being drained
        eagerly into a backlog of asyncio tasks waiting on the semaphore.
        """
        self._running_flag = True
        while self._running_flag:
            if len(self._workers) >= self._concurrency:
                await asyncio.wait(set(self._workers), return_when=asyncio.FIRST_COMPLETED)
                continue
            task = await self._queue.dequeue()
            if task is None:
                await asyncio.sleep(0.1)
                continue
            self._spawn(self._execute(task), self._workers)

    def stop(self):
        """Ask the worker loop to exit; tasks already running are not interrupted."""
        self._running_flag = False

    @staticmethod
    def _spawn(coro, bucket: set) -> asyncio.Task:
        """Start *coro* as an asyncio task, held in *bucket* until it finishes."""
        aio_task = asyncio.create_task(coro)
        bucket.add(aio_task)
        aio_task.add_done_callback(bucket.discard)
        return aio_task

    async def _execute(self, task: ScheduledTask):
        """Run *task*'s handler while holding a concurrency slot.

        On success the return value is stored in ``task.result``. On an
        exception or timeout, ``task.error`` is recorded and, once the slot
        has been released, the task is passed to ``_handle_failure``.
        """
        failed = False
        async with self._semaphore:
            self._running.add(task.id)
            task.state = TaskState.RUNNING
            task.started_at = time.time()

            handler = self._callbacks.get(task.name)
            if not handler:
                task.state = TaskState.FAILED
                task.error = f"No handler: {task.name}"
                self._running.discard(task.id)
                return

            try:
                result = handler(task.payload)
                # Only coroutine results get the timeout; a synchronous
                # handler runs to completion directly on the event loop.
                if asyncio.iscoroutine(result):
                    result = await asyncio.wait_for(result, timeout=task.timeout)
                task.result = result
                task.state = TaskState.SUCCESS
            except asyncio.TimeoutError:
                task.error = f"Timeout after {task.timeout}s"
                failed = True
            except Exception as e:
                task.error = str(e)
                failed = True
            finally:
                task.completed_at = time.time()
                self._running.discard(task.id)
        # Outside the semaphore: retry/dead-letter bookkeeping must not hold
        # an execution slot.
        if failed:
            await self._handle_failure(task)

    async def _handle_failure(self, task: ScheduledTask):
        """Schedule a retry of *task*, or dead-letter it once retries are exhausted.

        Retry ``n`` (1-based) waits ``retry_delay * 2**(n-1)`` seconds in a
        background timer. This call returns immediately, leaving the task in
        RETRYING state; the timer later re-enqueues it one priority level
        higher.
        """
        if task.retry_count >= task.max_retries:
            task.state = TaskState.FAILED
            await self._queue.move_to_dead(task)
            return
        task.retry_count += 1
        task.state = TaskState.RETRYING
        delay = task.retry_delay * (2 ** (task.retry_count - 1))  # exponential backoff
        self._retrying[task.id] = task
        self._spawn(self._requeue_after(task, delay), self._timers)

    async def _requeue_after(self, task: ScheduledTask, delay: float):
        """Sleep *delay* seconds, then put *task* back on the queue unless it was cancelled."""
        try:
            await asyncio.sleep(delay)
        finally:
            self._retrying.pop(task.id, None)
        if task.state == TaskState.CANCELLED:
            return
        task.state = TaskState.PENDING
        task.priority = self._escalate(task.priority)
        try:
            await self._queue.enqueue(task)
        except RuntimeError as exc:
            # The queue filled up during the backoff. Dead-letter the task
            # rather than losing it inside an unobserved background task.
            task.error = f"{task.error} (requeue failed: {exc})"
            task.state = TaskState.FAILED
            await self._queue.move_to_dead(task)

    @staticmethod
    def _escalate(priority: TaskPriority) -> TaskPriority:
        """Return the next TaskPriority level above *priority*; CRITICAL stays CRITICAL.

        Levels are discrete enum members (0/50/100/200), so arithmetic such as
        ``value + 10`` never yields a valid member and would raise ValueError.
        """
        current = int(priority)
        higher = [level for level in TaskPriority if level > current]
        return min(higher) if higher else TaskPriority(current)

    def cancel(self, task_id: str) -> bool:
        """Cancel a task that is still queued or waiting out a retry backoff.

        A queued task is only marked CANCELLED; MemoryQueue skips non-PENDING
        tasks when dequeuing. A retrying task is not re-enqueued when its
        backoff ends. Running or finished tasks are left unchanged.

        Returns:
            True if the task was cancelled, False otherwise.
        """
        task = self._queue._pending.get(task_id)
        if task is None:
            task = self._retrying.get(task_id)
        if task is not None and task.state in (TaskState.PENDING, TaskState.RETRYING):
            task.state = TaskState.CANCELLED
            return True
        return False

    def stats(self) -> dict:
        """Snapshot of running count, concurrency, queue stats and handler names."""
        return {
            "running": len(self._running),
            "concurrency": self._concurrency,
            "queue": self._queue.stats(),
            "handlers": list(self._callbacks.keys()),
        }