Coverage for intelligence_toolkit/AI/utils.py: 95%

64 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4 

5"""Utility functions for the OpenAI API.""" 

6 

7import hashlib 

8import json 

9import logging 

10from typing import Any 

11 

12import tiktoken 

13 

14from intelligence_toolkit.AI.base_chat import BaseChat 

15from intelligence_toolkit.AI.client import OpenAIClient 

16from intelligence_toolkit.AI.defaults import DEFAULT_ENCODING, DEFAULT_REPORT_BATCH_SIZE 

17from intelligence_toolkit.AI.validation_prompt import GROUNDEDNESS_PROMPT 

18from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback 

19 

20log = logging.getLogger(__name__) 

21 

22 

def generate_text(ai_configuration, messages, **kwargs):
    """Request a non-streaming chat completion.

    Builds an ``OpenAIClient`` from ``ai_configuration`` and forwards
    ``messages`` together with any extra keyword arguments to its
    ``generate_chat`` method, always with ``stream=False``.
    """
    client = OpenAIClient(ai_configuration)
    return client.generate_chat(messages, stream=False, **kwargs)

27 

28 

async def generate_text_async(ai_configuration, messages, stream, **kwargs):
    """Asynchronously request a chat completion.

    Builds an ``OpenAIClient`` from ``ai_configuration`` and awaits its
    ``generate_chat_async`` method, passing ``messages``, ``stream`` and any
    extra keyword arguments through unchanged.
    """
    client = OpenAIClient(ai_configuration)
    response = await client.generate_chat_async(messages, stream, **kwargs)
    return response

31 

32 

async def map_generate_text(
    ai_configuration,
    messages_list,
    callbacks: list[ProgressBatchCallback] | None = None,
    **kwargs,
):
    """Asynchronously generate text for a list of message sets.

    Builds a ``BaseChat`` from ``ai_configuration`` and awaits its
    ``generate_texts_async`` method with ``messages_list``, the optional
    ``callbacks`` and any extra keyword arguments.
    """
    chat = BaseChat(ai_configuration)
    return await chat.generate_texts_async(messages_list, callbacks, **kwargs)

42 

43 

def get_token_count(text: str, encoding=None, model=None) -> int:
    """Count the tokens in the JSON-serialized form of ``text``.

    ``text`` is passed through ``json.dumps`` before encoding, so the count
    includes the surrounding quotes and any escape sequences.

    Args:
        text: The value to measure.
        encoding: Name of the tiktoken encoding to use when ``model`` is not
            given or is not recognised. Falls back to ``DEFAULT_ENCODING``
            when omitted.
        model: Optional model name; when tiktoken recognises it, the model's
            own encoding takes precedence over ``encoding``.

    Returns:
        The number of tokens produced by the selected encoder.
    """
    fallback_encoding = encoding or DEFAULT_ENCODING

    encoder = None
    if model:
        try:
            encoder = tiktoken.encoding_for_model(model)
        except KeyError:
            log.warning(
                "model not found, using fallback encoding: %s", fallback_encoding
            )

    # Only load the fallback encoder when it is actually needed, and honour an
    # explicitly requested encoding rather than silently using the default.
    if encoder is None:
        encoder = tiktoken.get_encoding(fallback_encoding)

    return len(encoder.encode(json.dumps(text)))

54 

55 

def prepare_messages(
    system_message: str, variables: dict[str, Any], user_message=None
) -> list[dict[str, str]]:
    """Build the chat message list from prompt templates.

    ``system_message`` (and ``user_message``, when it is not ``None``) are
    ``str.format`` templates filled in with ``variables``. The system message
    always comes first; the user message follows when provided.
    """
    templates = [("system", system_message)]
    if user_message is not None:
        templates.append(("user", user_message))

    return [
        {"role": role, "content": template.format(**variables)}
        for role, template in templates
    ]

65 

66 

def prepare_validation(messages: str, ai_report: str) -> list[dict[str, str]]:
    """Build the single system message used for report validation.

    Fills ``GROUNDEDNESS_PROMPT`` with ``messages`` as its ``instructions``
    field and ``ai_report`` as its ``report`` field.
    """
    content = GROUNDEDNESS_PROMPT.format(instructions=messages, report=ai_report)
    return [{"role": "system", "content": content}]

76 

77 

def try_parse_json_object(input: str) -> dict:  # noqa: A002 - public parameter name
    """Parse ``input`` as JSON and return it only if it is a JSON object.

    Args:
        input: JSON text expected to encode an object at the top level.

    Returns:
        The decoded dictionary.

    Raises:
        json.JSONDecodeError: If ``input`` is not valid JSON. The failure is
            logged (with traceback) before being re-raised.
        TypeError: If the decoded top-level value is not a dictionary.
    """
    try:
        result = json.loads(input)
    except json.JSONDecodeError:
        log.exception("error loading json, json=%s", input)
        raise

    if not isinstance(result, dict):
        # Say what was actually received so the caller can diagnose the
        # failure without re-parsing the payload.
        msg = f"expected a JSON object, got {type(result).__name__}"
        raise TypeError(msg)
    return result

89 

90 

def hash_text(text: str) -> str:
    """Return the SHA-256 hex digest of ``text`` with newlines turned into spaces."""
    single_line = " ".join(text.split("\n"))
    digest = hashlib.sha256()
    digest.update(single_line.encode())
    return digest.hexdigest()

95 

96 

def generate_messages(
    user_prompt, system_prompt, variables, safety_prompt=""
) -> list[dict[str, str]]:
    """Combine the prompts into one system message and format it.

    The system, user and safety prompts are joined in that order, separated
    by single spaces, and handed to ``prepare_messages`` with ``variables``.
    """
    combined_prompt = "{} {} {}".format(system_prompt, user_prompt, safety_prompt)
    return prepare_messages(combined_prompt, variables)

103 

104 

def generate_batch_messages(
    prompt,
    batch_name,
    batch_value,
    variables: dict | None = None,
    batch_size: int | None = DEFAULT_REPORT_BATCH_SIZE,
) -> list[list[dict[str, str]]]:
    """Split ``batch_value`` into batches and build one message list per batch.

    Args:
        prompt: Mapping providing the ``"report_prompt"``, ``"user_prompt"``
            and ``"safety_prompt"`` templates; they are joined with single
            spaces, in that order, into one system prompt.
        batch_name: Template variable name that receives each batch's CSV text.
        batch_value: Sized, sliceable collection whose slices expose
            ``to_csv()`` (presumably a pandas DataFrame - confirm against
            callers).
        variables: Extra template variables shared by every batch. The
            mapping is copied per batch and never mutated.
        batch_size: Maximum number of items per batch. ``None`` selects
            ``DEFAULT_REPORT_BATCH_SIZE``.

    Returns:
        One ``prepare_messages`` result per batch, in order. Empty when
        ``batch_value`` is empty.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # The signature allows None, so treat it as "use the default" instead of
    # failing on the arithmetic below.
    if batch_size is None:
        batch_size = DEFAULT_REPORT_BATCH_SIZE
    if batch_size <= 0:
        msg = f"batch_size must be a positive integer, got {batch_size}"
        raise ValueError(msg)

    shared_variables = {} if variables is None else variables

    full_prompt = " ".join(
        [
            prompt["report_prompt"],
            prompt["user_prompt"],
            prompt["safety_prompt"],
        ]
    )

    total = len(batch_value)
    batch_messages = []
    # Stepping by batch_size covers the final partial batch without separate
    # quotient/remainder bookkeeping.
    for start in range(0, total, batch_size):
        batch = batch_value[start : min(start + batch_size, total)]
        batch_variables = {**shared_variables, batch_name: batch.to_csv()}
        batch_messages.append(prepare_messages(full_prompt, batch_variables))

    return batch_messages