Coverage for intelligence_toolkit/AI/base_chat.py: 100%
33 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
6import asyncio
8from tqdm.asyncio import tqdm_asyncio
10from intelligence_toolkit.AI.base_batch_async import BaseBatchAsync
11from intelligence_toolkit.AI.client import OpenAIClient
12from intelligence_toolkit.AI.defaults import DEFAULT_CONCURRENT_COROUTINES
13from intelligence_toolkit.helpers.decorators import retry_with_backoff
14from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback
class BaseChat(BaseBatchAsync, OpenAIClient):
    """Chat client that can run many chat requests concurrently.

    Combines the request method ``generate_chat_async`` and the progress
    helpers ``progress_callback`` / ``track_progress`` (all defined outside
    this class) with a semaphore that bounds how many requests are in flight
    at the same time.
    """

    def __init__(
        self, configuration=None, concurrent_coroutines=DEFAULT_CONCURRENT_COROUTINES
    ) -> None:
        """Initialise the client and the concurrency limiter.

        Args:
            configuration: Forwarded unchanged to ``OpenAIClient.__init__``.
            concurrent_coroutines: Initial value of the semaphore, i.e. the
                maximum number of ``generate_text_async`` calls allowed to
                issue a request simultaneously.
        """
        # NOTE(review): only OpenAIClient.__init__ is invoked explicitly here;
        # confirm BaseBatchAsync needs no initialisation of its own.
        OpenAIClient.__init__(self, configuration)
        self.semaphore = asyncio.Semaphore(concurrent_coroutines)

    @retry_with_backoff()
    async def generate_text_async(self, messages, callbacks, stream, **llm_kwargs):
        """Run a single chat request while holding the concurrency semaphore.

        Args:
            messages: Chat messages passed to ``generate_chat_async``.
            callbacks: When truthy, ``self.progress_callback()`` is called
                after the request succeeds. The callbacks themselves are not
                invoked by this method.
            stream: Passed to ``generate_chat_async`` as ``stream``.
            **llm_kwargs: Extra keyword arguments passed to
                ``generate_chat_async``.

        Returns:
            Whatever ``generate_chat_async`` returned.

        Raises:
            Exception: Any ``Exception`` raised inside the request section is
                re-raised as a generic ``Exception`` whose message starts with
                "Problem in OpenAI response." and whose ``__cause__`` is the
                original error.
        """
        async with self.semaphore:
            try:
                chat = await self.generate_chat_async(
                    messages=messages, stream=stream, **llm_kwargs
                )
                if callbacks:
                    self.progress_callback()
            except Exception as e:
                # NOTE(review): the "validating report" wording does not match
                # what this method does; text kept verbatim because it is
                # runtime output.
                print(f"Error validating report: {e}")
                msg = f"Problem in OpenAI response. {e}"
                raise Exception(msg) from e
            return chat

    @retry_with_backoff()
    async def generate_texts_async(
        self,
        messages_list: list[list[dict[str, str]]],
        callbacks: list[ProgressBatchCallback] | None = None,
        **llm_kwargs,
    ):
        """Run one non-streaming chat request per entry of ``messages_list``.

        Each entry is scheduled as its own task through
        ``generate_text_async``; when ``callbacks`` is truthy an additional
        task runs ``self.track_progress(tasks, callbacks)`` alongside them.

        Args:
            messages_list: One list of chat messages per request.
            callbacks: Optional progress callbacks handed to
                ``track_progress``.
            **llm_kwargs: Extra keyword arguments forwarded to every request.

        Returns:
            The list produced by ``tqdm_asyncio.gather`` over the request
            tasks (one result per entry of ``messages_list``).

        Raises:
            Exception: Propagates the first failure raised while awaiting the
                batch. Before it propagates, every still-pending request task
                and the progress task are cancelled and drained so nothing
                keeps running in the background.
        """
        self.total_tasks = len(messages_list)
        tasks = [
            asyncio.create_task(
                self.generate_text_async(messages, callbacks, False, **llm_kwargs)
            )
            for messages in messages_list
        ]
        progress_task = (
            asyncio.create_task(self.track_progress(tasks, callbacks))
            if callbacks
            else None
        )
        try:
            result = await tqdm_asyncio.gather(*tasks)
            if progress_task is not None:
                await progress_task
        except BaseException:
            # A failure (or cancellation of this coroutine) would otherwise
            # leave the sibling requests and the progress tracker running
            # unattended; stop them before re-raising.
            spawned = list(tasks)
            if progress_task is not None:
                spawned.append(progress_task)
            for task in spawned:
                if not task.done():
                    task.cancel()
            # Drain every task so cancellations take effect and exceptions of
            # already-failed tasks are retrieved rather than reported later as
            # "Task exception was never retrieved".
            await asyncio.gather(*spawned, return_exceptions=True)
            raise
        return result