Coverage for intelligence_toolkit/AI/base_chat.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5 

6import asyncio 

7 

8from tqdm.asyncio import tqdm_asyncio 

9 

10from intelligence_toolkit.AI.base_batch_async import BaseBatchAsync 

11from intelligence_toolkit.AI.client import OpenAIClient 

12from intelligence_toolkit.AI.defaults import DEFAULT_CONCURRENT_COROUTINES 

13from intelligence_toolkit.helpers.decorators import retry_with_backoff 

14from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback 

15 

16 

class BaseChat(BaseBatchAsync, OpenAIClient):
    """Chat-completion client that can issue many requests concurrently.

    Builds on ``OpenAIClient`` for the actual API call
    (``generate_chat_async``). It limits the number of in-flight requests
    with an ``asyncio.Semaphore``.
    ``progress_callback``/``track_progress`` are not defined here; they are
    presumably provided by ``BaseBatchAsync`` (verify).
    """

    def __init__(
        self, configuration=None, concurrent_coroutines=DEFAULT_CONCURRENT_COROUTINES
    ) -> None:
        """Initialise the OpenAI client and the concurrency limiter.

        Args:
            configuration: Passed unchanged to ``OpenAIClient.__init__``.
            concurrent_coroutines: Maximum number of ``generate_text_async``
                calls allowed inside the semaphore at the same time.
        """
        # NOTE(review): BaseBatchAsync.__init__ is not invoked here; confirm it
        # has no state of its own that needs initialising.
        OpenAIClient.__init__(self, configuration)
        self.semaphore = asyncio.Semaphore(concurrent_coroutines)

    @retry_with_backoff()
    async def generate_text_async(
        self,
        messages,
        callbacks: list[ProgressBatchCallback] | None,
        stream: bool,
        **llm_kwargs,
    ):
        """Request a single chat completion, bounded by ``self.semaphore``.

        Args:
            messages: Chat messages forwarded to ``generate_chat_async``.
            callbacks: When truthy, ``self.progress_callback()`` is called once
                the completion has been received.
            stream: Forwarded to ``generate_chat_async``.
            **llm_kwargs: Extra keyword arguments forwarded to
                ``generate_chat_async``.

        Returns:
            Whatever ``generate_chat_async`` returns.

        Raises:
            Exception: Wrapping any error from ``generate_chat_async``; the
                original error is chained as ``__cause__``.
        """
        # NOTE(review): errors are re-raised as a plain ``Exception``, which
        # hides the original type from ``retry_with_backoff``; confirm the
        # decorator is meant to retry on ``Exception``.
        async with self.semaphore:
            try:
                chat = await self.generate_chat_async(
                    messages=messages, stream=stream, **llm_kwargs
                )
            except Exception as e:
                print(f"Error generating chat completion: {e}")
                msg = f"Problem in OpenAI response. {e}"
                raise Exception(msg) from e
            # Kept outside the try so a failing progress callback is not
            # reported as an OpenAI response problem.
            if callbacks:
                self.progress_callback()
        return chat

    @retry_with_backoff()
    async def generate_texts_async(
        self,
        messages_list: list[list[dict[str, str]]],
        callbacks: list[ProgressBatchCallback] | None = None,
        **llm_kwargs,
    ):
        """Generate one (non-streamed) completion per entry of ``messages_list``.

        Args:
            messages_list: One list of chat messages per request.
            callbacks: Optional progress callbacks. When truthy, progress is
                tracked via ``self.track_progress`` while the batch runs.
            **llm_kwargs: Extra keyword arguments forwarded to every request.

        Returns:
            The completions, in the same order as ``messages_list``.

        Raises:
            Exception: The first error raised by any request. All other
                outstanding requests, and the progress tracker, are cancelled
                before it propagates.
        """
        self.total_tasks = len(messages_list)
        tasks = [
            asyncio.create_task(
                self.generate_text_async(messages, callbacks, False, **llm_kwargs)
            )
            for messages in messages_list
        ]
        progress_task = (
            asyncio.create_task(self.track_progress(tasks, callbacks))
            if callbacks
            else None
        )
        try:
            result = await tqdm_asyncio.gather(*tasks)
        except BaseException:
            # gather() does not cancel sibling tasks when one fails (or when we
            # are cancelled ourselves). Cancel them here so they do not keep
            # holding semaphore slots, and do not run alongside a retried batch.
            # Cancelling an already-finished task is a no-op.
            for task in tasks:
                task.cancel()
            if progress_task is not None:
                progress_task.cancel()
            raise
        if progress_task is not None:
            await progress_task
        return result