Coverage for src / kemi / summarizer.py: 100%
65 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
"""LLM-powered abstractive summarization for memory consolidation.

Provides a pluggable :class:`LLMSummarizer` that can use OpenAI, Anthropic,
Ollama (via OpenAI-compatible API), or a custom callable to generate concise
abstractive summaries from a list of related memory texts.

Typical usage::

    summarizer = LLMSummarizer(provider="openai", model="gpt-4o-mini")
    summary = summarizer.summarize([
        "User loves hiking in the Alps",
        "User visited Switzerland last summer",
        "User plans to hike Mont Blanc next year",
    ])
    # -> "User enjoys alpine hiking, has visited Switzerland, and plans to hike Mont Blanc."
"""
18import logging
19from collections.abc import Callable
20from typing import Any
22logger = logging.getLogger(__name__)
24_DEFAULT_PROMPT_TEMPLATE = (
25 "Summarize these related memories into a concise, integrated statement: {memories}"
26)
28# Type alias for custom callables: (prompt: str, **kwargs) -> str
29LLMCallback = Callable[..., str]
32class LLMSummarizer:
33 """Generate abstractive summaries of memory groups using an LLM.
35 Supports four modes via the *provider* argument:
37 ``"openai"``
38 Uses ``openai.OpenAI``. The ``model`` defaults to ``gpt-4o-mini``.
39 Reads ``OPENAI_API_KEY`` from the environment.
41 ``"anthropic"``
42 Uses ``anthropic.Anthropic``. The ``model`` defaults to
43 ``claude-3-haiku-20240307``. Reads ``ANTHROPIC_API_KEY`` from the
44 environment.
46 ``"ollama"``
47 Uses the OpenAI-compatible endpoint at ``ollama_base_url``
48 (default ``http://localhost:11434/v1``). The ``model`` defaults to
49 ``llama3.2``.
51 ``"custom"``
52 Uses the provided ``custom_callback`` callable directly.
53 ``custom_callback(prompt, **kwargs)`` must return the summary text.
55 Args:
56 provider: One of ``"openai"``, ``"anthropic"``, ``"ollama"``,
57 or ``"custom"``.
58 model: Model name override. If omitted, a sensible default is used
59 per provider.
60 api_key: API key override. Falls back to the standard env var when
61 not provided.
62 ollama_base_url: Base URL for Ollama's OpenAI-compatible endpoint.
63 custom_callback: Callable used when ``provider="custom"``.
64 prompt_template: Template string with a ``{memories}`` placeholder
65 (one memory per line). Defaults to :data:`_DEFAULT_PROMPT_TEMPLATE`.
66 """
68 def __init__(
69 self,
70 provider: str = "openai",
71 model: str | None = None,
72 api_key: str | None = None,
73 ollama_base_url: str | None = None,
74 custom_callback: LLMCallback | None = None,
75 prompt_template: str | None = None,
76 ) -> None:
77 self._provider = provider.lower()
78 self._model = model
79 self._api_key = api_key
80 self._ollama_base_url = ollama_base_url or "http://localhost:11434/v1"
81 self._custom_callback = custom_callback
82 self._prompt_template = prompt_template or _DEFAULT_PROMPT_TEMPLATE
84 self._client: Any = None
85 self._effective_model: str | None = None
86 self._init_client()
88 def _init_client(self) -> None:
89 """Lazily initialise the LLM client based on *provider*."""
90 if self._provider == "custom":
91 if self._custom_callback is None:
92 raise ValueError("custom_callback is required when provider='custom'")
93 return
95 if self._provider == "openai" or self._provider == "ollama":
96 try:
97 from openai import OpenAI
98 except ImportError as exc:
99 raise ImportError(
100 "openai package is required for provider='openai' or "
101 "provider='ollama'. Install with: pip install openai"
102 ) from exc
104 kwargs: dict[str, Any] = {}
105 if self._provider == "ollama":
106 kwargs["base_url"] = self._ollama_base_url
107 kwargs["api_key"] = self._api_key or "ollama"
108 else:
109 kwargs["api_key"] = self._api_key
110 self._client = OpenAI(**kwargs)
111 self._effective_model = self._model or (
112 "gpt-4o-mini" if self._provider == "openai" else "llama3.2"
113 )
115 elif self._provider == "anthropic":
116 try:
117 from anthropic import Anthropic
118 except ImportError as exc:
119 raise ImportError(
120 "anthropic package is required for provider='anthropic'. "
121 "Install with: pip install anthropic"
122 ) from exc
124 self._client = Anthropic(api_key=self._api_key)
125 self._effective_model = self._model or "claude-3-haiku-20240307"
127 else:
128 raise ValueError(
129 f"Unknown provider: {self._provider}. "
130 "Expected one of: openai, anthropic, ollama, custom"
131 )
133 def summarize(
134 self,
135 memories: list[str],
136 **kwargs: Any,
137 ) -> str:
138 """Generate an abstractive summary from a list of memory texts.
140 Args:
141 memories: List of memory content strings to summarize.
142 **kwargs: Additional keyword arguments passed through to the
143 underlying LLM call (e.g. ``temperature=0.3``,
144 ``max_tokens=200``).
146 Returns:
147 The generated summary text, or an empty string on failure.
148 """
149 if not memories:
150 return ""
152 # Format the prompt
153 memories_text = "\n".join(f"- {m}" for m in memories)
154 prompt = self._prompt_template.format(memories=memories_text)
156 try:
157 if self._provider == "custom":
158 if self._custom_callback is None:
159 return ""
160 return self._custom_callback(prompt, **kwargs)
162 if self._provider in ("openai", "ollama"):
163 if self._client is None:
164 return ""
165 response = self._client.chat.completions.create(
166 model=self._effective_model or "gpt-4o-mini",
167 messages=[
168 {"role": "user", "content": prompt},
169 ],
170 **kwargs,
171 )
172 return response.choices[0].message.content or ""
174 if self._provider == "anthropic":
175 if self._client is None:
176 return ""
177 response = self._client.messages.create(
178 model=self._effective_model or "claude-3-haiku-20240307",
179 max_tokens=kwargs.pop("max_tokens", 300),
180 messages=[
181 {"role": "user", "content": prompt},
182 ],
183 **kwargs,
184 )
185 return response.content[0].text if response.content else ""
187 except Exception:
188 logger.warning("LLM summarization failed, falling back to extractive summary", exc_info=True)
190 return ""