Coverage for src/dash_ai_chat/dash_ai_chat.py: 68%

235 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-08-03 20:34 +0300

1import datetime 

2import json 

3import os 

4import unicodedata 

5from pathlib import Path 

6from typing import Any, Dict, Iterator, List, Optional 

7 

8import dash_bootstrap_components as dbc 

9from dash import ALL, Dash, Input, Output, State, callback_context, dcc, html, no_update 

10from openai import OpenAI 

11 

# NOTE(review): this bare call builds a Dash app at import time and discards
# the result; nothing visible in this file uses it. Presumably a leftover -
# confirm before removing.
12Dash() 

13 

14 

15class DashAIChat(Dash): 

def __init__(self, base_dir, **kwargs):
    """Create the chat app, then install its layout and callbacks.

    Args:
        base_dir: Root directory for stored conversation data; kept as a
            ``Path`` in ``self.BASE_DIR``.
        **kwargs: Forwarded to ``Dash.__init__``. ``external_stylesheets``
            and ``assets_folder`` are filled in only when the caller did not
            pass them.

    Raises:
        KeyError: If the ``OPENAI_API_KEY`` environment variable is not set.
        ValueError: If the default layout lacks one of ``required_ids``.
    """
    # Defaults only - anything the caller supplied is left untouched.
    kwargs.setdefault(
        "external_stylesheets", [dbc.themes.BOOTSTRAP, dbc.icons.BOOTSTRAP]
    )
    bundled_assets = (Path(__file__).parent / "assets").absolute()
    kwargs.setdefault("assets_folder", str(bundled_assets))

    super().__init__(__name__, **kwargs)

    # Component ids the callbacks depend on; checked by _validate_layout.
    self.required_ids = {
        "burger_menu",
        "sidebar_offcanvas",
        "conversation_list",
        "url",
        "chat_area_div",
        "user_input_textarea",
        "new_chat_button",
    }
    self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
    self.BASE_DIR = Path(base_dir)

    def call_chat(messages, model, **extra):
        return self.client.chat.completions.create(
            model=model, messages=messages, **extra
        )

    def extract_chat(resp):
        return resp["choices"][0]["message"]["content"]

    def format_chat(history):
        return [{"role": m["role"], "content": m["content"]} for m in history]

    def call_completion(prompt, model, **extra):
        return self.client.completions.create(model=model, prompt=prompt, **extra)

    def extract_completion(resp):
        return resp["choices"][0]["text"]

    def format_completion(history):
        return "\n".join(f"{m['role']}: {m['content']}" for m in history)

    # (provider, endpoint) -> how to call it, read its reply, and shape the
    # stored history into its request payload.
    self.AI_REGISTRY = {
        ("openai", "chat.completions"): {
            "call": call_chat,
            "extract": extract_chat,
            "format_messages": format_chat,
        },
        ("openai", "completions"): {
            "call": call_completion,
            "extract": extract_completion,
            "format_messages": format_completion,
        },
    }

    self.layout = self.default_layout()
    self._validate_layout()
    self._register_callbacks()
    self._register_clientside_callbacks()

65 

66 # --- Layout Factories --- 

def sidebar(self):
    """Return the off-canvas sidebar: a "New chat" button above the
    (initially empty) conversation list. It starts closed."""
    new_chat_button = html.Div(
        [html.I(className="bi bi-pencil-square icon-new-chat"), " New chat"],
        id="new_chat_button",
        className="new-chat-button",
    )
    conversation_list = html.Div(
        id="conversation_list",
        children=[],
        className="conversation-list",
    )
    return dbc.Offcanvas(
        [new_chat_button, conversation_list],
        id="sidebar_offcanvas",
        is_open=False,
        backdrop=False,
        placement="start",
        className="sidebar-offcanvas",
    )

90 

def chat_area(self):
    """Return the empty container that rendered chat messages are placed in."""
    container_props = {
        "id": "chat_area_div",
        "children": [],
        "className": "chat-area-div",
    }
    return html.Div(**container_props)

97 

def input_area(self):
    """Return a wrapper holding the four-row, auto-focused message textarea."""
    message_box = dbc.Textarea(
        id="user_input_textarea",
        placeholder="Ask...",
        rows=4,
        autoFocus=True,
        className="form-control user-input-textarea",
    )
    return html.Div([message_box])

110 

def default_layout(self):
    """Assemble the default page.

    The page consists of the burger button, the sidebar, the URL tracker,
    and a main container with the chat area (wrapped in a loading
    indicator) stacked above the input area.
    """
    # The factory methods are invoked in the order sidebar -> chat_area ->
    # input_area, the same order in which they appear in the page.
    burger = html.Button("☰", id="burger_menu", className="burger-menu")
    side_panel = self.sidebar()
    location = dcc.Location(id="url", refresh=False)

    chat_with_spinner = dcc.Loading(
        self.chat_area(),
        type="circle",
        overlay_style={
            "visibility": "visible",
            "filter": "blur(0.7px)",
        },
    )
    chat_row = html.Div(
        [html.Div([chat_with_spinner], className="col")],
        className="row",
    )
    input_row = html.Div(
        [html.Div([self.input_area()], className="col")],
        className="row",
    )
    main_container = html.Div(
        [html.Br(), chat_row, input_row],
        className="container main-container",
    )
    return html.Div([burger, side_panel, location, main_container])

151 

152 def _validate_layout(self): 

153 def collect_ids(component): 

154 ids = set() 

155 if hasattr(component, "id") and component.id: 

156 ids.add(component.id) 

157 if hasattr(component, "children"): 

158 children = component.children 

159 if isinstance(children, list): 

160 for child in children: 

161 ids |= collect_ids(child) 

162 elif children is not None: 

163 ids |= collect_ids(children) 

164 return ids 

165 

166 ids = collect_ids(self.layout) 

167 missing = self.required_ids - ids 

168 if missing: 

169 raise ValueError( 

170 f"The following required component IDs are missing from the layout: {missing}" 

171 ) 

172 

def set_layout(self, layout):
    """Install ``layout`` as the app layout if it has all required ids.

    Args:
        layout: The component tree to use as the new layout.

    Raises:
        ValueError: If ``layout`` lacks a required component id. The layout
            that was in place before the call is restored first, so a
            rejected layout never stays installed.
    """
    previous_layout = self.layout
    # _validate_layout inspects self.layout, so the candidate has to be
    # assigned before it can be checked.
    self.layout = layout
    try:
        self._validate_layout()
    except ValueError:
        self.layout = previous_layout
        raise

176 

177 # --- Engine Methods (to be overridden as needed) --- 

178 def load_messages(self, user_id: str, conversation_id: str) -> List[Dict]: 

179 path = self._get_convo_dir(user_id, conversation_id) / "messages.json" 

180 return self._read_json(path) if path.exists() else [] 

181 

182 def save_messages( 

183 self, user_id: str, conversation_id: str, messages: List[Dict] 

184 ) -> None: 

185 path = self._ensure_convo_dir(user_id, conversation_id) / "messages.json" 

186 self._write_json(path, messages) 

187 

188 def add_message(self, user_id: str, conversation_id: str, message: Dict) -> None: 

189 messages = self.load_messages(user_id, conversation_id) 

190 messages.append(message) 

191 self.save_messages(user_id, conversation_id, messages) 

192 

193 def append_raw_response( 

194 self, user_id: str, conversation_id: str, response: Dict 

195 ) -> None: 

196 path = ( 

197 self._ensure_convo_dir(user_id, conversation_id) / "raw_api_responses.jsonl" 

198 ) 

199 self._append_jsonl(path, response) 

200 

201 def load_metadata(self, user_id: str, conversation_id: str) -> Dict: 

202 path = self._get_convo_dir(user_id, conversation_id) / "metadata.json" 

203 return self._read_json(path) if path.exists() else {} 

204 

205 def save_metadata(self, user_id: str, conversation_id: str, metadata: Dict) -> None: 

206 path = self._ensure_convo_dir(user_id, conversation_id) / "metadata.json" 

207 self._write_json(path, metadata) 

208 

def list_users(self) -> List[str]:
    """Return the sorted names of the directories directly under ``BASE_DIR``.

    Returns an empty list when ``BASE_DIR`` does not exist yet (nothing has
    been stored), the same way ``list_conversations`` treats a missing user
    directory, instead of raising ``FileNotFoundError``.
    """
    if not self.BASE_DIR.exists():
        return []
    return sorted(entry.name for entry in self.BASE_DIR.iterdir() if entry.is_dir())

211 

212 def list_conversations(self, user_id: str) -> List[str]: 

213 user_dir = self._get_user_dir(user_id) 

214 if not user_dir.exists(): 

215 return [] 

216 return sorted([p.name for p in user_dir.iterdir() if p.is_dir()]) 

217 

218 def get_conversation_titles(self, user_id: str) -> List[Dict[str, str]]: 

219 conversations = self.list_conversations(user_id) 

220 result = [] 

221 for convo_id in conversations: 

222 messages = self.load_messages(user_id, convo_id) 

223 if messages: 

224 first_message = messages[0].get("content", "") 

225 title = ( 

226 first_message[:30] + "..." 

227 if len(first_message) > 30 

228 else first_message 

229 ) 

230 title = title.capitalize() if title else "" 

231 result.append({"id": convo_id, "title": title}) 

232 return result 

233 

234 def get_last_convo_id(self, user_id: str) -> Optional[str]: 

235 conversations = self.list_conversations(user_id) 

236 return conversations[-1] if conversations else self.get_next_convo_id(user_id) 

237 

238 def get_next_convo_id(self, user_id: str) -> str: 

239 user_dir = self._get_user_dir(user_id) 

240 if not user_dir.exists(): 

241 return "001" 

242 existing_ids = [ 

243 int(p.name) for p in user_dir.iterdir() if p.is_dir() and p.name.isdigit() 

244 ] 

245 if not existing_ids: 

246 return "001" 

247 next_id = max(existing_ids) + 1 

248 return f"{next_id:03d}" 

249 

def fetch_ai_response(
    self,
    messages: List[Dict],
    model: str,
    provider: str = "openai",
    endpoint: str = "chat.completions",
    **kwargs,
) -> Dict:
    """Send the conversation history to a registered provider endpoint.

    Args:
        messages: Stored history; shaped for the endpoint by the registry's
            ``format_messages`` entry.
        model: Model name handed to the endpoint; must be truthy.
        provider: First half of the ``AI_REGISTRY`` key.
        endpoint: Second half of the ``AI_REGISTRY`` key.
        **kwargs: Passed through to the registry's ``call`` entry.

    Returns:
        ``response.model_dump()`` when the response offers that method,
        otherwise the response as returned by the call.

    Raises:
        ValueError: For an unregistered provider/endpoint pair or a falsy
            ``model``.
    """
    registry_key = (provider, endpoint)
    if registry_key not in self.AI_REGISTRY:
        raise ValueError(f"Unknown provider/endpoint: {provider}/{endpoint}")
    if not model:
        raise ValueError("Model must be specified explicitly.")

    handlers = self.AI_REGISTRY[registry_key]
    send, to_payload = handlers["call"], handlers["format_messages"]
    response = send(to_payload(messages), model, **kwargs)
    if hasattr(response, "model_dump"):
        return response.model_dump()
    return response

267 

def extract_assistant_content(
    self,
    raw_response: Dict,
    provider: str = "openai",
    endpoint: str = "chat.completions",
) -> str:
    """Pull the assistant's text out of a raw response.

    Uses the ``extract`` entry registered in ``AI_REGISTRY`` for the
    ``(provider, endpoint)`` pair.

    Raises:
        ValueError: If the pair is not registered.
    """
    registry_key = (provider, endpoint)
    if registry_key in self.AI_REGISTRY:
        extract = self.AI_REGISTRY[registry_key]["extract"]
        return extract(raw_response)
    raise ValueError(f"Unknown provider/endpoint: {provider}/{endpoint}")

279 

280 def update_convo( 

281 self, 

282 user_id: str, 

283 user_message: str, 

284 convo_id: Optional[str] = None, 

285 provider: str = "openai", 

286 endpoint: str = "chat.completions", 

287 ) -> str: 

288 convo_id = convo_id or self.get_next_convo_id(user_id) 

289 user_msg = {"role": "user", "content": user_message} 

290 self.add_message(user_id, convo_id, user_msg) 

291 history = self.load_messages(user_id, convo_id) 

292 raw_response = self.fetch_ai_response( 

293 history, 

294 model="gpt-4o", 

295 provider=provider, 

296 endpoint=endpoint, 

297 ) 

298 self.append_raw_response(user_id, convo_id, raw_response) 

299 assistant_content = self.extract_assistant_content( 

300 raw_response, provider, endpoint 

301 ) 

302 assistant_msg = {"role": "assistant", "content": assistant_content} 

303 self.add_message(user_id, convo_id, assistant_msg) 

304 return convo_id 

305 

def _register_callbacks(self):
    """Attach the server-side callbacks that drive the chat UI.

    Four callbacks are registered on this app:

    * ``handle_user_input`` - sends a submitted message to the model and
      renders the conversation named in the URL.
    * ``toggle_offcanvas_and_navigate`` - opens/closes the sidebar and
      navigates to a conversation clicked in it.
    * ``update_conversation_list`` - rebuilds the sidebar entries.
    * ``handle_new_chat`` - points the URL at a fresh conversation id.

    URLs have the shape ``/<user>/<conversation>``; conversation data is
    stored under ``chat_data/<user>`` relative to ``BASE_DIR``.
    """
    import uuid

    def safe_segment(segments, position):
        """Return the path segment at ``position``, or None if absent/unsafe."""
        if len(segments) <= position:
            return None
        segment = segments[position]
        # The pathname comes from the client and ends up inside filesystem
        # paths, so segments that could step out of the data directory are
        # treated as if they were missing.
        if (
            not segment
            or segment in (".", "..")
            or "\\" in segment
            or "\x00" in segment
        ):
            return None
        return segment

    def split_pathname(pathname):
        """Return the ``(user_id, convo_id)`` segments; either may be None."""
        segments = (pathname or "/").strip("/").split("/")
        return safe_segment(segments, 0), safe_segment(segments, 1)

    def parse_pathname(pathname):
        """Like ``split_pathname``, but a missing user gets a random short id."""
        user_id, convo_id = split_pathname(pathname)
        return user_id or str(uuid.uuid4())[:5], convo_id

    @self.callback(
        Output("chat_area_div", "children"),
        Output("user_input_textarea", "value"),
        Input("user_input_textarea", "n_submit"),
        Input("url", "pathname"),
        State("user_input_textarea", "value"),
    )
    def handle_user_input(n_submit, pathname, value):
        user_id, convo_id = parse_pathname(pathname)
        engine_user_id = f"chat_data/{user_id}"
        if not convo_id:
            convo_id = self.get_next_convo_id(engine_user_id)

        # n_submit keeps its count across URL changes. Only treat the text
        # as submitted when the textarea itself fired this callback;
        # otherwise navigating with a draft typed in would send that draft.
        submitted = bool(
            callback_context.triggered_id == "user_input_textarea"
            and n_submit
            and value
        )
        if submitted:
            self.update_convo(
                user_id=engine_user_id, user_message=value, convo_id=convo_id
            )
        messages = self.load_messages(engine_user_id, convo_id)
        # Clear the box after sending; leave an unsent draft alone.
        return self.format_messages(messages), ("" if submitted else no_update)

    @self.callback(
        Output("sidebar_offcanvas", "is_open"),
        Output("url", "pathname", allow_duplicate=True),
        Input("burger_menu", "n_clicks"),
        Input({"type": "conversation-item", "index": ALL}, "n_clicks"),
        State("sidebar_offcanvas", "is_open"),
        State("url", "pathname"),
        prevent_initial_call=True,
    )
    def toggle_offcanvas_and_navigate(
        burger_clicks, convo_clicks, is_open, current_pathname
    ):
        ctx = callback_context
        if not ctx.triggered:
            return no_update, no_update

        trigger = ctx.triggered_id
        if isinstance(trigger, dict) and trigger.get("type") == "conversation-item":
            # This also fires with n_clicks=None when the list is
            # (re)rendered; only a real click should navigate.
            if not ctx.triggered[0]["value"]:
                return no_update, no_update
            # The item's id carries its conversation id. List position is
            # not reliable: empty conversations are not listed and ids need
            # not be contiguous.
            convo_id = trigger["index"]
            user_id, _ = split_pathname(current_pathname)
            new_path = f"/{user_id}/{convo_id}" if user_id else f"/{convo_id}"
            return False, new_path

        if trigger == "burger_menu" and burger_clicks:
            return not is_open, no_update
        return no_update, no_update

    @self.callback(
        Output("conversation_list", "children"),
        Input("url", "pathname"),
    )
    def update_conversation_list(pathname):
        if not pathname:
            return []
        user_id, _ = parse_pathname(pathname)
        conversations = self.get_conversation_titles(f"chat_data/{user_id}")
        return [
            html.Div(
                convo["title"],
                id={"type": "conversation-item", "index": convo["id"]},
                style={
                    "cursor": "pointer",
                    "padding": "0.5rem",
                    "margin-bottom": "0.25rem",
                    "border-radius": "5px",
                },
            )
            for convo in conversations
        ]

    @self.callback(
        Output("url", "pathname", allow_duplicate=True),
        Output("sidebar_offcanvas", "is_open", allow_duplicate=True),
        Input("new_chat_button", "n_clicks"),
        State("url", "pathname"),
        prevent_initial_call=True,
    )
    def handle_new_chat(n_clicks, current_pathname):
        if not n_clicks:
            return no_update, no_update
        user_id, _ = parse_pathname(current_pathname)
        next_convo_id = self.get_next_convo_id(f"chat_data/{user_id}")
        # Navigate to the new conversation and close the sidebar.
        return f"/{user_id}/{next_convo_id}", False

412 

def _register_clientside_callbacks(self):
    """Register the two browser-side (JavaScript) callbacks.

    The first scrolls the chat area to the bottom whenever its children
    change; the second switches the textarea between right-to-left and
    left-to-right as the user types. Both return ``no_update``, so the
    declared outputs are never actually written.
    """
    scroll_to_bottom_js = """
        function(chat_content) {
            if (chat_content && chat_content.length > 0) {
                setTimeout(() => {
                    const chatArea = document.getElementById('chat_area_div');
                    if (chatArea) {
                        chatArea.scrollTop = chatArea.scrollHeight;
                    }
                }, 100);
            }
            return window.dash_clientside.no_update;
        }
        """
    text_direction_js = """
        function(textarea_value) {
            const textarea = document.getElementById('user_input_textarea');
            if (textarea && textarea_value) {
                const rtlPattern = '[\u0590-\u05ff\u0600-\u06ff\u0750-\u077f' +
                    '\u08a0-\u08ff\ufb1d-\ufb4f\ufb50-\ufdff\ufe70-\ufeff]';
                const rtlRegex = new RegExp(rtlPattern);
                const isRTL = rtlRegex.test(textarea_value);
                textarea.style.direction = isRTL ? 'rtl' : 'ltr';
                textarea.style.textAlign = isRTL ? 'right' : 'left';
            }
            return window.dash_clientside.no_update;
        }
        """

    self.clientside_callback(
        scroll_to_bottom_js,
        Output("chat_area_div", "data-scroll-trigger", allow_duplicate=True),
        Input("chat_area_div", "children"),
        prevent_initial_call=True,
    )
    self.clientside_callback(
        text_direction_js,
        Output("user_input_textarea", "title", allow_duplicate=True),
        Input("user_input_textarea", "value"),
        prevent_initial_call=True,
    )

451 

452 # --- RTL Detection --- 

453 def _is_rtl(self, text): 

454 if not text or not text.strip(): 

455 return False 

456 for char in text: 

457 bidi = unicodedata.bidirectional(char) 

458 if bidi in ("R", "AL"): 

459 return True 

460 elif bidi == "L": 

461 return False 

462 return False 

463 

464 # --- Time --- 

465 def _now(self): 

466 return datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S") 

467 

468 # --- Path Operations --- 

469 def _get_user_dir(self, user_id: str) -> Path: 

470 return self.BASE_DIR / user_id 

471 

472 def _get_convo_dir(self, user_id: str, conversation_id: str) -> Path: 

473 return self._get_user_dir(user_id) / conversation_id 

474 

475 def _ensure_convo_dir(self, user_id: str, conversation_id: str) -> Path: 

476 path = self._get_convo_dir(user_id, conversation_id) 

477 path.mkdir(parents=True, exist_ok=True) 

478 return path 

479 

480 # --- File I/O --- 

481 def _read_json(self, path: Path) -> Any: 

482 with open(path, "r", encoding="utf-8") as f: 

483 return json.load(f) 

484 

485 def _read_jsonl(self, path: Path) -> Iterator[Dict]: 

486 with open(path, "r", encoding="utf-8") as f: 

487 for line in f: 

488 line = line.strip() 

489 if line: 

490 yield json.loads(line) 

491 

492 def _write_json(self, path: Path, data: Any) -> None: 

493 with open(path, "w", encoding="utf-8") as f: 

494 json.dump(data, f, ensure_ascii=False, indent=2) 

495 

496 def _append_jsonl(self, path: Path, entry: Dict) -> None: 

497 with open(path, "a", encoding="utf-8") as f: 

498 f.write(json.dumps(entry, ensure_ascii=False) + "\n") 

499 

def format_messages(self, messages):
    """Turn stored messages into the components shown in the chat area.

    User messages become a right-aligned bubble and assistant messages a
    Markdown block; each gets a copy-to-clipboard control with a "Copy"
    tooltip. Messages with any other role are skipped. Text direction is
    taken from the most recent user message seen so far, so an assistant
    reply is rendered in the direction of the user message it follows.

    Returns:
        A list of components, or ``[]`` when ``messages`` is empty or None.
    """
    if not messages:
        return []

    def copy_control(role, position, text, wrapper_style):
        # Clipboard button plus the tooltip pointing at it.
        clipboard_id = f"clipboard-{role}-{position}"
        return html.Div(
            [
                dcc.Clipboard(
                    content=text,
                    id=clipboard_id,
                    style={"cursor": "pointer"},
                ),
                dbc.Tooltip(
                    "Copy",
                    target=clipboard_id,
                    placement="top",
                ),
            ],
            style=wrapper_style,
        )

    rendered = []
    direction = "ltr"
    for position, message in enumerate(messages):
        role = message["role"]
        if role == "user":
            text = message["content"]
            direction = "rtl" if self._is_rtl(text) else "ltr"
            bubble = dcc.Markdown(
                text,
                id=f"user-msg-{position}",
                style={
                    "text-align": "right",
                    "width": "80%",
                    "margin-left": "auto",
                    "padding": "0.3em 0.3em",
                    "background-color": "var(--bs-light)",
                    "border-radius": "15px",
                },
            )
            controls = copy_control(
                "user",
                position,
                text,
                {
                    "text-align": "right",
                    "width": "4%",
                    "margin-left": "auto",
                },
            )
        elif role == "assistant":
            text = message["content"]
            bubble = dcc.Markdown(
                text,
                id=f"assistant-msg-{position}",
                className="table table-striped table-hover",
            )
            controls = copy_control("assistant", position, text, {"width": "4%"})
        else:
            continue
        rendered.append(html.Div([bubble, controls], dir=direction))
    return rendered