Coverage for src / autoencodix / utils / _llm_explainer.py: 15%
135 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import os
2import re
3import json
4import ollama
5from dotenv import find_dotenv, load_dotenv
6from mistralai.client import Mistral
7from typing import List, Dict, Any
9import requests
10import warnings
# Report every UserWarning occurrence rather than only the first one per call site.
# NOTE(review): simplefilter mutates the process-wide warnings filter list at import time.
warnings.simplefilter(action="always", category=UserWarning)
class LLMExplainer:
    """Explain latent dimensions by asking an LLM about their most influential genes.

    Supported providers are "ollama", "mistral", "openrouter" and "scads-llm".
    Credentials for the hosted providers are read from environment variables
    (a ``.env`` file is loaded first when one can be found).
    """

    def __init__(
        self,
        client_name: str,
        model_name: str,
        genes_to_latent: Dict[str, List],
        prompt: str,
    ):
        """Initialize LLM client.

        Args:
            client_name: Provider to use: "ollama", "mistral", "openrouter"
                or "scads-llm".
            model_name: Name of the model to request from the provider.
            genes_to_latent: Mapping from latent-dimension key to the genes
                that should be explained for that dimension.
            prompt: Prompt template; the gene list replaces the
                ``{gene_block}`` placeholder.

        Raises:
            ValueError: If the API-key environment variable required by a
                hosted provider is not set.
            NotImplementedError: If ``client_name`` is not supported.
        """
        load_dotenv(find_dotenv())
        self._client_name = client_name
        self._model = model_name
        self.prompt = prompt
        self.genes_to_latent = genes_to_latent
        if self._client_name == "ollama":
            # Set Ollama host for Docker compatibility
            ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434")
            self._ollama_client = ollama.Client(host=ollama_host)
            self._warn_if_ollama_model_missing()
        elif self._client_name == "mistral":
            api_key = os.environ.get("MISTRAL_API_KEY")
            if not api_key:
                raise ValueError("Environment variable MISTRAL_API_KEY not set")
            self._mistral_client = Mistral(api_key=api_key)
        elif self._client_name == "openrouter":
            self._openrouter_api_key = os.environ.get("OPENROUTER_PREMIUM_API_KEY")
            if not self._openrouter_api_key:
                raise ValueError(
                    "Environment variable OPENROUTER_PREMIUM_API_KEY not set"
                )
            self._openrouter_url = "https://openrouter.ai/api/v1/chat/completions"
        elif self._client_name == "scads-llm":
            self._scads_llm_api_key = os.environ.get("SCADS_LLM_API_KEY")
            if not self._scads_llm_api_key:
                raise ValueError("Environment variable SCADS_LLM_API_KEY not set")
            self._scads_llm_url = "https://llm.scads.ai/v1"
        else:
            raise NotImplementedError(f"Client {self._client_name} not implemented")

    def _warn_if_ollama_model_missing(self) -> None:
        """Best-effort check that the configured model exists on the Ollama server.

        Never raises: an unreachable server or an unknown model only emits a
        warning, so construction still succeeds.
        """
        try:
            response = self._ollama_client.list()
            available_models = [m.model for m in response.models]
        except Exception as e:  # server down, API shape changed, ...
            warnings.warn(f"Could not validate Ollama model '{self._model}': {e}")
            return
        # Ollama treats an untagged name such as "llama3" as "llama3:latest",
        # and list() reports the tagged form, so accept both spellings.
        accepted_names = {self._model}
        if ":" not in self._model:
            accepted_names.add(f"{self._model}:latest")
        if accepted_names.isdisjoint(available_models):
            warnings.warn(
                f"Could not validate Ollama model '{self._model}': "
                f"Model '{self._model}' not available. "
                f"Available: {available_models}"
            )

    def _build_prompt(self, *, gene_list: List[str], prompt: str) -> str:
        """Builds the prompt for the LLM.

        Args:
            gene_list: List of genes; each becomes one "- gene" line.
            prompt: The prompt template containing a ``{gene_block}`` placeholder.

        Returns:
            The formatted prompt.
        """
        gene_block = "\n".join(f"- {g}" for g in gene_list)
        try:
            return prompt.format(gene_block=gene_block)
        except (KeyError, IndexError, ValueError):
            # The template contains literal braces (e.g. a JSON example for the
            # expected answer) that str.format cannot interpret; substitute the
            # placeholder textually instead of failing.
            return prompt.replace("{gene_block}", gene_block)

    @staticmethod
    def _fallback_result(text: str) -> Dict[str, Any]:
        """Wrap unparseable LLM output in the same shape as a parsed answer."""
        return {
            "TLDR": text,
            "DETAILS": {
                "dominant_themes": text,
                "hypotheses": [],
                "pathways_summary": text,
            },
        }

    def extract_json_from_output(self, text: str) -> Dict[str, Any]:
        """Extract and parse JSON from LLM output robustly.

        Parsing is attempted on progressively more aggressive repairs so that
        well-formed JSON is never altered: first as-is, then with trailing
        commas removed, then additionally with single quotes turned into
        double quotes.

        Args:
            text: LLM output string.

        Returns:
            The parsed JSON object. If no JSON object can be recovered, a dict
            with keys "TLDR" and "DETAILS" holding the cleaned raw text (a
            warning is emitted).
        """
        cleaned = text.strip()
        # Remove a Markdown code fence (```json ... ``` or plain ``` ... ```).
        cleaned = re.sub(
            r"^```(?:json)?\s*|\s*```$", "", cleaned, flags=re.IGNORECASE
        )
        # Keep only the outermost {...} span when JSON is embedded in prose.
        match = re.search(r"\{.*\}", cleaned, re.DOTALL)
        if match:
            cleaned = match.group(0)

        without_trailing_commas = re.sub(r",\s*([}\]])", r"\1", cleaned)
        # Last resort: this also rewrites apostrophes inside strings, which is
        # why it is only tried after the lossless attempts.
        requoted = without_trailing_commas.replace("'", '"')

        last_error = None
        for candidate in (cleaned, without_trailing_commas, requoted):
            try:
                # strict=False tolerates raw newlines inside string values.
                parsed = json.loads(candidate, strict=False)
            except json.JSONDecodeError as e:
                last_error = e
                continue
            if isinstance(parsed, dict):
                return parsed
            warnings.warn(
                f"Parsed JSON is a {type(parsed).__name__}, not an object. "
                "Returning raw text."
            )
            return self._fallback_result(cleaned)

        warnings.warn(f"Failed to parse JSON. Returning raw text. Error: {last_error}")
        return self._fallback_result(cleaned)

    @staticmethod
    def _as_text(value: Any) -> str:
        """Render one answer field as Markdown text (lists become bullet lines)."""
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        if isinstance(value, list):
            return "\n".join(f"- {item}" for item in value)
        return str(value)

    @classmethod
    def _format_markdown_section(
        cls, key: Any, genes: List, parsed: Dict[str, Any]
    ) -> List[str]:
        """Build the Markdown lines describing one latent dimension."""
        lines = [
            f"# Latent Dimension {key}\n",
            "## Most influential genes",
            ", ".join(str(g) for g in genes) + "\n",
            "## TLDR",
            cls._as_text(parsed.get("TLDR", "")) + "\n",
            "## Details\n",
        ]
        details = parsed.get("DETAILS", {})
        if isinstance(details, dict):
            lines.append("### Dominant Biological Themes\n")
            lines.append(cls._as_text(details.get("dominant_themes", "")) + "\n")
            lines.append("### Mechanistic Hypotheses\n")
            hyps = details.get("hypotheses", [])
            if isinstance(hyps, list):
                lines.extend(f"- {hyp}\n" for hyp in hyps)
            else:
                lines.append(cls._as_text(hyps) + "\n")
            lines.append("### Summary of Pathways/Processes\n")
            lines.append(cls._as_text(details.get("pathways_summary", "")) + "\n")
        else:
            # Fallback if DETAILS is not an object
            lines.append(str(details) + "\n")
        lines.append("\n---\n")
        return lines

    def explain(
        self, *, output_path: str = "latent_explanations.md"
    ) -> Dict[str, Dict[str, Any]]:
        """Generate explanations for each latent dimension.

        Each dimension's answer is also rendered into one Markdown report.

        Args:
            output_path: Report file path; relative paths are resolved against
                the current working directory.

        Returns:
            A dict mapping latent dimension -> parsed JSON explanation (or raw
            output if parsing fails). Dimensions with an empty LLM response
            are skipped with a warning.
        """
        res: Dict[str, Dict[str, Any]] = {}
        markdown_sections: List[str] = []
        for key, genes in self.genes_to_latent.items():
            print(f"Explaining latent dimension {key} with LLM...")
            prompt = self._build_prompt(gene_list=genes, prompt=self.prompt)
            raw_output = self._get_llm_answer(question=prompt)
            # `not` also covers providers that return None instead of "".
            if not raw_output:
                warnings.warn(
                    f"Received empty response from LLM for latent dimension {key}. Skipping."
                )
                continue
            try:
                parsed = self.extract_json_from_output(raw_output)
            except Exception as e:  # best effort: keep the raw answer instead
                warnings.warn(
                    f"Failed to parse JSON for latent dimension {key}. "
                    f"Using raw output instead. Error: {e}"
                )
                parsed = self._fallback_result(raw_output)
            res[key] = parsed
            markdown_sections.extend(
                self._format_markdown_section(key, genes, parsed)
            )

        # ---- WRITE ONE SINGLE FILE ----
        report_path = os.path.join(os.getcwd(), output_path)
        # Explicit UTF-8 so non-ASCII LLM output does not depend on the locale.
        with open(report_path, "w", encoding="utf-8") as f:
            f.write("\n".join(markdown_sections))
        print(f"Saved explanations to: {report_path}")
        return res

    def _get_llm_answer(self, *, question: str) -> str:
        """Gets the LLM answer based on the client.

        Args:
            question: The input question.

        Returns:
            Generated response text.

        Raises:
            NotImplementedError: If the configured client is unknown.
        """
        if self._client_name == "mistral":
            return self._get_mistral_answer(question=question)
        elif self._client_name == "ollama":
            return self._get_ollama_answer(question=question)
        elif self._client_name == "openrouter":
            return self._get_openrouter_answer(
                question=question, model_name=self._model
            )
        elif self._client_name == "scads-llm":
            return self._get_scads_llm_answer(question=question, model_name=self._model)
        else:
            raise NotImplementedError(f"Client {self._client_name} not implemented")

    def _get_mistral_answer(self, *, question: str) -> str:
        """Get answer from Mistral API.

        Args:
            question: The input question.

        Returns:
            Generated response text ("" if the API returned no content).
        """
        chat_response = self._mistral_client.chat.complete(
            model=self._model,
            messages=[
                {
                    "role": "user",
                    "content": question,  # type: ignore
                },
            ],
            response_format={"type": "json_object"},  # Enforce JSON output
        )
        content = chat_response.choices[0].message.content  # type: ignore
        # The SDK types message content as optional; normalise None to "".
        return content if content is not None else ""

    def _get_ollama_answer(self, *, question: str) -> str:
        """Get answer from Ollama.

        Args:
            question: The input question.

        Returns:
            Generated response text.
        """
        response = self._ollama_client.generate(
            model=self._model,
            prompt=question,
        )
        return response["response"]

    def _post_chat_completion(
        self,
        *,
        url: str,
        api_key: str,
        model_name: str,
        question: str,
        provider: str,
    ) -> str:
        """POST one user message to an OpenAI-style chat-completions endpoint.

        Args:
            url: Full chat-completions URL.
            api_key: Bearer token for the provider.
            model_name: The name of the model to use.
            question: The input question.
            provider: Human-readable provider name used in warnings.

        Returns:
            The first choice's message content, or "" (with a warning) when
            the request fails or the response has no usable content.
        """
        response = requests.post(
            url=url,
            headers={"Authorization": f"Bearer {api_key}"},
            json={
                "model": model_name,
                "messages": [{"role": "user", "content": question}],
            },
        )
        if response.status_code != 200:
            warnings.warn(
                f"{provider} API request failed with status code {response.status_code}: {response.text}"
            )
            return ""
        try:
            payload = response.json()
        except ValueError:
            warnings.warn(f"{provider} API returned a non-JSON body: {response.text}")
            return ""
        choices = payload.get("choices") if isinstance(payload, dict) else None
        if not choices:
            warnings.warn(f"{provider} API response missing 'choices': {payload}")
            return ""
        message = choices[0].get("message") or {}
        content = message.get("content")
        return content if isinstance(content, str) else ""

    def _get_openrouter_answer(self, *, question: str, model_name: str) -> str:
        """Get answer from OpenRouter.

        Args:
            question: The input question.
            model_name: The name of the model to use.

        Returns:
            Generated response text ("" on failure, with a warning).
        """
        return self._post_chat_completion(
            url=self._openrouter_url,
            api_key=self._openrouter_api_key,
            model_name=model_name,
            question=question,
            provider="OpenRouter",
        )

    def _get_scads_llm_answer(self, *, question: str, model_name: str) -> str:
        """Get answer from SCADS-LLM.

        The SCADS endpoint speaks the OpenAI chat-completions protocol, so it
        is called directly over HTTP with ``requests`` rather than through an
        extra client library.

        Args:
            question: The input question.
            model_name: The name of the model to use.

        Returns:
            Generated response text ("" on failure, with a warning).
        """
        return self._post_chat_completion(
            url=f"{self._scads_llm_url}/chat/completions",
            api_key=self._scads_llm_api_key,
            model_name=model_name,
            question=question,
            provider="SCADS-LLM",
        )