Coverage for agentos/state/schema.py: 0%
329 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 16:36 +0800
1"""
2AgentOS v1.14.0 — 结构化 Agent 状态管理系统。
4基因来源: LangGraph Pydantic State Schema + AgentOS Checkpoint。
6核心设计:
7- AgentState: 强类型的全局 Agent 运行时状态,Pydantic v2 驱动
8- 支持 JSON Schema 自动生成、验证、序列化
9- 与 Checkpoint 系统无缝对接
10- 支持状态合并策略(reducer):append/extend/replace/merge
11- 支持子状态派生(SubState),实现层级化状态管理
12"""
14from __future__ import annotations
16import uuid
17import time
18from datetime import datetime, timezone
19from enum import Enum
20from typing import (
21 Any, Callable, Dict, Generic, List, Literal, Optional, Set, TypeVar, Union,
22)
24try:
25 from pydantic import (
26 BaseModel, Field, field_validator, model_validator,
27 ConfigDict, PrivateAttr,
28 )
29 from pydantic.json_schema import GenerateJsonSchema
30except ImportError:
31 raise ImportError(
32 "pydantic>=2.0 is required for agentos.state. "
33 "Install with: pip install pydantic>=2.0"
34 )
36# ── Reducers ────────────────────────────────
class ReducerStrategy(str, Enum):
    """Merge strategies deciding how a new value is combined with a stored one.

    Members are also plain strings, so they compare equal to their values
    (``ReducerStrategy.APPEND == "append"``) and can be rebuilt from them
    (``ReducerStrategy("append")``).
    """

    REPLACE = "replace"  # overwrite the stored value with the new one
    APPEND = "append"  # accumulate values: two lists are concatenated
    EXTEND = "extend"  # shallow merge of two dicts
    MERGE = "merge"  # recursive (deep) merge of nested dicts
    KEEP_EXISTING = "keep"  # keep the stored value
    CUSTOM = "custom"  # use a user-supplied reducer function
# Type variable for generic code that works with any pydantic model subclass.
S = TypeVar("S", bound=BaseModel)

# Signature of a custom reducer: a callable taking two values and returning one
# (presumably ``(old_value, new_value) -> merged_value`` -- confirm with callers).
ReducerFn = Callable[[Any, Any], Any]
# Module-level registry: field name -> registered default merge strategy.
_DEFAULT_REDUCERS: Dict[str, ReducerStrategy] = {}


def default_reducer(field_name: str, strategy: ReducerStrategy) -> None:
    """Register the default merge strategy for a field name.

    The pair is stored in the module-level ``_DEFAULT_REDUCERS`` mapping;
    registering the same field name again overwrites the earlier entry.

    Args:
        field_name: Name of the state field.
        strategy: Merge strategy to associate with that field.

    Usage:
        default_reducer("messages", ReducerStrategy.APPEND)
    """
    _DEFAULT_REDUCERS.update({field_name: strategy})
def _apply_reducer(old_val: Any, new_val: Any, strategy: ReducerStrategy) -> Any:
    """Combine ``old_val`` and ``new_val`` according to ``strategy``.

    Rules, checked in this order:
    - ``REPLACE``, or no stored value (``old_val is None``): return ``new_val``.
    - ``new_val is None``: return ``old_val`` unchanged.
    - ``KEEP_EXISTING``: return ``old_val``.
    - ``APPEND``: return a new flat list made of the old item(s) followed by
      the new item(s). A non-list operand counts as a single item.
    - ``EXTEND``: shallow-merge two dicts (keys of ``new_val`` win); if either
      side is not a dict, return ``new_val``.
    - ``MERGE``: recursive dict merge via ``_deep_merge``.
    - anything else (including ``CUSTOM``, since no function is passed to this
      helper): return ``new_val``.

    Args:
        old_val: Currently stored value, may be ``None``.
        new_val: Incoming value, may be ``None``.
        strategy: Merge strategy to apply.

    Returns:
        The combined value. Input lists and dicts are not mutated.
    """
    if strategy == ReducerStrategy.REPLACE or old_val is None:
        return new_val
    if new_val is None:
        return old_val
    if strategy == ReducerStrategy.KEEP_EXISTING:
        return old_val
    if strategy == ReducerStrategy.APPEND:
        # Normalise both operands to lists first. The previous code returned
        # ``[old_val, new_val]`` whenever one side was not a list, which nested
        # the existing list (``[[a, b], c]``) instead of appending to it.
        head = list(old_val) if isinstance(old_val, list) else [old_val]
        tail = list(new_val) if isinstance(new_val, list) else [new_val]
        return head + tail
    if strategy == ReducerStrategy.EXTEND:
        if isinstance(old_val, dict) and isinstance(new_val, dict):
            return {**old_val, **new_val}
        return new_val
    if strategy == ReducerStrategy.MERGE:
        return _deep_merge(old_val, new_val)
    return new_val
def _deep_merge(old: Any, new: Any) -> Any:
    """Recursively merge ``new`` into a copy of ``old``.

    When both arguments are dicts, the result is a new dict holding every key
    of ``old`` overlaid with every key of ``new``; keys whose values are dicts
    on both sides are merged recursively. If either argument is not a dict,
    ``new`` is returned as-is.

    The top-level copy is shallow: nested values that need no merging are
    shared with the inputs, not copied.
    """
    if not (isinstance(old, dict) and isinstance(new, dict)):
        return new

    merged = dict(old)
    for key, incoming in new.items():
        current = merged.get(key)
        # Recurse only when both sides hold a dict under this key;
        # in every other case the incoming value simply wins.
        both_dicts = isinstance(current, dict) and isinstance(incoming, dict)
        merged[key] = _deep_merge(current, incoming) if both_dicts else incoming
    return merged
102# ── Field Metadata ──────────────────────────
class StateFieldInfo(BaseModel):
    """Metadata describing a single state field."""

    # Permit field types that pydantic has no built-in schema for.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Merge strategy for the field; plain replacement unless stated otherwise.
    reducer: ReducerStrategy = ReducerStrategy.REPLACE
    # Optional user-supplied reducer function.
    custom_reducer: Optional[ReducerFn] = None
    description: str = ""
    required: bool = False
    # Marks the field as sensitive (intended to be masked when serialized).
    sensitive: bool = False
    # Whether the field is meant to be persisted to a checkpoint.
    checkpointed: bool = True
117# ── AgentState Core ─────────────────────────
class BaseAgentState(BaseModel):
    """Base class for all agent state models.

    Every agent state is expected to inherit from this class. It provides:
    - ``thread_id`` tracking
    - a ``step`` counter
    - snapshot / restore helpers
    - JSON Schema generation
    - per-field reducers that control how ``update_field`` and ``merge``
      combine an incoming value with the stored one

    Usage:
        class MyState(BaseAgentState):
            messages: list[dict] = Field(default_factory=list)
            tools_result: dict = Field(default_factory=dict)
            task_progress: float = 0.0
    """

    # Unique identifier of the conversation thread.
    thread_id: str = Field(
        default_factory=lambda: f"thread-{uuid.uuid4().hex[:8]}",
        description="对话线程唯一标识",
    )
    # Conversation messages.
    messages: list[Any] = Field(
        default_factory=list, description="对话消息列表"
    )
    # Runtime metrics.
    metrics: dict[str, Any] = Field(
        default_factory=dict, description="运行时指标"
    )
    # Current execution step (never negative).
    step: int = Field(default=0, ge=0, description="当前执行步骤")
    # Creation time, ISO-8601 string in UTC.
    created_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat(),
        description="创建时间",
    )
    # Time of the last update, ISO-8601 string in UTC.
    updated_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat(),
        description="最后更新时间",
    )
    # Free-form tags.
    tags: list[str] = Field(default_factory=list, description="标签")
    # Custom metadata; also the storage for undeclared ("dynamic") fields.
    metadata: dict[str, Any] = Field(default_factory=dict, description="自定义元数据")
    # Identifier of the parent state (for branching / backtracking).
    parent_state_id: Optional[str] = Field(
        default=None, description="父状态 ID(用于分支/回溯)"
    )

    # Per-instance reducer configuration (private, never serialized).
    _field_reducers: dict[str, ReducerStrategy] = PrivateAttr(default_factory=dict)
    _field_custom_reducers: dict[str, ReducerFn] = PrivateAttr(default_factory=dict)
    _sensitive_fields: set[str] = PrivateAttr(default_factory=set)

    model_config = ConfigDict(
        extra="allow",
        validate_assignment=True,
        json_schema_extra={
            "title": "AgentState",
            "description": "AgentOS Structured Agent State",
        },
    )

    def __init__(self, **data):
        super().__init__(**data)
        # Start every instance with empty reducer configuration; subclasses
        # fill in their defaults right after calling this constructor.
        self._field_reducers = {}
        self._field_custom_reducers = {}
        self._sensitive_fields = set()

    # ── Reducer Registration ──────────────────

    def set_reducer(self, field: str, strategy: ReducerStrategy) -> "BaseAgentState":
        """Set the merge strategy for ``field`` on this instance; returns ``self``."""
        self._field_reducers[field] = strategy
        return self

    def set_custom_reducer(self, field: str, fn: ReducerFn) -> "BaseAgentState":
        """Register ``fn`` as the reducer for ``field``; returns ``self``.

        The field's strategy is switched to ``ReducerStrategy.CUSTOM`` and
        ``update_field`` then calls ``fn(old_value, new_value)``.
        """
        self._field_custom_reducers[field] = fn
        self._field_reducers[field] = ReducerStrategy.CUSTOM
        return self

    def mark_sensitive(self, *fields: str) -> "BaseAgentState":
        """Mark the given fields as sensitive; returns ``self``."""
        self._sensitive_fields.update(fields)
        return self

    # ── State Mutation ────────────────────────

    def update_field(
        self,
        field: str,
        value: Any,
        reducer: Optional[ReducerStrategy] = None,
    ) -> None:
        """Update one field, combining old and new value through a reducer.

        The strategy is resolved in this order: the ``reducer`` argument, the
        strategy registered on this instance, the module-wide default set via
        ``default_reducer()``, and finally ``ReducerStrategy.REPLACE``.

        A name that is neither a declared nor a computed field is treated as
        a dynamic field and stored under ``self.metadata[field]``.

        Args:
            field: Field name.
            value: Incoming value.
            reducer: Merge strategy overriding the registered one.
        """
        cls = type(self)  # read field maps from the class, not the instance
        declared = field in cls.model_fields or field in cls.model_computed_fields

        strategy = (
            reducer
            or self._field_reducers.get(field)
            or _DEFAULT_REDUCERS.get(field, ReducerStrategy.REPLACE)
        )
        old = getattr(self, field, None) if declared else self.metadata.get(field)

        custom_fn = self._field_custom_reducers.get(field)
        if strategy == ReducerStrategy.CUSTOM and custom_fn is not None:
            # Custom reducers were registered but never invoked before; the
            # function sees both values as-is, including ``None``.
            combined = custom_fn(old, value)
        else:
            combined = _apply_reducer(old, value, strategy)

        if declared:
            setattr(self, field, combined)
        else:
            self.metadata[field] = combined

        self.updated_at = datetime.now(timezone.utc).isoformat()

    def merge(self, other: Union["BaseAgentState", dict]) -> "BaseAgentState":
        """Merge another state into this one, in place.

        - ``thread_id`` and ``created_at`` are never overwritten.
        - ``metadata`` is deep-merged and ``tags`` are unioned (order kept,
          duplicates dropped); both bypass the per-field reducers.
        - Every other field of ``other`` whose value is not ``None`` goes
          through ``update_field`` and therefore through its reducer.
        - When ``other`` is a dict it is validated by building an instance of
          this class, but only the keys actually present in the dict are
          merged, so constructor defaults never overwrite existing values.

        Args:
            other: Another state instance, or a dict of field values.

        Returns:
            ``self``.
        """
        provided = None
        if isinstance(other, dict):
            provided = set(other)
            other = self.__class__(**other)

        # Handled separately below or immutable. Sending metadata/tags through
        # the generic loop would replace them before they could be merged.
        skipped = {"thread_id", "created_at", "updated_at", "metadata", "tags"}

        for field_name in type(other).model_fields:
            if field_name in skipped:
                continue
            if provided is not None and field_name not in provided:
                continue  # default value filled in by the constructor

            other_val = getattr(other, field_name, None)
            if other_val is None:
                continue

            self.update_field(field_name, other_val)

        if other.metadata:
            self.metadata = _deep_merge(self.metadata, other.metadata)

        if other.tags:
            # dict.fromkeys de-duplicates while keeping first-seen order.
            self.tags = list(dict.fromkeys([*self.tags, *other.tags]))

        self.updated_at = datetime.now(timezone.utc).isoformat()
        return self

    def increment_step(self) -> int:
        """Increase the step counter by one and return the new value."""
        self.step += 1
        self.updated_at = datetime.now(timezone.utc).isoformat()
        return self.step

    # ── Snapshot & Restore ────────────────────

    def snapshot(self) -> dict[str, Any]:
        """Return a full snapshot of the current state as a dict.

        The dict comes from ``model_dump(mode="python")`` and keeps ``None``
        values. Reducer configuration and the set of sensitive fields are not
        part of the snapshot.
        """
        data = self.model_dump(mode="python", exclude_none=False)
        # Defensive: make sure no private bookkeeping key ends up in the dump.
        for private_key in ("_field_reducers", "_field_custom_reducers", "_sensitive_fields"):
            data.pop(private_key, None)
        return data

    def sanitized_snapshot(self) -> dict[str, Any]:
        """Return a snapshot in which sensitive fields are replaced by ``'***'``.

        Masking covers top-level fields and dynamic fields stored inside
        ``metadata`` (where ``update_field`` puts undeclared field names).
        """
        data = self.snapshot()
        dynamic = data.get("metadata")
        for field in self._sensitive_fields:
            if field in data:
                data[field] = "***"
            if isinstance(dynamic, dict) and field in dynamic:
                dynamic[field] = "***"
        return data

    @classmethod
    def restore(cls, data: dict[str, Any]) -> "BaseAgentState":
        """Rebuild a state instance from a snapshot dict.

        Top-level keys starting with ``_`` are dropped before validation. If
        ``data["metadata"]["_reducer_config"]`` is a dict mapping field names
        to strategy values, those strategies are re-registered on the new
        instance; entries naming an unknown strategy are skipped.

        Args:
            data: Dict as returned by ``snapshot()``.

        Returns:
            A new instance of ``cls``.
        """
        cleaned = {key: val for key, val in data.items() if not key.startswith("_")}
        instance = cls(**cleaned)

        saved_metadata = cleaned.get("metadata")
        if isinstance(saved_metadata, dict):
            reducer_config = saved_metadata.get("_reducer_config")
            if isinstance(reducer_config, dict):
                for field, strategy_name in reducer_config.items():
                    try:
                        strategy = ReducerStrategy(strategy_name)
                    except ValueError:
                        # Best effort: an unknown strategy name must not
                        # prevent the state itself from being restored.
                        continue
                    instance._field_reducers[field] = strategy
        return instance

    def diff(self, other: "BaseAgentState") -> dict[str, tuple]:
        """Compute the differences between this state and ``other``.

        Only the fields declared on this state's class are compared.

        Returns:
            ``{field: (value_in_self, value_in_other)}`` for every field
            whose values differ.
        """
        diffs = {}
        for field in type(self).model_fields:
            old = getattr(self, field, None)
            new = getattr(other, field, None)
            if old != new:
                diffs[field] = (old, new)
        return diffs

    # ── JSON Schema ───────────────────────────

    @classmethod
    def generate_schema(cls) -> dict[str, Any]:
        """Return the JSON Schema of this state class (``model_json_schema()``)."""
        return cls.model_json_schema()

    @classmethod
    def validate_json_input(cls, data: dict) -> "BaseAgentState":
        """Validate ``data`` and build an instance from it (``model_validate``)."""
        return cls.model_validate(data)
354# ── Specialized States ──────────────────────
class AgentState(BaseAgentState):
    """General-purpose agent runtime state.

    Common fields come pre-configured with default reducers:
    - messages: APPEND (conversation messages accumulate)
    - tools_result: MERGE (tool results are deep-merged)
    - errors: APPEND (errors are collected)
    - human_interrupts: APPEND (requests are queued)
    - intermediate: REPLACE (intermediate results are overwritten)
    """

    # Conversation message history.
    messages: list[dict[str, Any]] = Field(
        default_factory=list,
        description="对话消息历史",
    )
    # Result of the most recent tool call.
    tools_result: dict[str, Any] = Field(
        default_factory=dict,
        description="最近一次工具调用结果",
    )
    # Intermediate computation results (replaced on every update).
    intermediate: dict[str, Any] = Field(
        default_factory=dict,
        description="中间计算结果(每次替换)",
    )
    # Collected errors.
    errors: list[dict[str, Any]] = Field(
        default_factory=list,
        description="错误堆栈",
    )
    # Queue of pending human-intervention requests.
    human_interrupts: list[dict[str, Any]] = Field(
        default_factory=list,
        description="人工干预请求队列",
    )
    # Context summary.
    context_summary: str = Field(
        default="",
        description="上下文摘要(自动分页用)",
    )
    # Task progress between 0.0 and 1.0, or None when unknown.
    task_progress: Optional[float] = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="任务进度 0.0~1.0",
    )
    # Abort reason; a non-None value means the task should stop.
    abort_reason: Optional[str] = Field(
        default=None,
        description="中止原因(非空表示需中止)",
    )

    def __init__(self, **data):
        super().__init__(**data)
        # Install the default reducers listed in the class docstring.
        self._field_reducers.update(
            {
                "messages": ReducerStrategy.APPEND,
                "tools_result": ReducerStrategy.MERGE,
                "errors": ReducerStrategy.APPEND,
                "human_interrupts": ReducerStrategy.APPEND,
                "intermediate": ReducerStrategy.REPLACE,
            }
        )

    @property
    def last_message(self) -> Optional[dict[str, Any]]:
        """The most recent message, or ``None`` when there are no messages."""
        if not self.messages:
            return None
        return self.messages[-1]

    @property
    def error_count(self) -> int:
        """Number of errors recorded so far."""
        return len(self.errors)

    @property
    def should_abort(self) -> bool:
        """Whether an abort reason has been set."""
        reason = self.abort_reason
        return reason is not None

    def add_message(self, role: str, content: str, **extra) -> "AgentState":
        """Add one message built from ``role``, ``content`` and any extras."""
        entry = {"role": role, "content": content}
        entry.update(extra)
        self.update_field("messages", [entry])
        return self

    def add_error(self, error_type: str, message: str, **extra) -> "AgentState":
        """Record an error stamped with the current step and UTC time."""
        entry = {
            "type": error_type,
            "message": message,
            "step": self.step,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        entry.update(extra)  # extras may override the generated keys
        self.update_field("errors", [entry])
        return self

    def request_human_input(
        self,
        prompt: str,
        options: Optional[list[str]] = None,
        **extra,
    ) -> "AgentState":
        """Queue a human-intervention request."""
        request = {
            "prompt": prompt,
            "options": options,
            "step": self.step,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        request.update(extra)  # extras may override the generated keys
        self.update_field("human_interrupts", [request])
        return self

    def clear_human_interrupts(self) -> "AgentState":
        """Drop all pending human-intervention requests."""
        self.human_interrupts = []
        return self

    def abort(self, reason: str) -> "AgentState":
        """Mark the task as needing to abort, recording ``reason``."""
        self.abort_reason = reason
        return self

    def reset_abort(self) -> "AgentState":
        """Clear the abort marker."""
        self.abort_reason = None
        return self
class MultiAgentState(BaseAgentState):
    """State for multi-agent collaboration.

    Tracks the per-agent state dicts, the inter-agent message queue, role
    assignment and task hand-offs.
    """

    # State of every sub-agent, keyed by agent id.
    agents: dict[str, dict[str, Any]] = Field(
        default_factory=dict,
        description="所有子 Agent 的状态 {agent_id: {state dict}}",
    )
    # Queue of messages exchanged between agents.
    message_queue: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Agent 间消息队列",
    )
    # Role assigned to each agent, keyed by agent id.
    roles: dict[str, str] = Field(
        default_factory=dict,
        description="Agent 角色分配 {agent_id: role_name}",
    )
    # Internal state of the coordinator.
    coordinator_state: dict[str, Any] = Field(
        default_factory=dict,
        description="协调器内部状态",
    )
    # Log of task hand-offs.
    handoff_log: list[dict[str, Any]] = Field(
        default_factory=list,
        description="任务移交记录",
    )

    def __init__(self, **data):
        super().__init__(**data)
        self._field_reducers["agents"] = ReducerStrategy.MERGE
        self._field_reducers["message_queue"] = ReducerStrategy.APPEND
        self._field_reducers["handoff_log"] = ReducerStrategy.APPEND

    def register_agent(
        self, agent_id: str, role: str = "worker", **meta
    ) -> "MultiAgentState":
        """Register a sub-agent, replacing any existing entry for ``agent_id``.

        The new entry starts as ``state="idle"``, ``step=0``, ``errors=0``;
        keys in ``meta`` are added and may override those initial values.
        """
        self.agents[agent_id] = {
            "role": role,
            "state": "idle",
            "step": 0,
            "errors": 0,
            **meta,
        }
        self.roles[agent_id] = role
        # The dicts are mutated in place, so refresh the timestamp explicitly,
        # as update_agent_state() does.
        self._touch()
        return self

    def update_agent_state(
        self, agent_id: str, updates: dict[str, Any]
    ) -> "MultiAgentState":
        """Apply ``updates`` to a registered sub-agent; unknown ids are ignored."""
        if agent_id in self.agents:
            self.agents[agent_id].update(updates)
            self._touch()
        return self

    def send_message(
        self, from_agent: str, to_agent: str, content: Any
    ) -> "MultiAgentState":
        """Queue a message from one agent to another, stamped with UTC time."""
        msg = {
            "from": from_agent,
            "to": to_agent,
            "content": content,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        self.update_field("message_queue", [msg])
        return self

    def log_handoff(
        self, from_agent: str, to_agent: str, task_id: str, reason: str = ""
    ) -> "MultiAgentState":
        """Record a task hand-off between two agents."""
        self.update_field("handoff_log", [{
            "from": from_agent,
            "to": to_agent,
            "task_id": task_id,
            "reason": reason,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }])
        return self

    def _touch(self) -> None:
        """Set ``updated_at`` to the current UTC time."""
        self.updated_at = datetime.now(timezone.utc).isoformat()
class ToolCallState(BaseAgentState):
    """State that tracks tool calls.

    Records arguments, result, error and elapsed time of every tool call and
    keeps per-tool aggregate statistics.
    """

    # History of completed tool calls.
    calls: list[dict[str, Any]] = Field(
        default_factory=list,
        description="工具调用历史",
    )
    # Tool calls in flight, keyed by call id.
    pending_calls: dict[str, dict[str, Any]] = Field(
        default_factory=dict,
        description="进行中的工具调用 {call_id: {tool, args, start_time}}",
    )
    # Aggregate statistics, keyed by tool name.
    tool_stats: dict[str, dict[str, Any]] = Field(
        default_factory=dict,
        description="工具统计 {tool_name: {count, total_ms, errors, avg_ms}}",
    )

    def __init__(self, **data):
        super().__init__(**data)
        self._field_reducers["calls"] = ReducerStrategy.APPEND

    def start_call(self, call_id: str, tool: str, args: dict) -> "ToolCallState":
        """Record the start of a tool call and ensure the tool has a stats entry."""
        self.pending_calls[call_id] = {
            "tool": tool,
            "args": args,
            "start_time": time.time(),
            "step": self.step,
        }
        self._init_stats(tool)
        return self

    def complete_call(
        self, call_id: str, result: Any, error: Optional[str] = None
    ) -> "ToolCallState":
        """Record the completion of a tool call.

        Moves the call from ``pending_calls`` to ``calls`` and updates the
        per-tool statistics. A ``call_id`` that is not pending is ignored, so
        completing the same call twice is harmless.

        Args:
            call_id: Identifier passed to ``start_call``.
            result: Result of the call.
            error: Error description; any value other than ``None`` marks
                the call as failed.
        """
        if call_id not in self.pending_calls:
            return self  # idempotent: unknown or already completed call

        call_info = self.pending_calls.pop(call_id)
        elapsed_ms = (time.time() - call_info["start_time"]) * 1000
        # One definition of failure for both the record and the statistics.
        # Previously ``success`` used ``error is None`` while the error count
        # used truthiness, so ``error=""`` was a failure that was not counted.
        failed = error is not None
        record = {
            **call_info,
            "call_id": call_id,
            "elapsed_ms": round(elapsed_ms, 2),
            "result": result,
            "error": error,
            "success": not failed,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        self.update_field("calls", [record])

        # Update the aggregate statistics of the tool.
        tool = call_info["tool"]
        stats = self.tool_stats.get(tool, {})
        stats["count"] = stats.get("count", 0) + 1
        stats["total_ms"] = stats.get("total_ms", 0) + elapsed_ms
        stats["errors"] = stats.get("errors", 0) + (1 if failed else 0)
        stats["avg_ms"] = stats["total_ms"] / stats["count"]
        self.tool_stats[tool] = stats

        return self

    def _init_stats(self, tool: str) -> None:
        """Create a zeroed statistics entry for ``tool`` if none exists yet."""
        if tool not in self.tool_stats:
            self.tool_stats[tool] = {
                "count": 0,
                "total_ms": 0,
                "errors": 0,
                "avg_ms": 0,
            }

    @property
    def total_tool_calls(self) -> int:
        """Number of completed tool calls."""
        return len(self.calls)

    @property
    def failed_calls(self) -> int:
        """Number of completed tool calls whose record is not a success."""
        return sum(1 for c in self.calls if not c.get("success", True))
648# ── Schema Registry ──────────────────────────
class StateSchemaRegistry:
    """Registry of state schema classes.

    Supports registering state classes under a name, looking them up,
    creating instances and validating data against them.
    """

    def __init__(self):
        self._schemas: dict[str, type[BaseAgentState]] = {}
        self._default_name: str = "AgentState"
        # Built-in state types are available out of the box.
        self._schemas[self._default_name] = AgentState
        self._schemas["MultiAgentState"] = MultiAgentState
        self._schemas["ToolCallState"] = ToolCallState

    def register(self, name: str, schema_cls: type[BaseAgentState]) -> None:
        """Register a custom state class under ``name``.

        Raises:
            TypeError: If ``schema_cls`` is not a subclass of ``BaseAgentState``.
        """
        # Check for "is a class" first so a non-class argument gets the same
        # descriptive TypeError instead of the bare one raised by issubclass().
        if not (isinstance(schema_cls, type) and issubclass(schema_cls, BaseAgentState)):
            shown = getattr(schema_cls, "__name__", repr(schema_cls))
            raise TypeError(
                f"State class must inherit from BaseAgentState, "
                f"got {shown}"
            )
        self._schemas[name] = schema_cls

    def get(self, name: str) -> type[BaseAgentState]:
        """Return the state class registered under ``name``.

        Raises:
            KeyError: If no class is registered under ``name``.
        """
        if name not in self._schemas:
            raise KeyError(f"Unknown state schema: '{name}'. "
                           f"Available: {list(self._schemas.keys())}")
        return self._schemas[name]

    def list_schemas(self) -> list[str]:
        """Return the names of all registered state classes."""
        return list(self._schemas.keys())

    def create_state(self, name: str, **kwargs) -> BaseAgentState:
        """Instantiate the state class registered under ``name``."""
        cls = self.get(name)
        return cls(**kwargs)

    def validate(self, name: str, data: dict) -> BaseAgentState:
        """Validate ``data`` against the class registered under ``name``."""
        cls = self.get(name)
        return cls.model_validate(data)

    @property
    def default_state_class(self) -> type[BaseAgentState]:
        """The default state class."""
        return self._schemas[self._default_name]

    @default_state_class.setter
    def default_state_class(self, name: Union[str, type[BaseAgentState]]) -> None:
        """Select the default state class.

        Accepts the registered name or, matching what the getter returns,
        the registered class itself.

        Raises:
            KeyError: If the name or class is not registered.
        """
        if isinstance(name, type):
            # Resolve a class back to the first name it is registered under.
            for registered_name, registered_cls in self._schemas.items():
                if registered_cls is name:
                    name = registered_name
                    break
            else:
                raise KeyError(f"Unknown state schema: '{name.__name__}'")
        elif name not in self._schemas:
            raise KeyError(f"Unknown state schema: '{name}'")
        self._default_name = name
706# ── 全局单例 ─────────────────────────────────
# Shared module-level registry instance, created once at import time.
state_registry = StateSchemaRegistry()
710# ── State Reducers (test compatibility) ──
class StateReducer:
    """Base state reducer: fields set on the new state override the base state."""

    @staticmethod
    def merge(base_state, new_state):
        """Return a new state of ``base_state``'s type.

        The result is built from the full dump of ``base_state`` overlaid with
        the fields of ``new_state`` reported by ``model_dump(exclude_unset=True)``.
        Neither input is modified.
        """
        combined = {
            **base_state.model_dump(),
            **new_state.model_dump(exclude_unset=True),
        }
        return type(base_state)(**combined)
class LastWriteWinsReducer:
    """Reducer: the state with the newer (or equal) ``version`` wins."""

    @staticmethod
    def merge(base_state, new_state):
        """Overlay ``new_state`` on ``base_state`` unless it is older.

        A state without a ``version`` attribute counts as version 0. When the
        new version is not greater than or equal to the base version,
        ``base_state`` itself is returned; otherwise a new state of
        ``base_state``'s type is built from the base dump overlaid with the
        fields of ``new_state`` reported by ``model_dump(exclude_unset=True)``.
        """
        base_version = getattr(base_state, "version", 0)
        new_version = getattr(new_state, "version", 0)
        if not (new_version >= base_version):
            return base_state

        combined = base_state.model_dump()
        combined.update(new_state.model_dump(exclude_unset=True))
        return type(base_state)(**combined)
class AppendOnlyReducer:
    """Reducer: list fields named in ``MERGE_FIELDS`` are appended, not replaced."""

    MERGE_FIELDS = ["messages", "logs"]

    @staticmethod
    def merge(base_state, new_state):
        """Return a new state of ``base_state``'s type.

        For each name in ``MERGE_FIELDS`` that ``new_state`` reports through
        ``model_dump(exclude_unset=True)``, the result holds the base items
        followed by the new items. A side that is missing or ``None`` counts
        as an empty list. All other fields reported by ``new_state`` replace
        the base values. Neither input is modified.
        """
        append_fields = AppendOnlyReducer.MERGE_FIELDS
        data = base_state.model_dump()
        new_data = new_state.model_dump(exclude_unset=True)

        for field in append_fields:
            if field not in new_data:
                continue
            # ``or []`` covers a missing or None list. Previously an append
            # field present only in the new state was dropped entirely.
            data[field] = list(data.get(field) or []) + list(new_data[field] or [])

        data.update({k: v for k, v in new_data.items() if k not in append_fields})
        return type(base_state)(**data)