Coverage for intelligence_toolkit/AI/openai_configuration.py: 100%

73 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4import os 

5 

6from .defaults import ( 

7 DEFAULT_AZ_AUTH_TYPE, 

8 DEFAULT_EMBEDDING_MODEL, 

9 DEFAULT_EMBEDDING_MODEL_AZURE, 

10 DEFAULT_LLM_MAX_TOKENS, 

11 DEFAULT_LLM_MODEL, 

12 DEFAULT_OPENAI_VERSION, 

13 DEFAULT_TEMPERATURE, 

14) 

15 

16 

17def _non_blank(value: str | None) -> str | None: 

18 if value is None: 

19 return None 

20 stripped = value.strip() 

21 return None if stripped == "" else value 

22 

23 

24class OpenAIConfiguration: 

25 """OpenAI Configuration class definition.""" 

26 

27 # Core Configuration 

28 _api_key: str 

29 _model: str 

30 

31 _api_base: str | None 

32 _api_version: str | None 

33 

34 _temperature: float | None 

35 _max_tokens: int | None 

36 _api_type: str 

37 _az_auth_type: str 

38 _embedding_model: str 

39 

40 def __init__( 

41 self, 

42 config: dict | None = None, 

43 ): 

44 """Init method definition.""" 

45 if config is None: 

46 config = {} 

47 oai_type = self._get_openai_type() 

48 self._api_key = config.get("api_key", self._get_api_key()) 

49 self._model = config.get( 

50 "model", 

51 self._get_chat_model(), 

52 ) 

53 self._api_base = config.get("api_base", self._get_azure_api_base()) 

54 self._api_version = config.get("api_version", self._get_azure_openai_version()) 

55 self._temperature = config.get("temperature", DEFAULT_TEMPERATURE) 

56 self._max_tokens = config.get("max_tokens", DEFAULT_LLM_MAX_TOKENS) 

57 self._az_auth_type = config.get("az_auth_type", self._get_az_auth_type()) 

58 self._api_type = config.get("api_type", oai_type) 

59 self._embedding_model = config.get( 

60 "embedding_model", self._get_embedding_model() 

61 ) 

62 

63 def _get_openai_type(self): 

64 return os.environ.get("OPENAI_TYPE", "OpenAI") 

65 

66 def _get_az_auth_type(self): 

67 return os.environ.get("AZURE_AUTH_TYPE", DEFAULT_AZ_AUTH_TYPE) 

68 

69 def _get_azure_openai_version(self): 

70 return os.environ.get("AZURE_OPENAI_VERSION", DEFAULT_OPENAI_VERSION) 

71 

72 def _get_chat_model(self): 

73 return os.environ.get("OPENAI_API_MODEL", DEFAULT_LLM_MODEL) 

74 

75 def _get_embedding_model(self): 

76 default_embedding_per_type = ( 

77 DEFAULT_EMBEDDING_MODEL_AZURE 

78 if self._api_type == "Azure OpenAI" 

79 else DEFAULT_EMBEDDING_MODEL 

80 ) 

81 return os.environ.get("OPENAI_EMBEDDING_MODEL", default_embedding_per_type) 

82 

83 def _get_azure_api_base(self): 

84 return os.environ.get("AZURE_OPENAI_ENDPOINT", "") 

85 

86 def _get_api_key(self): 

87 return os.environ.get("OPENAI_API_KEY", "") 

88 

89 @property 

90 def api_key(self) -> str: 

91 """API key property definition.""" 

92 return self._api_key 

93 

94 @property 

95 def model(self) -> str: 

96 """Model property definition.""" 

97 return self._model 

98 

99 @property 

100 def api_base(self) -> str | None: 

101 """API base property definition.""" 

102 result = _non_blank(self._api_base) 

103 # Remove trailing slash 

104 return result[:-1] if result and result.endswith("/") else result 

105 

106 @property 

107 def api_version(self) -> str | None: 

108 """API version property definition.""" 

109 return _non_blank(self._api_version) 

110 

111 @property 

112 def temperature(self) -> float | None: 

113 """Temperature property definition.""" 

114 return self._temperature 

115 

116 @property 

117 def max_tokens(self) -> int | None: 

118 """Max tokens property definition.""" 

119 return self._max_tokens 

120 

121 @property 

122 def embedding_model(self) -> str | None: 

123 return self._embedding_model 

124 

125 @property 

126 def api_type(self) -> str | None: 

127 """Type of the AI connection.""" 

128 return self._api_type 

129 

130 @property 

131 def az_auth_type(self) -> str: 

132 """Type of the Azure OpenAI connection.""" 

133 return self._az_auth_type