Coverage for agentos/models/backends/ollama.py: 26%
97 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""Ollama backend for AgentOS.
3Supports local LLM inference via Ollama.
4Models: llama3, mistral, codellama, phi3, gemma2, deepseek-r1, etc.
5"""
7from dataclasses import dataclass, field
8from typing import Any, AsyncIterator, Dict, List, Optional
@dataclass
class OllamaConfig:
    """Settings bundle consumed by the Ollama client.

    All fields have defaults, so ``OllamaConfig()`` describes a local
    server on the standard port running ``llama3``. Field order is part of
    the positional constructor signature and must not be rearranged.
    """

    # --- Server and model selection ---
    base_url: str = "http://localhost:11434"  # root URL; API paths are appended to it
    model: str = "llama3"  # model name placed in the request payload

    # --- Generation options (forwarded in the request "options" object) ---
    temperature: float = 0.7
    max_tokens: int = 4096  # sent to the server under the name "num_predict"
    top_p: float = 0.9
    top_k: int = 40
    num_ctx: int = 8192

    # --- Transport behaviour ---
    timeout: int = 120  # handed to httpx.Timeout for chat and model-list calls
    max_retries: int = 3  # number of attempts made for a non-streaming chat call

    # --- Model residency ---
    keep_alive: str = "5m"  # keep-alive duration string
class OllamaClient:
    """Ollama LLM client for local model inference.

    Supports:
    - Chat completions (streaming and non-streaming)
    - Tool calling (function calling)
    - Model listing and management
    - Custom system prompts

    ``httpx`` is imported lazily inside the methods that perform network
    I/O, so building payloads and parsing responses work without it.
    """

    # Non-5xx HTTP statuses that are still worth retrying.
    _RETRYABLE_STATUS = frozenset({408, 429})

    # Ollama reports durations in nanoseconds; usage exposes milliseconds.
    _NS_PER_MS = 1_000_000

    def __init__(self, config: Optional["OllamaConfig"] = None):
        """Create a client.

        Args:
            config: Connection and generation settings. When omitted (or
                falsy) a default ``OllamaConfig()`` is used.
        """
        self.config = config or OllamaConfig()

    def _url(self, path: str) -> str:
        """Join ``path`` onto the configured base URL.

        A trailing slash on ``base_url`` is dropped so that
        ``http://host:11434/`` does not produce ``//api/...`` paths.
        """
        return self.config.base_url.rstrip("/") + path

    @property
    def _generate_url(self) -> str:
        """URL of the chat completion endpoint."""
        return self._url("/api/chat")

    def _build_payload(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Build the JSON body for a (non-streaming) chat request.

        Args:
            messages: Conversation so far. The list is copied, never mutated.
            system: Optional system prompt, inserted as the first message.
            **kwargs: Per-call overrides. Recognised keys: ``temperature``,
                ``max_tokens`` (sent as ``num_predict``), ``top_p``,
                ``top_k``, ``num_ctx``, ``keep_alive`` and ``tools``.
                Anything not supplied falls back to ``self.config``.

        Returns:
            The request payload with ``stream`` set to ``False``.
        """
        cfg = self.config

        conversation = list(messages)
        if system:
            conversation.insert(0, {"role": "system", "content": system})

        payload: Dict[str, Any] = {
            "model": cfg.model,
            "messages": conversation,
            "stream": False,
            # Without this the configured keep-alive would never reach the server.
            "keep_alive": kwargs.get("keep_alive", cfg.keep_alive),
            "options": {
                "temperature": kwargs.get("temperature", cfg.temperature),
                "num_predict": kwargs.get("max_tokens", cfg.max_tokens),
                "top_p": kwargs.get("top_p", cfg.top_p),
                "top_k": kwargs.get("top_k", cfg.top_k),
                "num_ctx": kwargs.get("num_ctx", cfg.num_ctx),
            },
        }

        tools = kwargs.get("tools")
        if tools:
            payload["tools"] = tools

        return payload

    @staticmethod
    def _parse_json_line(line: str) -> Optional[Dict[str, Any]]:
        """Decode one line of a newline-delimited JSON stream.

        Returns:
            The decoded object, or ``None`` when the line is blank, is not
            valid JSON, or does not decode to a JSON object.
        """
        import json

        if not line or not line.strip():
            return None
        try:
            event = json.loads(line)
        except json.JSONDecodeError:
            return None
        return event if isinstance(event, dict) else None

    def _parse_chat_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a raw chat response into the dict returned by ``chat``."""
        message = result.get("message") or {}

        tool_calls = []
        # ``or []`` tolerates an explicit ``"tool_calls": null`` in the response.
        for call in message.get("tool_calls") or []:
            func = call.get("function") or {}
            tool_calls.append({
                "id": call.get("id", ""),
                "name": func.get("name", ""),
                "arguments": func.get("arguments", "{}"),
            })

        total_ns = result.get("total_duration") or 0

        return {
            "content": message.get("content", ""),
            "role": "assistant",
            "usage": {
                "input_tokens": result.get("prompt_eval_count", 0),
                "output_tokens": result.get("eval_count", 0),
                # The key promises milliseconds; the server reports nanoseconds.
                "total_duration_ms": total_ns // self._NS_PER_MS,
            },
            "model": result.get("model", self.config.model),
            "done_reason": result.get("done_reason", ""),
            "tool_calls": tool_calls,
        }

    async def _async_request(self, payload: Dict) -> Dict:
        """POST ``payload`` to the chat endpoint, retrying transient failures.

        Connection errors, 5xx responses and 408/429 responses are retried
        with exponential backoff (1s, 2s, 4s, ...). Other 4xx responses are
        raised immediately because repeating the same request cannot succeed.
        At least one attempt is always made, even if ``max_retries`` is
        zero or negative.

        Returns:
            The decoded JSON response body.

        Raises:
            httpx.HTTPStatusError: On a non-retryable status, or when the
                last attempt still returns an error status.
            httpx.ConnectError: When the last attempt cannot connect.
        """
        import asyncio

        import httpx

        timeout = httpx.Timeout(self.config.timeout)
        attempts = max(1, self.config.max_retries)

        for attempt in range(attempts):
            is_last = attempt == attempts - 1
            try:
                async with httpx.AsyncClient() as client:
                    resp = await client.post(
                        self._generate_url,
                        json=payload,
                        timeout=timeout,
                    )
                    resp.raise_for_status()
                    return resp.json()
            except httpx.HTTPStatusError as exc:
                status = exc.response.status_code
                transient = status >= 500 or status in self._RETRYABLE_STATUS
                if is_last or not transient:
                    raise
            except httpx.ConnectError:
                if is_last:
                    raise
            await asyncio.sleep(2 ** attempt)

        # Every iteration either returns or raises on the final attempt.
        raise RuntimeError("retry loop exited without a result")

    async def chat(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Send a chat completion request and return the parsed reply.

        Args:
            messages: Conversation messages (``role`` / ``content`` dicts).
            system: Optional system prompt prepended to the conversation.
            **kwargs: Per-call overrides; see ``_build_payload``.

        Returns:
            Dict with keys ``content``, ``role`` (always ``"assistant"``),
            ``usage`` (``input_tokens``, ``output_tokens``,
            ``total_duration_ms``), ``model``, ``done_reason`` and
            ``tool_calls`` (list of ``id`` / ``name`` / ``arguments`` dicts).

        Raises:
            httpx.HTTPStatusError, httpx.ConnectError: See ``_async_request``.
        """
        payload = self._build_payload(messages, system, **kwargs)
        result = await self._async_request(payload)
        return self._parse_chat_result(result)

    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Stream chat completion tokens.

        Yields one dict per server event with keys ``delta`` (text chunk),
        ``done`` and ``model``. Iteration stops after the event whose
        ``done`` flag is truthy. Blank or undecodable lines are skipped.
        The request is not retried, and the timeout is twice the
        configured one.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        import httpx

        payload = self._build_payload(messages, system, **kwargs)
        payload["stream"] = True

        timeout = httpx.Timeout(self.config.timeout * 2)
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                self._generate_url,
                json=payload,
                timeout=timeout,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    event = self._parse_json_line(line)
                    if event is None:
                        continue
                    message = event.get("message") or {}
                    yield {
                        "delta": message.get("content", ""),
                        "done": event.get("done", False),
                        "model": event.get("model", self.config.model),
                    }
                    if event.get("done"):
                        break

    def sync_chat(
        self,
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Synchronous chat completion.

        Runs ``chat`` to completion on a private event loop and returns its
        result.

        Raises:
            RuntimeError: If called from a thread that already has a running
                event loop; ``await chat(...)`` must be used there instead.
        """
        import asyncio

        # Check before creating the coroutine so a misuse does not leave a
        # never-awaited coroutine behind.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass
        else:
            raise RuntimeError(
                "sync_chat() cannot be called from a running event loop; "
                "use 'await chat(...)' instead"
            )

        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(self.chat(messages, system, **kwargs))
        finally:
            loop.close()

    async def list_models(self) -> List[Dict[str, Any]]:
        """List locally available Ollama models.

        Returns:
            One dict per model with keys ``name``, ``size``, ``modified``
            and ``format``; missing values default to ``""`` / ``0``.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        import httpx

        timeout = httpx.Timeout(self.config.timeout)
        async with httpx.AsyncClient() as client:
            resp = await client.get(self._url("/api/tags"), timeout=timeout)
            resp.raise_for_status()
            data = resp.json()

        return [
            {
                "name": entry.get("name", ""),
                "size": entry.get("size", 0),
                "modified": entry.get("modified_at", ""),
                "format": (entry.get("details") or {}).get("format", ""),
            }
            for entry in data.get("models") or []
        ]

    async def pull_model(self, model_name: str) -> AsyncIterator[Dict[str, Any]]:
        """Pull a model from Ollama registry.

        Yields each progress event sent by the server as a dict. Blank or
        undecodable lines are skipped.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        import httpx

        timeout = httpx.Timeout(600)  # 10 min for downloads
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                self._url("/api/pull"),
                json={"name": model_name, "stream": True},
                timeout=timeout,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    event = self._parse_json_line(line)
                    if event is not None:
                        yield event