Coverage for agentos/models/backends/ollama.py: 26%

97 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1"""Ollama backend for AgentOS. 

2 

3Supports local LLM inference via Ollama. 

4Models: llama3, mistral, codellama, phi3, gemma2, deepseek-r1, etc. 

5""" 

6 

7from dataclasses import dataclass, field 

8from typing import Any, AsyncIterator, Dict, List, Optional 

9 

10 

@dataclass
class OllamaConfig:
    """Settings used by the Ollama backend.

    Every field has a default, so ``OllamaConfig()`` describes a server on
    ``localhost:11434`` serving the ``llama3`` model.
    """

    # Root URL of the Ollama HTTP server; API paths are appended to it.
    base_url: str = "http://localhost:11434"
    # Model tag included in each request.
    model: str = "llama3"

    # Sampling defaults, placed in the request's "options" object.
    temperature: float = 0.7
    max_tokens: int = 4096  # sent to Ollama as "num_predict"
    top_p: float = 0.9
    top_k: int = 40
    num_ctx: int = 8192  # context window size requested from the server

    # Transport behaviour.
    timeout: int = 120  # seconds, passed to httpx.Timeout
    max_retries: int = 3  # number of request attempts for non-streaming chat
    # Intended value for Ollama's "keep_alive" request field (duration string).
    keep_alive: str = "5m"

24 

25 

class OllamaClient:
    """Ollama LLM client for local model inference over Ollama's HTTP API.

    Supports:
    - Chat completions (streaming and non-streaming)
    - Tool calling (function calling)
    - Model listing and pulling
    - Custom system prompts

    ``httpx`` is imported lazily inside the network methods, so constructing
    a client does not itself require ``httpx``.
    """

    def __init__(self, config: Optional["OllamaConfig"] = None):
        """Create a client; a default ``OllamaConfig()`` is used when ``config`` is falsy."""
        self.config = config or OllamaConfig()

    def _api_url(self, path: str) -> str:
        """Join an API ``path`` such as ``"/api/chat"`` onto the configured base URL.

        Trailing slashes on ``base_url`` are stripped so that a value like
        ``"http://localhost:11434/"`` does not yield ``//api/...``.
        """
        return f"{self.config.base_url.rstrip('/')}{path}"

    @property
    def _generate_url(self) -> str:
        """URL of the chat endpoint (despite the name, this is ``/api/chat``)."""
        return self._api_url("/api/chat")

    def _build_payload(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Build the JSON body for a non-streaming ``POST /api/chat`` request.

        Args:
            messages: Chat history. It is copied; the caller's list is never mutated.
            system: Optional system prompt, prepended as the first message when truthy.
            **kwargs: Per-call overrides. ``temperature``, ``max_tokens`` (sent as
                ``num_predict``), ``top_p``, ``top_k`` and ``num_ctx`` fall back to
                the config values; a truthy ``tools`` value is forwarded as-is.

        Returns:
            The request payload with ``"stream": False``.
        """
        cfg = self.config
        msgs = list(messages)
        if system:
            msgs.insert(0, {"role": "system", "content": system})

        payload: Dict[str, Any] = {
            "model": cfg.model,
            "messages": msgs,
            "stream": False,
            # Forward the configured keep-alive so that config field takes effect.
            "keep_alive": cfg.keep_alive,
            "options": {
                "temperature": kwargs.get("temperature", cfg.temperature),
                "num_predict": kwargs.get("max_tokens", cfg.max_tokens),
                "top_p": kwargs.get("top_p", cfg.top_p),
                "top_k": kwargs.get("top_k", cfg.top_k),
                "num_ctx": kwargs.get("num_ctx", cfg.num_ctx),
            },
        }

        if kwargs.get("tools"):
            payload["tools"] = kwargs["tools"]

        return payload

    @staticmethod
    def _is_retryable_status(status_code: int) -> bool:
        """Return True for HTTP statuses worth retrying (server errors and 429).

        Client errors such as 400/404 (e.g. an unknown model) will not succeed
        on retry, so they are raised immediately.
        """
        return status_code >= 500 or status_code == 429

    async def _async_request(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` to the chat endpoint and return the decoded JSON body.

        Makes up to ``config.max_retries`` attempts (always at least one), with
        exponential backoff of 1s, 2s, 4s, ... between attempts. Connection
        failures and retryable HTTP statuses are retried.

        Raises:
            httpx.HTTPStatusError: Non-retryable status, or retries exhausted.
            httpx.ConnectError / httpx.ConnectTimeout: Retries exhausted.
            Other httpx exceptions propagate immediately.
        """
        import asyncio

        import httpx

        timeout = httpx.Timeout(self.config.timeout)
        # A max_retries of 0 or less previously skipped the loop and returned None.
        attempts = max(1, self.config.max_retries)
        last_attempt = attempts - 1
        for attempt in range(attempts):
            try:
                async with httpx.AsyncClient() as client:
                    resp = await client.post(
                        self._generate_url,
                        json=payload,
                        timeout=timeout,
                    )
                    resp.raise_for_status()
                    return resp.json()
            except httpx.HTTPStatusError as exc:
                if attempt == last_attempt or not self._is_retryable_status(
                    exc.response.status_code
                ):
                    raise
            except (httpx.ConnectError, httpx.ConnectTimeout):
                if attempt == last_attempt:
                    raise
            await asyncio.sleep(2 ** attempt)
        # Unreachable: the final attempt either returns or re-raises.
        raise RuntimeError("Ollama request loop exited without a result")

    def _parse_chat_response(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a non-streaming ``/api/chat`` response body.

        Returns:
            Dict with ``content``, ``role``, ``usage`` (``input_tokens``,
            ``output_tokens``, ``total_duration_ms``), ``model``,
            ``done_reason`` and ``tool_calls``.
        """
        message = result.get("message") or {}

        tool_calls: List[Dict[str, Any]] = []
        # "tool_calls" may be absent or null; both mean no calls.
        for tc in message.get("tool_calls") or []:
            func = tc.get("function") or {}
            tool_calls.append({
                "id": tc.get("id", ""),
                "name": func.get("name", ""),
                # Ollama typically sends arguments as a JSON object (dict); the
                # "{}" string default is kept for backward compatibility.
                "arguments": func.get("arguments", "{}"),
            })

        return {
            "content": message.get("content", ""),
            "role": "assistant",
            "usage": {
                "input_tokens": result.get("prompt_eval_count", 0),
                "output_tokens": result.get("eval_count", 0),
                # Ollama reports durations in nanoseconds; convert to milliseconds.
                "total_duration_ms": (result.get("total_duration") or 0) // 1_000_000,
            },
            "model": result.get("model", self.config.model),
            "done_reason": result.get("done_reason", ""),
            "tool_calls": tool_calls,
        }

    async def chat(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Send a non-streaming chat completion request.

        Args:
            messages: Chat history.
            system: Optional system prompt.
            **kwargs: Overrides accepted by ``_build_payload`` (including ``tools``).

        Returns:
            Dict with 'content', 'role', 'usage', 'model', 'done_reason' and
            'tool_calls' (each call has 'id', 'name', 'arguments').
        """
        payload = self._build_payload(messages, system, **kwargs)
        result = await self._async_request(payload)
        return self._parse_chat_response(result)

    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Stream chat completion tokens.

        Yields one dict per decoded server event with ``delta`` (message content
        fragment), ``done`` and ``model``. Iteration stops after the first event
        whose ``done`` is truthy. Lines that are not valid JSON are skipped.
        No retries are attempted.
        """
        import json

        import httpx

        payload = self._build_payload(messages, system, **kwargs)
        payload["stream"] = True

        # Streaming uses double the configured timeout.
        timeout = httpx.Timeout(self.config.timeout * 2)
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                self._generate_url,
                json=payload,
                timeout=timeout,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    try:
                        event = json.loads(line)
                    except json.JSONDecodeError:
                        # Blank or non-JSON lines carry no event.
                        continue
                    message = event.get("message") or {}
                    done = event.get("done", False)
                    yield {
                        "delta": message.get("content", ""),
                        "done": done,
                        "model": event.get("model", self.config.model),
                    }
                    if done:
                        break

    def sync_chat(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Synchronous wrapper around :meth:`chat`.

        Runs the coroutine on a fresh private event loop, which is always
        closed afterwards. It cannot be called from a thread whose event loop is
        already running.
        """
        import asyncio

        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(self.chat(messages, system, **kwargs))
        finally:
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    async def list_models(self) -> List[Dict[str, Any]]:
        """List locally available Ollama models (``GET /api/tags``).

        Returns:
            One dict per model with 'name', 'size', 'modified' and 'format'.

        Raises:
            httpx.HTTPStatusError: On a non-2xx response.
        """
        import httpx

        timeout = httpx.Timeout(self.config.timeout)
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                self._api_url("/api/tags"),
                timeout=timeout,
            )
            resp.raise_for_status()
            data = resp.json()
            return [
                {
                    "name": m.get("name", ""),
                    "size": m.get("size", 0),
                    "modified": m.get("modified_at", ""),
                    # "details" may be missing or null.
                    "format": (m.get("details") or {}).get("format", ""),
                }
                for m in data.get("models", [])
            ]

    async def pull_model(self, model_name: str) -> AsyncIterator[Dict[str, Any]]:
        """Pull a model from the Ollama registry (``POST /api/pull``).

        Yields each decoded progress event as returned by the server; lines
        that are not valid JSON are skipped.
        """
        import json

        import httpx

        # Per-operation timeout of 10 minutes to accommodate large downloads.
        timeout = httpx.Timeout(600)
        # "model" is the documented field; "name" is kept for older servers.
        body = {"model": model_name, "name": model_name, "stream": True}
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                self._api_url("/api/pull"),
                json=body,
                timeout=timeout,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    try:
                        event = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    yield event