Coverage for intelligence_toolkit/AI/openai_configuration.py: 100%
73 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
4import os
6from .defaults import (
7 DEFAULT_AZ_AUTH_TYPE,
8 DEFAULT_EMBEDDING_MODEL,
9 DEFAULT_EMBEDDING_MODEL_AZURE,
10 DEFAULT_LLM_MAX_TOKENS,
11 DEFAULT_LLM_MODEL,
12 DEFAULT_OPENAI_VERSION,
13 DEFAULT_TEMPERATURE,
14)
17def _non_blank(value: str | None) -> str | None:
18 if value is None:
19 return None
20 stripped = value.strip()
21 return None if stripped == "" else value
24class OpenAIConfiguration:
25 """OpenAI Configuration class definition."""
27 # Core Configuration
28 _api_key: str
29 _model: str
31 _api_base: str | None
32 _api_version: str | None
34 _temperature: float | None
35 _max_tokens: int | None
36 _api_type: str
37 _az_auth_type: str
38 _embedding_model: str
40 def __init__(
41 self,
42 config: dict | None = None,
43 ):
44 """Init method definition."""
45 if config is None:
46 config = {}
47 oai_type = self._get_openai_type()
48 self._api_key = config.get("api_key", self._get_api_key())
49 self._model = config.get(
50 "model",
51 self._get_chat_model(),
52 )
53 self._api_base = config.get("api_base", self._get_azure_api_base())
54 self._api_version = config.get("api_version", self._get_azure_openai_version())
55 self._temperature = config.get("temperature", DEFAULT_TEMPERATURE)
56 self._max_tokens = config.get("max_tokens", DEFAULT_LLM_MAX_TOKENS)
57 self._az_auth_type = config.get("az_auth_type", self._get_az_auth_type())
58 self._api_type = config.get("api_type", oai_type)
59 self._embedding_model = config.get(
60 "embedding_model", self._get_embedding_model()
61 )
63 def _get_openai_type(self):
64 return os.environ.get("OPENAI_TYPE", "OpenAI")
66 def _get_az_auth_type(self):
67 return os.environ.get("AZURE_AUTH_TYPE", DEFAULT_AZ_AUTH_TYPE)
69 def _get_azure_openai_version(self):
70 return os.environ.get("AZURE_OPENAI_VERSION", DEFAULT_OPENAI_VERSION)
72 def _get_chat_model(self):
73 return os.environ.get("OPENAI_API_MODEL", DEFAULT_LLM_MODEL)
75 def _get_embedding_model(self):
76 default_embedding_per_type = (
77 DEFAULT_EMBEDDING_MODEL_AZURE
78 if self._api_type == "Azure OpenAI"
79 else DEFAULT_EMBEDDING_MODEL
80 )
81 return os.environ.get("OPENAI_EMBEDDING_MODEL", default_embedding_per_type)
83 def _get_azure_api_base(self):
84 return os.environ.get("AZURE_OPENAI_ENDPOINT", "")
86 def _get_api_key(self):
87 return os.environ.get("OPENAI_API_KEY", "")
89 @property
90 def api_key(self) -> str:
91 """API key property definition."""
92 return self._api_key
94 @property
95 def model(self) -> str:
96 """Model property definition."""
97 return self._model
99 @property
100 def api_base(self) -> str | None:
101 """API base property definition."""
102 result = _non_blank(self._api_base)
103 # Remove trailing slash
104 return result[:-1] if result and result.endswith("/") else result
106 @property
107 def api_version(self) -> str | None:
108 """API version property definition."""
109 return _non_blank(self._api_version)
111 @property
112 def temperature(self) -> float | None:
113 """Temperature property definition."""
114 return self._temperature
116 @property
117 def max_tokens(self) -> int | None:
118 """Max tokens property definition."""
119 return self._max_tokens
121 @property
122 def embedding_model(self) -> str | None:
123 return self._embedding_model
125 @property
126 def api_type(self) -> str | None:
127 """Type of the AI connection."""
128 return self._api_type
130 @property
131 def az_auth_type(self) -> str:
132 """Type of the Azure OpenAI connection."""
133 return self._az_auth_type