Coverage for src / kemi / summarizer.py: 100%

65 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-05 15:47 +0000

1"""LLM-powered abstractive summarization for memory consolidation. 

2 

3Provides a pluggable :class:`LLMSummarizer` that can use OpenAI, Anthropic, 

4Ollama (via OpenAI-compatible API), or a custom callable to generate concise 

5abstractive summaries from a list of related memory texts. 

6 

7Typical usage:: 

8 

9 summarizer = LLMSummarizer(provider="openai", model="gpt-4o-mini") 

10 summary = summarizer.summarize([ 

11 "User loves hiking in the Alps", 

12 "User visited Switzerland last summer", 

13 "User plans to hike Mont Blanc next year", 

14 ]) 

15 # -> "User enjoys alpine hiking, has visited Switzerland, and plans to hike Mont Blanc." 

16""" 

17 

18import logging 

19from collections.abc import Callable 

20from typing import Any 

21 

# Module-level logger; summarization failures are reported through it.
logger = logging.getLogger(__name__)

# Prompt used when the caller does not supply ``prompt_template``.
# ``{memories}`` is filled in with ``str.format`` and receives the memory
# texts rendered as a "- "-prefixed list joined by newlines.
_DEFAULT_PROMPT_TEMPLATE = (
    "Summarize these related memories into a concise, integrated statement: "
    "{memories}"
)

# Signature expected from a user-supplied summarization callable:
# ``callback(prompt: str, **kwargs) -> str``.
LLMCallback = Callable[..., str]

30 

31 

class LLMSummarizer:
    """Generate abstractive summaries of memory groups using an LLM.

    Supports four modes via the *provider* argument:

    ``"openai"``
        Uses ``openai.OpenAI``. The ``model`` defaults to ``gpt-4o-mini``.
        When ``api_key`` is omitted, ``None`` is handed to the client, which
        is expected to read ``OPENAI_API_KEY`` from the environment.

    ``"anthropic"``
        Uses ``anthropic.Anthropic``. The ``model`` defaults to
        ``claude-3-haiku-20240307``. When ``api_key`` is omitted, ``None`` is
        handed to the client, which is expected to read ``ANTHROPIC_API_KEY``
        from the environment.

    ``"ollama"``
        Uses the OpenAI-compatible endpoint at ``ollama_base_url``
        (default ``http://localhost:11434/v1``). The ``model`` defaults to
        ``llama3.2``. When ``api_key`` is omitted the placeholder key
        ``"ollama"`` is sent.

    ``"custom"``
        Uses the provided ``custom_callback`` callable directly.
        ``custom_callback(prompt, **kwargs)`` must return the summary text.

    Args:
        provider: One of ``"openai"``, ``"anthropic"``, ``"ollama"``,
            or ``"custom"`` (case-insensitive).
        model: Model name override. If omitted, a sensible default is used
            per provider.
        api_key: API key override. Falls back to the standard env var when
            not provided.
        ollama_base_url: Base URL for Ollama's OpenAI-compatible endpoint.
        custom_callback: Callable used when ``provider="custom"``.
        prompt_template: Template string with a ``{memories}`` placeholder.
            The placeholder receives the memories rendered as a
            ``"- "``-prefixed list, one memory per line. Defaults to
            :data:`_DEFAULT_PROMPT_TEMPLATE`.

    Raises:
        ValueError: If *provider* is unknown, or if ``provider="custom"`` is
            used without a ``custom_callback``.
        ImportError: If the SDK package required by *provider* is not
            installed.
    """

    # Single source of truth for per-provider default models, shared by
    # client initialisation and the defensive fallback in ``summarize``.
    _DEFAULT_MODELS = {
        "openai": "gpt-4o-mini",
        "ollama": "llama3.2",
        "anthropic": "claude-3-haiku-20240307",
    }

    def __init__(
        self,
        provider: str = "openai",
        model: str | None = None,
        api_key: str | None = None,
        ollama_base_url: str | None = None,
        custom_callback: LLMCallback | None = None,
        prompt_template: str | None = None,
    ) -> None:
        self._provider = provider.lower()
        self._model = model
        self._api_key = api_key
        self._ollama_base_url = ollama_base_url or "http://localhost:11434/v1"
        self._custom_callback = custom_callback
        self._prompt_template = prompt_template or _DEFAULT_PROMPT_TEMPLATE

        self._client: Any = None
        self._effective_model: str | None = None
        self._init_client()

    def _init_client(self) -> None:
        """Validate the provider and build its client.

        Called once from ``__init__``. The SDK imports are deferred to this
        method so that only the package for the selected provider has to be
        installed; the ``custom`` provider needs no SDK at all.

        Raises:
            ValueError: Unknown provider, or ``custom`` without a callback.
            ImportError: The provider's SDK package is not installed.
        """
        if self._provider == "custom":
            if self._custom_callback is None:
                raise ValueError("custom_callback is required when provider='custom'")
            return

        if self._provider in ("openai", "ollama"):
            try:
                from openai import OpenAI
            except ImportError as exc:
                raise ImportError(
                    "openai package is required for provider='openai' or "
                    "provider='ollama'. Install with: pip install openai"
                ) from exc

            client_kwargs: dict[str, Any] = {}
            if self._provider == "ollama":
                client_kwargs["base_url"] = self._ollama_base_url
                # Placeholder key used when the caller supplies none.
                client_kwargs["api_key"] = self._api_key or "ollama"
            else:
                client_kwargs["api_key"] = self._api_key
            self._client = OpenAI(**client_kwargs)

        elif self._provider == "anthropic":
            try:
                from anthropic import Anthropic
            except ImportError as exc:
                raise ImportError(
                    "anthropic package is required for provider='anthropic'. "
                    "Install with: pip install anthropic"
                ) from exc

            self._client = Anthropic(api_key=self._api_key)

        else:
            raise ValueError(
                f"Unknown provider: {self._provider}. "
                "Expected one of: openai, anthropic, ollama, custom"
            )

        self._effective_model = self._model or self._DEFAULT_MODELS[self._provider]

    def summarize(
        self,
        memories: list[str],
        **kwargs: Any,
    ) -> str:
        """Generate an abstractive summary from a list of memory texts.

        Blank entries (``None``, empty or whitespace-only strings) are
        skipped; if nothing remains, no LLM call is made.

        Args:
            memories: List of memory content strings to summarize.
            **kwargs: Additional keyword arguments passed through to the
                underlying LLM call (e.g. ``temperature=0.3``,
                ``max_tokens=200``). For ``provider="anthropic"``,
                ``max_tokens`` defaults to ``300`` when not given.

        Returns:
            The generated summary text, or an empty string when there is
            nothing to summarize or the LLM call fails (the failure is
            logged as a warning).

        Raises:
            KeyError, IndexError, ValueError: If ``prompt_template`` is not a
                valid ``str.format`` template taking only ``{memories}``.
                Template errors are configuration bugs and are deliberately
                not swallowed by the best-effort handling below.
        """
        if not memories:
            return ""

        # Drop blank entries so the model is never asked to summarise a list
        # of empty bullets.
        entries = [m for m in memories if m is not None and str(m).strip()]
        if not entries:
            return ""

        memories_text = "\n".join(f"- {m}" for m in entries)
        prompt = self._prompt_template.format(memories=memories_text)

        try:
            if self._provider == "custom":
                if self._custom_callback is None:
                    return ""
                # Normalise a None/empty callback result to "" so the return
                # type holds, mirroring the OpenAI branch below.
                return self._custom_callback(prompt, **kwargs) or ""

            if self._provider in ("openai", "ollama"):
                if self._client is None:
                    return ""
                response = self._client.chat.completions.create(
                    model=self._effective_model or self._DEFAULT_MODELS[self._provider],
                    messages=[
                        {"role": "user", "content": prompt},
                    ],
                    **kwargs,
                )
                return response.choices[0].message.content or ""

            if self._provider == "anthropic":
                if self._client is None:
                    return ""
                # Pull max_tokens out of kwargs first so it is not passed
                # twice; this call always supplies a value for it.
                max_tokens = kwargs.pop("max_tokens", 300)
                response = self._client.messages.create(
                    model=self._effective_model or self._DEFAULT_MODELS["anthropic"],
                    max_tokens=max_tokens,
                    messages=[
                        {"role": "user", "content": prompt},
                    ],
                    **kwargs,
                )
                return response.content[0].text if response.content else ""

        except Exception:
            # Deliberate best-effort boundary: any provider/callback failure
            # degrades to "" and is logged with its traceback.
            logger.warning("LLM summarization failed, falling back to extractive summary", exc_info=True)

        return ""

189 

190 return ""