"""Adapter for OpenAI chat completions via the official ``openai`` SDK.
Streams via ``chat.completions.create(stream=True)``. OpenAI streams
text in ``delta.content`` and tool calls in ``delta.tool_calls`` arrays
where each entry carries an ``index``; the same index across chunks
refers to the same tool call (so we accumulate by index). The final
chunk with ``stream_options={"include_usage": True}`` carries token
counts.
The SDK is imported lazily inside ``__init__`` so users without the
``openai`` extra installed can still ``from jeevesagent.model import
OpenAIModel`` — the import only fires when constructing without
passing a ``client``.
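
A minimal usage sketch (illustrative, to be run inside an ``async def``;
it assumes ``OPENAI_API_KEY`` is set and that ``Message``/``Role`` take
the field names this module reads off them)::

    from jeevesagent.core.types import Message, Role
    from jeevesagent.model import OpenAIModel

    model = OpenAIModel("gpt-4o")
    history = [Message(role=Role.USER, content="Hi!")]

    # One-shot: a single awaited call returning the full reply.
    text, tool_calls, usage, finish_reason = await model.complete(history)

    # Streaming: text deltas, then tool calls, then a finish chunk.
    async for chunk in model.stream(history):
        if chunk.kind == "text":
            print(chunk.text, end="")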
"""
from __future__ import annotations
import json
import os
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any
from ..core.ids import new_id
from ..core.types import Message, ModelChunk, Role, ToolCall, ToolDef, Usage
@dataclass
class _OAIPartial:
    """Accumulator for one in-flight tool call, keyed by delta ``index``."""
id: str = ""
name: str = ""
args_json: str = ""
class OpenAIModel:
"""Talks to OpenAI via :class:`openai.AsyncOpenAI`."""
def __init__(
self,
model: str = "gpt-4o",
*,
client: Any = None,
api_key: str | None = None,
base_url: str | None = None,
) -> None:
self.name = model
if client is not None:
self._client = client
else:
try:
from openai import AsyncOpenAI
except ImportError as exc: # pragma: no cover — depends on user env
raise ImportError(
"OpenAI SDK not installed. "
"Install with: pip install 'jeevesagent[openai]'"
) from exc
self._client = AsyncOpenAI(
api_key=api_key or os.environ.get("OPENAI_API_KEY"),
base_url=base_url,
)
async def complete(
self,
messages: list[Message],
*,
tools: list[ToolDef] | None = None,
temperature: float = 1.0,
max_tokens: int | None = None,
) -> tuple[str, list[ToolCall], Usage, str]:
"""Single-shot completion (no per-chunk yields).
Tries the OpenAI non-streaming endpoint
(``stream=False``) first. If that fails — e.g. when a test
fake client only supports streaming, or a transport doesn't
honor ``stream=False`` — falls back to consuming
:meth:`stream` internally and accumulating the result. The
fallback still saves the per-chunk yield + Event
construction overhead on the architecture side because
ReAct calls ``complete`` with a single ``await``.
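
        Returns ``(text, tool_calls, usage, finish_reason)``; an
        illustrative (made-up) result for a plain text reply::

            ("Hi!", [], Usage(input_tokens=12, output_tokens=3), "stop")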
"""
oai_messages = _to_openai_messages(messages)
oai_tools = [_to_openai_tool(t) for t in (tools or [])]
kwargs: dict[str, Any] = {
"model": self.name,
"messages": oai_messages,
"stream": False,
}
if temperature != 1.0:
kwargs["temperature"] = temperature
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
if oai_tools:
kwargs["tools"] = oai_tools
try:
response = await self._client.chat.completions.create(**kwargs)
        except Exception:  # noqa: BLE001 — fall back transparently
return await _consume_stream(
self.stream(
messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
)
)
choices = getattr(response, "choices", None) or []
if not choices:
            # Some clients (test fakes; transports that don't honor
            # ``stream=False``) return an iterator instead of a
# single response object. Fall back to consuming the
# adapter's own ``stream`` — it knows how to parse the
# chunks into ModelChunks.
if hasattr(response, "__aiter__"):
return await _consume_stream(
self.stream(
messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
)
)
return ("", [], Usage(), "stop")
choice = choices[0]
message = getattr(choice, "message", None)
text = (getattr(message, "content", None) or "") if message else ""
tool_calls: list[ToolCall] = []
raw_tool_calls = (
getattr(message, "tool_calls", None) if message else None
)
for tc in raw_tool_calls or []:
fn = getattr(tc, "function", None)
name = getattr(fn, "name", "") if fn else ""
args_json = getattr(fn, "arguments", "") if fn else ""
try:
args = json.loads(args_json) if args_json else {}
except json.JSONDecodeError:
args = {}
tool_calls.append(
ToolCall(
id=getattr(tc, "id", None) or new_id("tc"),
tool=name,
args=args,
)
)
u = getattr(response, "usage", None)
usage = Usage(
input_tokens=getattr(u, "prompt_tokens", 0) or 0,
output_tokens=getattr(u, "completion_tokens", 0) or 0,
)
finish_reason = getattr(choice, "finish_reason", None) or "stop"
return text, tool_calls, usage, str(finish_reason)
async def stream(
self,
messages: list[Message],
*,
tools: list[ToolDef] | None = None,
temperature: float = 1.0,
max_tokens: int | None = None,
    ) -> AsyncIterator[ModelChunk]:
        """Stream the completion as :class:`ModelChunk` items: ``text``
        deltas as they arrive, one ``tool_call`` chunk per completed
        call once the stream ends, then a single ``finish`` chunk
        carrying the finish reason and token usage."""
oai_messages = _to_openai_messages(messages)
oai_tools = [_to_openai_tool(t) for t in (tools or [])]
kwargs: dict[str, Any] = {
"model": self.name,
"messages": oai_messages,
"stream": True,
"stream_options": {"include_usage": True},
}
if temperature != 1.0:
kwargs["temperature"] = temperature
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
if oai_tools:
kwargs["tools"] = oai_tools
partials: dict[int, _OAIPartial] = {}
usage = Usage()
finish_reason: str | None = None
stream = await self._client.chat.completions.create(**kwargs)
async for chunk in stream:
chunk_usage = getattr(chunk, "usage", None)
if chunk_usage is not None:
usage = Usage(
input_tokens=getattr(chunk_usage, "prompt_tokens", 0) or 0,
output_tokens=getattr(chunk_usage, "completion_tokens", 0) or 0,
)
choices = getattr(chunk, "choices", None)
if not choices:
continue
choice = choices[0]
delta = getattr(choice, "delta", None)
if delta is None:
continue
content = getattr(delta, "content", None)
if content:
yield ModelChunk(kind="text", text=content)
tool_call_deltas = getattr(delta, "tool_calls", None)
if tool_call_deltas:
for tc in tool_call_deltas:
idx = getattr(tc, "index", 0) or 0
p = partials.setdefault(idx, _OAIPartial())
tc_id = getattr(tc, "id", None)
if tc_id:
p.id = tc_id
fn = getattr(tc, "function", None)
if fn is not None:
fn_name = getattr(fn, "name", None)
if fn_name:
p.name = fn_name
fn_args = getattr(fn, "arguments", None)
if fn_args:
p.args_json += fn_args
fr = getattr(choice, "finish_reason", None)
if fr:
finish_reason = fr
# OpenAI emits each tool call across many deltas keyed by index;
# only after the stream ends do we have a complete picture.
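        # Illustrative delta sequence (shapes per the SDK's chunk
        # objects, values made up): three deltas with index=0,
        #   {index: 0, id: "call_1", function: {name: "search"}}
        #   {index: 0, function: {arguments: '{"q": "wea'}}
        #   {index: 0, function: {arguments: 'ther"}'}}
        # accumulate into one
        # ToolCall(id="call_1", tool="search", args={"q": "weather"}).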
for idx in sorted(partials.keys()):
p = partials[idx]
try:
args = json.loads(p.args_json) if p.args_json else {}
except json.JSONDecodeError:
args = {}
yield ModelChunk(
kind="tool_call",
tool_call=ToolCall(
id=p.id or new_id("tc"),
tool=p.name,
args=args,
),
)
yield ModelChunk(
kind="finish",
finish_reason=finish_reason or "stop",
usage=usage,
)
# ---------------------------------------------------------------------------
# Stream-consumption fallback for complete()
# ---------------------------------------------------------------------------
async def _consume_stream(
chunks: AsyncIterator[ModelChunk],
) -> tuple[str, list[ToolCall], Usage, str]:
"""Drain a ``ModelChunk`` stream into the same return tuple as
:meth:`OpenAIModel.complete`. Used when the non-streaming
transport isn't available."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
usage = Usage()
finish_reason = "stop"
async for chunk in chunks:
if chunk.kind == "text" and chunk.text is not None:
text_parts.append(chunk.text)
elif (
chunk.kind == "tool_call" and chunk.tool_call is not None
):
tool_calls.append(chunk.tool_call)
elif chunk.kind == "finish":
if chunk.usage is not None:
usage = chunk.usage
if chunk.finish_reason:
finish_reason = chunk.finish_reason
return "".join(text_parts), tool_calls, usage, finish_reason
async def _replay(it: AsyncIterator[Any]) -> AsyncIterator[ModelChunk]:
    """Last resort: re-yield items from an async iterator we couldn't
    classify as a response object, keeping only ``ModelChunk`` items
    (anything else is silently dropped). Only used when a fake test
    client returns an iterator from a ``stream=False`` call."""
async for item in it:
if isinstance(item, ModelChunk):
yield item
# ---------------------------------------------------------------------------
# Message conversion
# ---------------------------------------------------------------------------
def _to_openai_messages(messages: list[Message]) -> list[dict[str, Any]]:
"""Convert our messages to OpenAI's role/tool_call_id wire format."""
out: list[dict[str, Any]] = []
for m in messages:
if m.role == Role.SYSTEM:
out.append({"role": "system", "content": m.content})
elif m.role == Role.USER:
out.append({"role": "user", "content": m.content})
elif m.role == Role.ASSISTANT:
msg: dict[str, Any] = {"role": "assistant"}
if m.content:
msg["content"] = m.content
if m.tool_calls:
msg["tool_calls"] = [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.tool,
"arguments": json.dumps(tc.args),
},
}
for tc in m.tool_calls
]
out.append(msg)
elif m.role == Role.TOOL:
out.append(
{
"role": "tool",
"content": m.content,
"tool_call_id": m.tool_call_id or "",
}
)
return out
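# Illustrative output of _to_openai_tool for a made-up ToolDef: a tool
# named "search" with one string parameter serializes to
#   {"type": "function",
#    "function": {"name": "search",
#                 "description": "Web search",
#                 "parameters": {"type": "object",
#                                "properties": {"q": {"type": "string"}}}}}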
def _to_openai_tool(t: ToolDef) -> dict[str, Any]:
    """Convert a :class:`ToolDef` to OpenAI's function-tool wire format."""
return {
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.input_schema or {"type": "object", "properties": {}},
},
}