Coverage for agentos/state/schema.py: 0%

329 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 16:36 +0800

1""" 

2AgentOS v1.14.0 — 结构化 Agent 状态管理系统。 

3 

4基因来源: LangGraph Pydantic State Schema + AgentOS Checkpoint。 

5 

6核心设计: 

7- AgentState: 强类型的全局 Agent 运行时状态,Pydantic v2 驱动 

8- 支持 JSON Schema 自动生成、验证、序列化 

9- 与 Checkpoint 系统无缝对接 

10- 支持状态合并策略(reducer):replace/append/extend/merge/keep/custom 

11- 支持子状态派生(SubState),实现层级化状态管理 

12""" 

13 

14from __future__ import annotations 

15 

16import uuid 

17import time 

18from datetime import datetime, timezone 

19from enum import Enum 

20from typing import ( 

21 Any, Callable, Dict, Generic, List, Literal, Optional, Set, TypeVar, Union, 

22) 

23 

24try: 

25 from pydantic import ( 

26 BaseModel, Field, field_validator, model_validator, 

27 ConfigDict, PrivateAttr, 

28 ) 

29 from pydantic.json_schema import GenerateJsonSchema 

30except ImportError: 

31 raise ImportError( 

32 "pydantic>=2.0 is required for agentos.state. " 

33 "Install with: pip install pydantic>=2.0" 

34 ) 

35 

# ── Reducers ────────────────────────────────


class ReducerStrategy(str, Enum):
    """How an incoming value is combined with the value already in a field.

    Members subclass ``str``, so ``ReducerStrategy("append")`` performs a
    value lookup and members compare equal to their raw string values.
    """

    REPLACE = "replace"  # overwrite the old value with the new one
    APPEND = "append"  # list concatenation (old items, then new items)
    EXTEND = "extend"  # shallow dict merge ({**old, **new})
    MERGE = "merge"  # recursive (deep) dict merge
    KEEP_EXISTING = "keep"  # keep the old value, ignore the new one
    CUSTOM = "custom"  # delegate to a user-supplied reducer function



# Type variable bound to pydantic models, for generic state helpers.
S = TypeVar("S", bound=BaseModel)

# Custom reducer signature: two arguments (by convention the old value and
# the new value) returning the combined value.
ReducerFn = Callable[[Any, Any], Any]

# Process-wide mapping of field name -> default merge strategy; entries are
# added by default_reducer().
_DEFAULT_REDUCERS: Dict[str, ReducerStrategy] = dict()



def default_reducer(field_name: str, strategy: ReducerStrategy) -> None:
    """Register a process-wide default merge strategy for ``field_name``.

    Registering the same field again replaces the earlier strategy.

    Usage:
        default_reducer("messages", ReducerStrategy.APPEND)
    """
    _DEFAULT_REDUCERS.update({field_name: strategy})



def _apply_reducer(old_val: Any, new_val: Any, strategy: ReducerStrategy) -> Any:
    """Combine ``old_val`` and ``new_val`` according to ``strategy``.

    Rules, checked in this order:

    - REPLACE, or no existing value (``old_val is None``): return ``new_val``.
    - No incoming value (``new_val is None``): return ``old_val``.
    - KEEP_EXISTING: return ``old_val``.
    - APPEND: return a new list of the old items followed by the new items.
      A non-list operand counts as a single item, so appending one element
      to an existing list yields a flat list rather than a nested one.
    - EXTEND: shallow dict merge (new keys win). Non-dict operands fall back
      to ``new_val``.
    - MERGE: recursive dict merge via ``_deep_merge``.
    - Anything else (including CUSTOM, whose function is held by the state
      object): return ``new_val``.

    Neither input is mutated.
    """
    if strategy == ReducerStrategy.REPLACE or old_val is None:
        return new_val
    if new_val is None:
        return old_val
    if strategy == ReducerStrategy.KEEP_EXISTING:
        return old_val
    if strategy == ReducerStrategy.APPEND:
        # Wrap scalars so a single item extends the list instead of the whole
        # old list being nested as one element ([old_list, item]).
        head = old_val if isinstance(old_val, list) else [old_val]
        tail = new_val if isinstance(new_val, list) else [new_val]
        return head + tail
    if strategy == ReducerStrategy.EXTEND:
        if isinstance(old_val, dict) and isinstance(new_val, dict):
            return {**old_val, **new_val}
        return new_val
    if strategy == ReducerStrategy.MERGE:
        return _deep_merge(old_val, new_val)
    return new_val



89def _deep_merge(old: Any, new: Any) -> Any: 

90 """递归深度合并两个字典。""" 

91 if not isinstance(old, dict) or not isinstance(new, dict): 

92 return new 

93 result = dict(old) 

94 for k, v in new.items(): 

95 if k in result and isinstance(result[k], dict) and isinstance(v, dict): 

96 result[k] = _deep_merge(result[k], v) 

97 else: 

98 result[k] = v 

99 return result 

100 

101 

# ── Field Metadata ──────────────────────────


class StateFieldInfo(BaseModel):
    """Declarative per-field options for a state field.

    Holds a merge strategy (optionally a custom reducer function), a
    description, and flags for required / sensitive / checkpointed. Note: no
    code in this module reads StateFieldInfo; the flag meanings below follow
    the original author's comments.
    """

    # Callables are not pydantic-native types, so allow arbitrary types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    reducer: ReducerStrategy = ReducerStrategy.REPLACE
    custom_reducer: Union[ReducerFn, None] = None
    description: str = ""
    required: bool = False
    sensitive: bool = False  # intended: redact this field when serialising
    checkpointed: bool = True  # intended: persist this field to checkpoints



# ── AgentState Core ─────────────────────────


class BaseAgentState(BaseModel):
    """Base class for Agent state models.

    All Agent states must inherit from this class. It provides:

    - ``thread_id`` tracking and a ``parent_state_id`` link (branch/rewind)
    - a ``step`` counter and ``created_at`` / ``updated_at`` timestamps
    - per-instance reducers used by ``update_field`` and ``merge``
    - snapshots (``snapshot``) and restoration (``restore``)
    - JSON Schema generation

    The private reducer / sensitivity configuration is initialised per
    instance by pydantic (``PrivateAttr(default_factory=...)``) on every
    construction path, including ``model_validate``. Subclasses that need
    default reducers should register them in ``model_post_init`` so they
    also apply on that path.

    Usage:
        class MyState(BaseAgentState):
            messages: list[dict] = Field(default_factory=list)
            tools_result: dict = Field(default_factory=dict)
            task_progress: float = 0.0
    """

    thread_id: str = Field(
        default_factory=lambda: f"thread-{uuid.uuid4().hex[:8]}",
        description="对话线程唯一标识",
    )
    messages: list[Any] = Field(
        default_factory=list, description="对话消息列表"
    )
    metrics: dict[str, Any] = Field(
        default_factory=dict, description="运行时指标"
    )
    step: int = Field(default=0, ge=0, description="当前执行步骤")
    created_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat(),
        description="创建时间",
    )
    updated_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat(),
        description="最后更新时间",
    )
    tags: list[str] = Field(default_factory=list, description="标签")
    metadata: dict[str, Any] = Field(default_factory=dict, description="自定义元数据")
    parent_state_id: Optional[str] = Field(
        default=None, description="父状态 ID(用于分支/回溯)"
    )

    # Per-instance reducer configuration. Private attributes are not part of
    # model_dump() output, so none of this reaches snapshots.
    _field_reducers: dict[str, ReducerStrategy] = PrivateAttr(default_factory=dict)
    _field_custom_reducers: dict[str, ReducerFn] = PrivateAttr(default_factory=dict)
    _sensitive_fields: set[str] = PrivateAttr(default_factory=set)

    model_config = ConfigDict(
        extra="allow",
        validate_assignment=True,
        json_schema_extra={
            "title": "AgentState",
            "description": "AgentOS Structured Agent State",
        },
    )

    # NOTE: no __init__ override. The old one re-created the private dicts
    # after validation. That was redundant with the PrivateAttr factories,
    # and it wiped anything a subclass installed in model_post_init.

    # ── Reducer Registration ──────────────────

    def set_reducer(self, field: str, strategy: ReducerStrategy) -> "BaseAgentState":
        """Set the merge strategy for ``field`` on this instance (fluent)."""
        self._field_reducers[field] = strategy
        return self

    def set_custom_reducer(self, field: str, fn: ReducerFn) -> "BaseAgentState":
        """Use ``fn(old_value, new_value)`` to combine updates to ``field`` (fluent)."""
        self._field_custom_reducers[field] = fn
        self._field_reducers[field] = ReducerStrategy.CUSTOM
        return self

    def mark_sensitive(self, *fields: str) -> "BaseAgentState":
        """Mark fields to be redacted by ``sanitized_snapshot`` (fluent)."""
        self._sensitive_fields.update(fields)
        return self

    # ── State Mutation ────────────────────────

    def _resolve_reducer(
        self, field: str, override: Optional[ReducerStrategy]
    ) -> ReducerStrategy:
        """Pick the strategy for ``field``.

        Precedence: explicit override, then the per-instance registration,
        then the module-wide ``default_reducer`` registry, then REPLACE.
        """
        if override is not None:
            return override
        if field in self._field_reducers:
            return self._field_reducers[field]
        return _DEFAULT_REDUCERS.get(field, ReducerStrategy.REPLACE)

    def update_field(
        self,
        field: str,
        value: Any,
        reducer: Optional[ReducerStrategy] = None,
    ) -> None:
        """Update one field, combining old and new values via its reducer.

        Fields not declared on the model are stored in ``metadata[field]``.
        For a CUSTOM strategy with a registered function, the function
        receives ``(old_value, value)`` (``old_value`` may be ``None``). A
        CUSTOM strategy without a function behaves like REPLACE.
        ``updated_at`` is refreshed afterwards.

        Args:
            field: Field name.
            value: Incoming value.
            reducer: Strategy override. If omitted, the registered strategy
                is used (see ``_resolve_reducer``).
        """
        model_cls = type(self)
        is_declared = (
            field in model_cls.model_fields or field in model_cls.model_computed_fields
        )
        old = getattr(self, field, None) if is_declared else self.metadata.get(field)

        strategy = self._resolve_reducer(field, reducer)
        custom_fn = self._field_custom_reducers.get(field)
        if strategy == ReducerStrategy.CUSTOM and custom_fn is not None:
            new_val = custom_fn(old, value)
        else:
            new_val = _apply_reducer(old, value, strategy)

        if is_declared:
            setattr(self, field, new_val)
        else:
            # Dynamic field: kept in metadata
            self.metadata[field] = new_val

        self.updated_at = datetime.now(timezone.utc).isoformat()

    def merge(self, other: Union["BaseAgentState", dict]) -> "BaseAgentState":
        """Merge another state into this one, in place.

        Every declared field of ``other`` with a non-``None`` value goes
        through ``update_field``, so registered reducers decide how values
        combine. The exceptions are:

        - ``thread_id`` and ``created_at``: never overwritten;
        - ``metadata``: deep-merged into this state's metadata;
        - ``tags``: set to the order-preserving union of both tag lists.

        When ``other`` is a dict, only the keys it contains are merged.
        Omitted fields are not reset to their defaults.

        Args:
            other: Another AgentState instance, or a dict of field values
                (validated by constructing ``self.__class__``).

        Returns:
            self (in-place merge)
        """
        explicit_fields: Optional[set[str]] = None
        if isinstance(other, dict):
            other = self.__class__(**other)
            explicit_fields = set(other.model_fields_set)

        for field_name in type(other).model_fields:
            if field_name in ("thread_id", "created_at"):
                continue  # Immutable fields
            if field_name in ("metadata", "tags"):
                continue  # Merged specially below; update_field would REPLACE them
            if explicit_fields is not None and field_name not in explicit_fields:
                continue  # Not supplied by the caller: keep the current value

            other_val = getattr(other, field_name, None)
            if other_val is None:
                continue

            self.update_field(field_name, other_val)

        if other.metadata:
            self.metadata = _deep_merge(self.metadata, other.metadata)

        if other.tags:
            # dict.fromkeys de-duplicates while keeping first-seen order.
            self.tags = list(dict.fromkeys(self.tags + other.tags))

        self.updated_at = datetime.now(timezone.utc).isoformat()
        return self

    def increment_step(self) -> int:
        """Increment the step counter, refresh ``updated_at``, and return the new step."""
        self.step += 1
        self.updated_at = datetime.now(timezone.utc).isoformat()
        return self.step

    # ── Snapshot & Restore ────────────────────

    def snapshot(self) -> dict[str, Any]:
        """Return a full snapshot of the state.

        Built with ``model_dump(mode="python")`` with ``None`` values kept,
        so it can be passed back to ``restore``. Reducer configuration is not
        included.

        Returns:
            The state as a dict (for storing in a Checkpoint).
        """
        data = self.model_dump(mode="python", exclude_none=False)
        # Defensive: private attributes are not dumped by pydantic, but make
        # sure keys with these names can never leak into a checkpoint.
        for private_name in ("_field_reducers", "_field_custom_reducers", "_sensitive_fields"):
            data.pop(private_name, None)
        return data

    def sanitized_snapshot(self) -> dict[str, Any]:
        """Return a snapshot with sensitive fields replaced by ``'***'``.

        Covers declared fields and dynamic fields stored in ``metadata`` by
        ``update_field``.
        """
        data = self.snapshot()
        dynamic = data.get("metadata")
        for field in self._sensitive_fields:
            if field in data:
                data[field] = "***"
            elif isinstance(dynamic, dict) and field in dynamic:
                dynamic[field] = "***"
        return data

    @classmethod
    def restore(cls, data: dict[str, Any]) -> "BaseAgentState":
        """Rebuild a state from a ``snapshot()`` dict.

        Top-level keys starting with ``_`` are dropped before construction.
        If ``metadata["_reducer_config"]`` is a ``{field: strategy_value}``
        dict, those strategies are registered on the new instance. Unknown
        strategy values are skipped. ``snapshot()`` does not write that key
        itself.

        Args:
            data: A dict returned by ``snapshot()``.

        Returns:
            A new instance of ``cls``.
        """
        cleaned = {key: value for key, value in data.items() if not key.startswith("_")}
        instance = cls(**cleaned)

        metadata = cleaned.get("metadata")
        reducer_config = metadata.get("_reducer_config") if isinstance(metadata, dict) else None
        if isinstance(reducer_config, dict):
            for field, strategy_name in reducer_config.items():
                try:
                    strategy = ReducerStrategy(strategy_name)
                except ValueError:
                    continue  # Best-effort restore: ignore unknown strategies
                instance._field_reducers[field] = strategy
        return instance

    def diff(self, other: "BaseAgentState") -> dict[str, tuple]:
        """Compare this state's declared fields with ``other``.

        Only fields declared on ``type(self)`` are compared. A field missing
        on ``other`` compares as ``None``.

        Returns:
            ``{field: (self_value, other_value)}`` for each differing field.
        """
        diffs = {}
        for field in type(self).model_fields:
            old = getattr(self, field, None)
            new = getattr(other, field, None)
            if old != new:
                diffs[field] = (old, new)
        return diffs

    # ── JSON Schema ───────────────────────────

    @classmethod
    def generate_schema(cls) -> dict[str, Any]:
        """Return the pydantic JSON Schema for this state class.

        This is ``model_json_schema()`` unchanged. It is intended for use as
        a tool/function ``parameters`` definition. Check it against a
        provider's function-calling constraints before relying on that.
        """
        return cls.model_json_schema()

    @classmethod
    def validate_json_input(cls, data: dict) -> "BaseAgentState":
        """Validate a JSON-like dict and create an instance (``model_validate``)."""
        return cls.model_validate(data)



# ── Specialized States ──────────────────────


class AgentState(BaseAgentState):
    """General-purpose Agent runtime state.

    Default reducers registered on every instance:

    - messages: APPEND (conversation messages accumulate)
    - tools_result: MERGE (tool results are deep-merged)
    - errors: APPEND (errors are collected)
    - human_interrupts: APPEND (interrupt requests queue up)
    - intermediate: REPLACE (intermediate results are overwritten)

    The defaults are installed in ``model_post_init``. That hook also runs
    for ``model_validate`` (e.g. ``validate_json_input``), which does not
    call ``__init__``.
    """

    messages: list[dict[str, Any]] = Field(
        default_factory=list,
        description="对话消息历史",
    )
    tools_result: dict[str, Any] = Field(
        default_factory=dict,
        description="最近一次工具调用结果",
    )
    intermediate: dict[str, Any] = Field(
        default_factory=dict,
        description="中间计算结果(每次替换)",
    )
    errors: list[dict[str, Any]] = Field(
        default_factory=list,
        description="错误堆栈",
    )
    human_interrupts: list[dict[str, Any]] = Field(
        default_factory=list,
        description="人工干预请求队列",
    )
    context_summary: str = Field(
        default="",
        description="上下文摘要(自动分页用)",
    )
    task_progress: Optional[float] = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="任务进度 0.0~1.0",
    )
    abort_reason: Optional[str] = Field(
        default=None,
        description="中止原因(非空表示需中止)",
    )

    def model_post_init(self, context: Any) -> None:
        """Install the default reducers after validation (all construction paths)."""
        super().model_post_init(context)
        self._install_default_reducers()

    def __init__(self, **data):
        super().__init__(**data)
        # model_post_init already ran inside super().__init__(). Re-apply in
        # case a base-class __init__ reset the private reducer map afterwards.
        self._install_default_reducers()

    def _install_default_reducers(self) -> None:
        """Register this class's default per-field reducers on the instance."""
        self._field_reducers.update(
            messages=ReducerStrategy.APPEND,
            tools_result=ReducerStrategy.MERGE,
            errors=ReducerStrategy.APPEND,
            human_interrupts=ReducerStrategy.APPEND,
            intermediate=ReducerStrategy.REPLACE,
        )

    @property
    def last_message(self) -> Optional[dict[str, Any]]:
        """The most recent message, or ``None`` if there are none."""
        return self.messages[-1] if self.messages else None

    @property
    def error_count(self) -> int:
        """Number of recorded errors."""
        return len(self.errors)

    @property
    def should_abort(self) -> bool:
        """Whether an abort has been requested (``abort_reason`` is set)."""
        return self.abort_reason is not None

    def add_message(self, role: str, content: str, **extra) -> "AgentState":
        """Append a ``{"role", "content", **extra}`` message (fluent)."""
        msg = {"role": role, "content": content, **extra}
        self.update_field("messages", [msg])
        return self

    def add_error(self, error_type: str, message: str, **extra) -> "AgentState":
        """Record an error tagged with the current step and a UTC timestamp (fluent)."""
        err = {
            "type": error_type,
            "message": message,
            "step": self.step,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            **extra,
        }
        self.update_field("errors", [err])
        return self

    def request_human_input(
        self,
        prompt: str,
        options: Optional[list[str]] = None,
        **extra,
    ) -> "AgentState":
        """Queue a human-intervention request (fluent)."""
        interrupt = {
            "prompt": prompt,
            "options": options,
            "step": self.step,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            **extra,
        }
        self.update_field("human_interrupts", [interrupt])
        return self

    def clear_human_interrupts(self) -> "AgentState":
        """Drop all pending human-intervention requests (fluent)."""
        self.human_interrupts = []
        self.updated_at = datetime.now(timezone.utc).isoformat()
        return self

    def abort(self, reason: str) -> "AgentState":
        """Flag the task for abort with ``reason`` (fluent)."""
        self.abort_reason = reason
        self.updated_at = datetime.now(timezone.utc).isoformat()
        return self

    def reset_abort(self) -> "AgentState":
        """Clear the abort flag (fluent)."""
        self.abort_reason = None
        self.updated_at = datetime.now(timezone.utc).isoformat()
        return self



class MultiAgentState(BaseAgentState):
    """Multi-agent collaboration state.

    Tracks per-agent state, inter-agent messages, role assignments and task
    handoffs. Default reducers (installed in ``model_post_init``, so they
    also apply on ``model_validate``):

    - agents: MERGE
    - message_queue: APPEND
    - handoff_log: APPEND
    """

    agents: dict[str, dict[str, Any]] = Field(
        default_factory=dict,
        description="所有子 Agent 的状态 {agent_id: {state dict}}",
    )
    message_queue: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Agent 间消息队列",
    )
    roles: dict[str, str] = Field(
        default_factory=dict,
        description="Agent 角色分配 {agent_id: role_name}",
    )
    coordinator_state: dict[str, Any] = Field(
        default_factory=dict,
        description="协调器内部状态",
    )
    handoff_log: list[dict[str, Any]] = Field(
        default_factory=list,
        description="任务移交记录",
    )

    def model_post_init(self, context: Any) -> None:
        """Install the default reducers after validation (all construction paths)."""
        super().model_post_init(context)
        self._install_default_reducers()

    def __init__(self, **data):
        super().__init__(**data)
        # Re-apply in case a base-class __init__ reset the private reducer map
        # after model_post_init ran.
        self._install_default_reducers()

    def _install_default_reducers(self) -> None:
        """Register this class's default per-field reducers on the instance."""
        self._field_reducers.update(
            agents=ReducerStrategy.MERGE,
            message_queue=ReducerStrategy.APPEND,
            handoff_log=ReducerStrategy.APPEND,
        )

    def register_agent(
        self, agent_id: str, role: str = "worker", **meta
    ) -> "MultiAgentState":
        """Register (or re-register, overwriting) a sub-agent (fluent).

        The agent starts as ``{"role", "state": "idle", "step": 0,
        "errors": 0}`` plus any ``meta`` entries (which may override those).
        """
        self.agents[agent_id] = {
            "role": role,
            "state": "idle",
            "step": 0,
            "errors": 0,
            **meta,
        }
        self.roles[agent_id] = role
        self._touch()
        return self

    def update_agent_state(
        self, agent_id: str, updates: dict[str, Any]
    ) -> "MultiAgentState":
        """Shallow-update a registered sub-agent's entry (fluent).

        Unknown ``agent_id`` values are ignored.
        """
        if agent_id in self.agents:
            self.agents[agent_id].update(updates)
            self._touch()
        return self

    def send_message(
        self, from_agent: str, to_agent: str, content: Any
    ) -> "MultiAgentState":
        """Append a message from one agent to another to the queue (fluent)."""
        msg = {
            "from": from_agent,
            "to": to_agent,
            "content": content,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        self.update_field("message_queue", [msg])
        return self

    def log_handoff(
        self, from_agent: str, to_agent: str, task_id: str, reason: str = ""
    ) -> "MultiAgentState":
        """Record a task handoff between agents (fluent)."""
        self.update_field("handoff_log", [{
            "from": from_agent,
            "to": to_agent,
            "task_id": task_id,
            "reason": reason,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }])
        return self

    def _touch(self) -> None:
        """Refresh ``updated_at`` after an in-place mutation."""
        self.updated_at = datetime.now(timezone.utc).isoformat()



class ToolCallState(BaseAgentState):
    """Tool-call tracking state.

    Records each tool call's arguments, result, error and elapsed time, and
    keeps per-tool aggregate statistics. ``calls`` uses the APPEND reducer,
    installed in ``model_post_init`` so it also applies on ``model_validate``.
    """

    calls: list[dict[str, Any]] = Field(
        default_factory=list,
        description="工具调用历史",
    )
    pending_calls: dict[str, dict[str, Any]] = Field(
        default_factory=dict,
        description="进行中的工具调用 {call_id: {tool, args, start_time}}",
    )
    tool_stats: dict[str, dict[str, Any]] = Field(
        default_factory=dict,
        description="工具统计 {tool_name: {count, total_ms, errors, avg_ms}}",
    )

    def model_post_init(self, context: Any) -> None:
        """Install the default reducers after validation (all construction paths)."""
        super().model_post_init(context)
        self._install_default_reducers()

    def __init__(self, **data):
        super().__init__(**data)
        # Re-apply in case a base-class __init__ reset the private reducer map
        # after model_post_init ran.
        self._install_default_reducers()

    def _install_default_reducers(self) -> None:
        """Register this class's default per-field reducers on the instance."""
        self._field_reducers["calls"] = ReducerStrategy.APPEND

    def start_call(self, call_id: str, tool: str, args: dict) -> "ToolCallState":
        """Record the start of a tool call (fluent).

        Reusing a ``call_id`` that is still pending overwrites it. The start
        time is wall-clock ``time.time()``. Presumably that is so it stays
        meaningful across a checkpoint/restore; note it is affected by
        clock adjustments.
        """
        self.pending_calls[call_id] = {
            "tool": tool,
            "args": args,
            "start_time": time.time(),
            "step": self.step,
        }
        self._init_stats(tool)
        self.updated_at = datetime.now(timezone.utc).isoformat()
        return self

    def complete_call(
        self, call_id: str, result: Any, error: Optional[str] = None
    ) -> "ToolCallState":
        """Record the completion of a pending tool call (fluent).

        Unknown or already-completed ``call_id`` values are ignored, so
        repeated calls are harmless.

        Args:
            call_id: Identifier previously passed to ``start_call``.
            result: The tool's return value (stored as-is).
            error: Error message, or ``None`` on success. Any non-``None``
                value, including ``""``, counts as a failure in both the
                call record (``success``) and ``tool_stats["errors"]``.
        """
        if call_id not in self.pending_calls:
            return self  # Idempotent
        call_info = self.pending_calls.pop(call_id)
        elapsed_ms = (time.time() - call_info["start_time"]) * 1000
        failed = error is not None
        record = {
            **call_info,
            "call_id": call_id,
            "elapsed_ms": round(elapsed_ms, 2),
            "result": result,
            "error": error,
            "success": not failed,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        self.update_field("calls", [record])

        # Update aggregate stats (same failure rule as the record above).
        tool = call_info["tool"]
        stats = self.tool_stats.get(tool, {})
        stats["count"] = stats.get("count", 0) + 1
        stats["total_ms"] = stats.get("total_ms", 0) + elapsed_ms
        stats["errors"] = stats.get("errors", 0) + int(failed)
        stats["avg_ms"] = stats["total_ms"] / stats["count"]
        self.tool_stats[tool] = stats

        return self

    def _init_stats(self, tool: str) -> None:
        """Create zeroed stats for ``tool`` if it has none yet."""
        if tool not in self.tool_stats:
            self.tool_stats[tool] = {
                "count": 0,
                "total_ms": 0,
                "errors": 0,
                "avg_ms": 0,
            }

    @property
    def total_tool_calls(self) -> int:
        """Number of completed tool calls."""
        return len(self.calls)

    @property
    def failed_calls(self) -> int:
        """Number of completed tool calls recorded as unsuccessful."""
        return sum(1 for c in self.calls if not c.get("success", True))



# ── Schema Registry ──────────────────────────


class StateSchemaRegistry:
    """Registry of state schema classes, addressable by name.

    Comes pre-populated with ``AgentState`` (the default), ``MultiAgentState``
    and ``ToolCallState``. Custom ``BaseAgentState`` subclasses are added
    with ``register``.
    """

    def __init__(self):
        self._schemas: dict[str, type[BaseAgentState]] = {}
        self._default_name: str = "AgentState"
        self._schemas[self._default_name] = AgentState
        self._schemas["MultiAgentState"] = MultiAgentState
        self._schemas["ToolCallState"] = ToolCallState

    def register(self, name: str, schema_cls: type[BaseAgentState]) -> None:
        """Register (or replace) a state class under ``name``.

        Raises:
            TypeError: if ``schema_cls`` is not a subclass of BaseAgentState
                (including when it is not a class at all).
        """
        # Check isinstance(type) first: issubclass() on a non-class raises
        # its own TypeError with an unhelpful message.
        if not (isinstance(schema_cls, type) and issubclass(schema_cls, BaseAgentState)):
            raise TypeError(
                f"State class must inherit from BaseAgentState, "
                f"got {getattr(schema_cls, '__name__', repr(schema_cls))}"
            )
        self._schemas[name] = schema_cls

    def get(self, name: str) -> type[BaseAgentState]:
        """Return the class registered under ``name``.

        Raises:
            KeyError: if ``name`` is not registered (message lists the
                available names).
        """
        if name not in self._schemas:
            raise KeyError(f"Unknown state schema: '{name}'. "
                           f"Available: {list(self._schemas.keys())}")
        return self._schemas[name]

    def list_schemas(self) -> list[str]:
        """Return the registered schema names, in registration order."""
        return list(self._schemas.keys())

    def create_state(self, name: str, **kwargs) -> BaseAgentState:
        """Instantiate the class registered under ``name`` with ``kwargs``."""
        cls = self.get(name)
        return cls(**kwargs)

    def validate(self, name: str, data: dict) -> BaseAgentState:
        """Validate ``data`` with the class registered under ``name``."""
        cls = self.get(name)
        return cls.model_validate(data)

    @property
    def default_state_class(self) -> type[BaseAgentState]:
        """The default state class.

        The setter takes a registered *name*, not a class.
        """
        return self._schemas[self._default_name]

    @default_state_class.setter
    def default_state_class(self, name: str) -> None:
        if name not in self._schemas:
            raise KeyError(f"Unknown state schema: '{name}'")
        self._default_name = name



# ── Global singleton ─────────────────────────

# Shared module-level registry, pre-populated with the built-in state schemas.
state_registry = StateSchemaRegistry()


# ── State Reducers (test compatibility) ──


class StateReducer:
    """Shallow field-wise merge in which ``new_state``'s explicitly set fields win."""

    @staticmethod
    def merge(base_state, new_state):
        """Return a new ``type(base_state)`` built from both states.

        Starts from ``base_state.model_dump()`` and overlays the fields that
        ``new_state`` explicitly set (``exclude_unset=True``).
        """
        combined = base_state.model_dump()
        combined |= new_state.model_dump(exclude_unset=True)
        return type(base_state)(**combined)


class LastWriteWinsReducer:
    """Version-gated merge: apply ``new_state`` unless its version is older.

    A missing ``version`` attribute counts as 0. On equal versions the new
    state wins.
    """

    @staticmethod
    def merge(base_state, new_state):
        """Return the merged state, or ``base_state`` itself if ``new_state`` is stale."""
        base_version = getattr(base_state, 'version', 0)
        new_version = getattr(new_state, 'version', 0)
        is_fresh = new_version >= base_version
        if not is_fresh:
            return base_state
        merged = base_state.model_dump()
        merged.update(new_state.model_dump(exclude_unset=True))
        return type(base_state)(**merged)


class AppendOnlyReducer:
    """Reducer that appends list fields and overwrites everything else.

    Fields named in ``MERGE_FIELDS`` are concatenated, base items first and
    then new items. A ``None`` or missing base value is treated as an empty
    list. All other fields explicitly set on ``new_state``
    (``exclude_unset=True``) replace the base values. Subclasses may
    override ``MERGE_FIELDS``.
    """

    MERGE_FIELDS = ["messages", "logs"]

    @classmethod
    def merge(cls, base_state, new_state):
        """Return a new ``type(base_state)`` with list fields appended."""
        data = base_state.model_dump()
        new_data = new_state.model_dump(exclude_unset=True)
        append_fields = set(cls.MERGE_FIELDS)
        for field in cls.MERGE_FIELDS:
            if field not in new_data:
                continue
            existing = data.get(field)
            incoming = new_data[field]
            # Previously a merge field present only in new_state was silently
            # dropped, and a None base value crashed list(None).
            data[field] = (
                ([] if existing is None else list(existing))
                + ([] if incoming is None else list(incoming))
            )
        data.update({k: v for k, v in new_data.items() if k not in append_fields})
        return type(base_state)(**data)