Coverage for intelligence_toolkit/AI/utils.py: 95%
64 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
5"""Utility functions for the OpenAI API."""
7import hashlib
8import json
9import logging
10from typing import Any
12import tiktoken
14from intelligence_toolkit.AI.base_chat import BaseChat
15from intelligence_toolkit.AI.client import OpenAIClient
16from intelligence_toolkit.AI.defaults import DEFAULT_ENCODING, DEFAULT_REPORT_BATCH_SIZE
17from intelligence_toolkit.AI.validation_prompt import GROUNDEDNESS_PROMPT
18from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback
20log = logging.getLogger(__name__)
def generate_text(ai_configuration, messages, **kwargs):
    """Request a single, non-streamed chat completion.

    Args:
        ai_configuration: Configuration object handed to ``OpenAIClient``.
        messages: Chat messages forwarded to ``generate_chat``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``generate_chat``.

    Returns:
        Whatever ``OpenAIClient.generate_chat`` returns.
    """
    client = OpenAIClient(ai_configuration)
    # Streaming is always disabled for this synchronous helper.
    return client.generate_chat(messages, stream=False, **kwargs)
async def generate_text_async(ai_configuration, messages, stream, **kwargs):
    """Asynchronously request a chat completion.

    Args:
        ai_configuration: Configuration object handed to ``OpenAIClient``.
        messages: Chat messages forwarded to ``generate_chat_async``.
        stream: Streaming flag, passed positionally to
            ``generate_chat_async``.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The awaited result of ``OpenAIClient.generate_chat_async``.
    """
    client = OpenAIClient(ai_configuration)
    response = await client.generate_chat_async(messages, stream, **kwargs)
    return response
async def map_generate_text(
    ai_configuration,
    messages_list,
    callbacks: list[ProgressBatchCallback] | None = None,
    **kwargs,
):
    """Asynchronously generate text for several message lists.

    Args:
        ai_configuration: Configuration object handed to ``BaseChat``.
        messages_list: Collection of message lists forwarded to
            ``generate_texts_async``.
        callbacks: Optional progress callbacks, forwarded as-is (``None`` when
            not supplied).
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The awaited result of ``BaseChat.generate_texts_async``.
    """
    chat = BaseChat(ai_configuration)
    results = await chat.generate_texts_async(messages_list, callbacks, **kwargs)
    return results
def get_token_count(text: str, encoding=None, model=None) -> int:
    """Count the tokens in the JSON-serialised form of ``text``.

    Args:
        text: Value to measure; it is passed through ``json.dumps`` before
            being tokenised.
        encoding: Name of the tiktoken encoding to use. Falls back to
            ``DEFAULT_ENCODING`` when falsy.
        model: Optional model name. When truthy, the encoding registered for
            that model replaces the one selected via ``encoding``.

    Returns:
        Number of tokens produced by the selected encoder.
    """
    tokenizer = tiktoken.get_encoding(encoding or DEFAULT_ENCODING)

    if model:
        try:
            tokenizer = tiktoken.encoding_for_model(model)
        except KeyError:
            # The model lookup failed: warn and use the default encoding,
            # regardless of any explicit ``encoding`` argument.
            log.warning("model not found, using default encoding: %s", DEFAULT_ENCODING)
            tokenizer = tiktoken.get_encoding(DEFAULT_ENCODING)

    serialised = json.dumps(text)
    tokens = tokenizer.encode(serialised)
    return len(tokens)
def prepare_messages(
    system_message: str, variables: dict[str, Any], user_message=None
) -> list[dict[str, str]]:
    """Build chat messages by filling prompt templates with ``variables``.

    Args:
        system_message: ``str.format`` template for the system message.
        variables: Values substituted into the templates.
        user_message: Optional ``str.format`` template for a user message;
            skipped only when it is ``None``.

    Returns:
        A list holding the system message, followed by the user message when
        one was supplied. Each entry has ``role`` and ``content`` keys.
    """
    # Collect (role, template) pairs first so formatting happens in one place,
    # system message before user message.
    templates = [("system", system_message)]
    if user_message is not None:
        templates.append(("user", user_message))

    return [
        {"role": role, "content": template.format(**variables)}
        for role, template in templates
    ]
def prepare_validation(messages: str, ai_report: str) -> list[dict[str, str]]:
    """Build the single system message used for the groundedness check.

    Args:
        messages: Text substituted for the ``instructions`` placeholder of
            ``GROUNDEDNESS_PROMPT``.
        ai_report: Text substituted for the ``report`` placeholder.

    Returns:
        A one-element list containing a system message whose content is the
        filled-in ``GROUNDEDNESS_PROMPT``.
    """
    validation_prompt = GROUNDEDNESS_PROMPT.format(
        instructions=messages, report=ai_report
    )
    return [{"role": "system", "content": validation_prompt}]
def try_parse_json_object(input: str) -> dict:
    """Parse ``input`` as JSON and return it, requiring a top-level object.

    Args:
        input: JSON document to parse. (The name shadows the builtin but is
            kept so keyword callers continue to work.)

    Returns:
        The decoded JSON object as a ``dict``.

    Raises:
        json.JSONDecodeError: If ``input`` is not valid JSON. The failure is
            logged with the offending text before being re-raised.
        TypeError: If ``input`` is valid JSON but its top-level value is not
            an object (for example a list, string or number).
    """
    try:
        result = json.loads(input)
    except json.JSONDecodeError:
        log.exception("error loading json, json=%s", input)
        raise

    if not isinstance(result, dict):
        # Name the actual type so callers can tell what was received instead
        # of getting a message-less TypeError.
        msg = f"expected a JSON object, got {type(result).__name__}"
        raise TypeError(msg)
    return result
def hash_text(text: str) -> str:
    """Return the SHA-256 hex digest of ``text`` with newlines flattened.

    Every ``"\\n"`` is replaced by a single space before hashing, so texts
    that differ only in using a newline versus a space at the same position
    hash identically.
    """
    flattened = " ".join(text.split("\n"))
    digest = hashlib.sha256()
    digest.update(flattened.encode())
    return digest.hexdigest()
def generate_messages(
    user_prompt, system_prompt, variables, safety_prompt=""
) -> list[dict[str, str]]:
    """Combine the prompt parts into one template and build chat messages.

    The parts are joined with single spaces in the order system prompt, user
    prompt, safety prompt; with the default empty ``safety_prompt`` the
    template therefore ends in a trailing space.

    Args:
        user_prompt: User-facing part of the prompt template.
        system_prompt: System part of the prompt template; placed first.
        variables: Values substituted into the combined template by
            ``prepare_messages``.
        safety_prompt: Optional safety text appended last.

    Returns:
        The result of ``prepare_messages`` for the combined template.
    """
    combined_template = f"{system_prompt} {user_prompt} {safety_prompt}"
    messages = prepare_messages(combined_template, variables)
    return messages
def generate_batch_messages(
    prompt,
    batch_name,
    batch_value,
    variables: dict | None = None,
    batch_size: int | None = DEFAULT_REPORT_BATCH_SIZE,
) -> list[list[dict[str, str]]]:
    """Split ``batch_value`` into batches and build one message list per batch.

    Args:
        prompt: Mapping with ``report_prompt``, ``user_prompt`` and
            ``safety_prompt`` entries; they are joined with single spaces to
            form the prompt template.
        batch_name: Template variable name that receives each batch's CSV text.
        batch_value: Sized, sliceable object whose slices provide ``to_csv()``
            (presumably a data frame; verify against callers).
        variables: Additional template variables shared by every batch. The
            mapping is copied per batch and never mutated.
        batch_size: Maximum number of items per batch. ``None`` selects
            ``DEFAULT_REPORT_BATCH_SIZE``.

    Returns:
        One ``prepare_messages`` result per batch, in input order. Empty when
        ``batch_value`` is empty.

    Raises:
        ValueError: If ``batch_size`` is zero or negative.
    """
    if variables is None:
        variables = {}
    # The annotation allows None, so treat it as "use the default" instead of
    # failing on the arithmetic below.
    if batch_size is None:
        batch_size = DEFAULT_REPORT_BATCH_SIZE
    if batch_size <= 0:
        msg = f"batch_size must be a positive integer, got {batch_size}"
        raise ValueError(msg)

    full_prompt = " ".join(
        [
            prompt["report_prompt"],
            prompt["user_prompt"],
            prompt["safety_prompt"],
        ]
    )

    total = len(batch_value)
    batch_messages = []
    # Stepping by batch_size visits every batch start, including a final
    # partial batch, without separate count/remainder bookkeeping.
    for start in range(0, total, batch_size):
        batch = batch_value[start : min(start + batch_size, total)]
        batch_variables = dict(variables)
        batch_variables[batch_name] = batch.to_csv()
        batch_messages.append(prepare_messages(full_prompt, batch_variables))

    return batch_messages