Coverage for agentos/cache/embedder.py: 34%

110 statements  

« prev     ^ index     » next       coverage.py v7.14.3, created at 2026-07-02 09:59 +0800

1""" 

2Embedding实现层 — 多种embedding provider的真实调用。 

3v0.50: 新增模块。为语义缓存/向量数据库提供embedding实现。 

4""" 

5 

6from __future__ import annotations 

7 

8import asyncio 

9import os 

10from abc import ABC, abstractmethod 

11from dataclasses import dataclass 

12from typing import Any 

13 

14import httpx 

15 

16 

@dataclass
class EmbeddingResult:
    """Outcome of a single embedding request.

    Instances can be used like a read-only sequence: ``len()``, iteration
    and indexing are all forwarded to ``vector``.
    """

    # Embedding components as handed over by the provider.
    vector: list[float]
    # Token count for the request; stays 0 unless the creator supplies it.
    tokens: int = 0
    # Identifier of the model that produced the vector; "" unless supplied.
    model: str = ""

    def __getitem__(self, idx):
        """Return ``self.vector[idx]`` (anything ``list`` indexing accepts)."""
        return self.vector[idx]

    def __iter__(self):
        """Return an iterator over the components of ``vector``."""
        return iter(self.vector)

    def __len__(self) -> int:
        """Return the number of components in ``vector``."""
        return len(self.vector)

32 

33 

class BaseEmbedder(ABC):
    """Abstract base class for embedding providers.

    Concrete subclasses must implement all three methods below before they
    can be instantiated.
    """

    @abstractmethod
    def dimension(self) -> int:
        """Return the length of the vectors this embedder produces."""

    @abstractmethod
    async def embed(self, text: str) -> EmbeddingResult:
        """Embed a single text and return its result."""

    @abstractmethod
    async def embed_batch(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed several texts and return one result per text."""

48 

49 

class OpenAIEmbedder(BaseEmbedder):
    """Embedder backed by an OpenAI-style ``/embeddings`` HTTP endpoint.

    Supports text-embedding-3-small, text-embedding-3-large and
    text-embedding-ada-002, selected through the short keys in ``MODELS``.
    """

    # Short model key -> (API model id, vector dimension).
    MODELS = {
        "small": ("text-embedding-3-small", 1536),
        "large": ("text-embedding-3-large", 3072),
        "ada": ("text-embedding-ada-002", 1536),
    }

    def __init__(self, model: str = "small", api_key: str = "",
                 base_url: str = "https://api.openai.com/v1"):
        """Create the embedder and its HTTP client.

        Args:
            model: One of the keys of ``MODELS``.
            api_key: API key; when empty, ``OPENAI_API_KEY`` from the
                environment is used (empty string if that is unset too).
            base_url: Root URL that ``/embeddings`` is appended to.

        Raises:
            ValueError: If ``model`` is not a key of ``MODELS``.
        """
        info = self.MODELS.get(model)
        if not info:
            raise ValueError(f"Unknown model key: {model}. Use: {list(self.MODELS.keys())}")
        self.model_id, self._dim = info
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
        self.base_url = base_url
        self._http = httpx.AsyncClient(timeout=60, headers={
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        })

    def dimension(self) -> int:
        """Return the vector dimension listed in ``MODELS`` for this model."""
        return self._dim

    async def embed(self, text: str) -> EmbeddingResult:
        """Embed one text by sending it as a single-item batch."""
        results = await self.embed_batch([text])
        return results[0]

    async def embed_batch(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed ``texts`` with one POST request.

        Returns:
            One ``EmbeddingResult`` per response item, ordered by the item's
            ``index`` field so results line up with ``texts``. An empty
            ``texts`` yields ``[]`` without contacting the API.

        Raises:
            httpx.HTTPStatusError: If the response status is 4xx/5xx
                (raised by ``raise_for_status``).
        """
        if not texts:
            # Nothing to embed; skip a request that could only fail or waste a call.
            return []
        body = {"model": self.model_id, "input": texts}
        # Tolerate a base_url given with a trailing slash (avoids "//embeddings").
        url = f"{self.base_url.rstrip('/')}/embeddings"
        resp = await self._http.post(url, json=body)
        resp.raise_for_status()
        items = resp.json()["data"]
        # Order by the reported "index"; items lacking it keep their
        # position in the response (the sort is stable).
        ordered = sorted(
            enumerate(items),
            key=lambda pair: pair[1].get("index", pair[0]),
        )
        return [
            EmbeddingResult(vector=item["embedding"], model=self.model_id)
            for _, item in ordered
        ]

    async def close(self):
        """Close the underlying HTTP client."""
        await self._http.aclose()

94 

95 

class LocalEmbedder(BaseEmbedder):
    """Embedder backed by a local sentence-transformers model (no API calls).

    The model is loaded lazily on first use. Loading and encoding are
    blocking calls, so the async methods run them in a worker thread via
    ``asyncio.to_thread`` instead of stalling the event loop.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Record the model name; the model itself is not loaded yet."""
        import threading  # local import: only this class needs it

        self.model_name = model_name
        self._model = None
        self._dim = 384
        # Serialises model loading and encoding across worker threads, so
        # work still runs one call at a time and the model loads only once.
        self._lock = threading.Lock()

    def _ensure_model(self):
        """Load the model on first call and record its reported dimension."""
        if self._model is None:
            from sentence_transformers import SentenceTransformer
            self._model = SentenceTransformer(self.model_name)
            self._dim = self._model.get_sentence_embedding_dimension()

    def dimension(self) -> int:
        """Return the embedding dimension.

        Once the model is loaded this is the value the model reported.
        Before that it is a guess from the model name: 384 for
        "all-MiniLM-L6-v2", 1024 if the name contains "large", else 768.
        """
        if self._model is None:
            if self.model_name == "all-MiniLM-L6-v2":
                self._dim = 384
            elif "large" in self.model_name:
                self._dim = 1024
            else:
                self._dim = 768
        return self._dim

    def _encode(self, inputs):
        """Blocking helper: load the model if needed and encode ``inputs``."""
        with self._lock:
            self._ensure_model()
            return self._model.encode(inputs, normalize_embeddings=True)

    async def embed(self, text: str) -> EmbeddingResult:
        """Embed one text, requesting normalized embeddings from the model."""
        vec = await asyncio.to_thread(self._encode, text)
        return EmbeddingResult(vector=vec.tolist(), model=self.model_name)

    async def embed_batch(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed several texts in one ``encode`` call.

        An empty ``texts`` yields ``[]`` without loading the model.
        """
        if not texts:
            return []
        vecs = await asyncio.to_thread(self._encode, texts)
        return [
            EmbeddingResult(vector=v.tolist(), model=self.model_name)
            for v in vecs
        ]

132 

133 

class CohereEmbedder(BaseEmbedder):
    """Embedder backed by the Cohere v1 Embed HTTP API."""

    # Endpoint used for every request.
    _ENDPOINT = "https://api.cohere.ai/v1/embed"

    # Known model id -> vector dimension; unknown ids fall back to 1024.
    _DIMENSIONS = {
        "embed-english-v3.0": 1024,
        "embed-multilingual-v3.0": 1024,
        "embed-english-light-v3.0": 384,
        "embed-multilingual-light-v3.0": 384,
        "embed-english-v2.0": 4096,
        "embed-english-light-v2.0": 1024,
        "embed-multilingual-v2.0": 768,
    }

    def __init__(self, model: str = "embed-english-v3.0", api_key: str = ""):
        """Create the embedder and its HTTP client.

        Args:
            model: Cohere model id sent with each request.
            api_key: API key; when empty, ``COHERE_API_KEY`` from the
                environment is used (empty string if that is unset too).
        """
        self.model_id = model
        self.api_key = api_key or os.environ.get("COHERE_API_KEY", "")
        self._http = httpx.AsyncClient(timeout=60, headers={
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        })
        self._dim = self._DIMENSIONS.get(model, 1024)

    def dimension(self) -> int:
        """Return the vector dimension looked up for this model id."""
        return self._dim

    async def embed(self, text: str) -> EmbeddingResult:
        """Embed one text by sending it as a single-item batch."""
        results = await self.embed_batch([text])
        return results[0]

    async def embed_batch(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed ``texts`` with one POST request (``input_type`` = "search_document").

        Returns:
            One ``EmbeddingResult`` per entry of the response's
            ``embeddings`` list. An empty ``texts`` yields ``[]`` without
            contacting the API.

        Raises:
            httpx.HTTPStatusError: If the response status is 4xx/5xx
                (raised by ``raise_for_status``).
        """
        if not texts:
            # Nothing to embed; skip the request entirely.
            return []
        body = {"model": self.model_id, "texts": texts, "input_type": "search_document"}
        resp = await self._http.post(self._ENDPOINT, json=body)
        resp.raise_for_status()
        data = resp.json()
        return [
            EmbeddingResult(vector=vec, model=self.model_id)
            for vec in data["embeddings"]
        ]

    async def close(self):
        """Close the underlying HTTP client."""
        await self._http.aclose()

168 

169 

async def get_embedder(provider: str = "openai", **kwargs) -> BaseEmbedder:
    """Factory: build the embedder that matches ``provider``.

    Args:
        provider: "openai", "local" or "cohere".
        **kwargs: Passed unchanged to the selected class's constructor.

    Returns:
        A new ``OpenAIEmbedder``, ``LocalEmbedder`` or ``CohereEmbedder``.

    Raises:
        ValueError: If ``provider`` is none of the supported names.
    """
    if provider == "openai":
        return OpenAIEmbedder(**kwargs)
    if provider == "local":
        return LocalEmbedder(**kwargs)
    if provider == "cohere":
        return CohereEmbedder(**kwargs)
    raise ValueError(f"Unknown embedder provider: {provider}. Use: openai/local/cohere")

181 

182 

async def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of ``a`` and ``b``.

    Returns 0.0 when either vector has zero norm (including empty input).
    The result is clamped to [-1.0, 1.0] so floating-point rounding cannot
    push it just outside the valid range.

    NOTE: lengths are not validated. The dot product uses only the
    overlapping prefix (``zip`` stops at the shorter vector), while each
    norm uses its whole vector.
    """
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    if norm_a == 0 or norm_b == 0:
        return 0.0
    # Rounding can give e.g. 1.0000000000000002 for identical vectors.
    return max(-1.0, min(1.0, dot / (norm_a * norm_b)))