Coverage for python / weflayr / sdk / mistralai / client.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-23 15:28 +0100

1"""Weflayr instrumented wrapper for the Mistral AI SDK. 

2 

3Connector: **Mistral AI** 

4Provider: https://mistral.ai 

5 

6Available connectors 

7-------------------- 

8- :class:`Mistral` — top-level client (drop-in for ``mistralai.Mistral``) 

9 

10Available methods 

11----------------- 

12Via ``client.chat``: 

13 

14- :meth:`Chat.complete` — synchronous chat completion with telemetry 

15- :meth:`Chat.complete_async` — async chat completion with telemetry 

16 

17Example:: 

18 

19 from weflayr.sdk.mistralai.client import Mistral 

20 

21 client = Mistral(api_key="sk-...") 

22 response = client.chat.complete( 

23 model="mistral-small-latest", 

24 messages=[{"role": "user", "content": "Hello"}], 

25 ) 

26""" 

27 

28from __future__ import annotations 

29 

30from typing import Any 

31 

32from mistralai.client import Mistral as _Mistral 

33from mistralai.client.chat import Chat as _Chat 

34 

35from weflayr.sdk.helpers import CLIENT_ID, CLIENT_SECRET, INTAKE_URL, track_async, track_sync 

36 

37 

38def _usage(response) -> dict: 

39 """Extract token usage from a Mistral chat response. 

40 

41 Args: 

42 response: A ``ChatCompletionResponse`` returned by the Mistral SDK. 

43 

44 Returns: 

45 A dict with ``prompt_tokens`` and ``completion_tokens`` (both ``int | None``). 

46 """ 

47 usage = getattr(response, "usage", None) 

48 return { 

49 "prompt_tokens": getattr(usage, "prompt_tokens", None), 

50 "completion_tokens": getattr(usage, "completion_tokens", None), 

51 } 

52 

53 

class Chat(_Chat):
    """Instrumented Mistral chat client.

    Args:
        *args: Forwarded to the upstream ``mistralai.client.chat.Chat``.
        intake_url: Weflayr intake API base URL.
        client_id: Client identifier sent in the endpoint path.
        bearer_token: Bearer token for the Authorization header.
        **kwargs: Forwarded to the upstream ``mistralai.client.chat.Chat``.
    """

    def __init__(
        self,
        *args: Any,
        intake_url: str = INTAKE_URL,
        client_id: str = CLIENT_ID,
        bearer_token: str = CLIENT_SECRET,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self._intake_url = intake_url
        self._client_id = client_id
        self._bearer_token = bearer_token

    @staticmethod
    def _before(kwargs: dict, tags: dict) -> dict:
        """Build the pre-call telemetry payload shared by the sync and async paths."""
        # ``messages`` may be passed explicitly as None; ``or []`` guards len().
        return {
            "model": kwargs.get("model"),
            "message_count": len(kwargs.get("messages") or []),
            "tags": tags,
        }

    def complete(self, **kwargs: Any):
        """Send a synchronous chat completion request with telemetry.

        Args:
            model (str): Mistral model identifier (e.g. ``"mistral-small-latest"``).
            messages (list[dict]): Conversation history in OpenAI message format.
            tags (dict, optional): Arbitrary key/value metadata. Stripped before the upstream call.
            **kwargs: Any other kwargs accepted by ``mistralai`` ``chat.complete()``.

        Returns:
            ``ChatCompletionResponse``: The upstream Mistral response, unmodified.
        """
        # ``tags`` is Weflayr-only metadata; remove it before forwarding upstream.
        tags = kwargs.pop("tags", {})
        return track_sync(
            url=self._intake_url,
            call="chat.complete",
            before=self._before(kwargs, tags),
            # Explicit two-arg super(): the lambda body is not a method scope,
            # so zero-arg super() would not resolve there.
            fn=lambda: super(Chat, self).complete(**kwargs),
            after_extra=_usage,
            client_id=self._client_id,
            bearer_token=self._bearer_token,
        )

    async def complete_async(self, **kwargs: Any):
        """Send an async chat completion request with telemetry.

        Args:
            model (str): Mistral model identifier (e.g. ``"mistral-small-latest"``).
            messages (list[dict]): Conversation history in OpenAI message format.
            tags (dict, optional): Arbitrary key/value metadata. Stripped before the upstream call.
            **kwargs: Any other kwargs accepted by ``mistralai`` ``chat.complete_async()``.

        Returns:
            ``ChatCompletionResponse``: The upstream Mistral response, unmodified.
        """
        tags = kwargs.pop("tags", {})
        return await track_async(
            url=self._intake_url,
            call="chat.complete_async",
            before=self._before(kwargs, tags),
            fn=lambda: super(Chat, self).complete_async(**kwargs),
            after_extra=_usage,
            client_id=self._client_id,
            bearer_token=self._bearer_token,
        )

123 

124 

class Mistral(_Mistral):
    """Drop-in replacement for ``mistralai.client.Mistral`` with Weflayr telemetry.

    Args:
        api_key (str): Your Mistral API key.
        intake_url (str, optional): Weflayr intake API base URL.
        client_id (str, optional): Client identifier sent in the endpoint path.
        bearer_token (str, optional): Bearer token for the Authorization header.
        **kwargs: Forwarded unchanged to ``mistralai.client.Mistral``.

    Attributes:
        chat (:class:`Chat`): Instrumented chat client.
    """

    def __init__(
        self,
        *args: Any,
        intake_url: str = INTAKE_URL,
        client_id: str = CLIENT_ID,
        bearer_token: str = CLIENT_SECRET,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self._intake_url = intake_url
        self._client_id = client_id
        self._bearer_token = bearer_token

    def __getattr__(self, name: str) -> Any:
        """Lazily resolve sub-clients, swapping in the instrumented ``chat``.

        For ``name == "chat"`` the instrumented :class:`Chat` is built directly
        instead of first materializing the upstream chat client only to discard
        it. Either way the result is cached on the instance via
        ``object.__setattr__`` so later accesses bypass ``__getattr__`` entirely.
        """
        if name == "chat":
            instance = Chat(
                self.sdk_configuration,
                parent_ref=self,
                intake_url=self._intake_url,
                client_id=self._client_id,
                bearer_token=self._bearer_token,
            )
        else:
            # Delegate all other lazy attributes to the upstream SDK.
            instance = super().__getattr__(name)
        object.__setattr__(self, name, instance)
        return instance