Coverage for agentos/llm/base.py: 75%

107 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2LLM Provider 抽象层。 

3为 Nexus AgentOS 提供统一的 LLM 调用接口,实现 Provider 无关性。 

4v1.3.36: +Function Calling / Tool Use 抽象。 

5""" 

6 

7from __future__ import annotations 

8 

9from abc import ABC, abstractmethod 

10from dataclasses import dataclass, field 

11from enum import Enum 

12from typing import Any, Callable, Iterator, Union 

13 

# Public API of this module, grouped by role. Order is preserved as-is.
__all__ = [
    # chat messages
    "MessageRole",
    "Message",
    # completion results and usage accounting
    "CompletionUsage",
    "CompletionChoice",
    "CompletionResult",
    "StreamChunk",
    "TokenUsage",
    # function calling / tool use
    "Tool",
    "ToolFunction",
    "ToolParameter",
    "ToolCall",
    # provider interface
    "LLMProvider",
]

28 

29 

class MessageRole(str, Enum):
    """Author role of a chat message.

    Mixes in ``str`` so each member compares equal to its plain string
    value, e.g. ``MessageRole.USER == "user"``.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"

35 

36 

@dataclass
class TokenUsage:
    """Token accounting for an LLM call; every counter defaults to 0."""

    # Tokens attributed to the prompt (input side).
    prompt_tokens: int = 0
    # Tokens attributed to the completion (output side).
    completion_tokens: int = 0
    # Stored exactly as given; it is NOT computed from the two fields above.
    total_tokens: int = 0

42 

43 

@dataclass
class CompletionUsage(TokenUsage):
    """:class:`TokenUsage` extended with a cost figure."""

    # Cost in US dollars (per the field name); defaults to 0.0.
    cost_usd: float = 0.0

47 

48 

@dataclass
class Message:
    """A single chat message in provider-agnostic form.

    :meth:`as_dict` turns it into a plain, OpenAI-chat-style message dict.
    """

    role: MessageRole
    content: str
    # Optional author/function name; only serialized when non-empty.
    name: str | None = None
    # Id of the tool call this message answers (used with TOOL-role messages).
    tool_call_id: str | None = None
    # Tool calls requested by the model in this message, if any.
    tool_calls: list[ToolCall] | None = None

    def as_dict(self) -> dict[str, Any]:
        """Serialize this message to a plain dict.

        Returns:
            A dict with ``role`` (the enum's ``.value``) and ``content``,
            plus ``name``, ``tool_call_id`` and ``tool_calls`` when they are
            set and non-empty. Each tool call is rendered as
            ``{"id", "type": "function", "function": {"name", "arguments"}}``.
        """
        d: dict[str, Any] = {"role": self.role.value, "content": self.content}
        if self.name:
            d["name"] = self.name
        if self.tool_call_id:
            d["tool_call_id"] = self.tool_call_id
        # Without this, replaying an assistant turn that requested tools loses
        # the calls, and the following TOOL messages point at unknown ids.
        # An empty list is omitted rather than sent as an empty array.
        if self.tool_calls:
            d["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {"name": tc.name, "arguments": tc.arguments},
                }
                for tc in self.tool_calls
            ]
        return d

64 

65 

# --- Function Calling / Tool Use ---

@dataclass
class ToolParameter:
    """One JSON Schema property describing a single tool argument."""

    type: str = "string"
    description: str = ""
    enum: list[str] | None = None
    # Not emitted by as_schema(); ToolFunction gathers these flags into the
    # object-level "required" list instead.
    required: bool = False

    def as_schema(self) -> dict[str, Any]:
        """Return the JSON Schema fragment for this property.

        ``type`` is always present; ``description`` and ``enum`` are added,
        in that order, only when they are truthy.
        """
        schema: dict[str, Any] = {"type": self.type}
        optional_entries = (("description", self.description), ("enum", self.enum))
        schema.update((key, value) for key, value in optional_entries if value)
        return schema

83 

84 

@dataclass
class ToolFunction:
    """Function definition offered to the model as a callable tool."""

    name: str
    description: str = ""
    # Argument name -> property definition.
    parameters: dict[str, ToolParameter] = field(default_factory=dict)
    # Explicit required argument names. When empty, as_schema() derives the
    # list from each parameter's own ``required`` flag instead.
    required: list[str] = field(default_factory=list)

    def as_schema(self) -> dict[str, Any]:
        """Build the ``{"type": "function", "function": {...}}`` tool schema.

        ``parameters`` becomes a JSON Schema object whose ``properties`` come
        from each parameter's ``as_schema()``. A non-empty :attr:`required`
        list is used as-is; otherwise the required names are the parameters
        flagged ``required=True``, in declaration order.
        """
        properties: dict[str, Any] = {}
        for arg_name, param in self.parameters.items():
            properties[arg_name] = param.as_schema()

        if self.required:
            required_names = self.required
        else:
            required_names = [
                arg_name for arg_name, param in self.parameters.items() if param.required
            ]

        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required_names,
            },
        }
        return {"type": "function", "function": function_spec}

107 

108 

@dataclass
class Tool:
    """Top-level tool wrapper around a :class:`ToolFunction`."""

    function: ToolFunction

    def as_schema(self) -> dict[str, Any]:
        """Return the wrapped function's schema (delegates to ``function.as_schema()``)."""
        return self.function.as_schema()

    @classmethod
    def from_function(
        cls,
        name: str,
        description: str = "",
        parameters: dict[str, ToolParameter] | None = None,
        required: list[str] | None = None,
    ) -> Tool:
        """Alternate constructor that builds the inner :class:`ToolFunction`.

        Falsy ``parameters`` / ``required`` (``None`` or empty) are replaced
        with fresh empty containers.
        """
        fn = ToolFunction(
            name=name,
            description=description,
            parameters=parameters if parameters else {},
            required=required if required else [],
        )
        return cls(function=fn)

132 

133 

@dataclass
class ToolCall:
    """A tool invocation requested by the model."""

    id: str
    name: str
    arguments: str  # JSON-encoded arguments string

    @property
    def parsed_arguments(self) -> dict[str, Any]:
        """Decode :attr:`arguments` from JSON.

        Returns:
            The decoded JSON value (expected to be an object). Empty or
            whitespace-only ``arguments`` yield ``{}``, so that providers
            which send an empty string for a zero-argument call do not crash
            callers.

        Raises:
            json.JSONDecodeError: (a ``ValueError`` subclass) if ``arguments``
                is non-empty but not valid JSON.
        """
        import json

        # json.loads("") raises; treat "no arguments" as an empty argument object.
        if not (self.arguments or "").strip():
            return {}
        return json.loads(self.arguments)

145 

146 

@dataclass
class CompletionChoice:
    """One candidate completion inside a :class:`CompletionResult`."""

    # Choice index; the default stream fallbacks copy it into StreamChunk.index.
    index: int
    # The generated message.
    message: Message
    # Reason generation ended; defaults to "stop".
    finish_reason: str = "stop"

152 

153 

@dataclass
class CompletionResult:
    """Provider-agnostic result returned by ``chat`` / ``achat``."""

    # Result identifier; empty string by default.
    id: str = ""
    # Model name for this result; empty string by default.
    model: str = ""
    # Candidate completions; each instance gets its own list.
    choices: list[CompletionChoice] = field(default_factory=list)
    # Usage/cost accounting; each instance gets its own CompletionUsage.
    usage: CompletionUsage = field(default_factory=CompletionUsage)
    # Creation time; 0 by default (presumably a Unix timestamp — confirm).
    created: int = 0

161 

162 

@dataclass
class StreamChunk:
    """One increment yielded by ``stream`` / ``astream``."""

    # Text carried by this chunk; empty by default.
    content: str = ""
    # Finish reason, when known; None by default.
    finish_reason: str | None = None
    # Index of the choice this chunk belongs to.
    index: int = 0
    # Tool calls carried by this chunk, if any.
    tool_calls: list[ToolCall] | None = None

169 

170 

class LLMProvider(ABC):
    """Unified LLM provider abstraction.

    Gives OpenAI / Anthropic / local-model backends a standardized call
    interface. Subclasses must implement :meth:`chat`, :meth:`achat` and
    :attr:`provider_name`. :meth:`stream` and :meth:`astream` have
    non-streaming fallbacks built on top of them, which subclasses with
    native streaming may override.
    """

    def __init__(self, model: str = "", api_key: str = "", base_url: str = ""):
        """Store the model name, API key and base URL as plain attributes.

        All three default to the empty string; how they are used is up to
        the concrete provider.
        """
        self.model = model
        self.api_key = api_key
        self.base_url = base_url

    @abstractmethod
    def chat(
        self,
        messages: list[Message],
        *,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        top_p: float = 1.0,
        stop: list[str] | None = None,
        tools: list[Tool] | None = None,
        tool_choice: str = "auto",
        **kwargs: Any,
    ) -> CompletionResult:
        """Synchronous chat completion.

        Args:
            messages: Conversation to send.
            temperature: Sampling temperature (default 0.7).
            max_tokens: Output length limit (default 4096).
            top_p: Nucleus-sampling parameter (default 1.0).
            stop: Optional stop sequences.
            tools: Optional tools the model may call.
            tool_choice: Tool-selection mode (default ``"auto"``).
            **kwargs: Provider-specific extras.

        Returns:
            The provider's :class:`CompletionResult`.
        """
        ...

    @abstractmethod
    async def achat(
        self,
        messages: list[Message],
        *,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        top_p: float = 1.0,
        stop: list[str] | None = None,
        tools: list[Tool] | None = None,
        tool_choice: str = "auto",
        **kwargs: Any,
    ) -> CompletionResult:
        """Asynchronous chat completion; same parameters and result as :meth:`chat`."""
        ...

    @staticmethod
    def _chunk_from_choice(choice: CompletionChoice) -> StreamChunk:
        """Convert one finished choice into a single equivalent stream chunk."""
        return StreamChunk(
            content=choice.message.content,
            finish_reason=choice.finish_reason,
            index=choice.index,
            # Carry tool calls through; otherwise the streaming fallback
            # silently drops any tools the model asked to invoke.
            tool_calls=choice.message.tool_calls,
        )

    def stream(
        self,
        messages: list[Message],
        *,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        tools: list[Tool] | None = None,
        **kwargs: Any,
    ) -> Iterator[StreamChunk]:
        """Streaming chat completion.

        Default implementation: makes one blocking :meth:`chat` call and
        yields one :class:`StreamChunk` per returned choice, carrying that
        choice's full content, finish reason, index and tool calls. Extra
        keyword arguments (e.g. ``top_p``, ``stop``, ``tool_choice``) are
        forwarded to :meth:`chat`.
        """
        result = self.chat(messages, temperature=temperature, max_tokens=max_tokens, tools=tools, **kwargs)
        for choice in result.choices:
            yield self._chunk_from_choice(choice)

    async def astream(
        self,
        messages: list[Message],
        *,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        tools: list[Tool] | None = None,
        **kwargs: Any,
    ):
        """Asynchronous streaming completion (an async generator of :class:`StreamChunk`).

        Default implementation: awaits one :meth:`achat` call and yields one
        chunk per returned choice, exactly like :meth:`stream`.
        """
        result = await self.achat(messages, temperature=temperature, max_tokens=max_tokens, tools=tools, **kwargs)
        for choice in result.choices:
            yield self._chunk_from_choice(choice)

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Name of this provider; implemented by subclasses."""
        ...