Coverage for src/dash_ai_chat/dash_ai_chat.py: 68%
235 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-03 20:34 +0300
1import datetime
2import json
3import os
4import unicodedata
5from pathlib import Path
6from typing import Any, Dict, Iterator, List, Optional
8import dash_bootstrap_components as dbc
9from dash import ALL, Dash, Input, Output, State, callback_context, dcc, html, no_update
10from openai import OpenAI
# NOTE(review): this constructs a throwaway Dash app at import time and discards
# the instance; nothing in this file references it. Confirm whether it is a
# deliberate workaround (something needing a Dash app to exist before
# DashAIChat is instantiated) or leftover code that can be removed.
Dash()
class DashAIChat(Dash):
    """Dash application providing a chat UI backed by the OpenAI API.

    Conversations are persisted on disk below ``base_dir``: one directory per
    conversation holding ``messages.json``, ``metadata.json`` and
    ``raw_api_responses.jsonl``. The "engine" methods (storage and AI calls)
    are intended to be overridden to plug in other backends.

    The active conversation is encoded in the URL as ``/<user_id>/<convo_id>``.
    """

    def __init__(self, base_dir, **kwargs):
        """Create the app, build the default layout and register callbacks.

        Args:
            base_dir: Root directory for persisted conversation data.
            **kwargs: Forwarded to ``dash.Dash``. Bootstrap theme/icons and the
                package's bundled ``assets`` folder are used unless
                ``external_stylesheets`` / ``assets_folder`` are supplied.

        Raises:
            KeyError: If the ``OPENAI_API_KEY`` environment variable is unset.
            ValueError: If the default layout lacks a required component ID.
        """
        kwargs.setdefault(
            "external_stylesheets", [dbc.themes.BOOTSTRAP, dbc.icons.BOOTSTRAP]
        )
        kwargs.setdefault(
            "assets_folder", str((Path(__file__).parent / "assets").absolute())
        )
        super().__init__(__name__, **kwargs)
        # Component IDs the registered callbacks depend on; any custom layout
        # passed to set_layout() must contain all of them.
        self.required_ids = {
            "burger_menu",
            "sidebar_offcanvas",
            "conversation_list",
            "url",
            "chat_area_div",
            "user_input_textarea",
            "new_chat_button",
        }
        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
        self.BASE_DIR = Path(base_dir)
        # (provider, endpoint) -> how to call the API, how to pull the
        # assistant text out of the (model_dump()'ed) response, and how to
        # turn the stored history into that endpoint's input format.
        self.AI_REGISTRY = {
            ("openai", "chat.completions"): {
                "call": lambda messages, model, **kwargs: (
                    self.client.chat.completions.create(
                        model=model, messages=messages, **kwargs
                    )
                ),
                "extract": lambda resp: resp["choices"][0]["message"]["content"],
                "format_messages": lambda history: [
                    {"role": m["role"], "content": m["content"]} for m in history
                ],
            },
            ("openai", "completions"): {
                "call": lambda prompt, model, **kwargs: self.client.completions.create(
                    model=model, prompt=prompt, **kwargs
                ),
                "extract": lambda resp: resp["choices"][0]["text"],
                "format_messages": lambda history: "\n".join(
                    f"{m['role']}: {m['content']}" for m in history
                ),
            },
        }
        self.layout = self.default_layout()
        self._validate_layout()
        self._register_callbacks()
        self._register_clientside_callbacks()

    # --- Layout Factories ---
    def sidebar(self):
        """Return the off-canvas sidebar with the "New chat" button and conversation list."""
        return dbc.Offcanvas(
            [
                html.Div(
                    [
                        html.I(className="bi bi-pencil-square icon-new-chat"),
                        " New chat",
                    ],
                    id="new_chat_button",
                    className="new-chat-button",
                ),
                html.Div(
                    id="conversation_list",
                    children=[],
                    className="conversation-list",
                ),
            ],
            id="sidebar_offcanvas",
            is_open=False,
            backdrop=False,
            placement="start",
            className="sidebar-offcanvas",
        )

    def chat_area(self):
        """Return the (initially empty) container the rendered messages go into."""
        return html.Div(
            id="chat_area_div",
            children=[],
            className="chat-area-div",
        )

    def input_area(self):
        """Return the wrapper holding the user's message textarea."""
        return html.Div(
            [
                dbc.Textarea(
                    id="user_input_textarea",
                    placeholder="Ask...",
                    rows=4,
                    autoFocus=True,
                    className="form-control user-input-textarea",
                ),
            ]
        )

    def default_layout(self):
        """Return the full default page: burger menu, sidebar, URL tracker, chat and input."""
        return html.Div(
            [
                html.Button(
                    "☰",
                    id="burger_menu",
                    className="burger-menu",
                ),
                self.sidebar(),
                dcc.Location(id="url", refresh=False),
                html.Div(
                    [
                        html.Br(),
                        html.Div(
                            [
                                html.Div(
                                    [
                                        dcc.Loading(
                                            self.chat_area(),
                                            type="circle",
                                            overlay_style={
                                                "visibility": "visible",
                                                "filter": "blur(0.7px)",
                                            },
                                        )
                                    ],
                                    className="col",
                                )
                            ],
                            className="row",
                        ),
                        html.Div(
                            [html.Div([self.input_area()], className="col")],
                            className="row",
                        ),
                    ],
                    className="container main-container",
                ),
            ]
        )

    def _validate_layout(self):
        """Ensure every ID in ``self.required_ids`` appears in ``self.layout``.

        Raises:
            ValueError: Listing the required IDs that are missing.
        """

        def collect_ids(component):
            ids = set()
            component_id = getattr(component, "id", None)
            # Only plain string IDs can match required_ids; pattern-matching
            # (dict) IDs are unhashable and would crash set.add().
            if isinstance(component_id, str) and component_id:
                ids.add(component_id)
            children = getattr(component, "children", None)
            if isinstance(children, (list, tuple)):
                for child in children:
                    ids |= collect_ids(child)
            elif children is not None:
                ids |= collect_ids(children)
            return ids

        missing = self.required_ids - collect_ids(self.layout)
        if missing:
            raise ValueError(
                f"The following required component IDs are missing from the layout: {missing}"
            )

    def set_layout(self, layout):
        """Replace the app layout and check it still has all required IDs."""
        self.layout = layout
        self._validate_layout()

    # --- Engine Methods (to be overridden as needed) ---
    def load_messages(self, user_id: str, conversation_id: str) -> List[Dict]:
        """Return the stored message list, or ``[]`` if the conversation has none."""
        path = self._get_convo_dir(user_id, conversation_id) / "messages.json"
        return self._read_json(path) if path.exists() else []

    def save_messages(
        self, user_id: str, conversation_id: str, messages: List[Dict]
    ) -> None:
        """Overwrite the conversation's ``messages.json`` (creating the directory)."""
        path = self._ensure_convo_dir(user_id, conversation_id) / "messages.json"
        self._write_json(path, messages)

    def add_message(self, user_id: str, conversation_id: str, message: Dict) -> None:
        """Append one message to the stored conversation (load, append, save)."""
        messages = self.load_messages(user_id, conversation_id)
        messages.append(message)
        self.save_messages(user_id, conversation_id, messages)

    def append_raw_response(
        self, user_id: str, conversation_id: str, response: Dict
    ) -> None:
        """Append a raw API response as one line of ``raw_api_responses.jsonl``."""
        path = (
            self._ensure_convo_dir(user_id, conversation_id) / "raw_api_responses.jsonl"
        )
        self._append_jsonl(path, response)

    def load_metadata(self, user_id: str, conversation_id: str) -> Dict:
        """Return the conversation's metadata dict, or ``{}`` if none is stored."""
        path = self._get_convo_dir(user_id, conversation_id) / "metadata.json"
        return self._read_json(path) if path.exists() else {}

    def save_metadata(self, user_id: str, conversation_id: str, metadata: Dict) -> None:
        """Overwrite the conversation's ``metadata.json`` (creating the directory)."""
        path = self._ensure_convo_dir(user_id, conversation_id) / "metadata.json"
        self._write_json(path, metadata)

    def list_users(self) -> List[str]:
        """Return the sorted names of the directories directly under ``BASE_DIR``.

        Returns ``[]`` when ``BASE_DIR`` does not exist yet (nothing saved).
        """
        if not self.BASE_DIR.exists():
            return []
        return sorted(p.name for p in self.BASE_DIR.iterdir() if p.is_dir())

    def list_conversations(self, user_id: str) -> List[str]:
        """Return the user's conversation IDs (directory names), sorted."""
        user_dir = self._get_user_dir(user_id)
        if not user_dir.exists():
            return []
        return sorted(p.name for p in user_dir.iterdir() if p.is_dir())

    def get_conversation_titles(self, user_id: str) -> List[Dict[str, str]]:
        """Return ``{"id", "title"}`` entries for the user's non-empty conversations.

        The title is the first message's content, truncated to 30 characters
        (plus "...") with its first character upper-cased. Conversations with
        no stored messages are omitted.
        """
        result = []
        for convo_id in self.list_conversations(user_id):
            messages = self.load_messages(user_id, convo_id)
            if not messages:
                continue
            first_message = messages[0].get("content") or ""
            title = (
                first_message[:30] + "..."
                if len(first_message) > 30
                else first_message
            )
            # Upper-case only the first character: str.capitalize() would also
            # lower-case the rest and mangle acronyms / proper nouns.
            title = title[:1].upper() + title[1:]
            result.append({"id": convo_id, "title": title})
        return result

    def get_last_convo_id(self, user_id: str) -> Optional[str]:
        """Return the user's highest-sorting conversation ID, or the next free ID if none exist."""
        conversations = self.list_conversations(user_id)
        return conversations[-1] if conversations else self.get_next_convo_id(user_id)

    def get_next_convo_id(self, user_id: str) -> str:
        """Return one more than the largest numeric conversation ID, zero-padded to 3 digits.

        Non-numeric directory names are ignored; returns ``"001"`` when there
        are no numeric conversations.
        """
        user_dir = self._get_user_dir(user_id)
        if not user_dir.exists():
            return "001"
        existing_ids = [
            int(p.name) for p in user_dir.iterdir() if p.is_dir() and p.name.isdigit()
        ]
        if not existing_ids:
            return "001"
        return f"{max(existing_ids) + 1:03d}"

    def fetch_ai_response(
        self,
        messages: List[Dict],
        model: str,
        provider: str = "openai",
        endpoint: str = "chat.completions",
        **kwargs,
    ) -> Dict:
        """Call the registered provider/endpoint with the message history.

        Returns:
            The response as a plain dict (via ``model_dump()`` when available).

        Raises:
            ValueError: For an unregistered provider/endpoint or empty ``model``.
        """
        key = (provider, endpoint)
        if key not in self.AI_REGISTRY:
            raise ValueError(f"Unknown provider/endpoint: {provider}/{endpoint}")
        if not model:
            raise ValueError("Model must be specified explicitly.")
        entry = self.AI_REGISTRY[key]
        resp = entry["call"](entry["format_messages"](messages), model, **kwargs)
        return resp.model_dump() if hasattr(resp, "model_dump") else resp

    def extract_assistant_content(
        self,
        raw_response: Dict,
        provider: str = "openai",
        endpoint: str = "chat.completions",
    ) -> str:
        """Pull the assistant text out of a raw response dict.

        Raises:
            ValueError: For an unregistered provider/endpoint.
        """
        key = (provider, endpoint)
        if key not in self.AI_REGISTRY:
            raise ValueError(f"Unknown provider/endpoint: {provider}/{endpoint}")
        return self.AI_REGISTRY[key]["extract"](raw_response)

    def update_convo(
        self,
        user_id: str,
        user_message: str,
        convo_id: Optional[str] = None,
        provider: str = "openai",
        endpoint: str = "chat.completions",
        model: str = "gpt-4o",
    ) -> str:
        """Store a user message, fetch and store the AI reply, return the convo ID.

        Args:
            user_id: Storage key for the user (relative to ``BASE_DIR``).
            user_message: The text the user submitted.
            convo_id: Target conversation; a new one is allocated when falsy.
            provider, endpoint: Key into ``AI_REGISTRY``.
            model: Model name passed to the API (defaults to the previously
                hard-coded ``"gpt-4o"``).
        """
        convo_id = convo_id or self.get_next_convo_id(user_id)
        self.add_message(user_id, convo_id, {"role": "user", "content": user_message})
        history = self.load_messages(user_id, convo_id)
        raw_response = self.fetch_ai_response(
            history,
            model=model,
            provider=provider,
            endpoint=endpoint,
        )
        self.append_raw_response(user_id, convo_id, raw_response)
        assistant_content = self.extract_assistant_content(
            raw_response, provider, endpoint
        )
        self.add_message(
            user_id, convo_id, {"role": "assistant", "content": assistant_content}
        )
        return convo_id

    # --- URL parsing ---
    @staticmethod
    def _safe_path_segment(segment: Optional[str]) -> Optional[str]:
        """Return *segment* if it is safe as a single directory name, else ``None``.

        URL segments end up in filesystem paths (``_get_user_dir`` /
        ``_get_convo_dir``). The pathname comes from the client and can be
        forged in a callback request, so ``.``/``..``, path separators, drive
        separators and NUL bytes are rejected to prevent path traversal.
        """
        if not segment or set(segment) <= {"."}:
            return None
        if any(ch in segment for ch in ("/", "\\", ":", "\x00")):
            return None
        return segment

    def _parse_pathname(
        self, pathname: Optional[str]
    ) -> "tuple[Optional[str], Optional[str]]":
        """Split ``/<user_id>/<convo_id>`` into its parts; missing/unsafe parts are ``None``."""
        segments = (pathname or "/").strip("/").split("/")
        user_id = self._safe_path_segment(segments[0])
        convo_id = (
            self._safe_path_segment(segments[1]) if len(segments) > 1 else None
        )
        return user_id, convo_id

    @staticmethod
    def _new_user_id() -> str:
        """Return a short random user ID (first 5 characters of a UUID4)."""
        import uuid

        return str(uuid.uuid4())[:5]

    def _register_callbacks(self):
        """Register the server-side callbacks for chat, sidebar and navigation."""

        @self.callback(
            Output("chat_area_div", "children"),
            Output("user_input_textarea", "value"),
            Input("user_input_textarea", "n_submit"),
            Input("url", "pathname"),
            State("user_input_textarea", "value"),
        )
        def handle_user_input(n_submit, pathname, value):
            user_id, convo_id = self._parse_pathname(pathname)
            # NOTE(review): without a user in the URL a new random ID is used on
            # every call and the URL is not updated, so such chats can't be
            # revisited.
            engine_user_id = f"chat_data/{user_id or self._new_user_id()}"
            if not convo_id:
                convo_id = self.get_next_convo_id(engine_user_id)
            # n_submit is cumulative, so it is truthy on every later URL change;
            # only send when the textarea itself fired, otherwise an unsent
            # draft would be posted into whichever conversation was opened.
            submitted = bool(
                callback_context.triggered_id == "user_input_textarea"
                and n_submit
                and value
                and value.strip()
            )
            if submitted:
                self.update_convo(
                    user_id=engine_user_id, user_message=value, convo_id=convo_id
                )
            messages = self.load_messages(engine_user_id, convo_id)
            # Clear the textarea only after sending; keep drafts on navigation.
            return self.format_messages(messages), "" if submitted else no_update

        @self.callback(
            Output("sidebar_offcanvas", "is_open"),
            Output("url", "pathname", allow_duplicate=True),
            Input("burger_menu", "n_clicks"),
            Input({"type": "conversation-item", "index": ALL}, "n_clicks"),
            State("sidebar_offcanvas", "is_open"),
            State("url", "pathname"),
            prevent_initial_call=True,
        )
        def toggle_offcanvas_and_navigate(
            burger_clicks, convo_clicks, is_open, current_pathname
        ):
            ctx = callback_context
            trigger = ctx.triggered_id
            if trigger is None:
                return no_update, no_update
            if isinstance(trigger, dict):
                # Re-rendering the conversation list fires this callback with
                # n_clicks=None for the new items; only react to real clicks.
                if trigger.get("type") != "conversation-item" or not ctx.triggered[
                    0
                ].get("value"):
                    return no_update, no_update
                user_id, _ = self._parse_pathname(current_pathname)
                if user_id is None:
                    return False, no_update
                # The clicked item's ID carries the real conversation ID; list
                # position is unreliable (empty conversations are omitted).
                return False, f"/{user_id}/{trigger['index']}"
            if trigger == "burger_menu" and burger_clicks:
                return not is_open, no_update
            return no_update, no_update

        @self.callback(
            Output("conversation_list", "children"),
            Input("url", "pathname"),
        )
        def update_conversation_list(pathname):
            if not pathname:
                return []
            user_id, _ = self._parse_pathname(pathname)
            if user_id is None:
                # A freshly generated ID has no conversations (and looking one
                # up could collide with a real user's directory).
                return []
            conversations = self.get_conversation_titles(f"chat_data/{user_id}")
            return [
                html.Div(
                    convo["title"],
                    id={"type": "conversation-item", "index": convo["id"]},
                    style={
                        "cursor": "pointer",
                        "padding": "0.5rem",
                        "marginBottom": "0.25rem",
                        "borderRadius": "5px",
                    },
                )
                for convo in conversations
            ]

        @self.callback(
            Output("url", "pathname", allow_duplicate=True),
            Output("sidebar_offcanvas", "is_open", allow_duplicate=True),
            Input("new_chat_button", "n_clicks"),
            State("url", "pathname"),
            prevent_initial_call=True,
        )
        def handle_new_chat(n_clicks, current_pathname):
            if not n_clicks:
                return no_update, no_update
            user_id, _ = self._parse_pathname(current_pathname)
            user_id = user_id or self._new_user_id()
            next_convo_id = self.get_next_convo_id(f"chat_data/{user_id}")
            return f"/{user_id}/{next_convo_id}", False

    def _register_clientside_callbacks(self):
        """Register browser-side callbacks: auto-scroll and textarea text direction."""
        # Scroll the chat area to the bottom shortly after new content renders.
        self.clientside_callback(
            """
            function(chat_content) {
                if (chat_content && chat_content.length > 0) {
                    setTimeout(() => {
                        const chatArea = document.getElementById('chat_area_div');
                        if (chatArea) {
                            chatArea.scrollTop = chatArea.scrollHeight;
                        }
                    }, 100);
                }
                return window.dash_clientside.no_update;
            }
            """,
            Output("chat_area_div", "data-scroll-trigger", allow_duplicate=True),
            Input("chat_area_div", "children"),
            prevent_initial_call=True,
        )
        # Switch the textarea to right-to-left when the typed text contains
        # Hebrew/Arabic-range characters. The Output is a dummy; the callback
        # mutates the DOM directly and returns no_update.
        self.clientside_callback(
            """
            function(textarea_value) {
                const textarea = document.getElementById('user_input_textarea');
                if (textarea && textarea_value) {
                    const rtlPattern = '[\u0590-\u05ff\u0600-\u06ff\u0750-\u077f' +
                        '\u08a0-\u08ff\ufb1d-\ufb4f\ufb50-\ufdff\ufe70-\ufeff]';
                    const rtlRegex = new RegExp(rtlPattern);
                    const isRTL = rtlRegex.test(textarea_value);
                    textarea.style.direction = isRTL ? 'rtl' : 'ltr';
                    textarea.style.textAlign = isRTL ? 'right' : 'left';
                }
                return window.dash_clientside.no_update;
            }
            """,
            Output("user_input_textarea", "title", allow_duplicate=True),
            Input("user_input_textarea", "value"),
            prevent_initial_call=True,
        )

    # --- RTL Detection ---
    def _is_rtl(self, text):
        """Return True if the first strongly-directional character in *text* is RTL."""
        if not text or not text.strip():
            return False
        for char in text:
            bidi = unicodedata.bidirectional(char)
            if bidi in ("R", "AL"):
                return True
            if bidi == "L":
                return False
        return False

    # --- Time ---
    def _now(self):
        """Return the current UTC time as ``YYYY-MM-DD HH:MM:SS``."""
        # timezone.utc is equivalent to datetime.UTC but also works before 3.11.
        return datetime.datetime.now(datetime.timezone.utc).strftime(
            "%Y-%m-%d %H:%M:%S"
        )

    # --- Path Operations ---
    def _get_user_dir(self, user_id: str) -> Path:
        """Return ``BASE_DIR / user_id``."""
        return self.BASE_DIR / user_id

    def _get_convo_dir(self, user_id: str, conversation_id: str) -> Path:
        """Return the directory for one conversation."""
        return self._get_user_dir(user_id) / conversation_id

    def _ensure_convo_dir(self, user_id: str, conversation_id: str) -> Path:
        """Return the conversation directory, creating it (and parents) if needed."""
        path = self._get_convo_dir(user_id, conversation_id)
        path.mkdir(parents=True, exist_ok=True)
        return path

    # --- File I/O ---
    def _read_json(self, path: Path) -> Any:
        """Load and return a UTF-8 JSON file."""
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _read_jsonl(self, path: Path) -> Iterator[Dict]:
        """Yield one parsed object per non-blank line of a JSON Lines file."""
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    yield json.loads(line)

    def _write_json(self, path: Path, data: Any) -> None:
        """Write *data* as pretty-printed UTF-8 JSON, replacing *path* atomically.

        Writing to a sibling temp file and then ``os.replace``-ing it means a
        crash mid-write cannot leave a truncated (and thus unreadable)
        ``messages.json`` behind.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)

    def _append_jsonl(self, path: Path, entry: Dict) -> None:
        """Append *entry* as a single JSON line."""
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")

    @staticmethod
    def _copy_button(content, clipboard_id, style):
        """Return a copy-to-clipboard icon with a "Copy" tooltip."""
        return html.Div(
            [
                dcc.Clipboard(
                    content=content,
                    id=clipboard_id,
                    style={"cursor": "pointer"},
                ),
                dbc.Tooltip(
                    "Copy",
                    target=clipboard_id,
                    placement="top",
                ),
            ],
            style=style,
        )

    def format_messages(self, messages):
        """Render stored messages as Dash components for the chat area.

        User messages become right-aligned Markdown bubbles; assistant messages
        are Markdown styled with Bootstrap table classes. Each message gets a
        copy button. The text direction comes from the most recent user message
        and is also applied to the assistant reply after it. Other roles are
        skipped.
        """
        if not messages:
            return []
        formatted = []
        direction = "ltr"
        for i, msg in enumerate(messages):
            role = msg["role"]
            if role == "user":
                content = msg["content"]
                direction = "rtl" if self._is_rtl(content) else "ltr"
                formatted.append(
                    html.Div(
                        [
                            dcc.Markdown(
                                content,
                                id=f"user-msg-{i}",
                                style={
                                    "textAlign": "right",
                                    "width": "80%",
                                    "marginLeft": "auto",
                                    "padding": "0.3em 0.3em",
                                    "backgroundColor": "var(--bs-light)",
                                    "borderRadius": "15px",
                                },
                            ),
                            self._copy_button(
                                content,
                                f"clipboard-user-{i}",
                                {
                                    "textAlign": "right",
                                    "width": "4%",
                                    "marginLeft": "auto",
                                },
                            ),
                        ],
                        dir=direction,
                    )
                )
            elif role == "assistant":
                content = msg["content"]
                formatted.append(
                    html.Div(
                        [
                            dcc.Markdown(
                                content,
                                id=f"assistant-msg-{i}",
                                className="table table-striped table-hover",
                            ),
                            self._copy_button(
                                content,
                                f"clipboard-assistant-{i}",
                                {"width": "4%"},
                            ),
                        ],
                        dir=direction,
                    )
                )
        return formatted