Coverage for src / autoencodix / utils / _llm_explainer.py: 15%

135 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import os 

2import re 

3import json 

4import ollama 

5from dotenv import find_dotenv, load_dotenv 

6from mistralai.client import Mistral 

7from typing import List, Dict, Any 

8 

9import requests 

10import warnings 

11 

12warnings.simplefilter("always", UserWarning) 

13 

14 

class LLMExplainer:
    """LLM client with support for multiple providers.

    Supported ``client_name`` values are ``"ollama"``, ``"mistral"``,
    ``"openrouter"`` and ``"scads-llm"``. For every entry of
    ``genes_to_latent`` the configured model is asked to explain the gene
    list; the answers are returned as dicts and also written to a single
    markdown report (see :meth:`explain`).
    """

    def __init__(
        self,
        client_name: str,
        model_name: str,
        genes_to_latent: Dict[str, List],
        prompt: str,
    ):
        """Initialize LLM client.

        Args:
            client_name: Name of the LLM client. One of "ollama", "mistral",
                "openrouter" or "scads-llm".
            model_name: Name of the model to use.
            genes_to_latent: Mapping of latent dimension -> list of genes.
            prompt: Prompt template. It is rendered with ``str.format`` and
                receives the keyword ``gene_block``.

        Raises:
            ValueError: If the API key environment variable required by the
                selected client is not set.
            NotImplementedError: If ``client_name`` is not supported.
        """
        # Pick up API keys / hosts from a .env file if one can be found.
        load_dotenv(find_dotenv())
        self._client_name = client_name
        self._model = model_name
        self.prompt = prompt
        self.genes_to_latent = genes_to_latent

        # Initialize client-specific objects
        if self._client_name == "ollama":
            # Set Ollama host for Docker compatibility
            ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434")
            self._ollama_client = ollama.Client(host=ollama_host)
            try:
                response = self._ollama_client.list()
                available_models = [m.model for m in response.models]
                if self._model not in available_models:
                    raise ValueError(
                        f"Model '{self._model}' not available. "
                        f"Available: {available_models}"
                    )
            except Exception as e:
                # NOTE(review): this handler also catches the ValueError
                # raised just above, so an unavailable model only triggers a
                # warning and construction continues -- confirm that this
                # best-effort validation is intended.
                warnings.warn(f"Could not validate Ollama model '{self._model}': {e}")
        elif self._client_name == "mistral":
            api_key = os.environ.get("MISTRAL_API_KEY")
            if not api_key:
                raise ValueError("Environment variable MISTRAL_API_KEY not set")
            self._mistral_client = Mistral(api_key=api_key)
        elif self._client_name == "openrouter":
            self._openrouter_api_key = os.environ.get("OPENROUTER_PREMIUM_API_KEY")
            if not self._openrouter_api_key:
                raise ValueError(
                    "Environment variable OPENROUTER_PREMIUM_API_KEY not set"
                )
            self._openrouter_url = "https://openrouter.ai/api/v1/chat/completions"
        elif self._client_name == "scads-llm":
            self._scads_llm_api_key = os.environ.get("SCADS_LLM_API_KEY")
            if not self._scads_llm_api_key:
                raise ValueError("Environment variable SCADS_LLM_API_KEY not set")
            self._scads_llm_url = "https://llm.scads.ai/v1"
        else:
            raise NotImplementedError(f"Client {self._client_name} not implemented")

    def _build_prompt(self, *, gene_list: List[str], prompt: str) -> str:
        """Builds the prompt for the LLM.

        Args:
            gene_list: List of genes.
            prompt: The prompt template. Rendered with ``str.format``, so any
                literal braces in the template must be escaped as ``{{``/``}}``.

        Returns:
            The formatted prompt, with ``{gene_block}`` replaced by one
            ``- <gene>`` line per gene.
        """
        gene_block = "\n".join(f"- {g}" for g in gene_list)
        return prompt.format(gene_block=gene_block)

    @staticmethod
    def _raw_text_explanation(text: Any) -> Dict[str, Any]:
        """Wrap unparseable LLM output in the expected explanation structure."""
        return {
            "TLDR": text,
            "DETAILS": {
                "dominant_themes": text,
                "hypotheses": [],
                "pathways_summary": text,
            },
        }

    @staticmethod
    def _as_text(value: Any) -> str:
        """Coerce a parsed JSON value to text for the markdown report."""
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    def extract_json_from_output(self, text: str) -> Dict[str, Any]:
        """Extract and parse JSON from LLM output robustly.

        The output is first parsed as-is (after removing a surrounding
        ```` ```json ```` fence and any text around the outermost ``{...}``).
        Only if that fails are the lossy repairs applied (single quotes ->
        double quotes, trailing commas removed, newlines collapsed), so valid
        JSON containing apostrophes is no longer corrupted.

        Args:
            text: LLM output string.

        Returns:
            Parsed dict with keys "TLDR" and "DETAILS". If the output cannot
            be parsed into a JSON object, a dict of the same shape holding the
            unrepaired text is returned and a warning is emitted.
        """
        # Strip whitespace and common wrappers
        candidate = text.strip()
        # Remove code block markers if present
        candidate = re.sub(r"^```json\s*|\s*```$", "", candidate, flags=re.DOTALL)
        # Extract the JSON substring if embedded in text
        match = re.search(r"\{.*\}", candidate, re.DOTALL)
        if match:
            candidate = match.group(0)

        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            # Fix common issues: single quotes to double, trailing commas
            repaired = candidate.replace("'", '"')
            repaired = re.sub(r",\s*([}\]])", r"\1", repaired)
            # Remove unescaped newlines in strings (approximate fix)
            repaired = re.sub(r"(?<!\\)\n", " ", repaired)
            try:
                parsed = json.loads(repaired)
            except json.JSONDecodeError as e:
                warnings.warn(f"Failed to parse JSON. Returning raw text. Error: {e}")
                return self._raw_text_explanation(candidate)

        if not isinstance(parsed, dict):
            # A bare list/number/string is valid JSON but not an explanation.
            warnings.warn(
                "Parsed JSON is not an object "
                f"(got {type(parsed).__name__}). Returning raw text."
            )
            return self._raw_text_explanation(candidate)
        return parsed

    def _markdown_for_dimension(
        self, key: Any, genes: List, parsed: Dict[str, Any]
    ) -> List[str]:
        """Render the markdown lines of the report for one latent dimension."""
        lines = [
            f"# Latent Dimension {key}\n",
            "## Most influential genes",
            # str() so that non-string gene identifiers do not break the join.
            ", ".join(str(g) for g in genes) + "\n",
            "## TLDR",
            self._as_text(parsed.get("TLDR", "")) + "\n",
            "## Details\n",
        ]

        details = parsed.get("DETAILS", {})
        if isinstance(details, dict):
            # Dominant themes
            lines.append("### Dominant Biological Themes\n")
            lines.append(self._as_text(details.get("dominant_themes", "")) + "\n")
            # Hypotheses
            lines.append("### Mechanistic Hypotheses\n")
            hyps = details.get("hypotheses", [])
            if isinstance(hyps, list):
                lines.extend(f"- {hyp}\n" for hyp in hyps)
            else:
                lines.append(str(hyps) + "\n")
            # Pathways summary
            lines.append("### Summary of Pathways/Processes\n")
            lines.append(self._as_text(details.get("pathways_summary", "")) + "\n")
        else:
            # Fallback if not dict
            lines.append(str(details) + "\n")

        lines.append("\n---\n")
        return lines

    def explain(self) -> Dict[str, Dict[str, Any]]:
        """Generate explanations for each latent dimension.

        Side effects: prints progress to stdout and writes all explanations
        to ``latent_explanations.md`` in the current working directory.
        Latent dimensions for which the LLM returns an empty answer are
        skipped with a warning.

        Returns:
            A dict mapping latent dimension -> parsed JSON explanation (or raw output if parsing fails).
        """
        res: Dict[str, Dict[str, Any]] = {}
        markdown_sections: List[str] = []
        for key, genes in self.genes_to_latent.items():
            print(f"Explaining latent dimension {key} with LLM...")
            prompt = self._build_prompt(gene_list=genes, prompt=self.prompt)
            raw_output = self._get_llm_answer(question=prompt)
            # Providers may return "" or None when there is no content.
            if not raw_output:
                warnings.warn(
                    f"Received empty response from LLM for latent dimension {key}. Skipping."
                )
                continue
            # ---- TRY TO PARSE JSON ----
            try:
                parsed = self.extract_json_from_output(raw_output)
            except Exception as e:
                warnings.warn(
                    f"Failed to parse JSON for latent dimension {key}. "
                    f"Using raw output instead. Error: {e}"
                )
                parsed = self._raw_text_explanation(raw_output)
            res[key] = parsed
            # ---- BUILD READABLE MARKDOWN ----
            markdown_sections.extend(self._markdown_for_dimension(key, genes, parsed))

        # ---- WRITE ONE SINGLE FILE ----
        output_path = os.path.join(os.getcwd(), "latent_explanations.md")
        # Explicit UTF-8: LLM output may contain characters that the
        # platform default encoding cannot represent.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write("\n".join(markdown_sections))
        print(f"Saved explanations to: {output_path}")
        return res

    def _get_llm_answer(self, *, question: str) -> str:
        """Gets the LLM answer based on the client.

        Args:
            question: The input question.

        Returns:
            Generated response text.

        Raises:
            NotImplementedError: If the configured client is not supported.
        """
        if self._client_name == "mistral":
            return self._get_mistral_answer(question=question)
        elif self._client_name == "ollama":
            return self._get_ollama_answer(question=question)
        elif self._client_name == "openrouter":
            return self._get_openrouter_answer(
                question=question, model_name=self._model
            )
        elif self._client_name == "scads-llm":
            return self._get_scads_llm_answer(question=question, model_name=self._model)
        else:
            raise NotImplementedError(f"Client {self._client_name} not implemented")

    def _get_mistral_answer(self, *, question: str) -> str:
        """Get answer from Mistral API.

        Args:
            question: The input question.

        Returns:
            Generated response text ("" if the API returned no content).
        """
        chat_response = self._mistral_client.chat.complete(
            model=self._model,
            messages=[
                {
                    "role": "user",
                    "content": question,  # type: ignore
                },
            ],
            response_format={"type": "json_object"},  # Enforce JSON output
        )
        content = chat_response.choices[0].message.content  # type: ignore
        return "" if content is None else content  # type: ignore

    def _get_ollama_answer(self, *, question: str) -> str:
        """Get answer from Ollama.

        Args:
            question: The input question.

        Returns:
            Generated response text.
        """
        response = self._ollama_client.generate(
            model=self._model,
            prompt=question,
        )
        return response["response"]

    def _get_openrouter_answer(self, *, question: str, model_name: str) -> str:
        """Get answer from OpenRouter.

        Any failure of the request (non-200 status, non-JSON body, missing
        ``choices``) is reported as a warning and yields an empty string.

        Args:
            question: The input question.
            model_name: The name of the model to use.

        Returns:
            Generated response text, or "" if the request failed.
        """
        # NOTE(review): no timeout is passed, so this call can block
        # indefinitely if the server never answers -- consider adding one.
        response = requests.post(
            url=self._openrouter_url,
            headers={
                "Authorization": f"Bearer {self._openrouter_api_key}",
            },
            data=json.dumps(
                {
                    "model": model_name,
                    "messages": [
                        {
                            "role": "user",
                            "content": question,
                        }
                    ],
                }
            ),
        )
        if response.status_code != 200:
            warnings.warn(
                f"OpenRouter API request failed with status code {response.status_code}: {response.text}"
            )
            # Return empty response
            return ""
        # Decode the body once instead of re-parsing it for every access.
        try:
            payload = response.json()
        except ValueError as e:
            warnings.warn(f"OpenRouter API returned a non-JSON response: {e}")
            return ""
        if not isinstance(payload, dict) or not payload.get("choices"):
            warnings.warn(f"OpenRouter API response missing 'choices': {payload}")
            return ""
        return payload["choices"][0]["message"]["content"]

    def _get_scads_llm_answer(self, *, question: str, model_name: str) -> str:
        """Get answer from SCADS-LLM.

        Args:
            question: The input question.
            model_name: The name of the model to use.

        Returns:
            Generated response text ("" if the API returned no content).

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        # ``OpenAI`` was referenced without ever being imported. Import it
        # lazily so the package is only required when this client is used.
        try:
            from openai import OpenAI
        except ImportError as e:
            raise ImportError(
                "The 'openai' package is required for the 'scads-llm' client."
            ) from e

        client = OpenAI(base_url=self._scads_llm_url, api_key=self._scads_llm_api_key)

        response = client.chat.completions.create(
            messages=[{"role": "user", "content": question}], model=model_name
        )

        content = response.choices[0].message.content
        return "" if content is None else content