Coverage for agentos/cache/embedder.py: 34%
110 statements
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
« prev ^ index » next coverage.py v7.14.3, created at 2026-07-02 09:59 +0800
1"""
2Embedding实现层 — 多种embedding provider的真实调用。
3v0.50: 新增模块。为语义缓存/向量数据库提供embedding实现。
4"""
6from __future__ import annotations
8import asyncio
9import os
10from abc import ABC, abstractmethod
11from dataclasses import dataclass
12from typing import Any
14import httpx
@dataclass
class EmbeddingResult:
    """Outcome of one embedding request.

    The object behaves like a read-only sequence over ``vector``: ``len()``,
    iteration and indexing/slicing are all forwarded to the underlying list.
    """

    # The embedding itself, one float per dimension.
    vector: list[float]
    # Token count attributed to this result (0 when not reported).
    tokens: int = 0
    # Identifier of the model that produced the vector ("" when unknown).
    model: str = ""

    def __len__(self) -> int:
        """Return the number of components in ``vector``."""
        components = self.vector
        return len(components)

    def __iter__(self):
        """Return an iterator over the components of ``vector``."""
        components = self.vector
        return iter(components)

    def __getitem__(self, idx):
        """Return ``vector[idx]``; *idx* may be anything a list accepts."""
        components = self.vector
        return components[idx]
class BaseEmbedder(ABC):
    """Abstract base class for embedding providers.

    Concrete providers must implement all three members below; the class
    cannot be instantiated until they do.
    """

    @abstractmethod
    async def embed(self, text: str) -> EmbeddingResult:
        """Embed a single piece of text and return its result."""

    @abstractmethod
    async def embed_batch(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed several texts and return one result per text."""

    @abstractmethod
    def dimension(self) -> int:
        """Return the length of the vectors this provider produces."""
class OpenAIEmbedder(BaseEmbedder):
    """Embedder that calls an OpenAI-style ``/embeddings`` HTTP endpoint.

    Supports text-embedding-3-small, text-embedding-3-large and
    text-embedding-ada-002, selected through the short keys in ``MODELS``.
    """

    # Short model key -> (model id sent to the API, embedding dimension).
    MODELS = {
        "small": ("text-embedding-3-small", 1536),
        "large": ("text-embedding-3-large", 3072),
        "ada": ("text-embedding-ada-002", 1536),
    }

    def __init__(self, model: str = "small", api_key: str = "",
                 base_url: str = "https://api.openai.com/v1"):
        """Create the embedder and its HTTP client.

        Args:
            model: One of the keys of ``MODELS``.
            api_key: API key; when empty, ``OPENAI_API_KEY`` is read from the
                environment (falling back to an empty string).
            base_url: Root URL of the API, with or without a trailing slash.

        Raises:
            ValueError: If *model* is not a key of ``MODELS``.
        """
        info = self.MODELS.get(model)
        if not info:
            raise ValueError(f"Unknown model key: {model}. Use: {list(self.MODELS.keys())}")
        self.model_id, self._dim = info
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
        # Drop trailing slashes so the request URL never contains "//embeddings".
        self.base_url = base_url.rstrip("/")
        self._http = httpx.AsyncClient(timeout=60, headers={
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        })

    def dimension(self) -> int:
        """Return the embedding dimension of the configured model."""
        return self._dim

    async def embed(self, text: str) -> EmbeddingResult:
        """Embed one text by sending it as a single-item batch."""
        results = await self.embed_batch([text])
        return results[0]

    async def embed_batch(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed *texts* in one request and return results in input order.

        Args:
            texts: Texts to embed. An empty list returns ``[]`` without
                contacting the API.

        Returns:
            One ``EmbeddingResult`` per input text, ordered like *texts*.

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        if not texts:
            # Nothing to embed: skip the network round trip entirely.
            return []

        body = {"model": self.model_id, "input": texts}
        resp = await self._http.post(f"{self.base_url}/embeddings", json=body)
        resp.raise_for_status()
        data = resp.json()

        # Each item carries the position of its input; order by it rather than
        # trusting response order. Items without "index" keep their relative
        # order because the sort is stable.
        items = sorted(data["data"], key=lambda item: item.get("index", 0))

        # Usage is reported for the whole request, so it can only be attributed
        # to a result when the request carried exactly one input.
        usage = data.get("usage") or {}
        tokens = (usage.get("prompt_tokens") or 0) if len(items) == 1 else 0

        return [
            EmbeddingResult(
                vector=item["embedding"],
                tokens=tokens,
                model=self.model_id,
            )
            for item in items
        ]

    async def close(self):
        """Close the underlying HTTP client."""
        await self._http.aclose()
class LocalEmbedder(BaseEmbedder):
    """Embedder that runs a sentence-transformers model in-process.

    No API calls are made. The model is loaded lazily on first use, and both
    loading and encoding run in a worker thread so the event loop is not
    blocked while the model computes.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Remember the model name; the model itself is loaded on first use.

        Args:
            model_name: Name passed to ``SentenceTransformer`` when loading.
        """
        # Imported here so the class does not depend on a module-level import.
        import threading

        self.model_name = model_name
        self._model = None
        self._dim = 384
        # Serialises model loading and encoding across worker threads, so the
        # model is loaded once and never used by two threads at the same time.
        self._lock = threading.Lock()

    def _guess_dimension(self) -> int:
        """Return a name-based guess of the embedding dimension."""
        if self.model_name == "all-MiniLM-L6-v2":
            return 384
        if "large" in self.model_name:
            return 1024
        return 768

    def _ensure_model(self):
        """Load the model on first call and record its real dimension."""
        if self._model is None:
            from sentence_transformers import SentenceTransformer
            self._model = SentenceTransformer(self.model_name)
            reported = self._model.get_sentence_embedding_dimension()
            # Keep dimension() an int even if the model reports no dimension.
            self._dim = reported if reported is not None else self._guess_dimension()

    def _encode(self, inputs):
        """Blocking helper: load the model if needed and encode *inputs*."""
        with self._lock:
            self._ensure_model()
            return self._model.encode(inputs, normalize_embeddings=True)

    def dimension(self) -> int:
        """Return the embedding dimension.

        Before the model is loaded this is a guess derived from the model
        name; afterwards it is the value recorded when the model was loaded.
        """
        if self._model is None:
            self._dim = self._guess_dimension()
        return self._dim

    async def embed(self, text: str) -> EmbeddingResult:
        """Embed one text with normalised output, off the event loop."""
        vec = await asyncio.to_thread(self._encode, text)
        return EmbeddingResult(vector=vec.tolist(), model=self.model_name)

    async def embed_batch(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed several texts with normalised output, off the event loop.

        An empty *texts* returns ``[]`` without loading the model.
        """
        if not texts:
            return []
        vecs = await asyncio.to_thread(self._encode, texts)
        return [
            EmbeddingResult(vector=v.tolist(), model=self.model_name)
            for v in vecs
        ]
class CohereEmbedder(BaseEmbedder):
    """Embedder that calls the Cohere v1 ``/embed`` HTTP endpoint."""

    # Endpoint used by every request.
    EMBED_URL = "https://api.cohere.ai/v1/embed"

    # Largest number of texts sent in one request; longer inputs are split.
    # NOTE(review): 96 is the per-call limit in Cohere's Embed docs — confirm
    # against the API version in use.
    MAX_BATCH = 96

    # Known model id -> embedding dimension (per Cohere's model table).
    # Unknown models fall back to DEFAULT_DIMENSION.
    DIMENSIONS = {
        "embed-english-v3.0": 1024,
        "embed-multilingual-v3.0": 1024,
        "embed-english-light-v3.0": 384,
        "embed-multilingual-light-v3.0": 384,
        "embed-english-v2.0": 4096,
        "embed-english-light-v2.0": 1024,
        "embed-multilingual-v2.0": 768,
    }
    DEFAULT_DIMENSION = 1024

    def __init__(self, model: str = "embed-english-v3.0", api_key: str = ""):
        """Create the embedder and its HTTP client.

        Args:
            model: Cohere model id sent with every request.
            api_key: API key; when empty, ``COHERE_API_KEY`` is read from the
                environment (falling back to an empty string).
        """
        self.model_id = model
        self.api_key = api_key or os.environ.get("COHERE_API_KEY", "")
        self._http = httpx.AsyncClient(timeout=60, headers={
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        })
        self._dim = self.DIMENSIONS.get(model, self.DEFAULT_DIMENSION)

    def dimension(self) -> int:
        """Return the embedding dimension recorded for the configured model."""
        return self._dim

    async def embed(self, text: str) -> EmbeddingResult:
        """Embed one text by sending it as a single-item batch."""
        results = await self.embed_batch([text])
        return results[0]

    async def embed_batch(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed *texts* and return one result per text, in input order.

        Inputs longer than ``MAX_BATCH`` are sent as several consecutive
        requests. An empty *texts* returns ``[]`` without contacting the API.

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        results: list[EmbeddingResult] = []
        for start in range(0, len(texts), self.MAX_BATCH):
            chunk = texts[start:start + self.MAX_BATCH]
            body = {"model": self.model_id, "texts": chunk, "input_type": "search_document"}
            resp = await self._http.post(self.EMBED_URL, json=body)
            resp.raise_for_status()
            data = resp.json()
            results.extend(
                EmbeddingResult(vector=vec, model=self.model_id)
                for vec in data["embeddings"]
            )
        return results

    async def close(self):
        """Close the underlying HTTP client."""
        await self._http.aclose()
async def get_embedder(provider: str = "openai", **kwargs) -> BaseEmbedder:
    """Factory: build the embedder named by *provider*.

    Args:
        provider: ``"openai"``, ``"local"`` or ``"cohere"``.
        **kwargs: Forwarded unchanged to the chosen embedder's constructor.

    Raises:
        ValueError: If *provider* is none of the supported names.
    """
    if provider == "openai":
        return OpenAIEmbedder(**kwargs)
    if provider == "local":
        return LocalEmbedder(**kwargs)
    if provider == "cohere":
        return CohereEmbedder(**kwargs)
    raise ValueError(f"Unknown embedder provider: {provider}. Use: openai/local/cohere")
async def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of vectors *a* and *b*.

    Args:
        a: First vector.
        b: Second vector; must have the same length as *a*.

    Returns:
        A value in ``[-1.0, 1.0]``, or ``0.0`` when either vector has zero
        norm (including two empty vectors).

    Raises:
        ValueError: If the vectors differ in length.
    """
    # zip() would silently truncate to the shorter vector while the norms
    # below use the full vectors, which gives a meaningless score.
    if len(a) != len(b):
        raise ValueError(f"Vector length mismatch: {len(a)} != {len(b)}")

    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    if norm_a == 0 or norm_b == 0:
        return 0.0

    # Floating-point rounding can push the ratio just outside [-1, 1];
    # clamp so callers can rely on the mathematical range.
    return max(-1.0, min(1.0, dot / (norm_a * norm_b)))