Coverage for src \ truenex_memory \ store \ qdrant_store.py: 80%
127 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 10:21 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 10:21 +0200
1"""Optional Qdrant adapter and local in-memory vector store."""
3from __future__ import annotations
5from dataclasses import dataclass, field
6import math
7from typing import Any
# Metadata mapping stored next to each vector and handed back on search hits.
Payload = dict[str, Any]
class VectorStoreUnavailable(RuntimeError):
    """Signal that an optional vector backend cannot be used.

    Raised when the backend library is missing or a backend operation fails.
    It subclasses ``RuntimeError``, so generic runtime handlers still catch it.
    """
@dataclass(frozen=True)
class VectorPoint:
    """Immutable record pairing an id and a vector with optional metadata.

    ``frozen=True`` blocks attribute reassignment, but the list and dict held
    in ``vector`` and ``payload`` are themselves still mutable objects.
    """

    # Identifier the stores key points by.
    id: str
    # Embedding values; stores check its length against their dimensions.
    vector: list[float]
    # Metadata; each instance gets its own fresh dict when omitted.
    payload: Payload = field(default_factory=dict)
@dataclass(frozen=True)
class VectorSearchHit:
    """One ranked result returned by a vector store search."""

    # Identifier of the matched point.
    id: str
    # Similarity score; higher means more similar.
    score: float
    # Copy of the metadata stored with the matched point.
    payload: Payload
class InMemoryVectorStore:
    """Small cosine-similarity store for unit tests and local fallback.

    Points live in a dict keyed by their string id, so upserting an id that is
    already present replaces its vector and payload. Search is a linear scan
    that scores every stored point with cosine similarity.
    """

    def __init__(self, dimensions: int) -> None:
        """Create an empty store for vectors of exactly ``dimensions`` values.

        Raises:
            ValueError: if ``dimensions`` is smaller than 1.
        """
        if dimensions < 1:
            raise ValueError("dimensions must be greater than zero")
        self.dimensions = dimensions
        self._points: dict[str, VectorPoint] = {}

    def upsert(self, points: list[object]) -> None:
        """Insert or replace points.

        Each point is any object exposing ``id`` (or ``point_id``), ``vector``
        and optionally ``payload`` attributes. The whole batch is validated
        before anything is written, so one bad point leaves the store unchanged
        instead of half-applied.

        Raises:
            ValueError: if a point lacks an id or vector, or a vector does not
                have ``self.dimensions`` values.
        """
        staged: list[VectorPoint] = []
        for point in points:
            vector = _point_vector(point)
            point_id = _point_id(point)
            payload = _point_payload(point)
            self._validate_vector(vector)
            # Copy so the stored point never aliases caller-owned containers.
            staged.append(VectorPoint(id=point_id, vector=list(vector), payload=dict(payload)))
        # Commit only after every point passed validation.
        for stored in staged:
            self._points[stored.id] = stored

    def search(
        self,
        vector: list[float],
        *,
        limit: int | None = None,
        top_k: int | None = None,
    ) -> list[VectorSearchHit]:
        """Return up to ``limit`` hits ordered by descending cosine similarity.

        ``top_k`` is an alias for ``limit`` and wins when both are given. With
        neither, at most 5 hits are returned. Points with equal scores keep the
        store's iteration order, because the sort is stable.

        Raises:
            ValueError: if the limit is below 1 or ``vector`` has the wrong
                number of dimensions.
        """
        limit = _resolve_limit(limit=limit, top_k=top_k)
        self._validate_limit(limit)
        self._validate_vector(vector)
        scored = [(_cosine(vector, point.vector), point) for point in self._points.values()]
        # Sort on the score only; stable sort preserves order among ties.
        scored.sort(key=lambda item: item[0], reverse=True)
        # Copy payloads only for the hits actually returned.
        return [
            VectorSearchHit(id=point.id, score=score, payload=dict(point.payload))
            for score, point in scored[:limit]
        ]

    def delete(self, ids: list[str]) -> None:
        """Remove the given ids; ids that are not stored are ignored."""
        for point_id in ids:
            self._points.pop(point_id, None)

    def count(self) -> int:
        """Return the number of stored points."""
        return len(self._points)

    def _validate_vector(self, vector: list[float]) -> None:
        """Raise ``ValueError`` unless ``vector`` has ``self.dimensions`` values."""
        if len(vector) != self.dimensions:
            raise ValueError(f"vector must have {self.dimensions} dimensions")

    @staticmethod
    def _validate_limit(limit: int) -> None:
        """Raise ``ValueError`` unless ``limit`` is at least 1."""
        if limit < 1:
            raise ValueError("limit must be greater than zero")
class QdrantVectorStore:
    """Thin adapter around ``qdrant-client`` with controlled failure modes.

    Backend failures are re-raised as :class:`VectorStoreUnavailable`.
    Malformed input (missing id/vector, wrong vector dimensions, limit below 1)
    raises ``ValueError`` before any network call, the same way
    ``InMemoryVectorStore`` reports it.

    NOTE(review): ids are sent to Qdrant as strings. Qdrant only accepts
    unsigned integers or UUIDs as point ids, so arbitrary string ids are
    presumably rejected server-side -- confirm how callers mint ids.
    """

    def __init__(
        self,
        *,
        collection_name: str,
        dimensions: int,
        url: str | None = None,
        client: Any | None = None,
        distance: str = "Cosine",
    ) -> None:
        """Configure the adapter; no collection is created until first use.

        Args:
            collection_name: Target collection; must not be blank.
            dimensions: Vector size used when creating the collection.
            url: Server URL. Without it (and without ``client``) an
                in-process ``QdrantClient(":memory:")`` is used.
            client: Pre-built client to use instead of constructing one.
            distance: Name of a ``models.Distance`` member, case-insensitive.

        Raises:
            ValueError: on non-positive ``dimensions`` or blank ``collection_name``.
            VectorStoreUnavailable: if ``qdrant-client`` is not installed or the
                client cannot be created.
        """
        if dimensions < 1:
            raise ValueError("dimensions must be greater than zero")
        if not collection_name.strip():
            raise ValueError("collection_name cannot be empty")
        self.collection_name = collection_name
        self.dimensions = dimensions
        self.distance = distance
        self._client = client if client is not None else self._build_client(url)
        self._models = _load_qdrant_models()
        # Flipped once the collection is known to exist, so later operations do
        # not repeat the existence check over the network.
        self._initialized = False

    def initialize(self) -> None:
        """Ensure the collection exists, creating it on first call if needed.

        Raises:
            VectorStoreUnavailable: if Qdrant cannot be reached or ``distance``
                is not a known ``models.Distance`` member.
        """
        # getattr keeps this safe for instances built without __init__.
        if getattr(self, "_initialized", False):
            return
        try:
            if not self._collection_exists():
                self._client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=self._models.VectorParams(
                        size=self.dimensions,
                        distance=self._resolve_distance(),
                    ),
                )
        except VectorStoreUnavailable:
            raise
        except Exception as exc:  # pragma: no cover - depends on live Qdrant
            raise VectorStoreUnavailable(f"Qdrant is not reachable: {exc}") from exc
        self._initialized = True

    def upsert(self, points: list[object]) -> None:
        """Insert or replace points in the collection.

        All points are converted and validated before the collection is touched.

        Raises:
            ValueError: if a point lacks an id or vector, or has the wrong size.
            VectorStoreUnavailable: if the Qdrant call fails.
        """
        qdrant_points = [self._to_point_struct(point) for point in points]
        self.initialize()
        if not qdrant_points:
            return
        try:
            self._client.upsert(collection_name=self.collection_name, points=qdrant_points)
        except Exception as exc:  # pragma: no cover - depends on live Qdrant
            raise VectorStoreUnavailable(f"Qdrant upsert failed: {exc}") from exc

    def search(
        self,
        vector: list[float],
        *,
        limit: int | None = None,
        top_k: int | None = None,
    ) -> list[VectorSearchHit]:
        """Return up to ``limit`` nearest hits (``top_k`` wins; default 5).

        Raises:
            ValueError: if the limit is below 1 or ``vector`` has the wrong size.
            VectorStoreUnavailable: if the Qdrant call fails.
        """
        limit = _resolve_limit(limit=limit, top_k=top_k)
        InMemoryVectorStore._validate_limit(limit)
        query = [float(value) for value in vector]
        self._validate_vector(query)
        self.initialize()
        try:
            rows = self._query(query, limit)
        except Exception as exc:  # pragma: no cover - depends on live Qdrant
            raise VectorStoreUnavailable(f"Qdrant search failed: {exc}") from exc
        return [
            VectorSearchHit(id=str(row.id), score=float(row.score), payload=dict(row.payload or {}))
            for row in rows
        ]

    def delete(self, ids: list[str]) -> None:
        """Delete the given point ids; an empty list makes no Qdrant call.

        Raises:
            VectorStoreUnavailable: if the Qdrant call fails.
        """
        self.initialize()
        selected = list(ids)
        if not selected:
            return
        try:
            self._client.delete(
                collection_name=self.collection_name,
                points_selector=self._models.PointIdsList(points=selected),
            )
        except Exception as exc:  # pragma: no cover - depends on live Qdrant
            raise VectorStoreUnavailable(f"Qdrant delete failed: {exc}") from exc

    def _to_point_struct(self, point: object) -> Any:
        """Convert one input point into a validated ``models.PointStruct``."""
        point_id = _point_id(point)
        vector = _point_vector(point)
        payload = _point_payload(point)
        self._validate_vector(vector)
        return self._models.PointStruct(id=point_id, vector=vector, payload=dict(payload))

    def _validate_vector(self, vector: list[float]) -> None:
        """Raise ``ValueError`` unless ``vector`` has ``self.dimensions`` values."""
        if len(vector) != self.dimensions:
            raise ValueError(f"vector must have {self.dimensions} dimensions")

    def _resolve_distance(self) -> Any:
        """Map ``self.distance`` to a ``models.Distance`` member."""
        try:
            return getattr(self._models.Distance, self.distance.upper())
        except AttributeError as exc:
            # A configuration mistake, not a connectivity problem.
            raise VectorStoreUnavailable(f"unsupported Qdrant distance: {self.distance!r}") from exc

    def _query(self, vector: list[float], limit: int) -> Any:
        """Run the nearest-neighbour query and return the scored rows.

        ``search`` is preferred when the client has it, so injected clients
        that implement only ``search`` keep working. Newer qdrant-client
        releases deprecate ``search`` in favour of ``query_points``, which
        wraps the rows in a response object exposing ``.points``.
        """
        search = getattr(self._client, "search", None)
        if search is not None:
            return search(
                collection_name=self.collection_name,
                query_vector=vector,
                limit=limit,
            )
        response = self._client.query_points(
            collection_name=self.collection_name,
            query=vector,
            limit=limit,
        )
        return response.points

    def _collection_exists(self) -> bool:
        """Return True if ``get_collection`` succeeds; any error counts as absent."""
        try:
            self._client.get_collection(self.collection_name)
            return True
        except Exception:
            return False

    @staticmethod
    def _build_client(url: str | None) -> Any:
        """Create a ``QdrantClient`` for ``url``, or an in-memory one without it."""
        try:
            from qdrant_client import QdrantClient
        except ImportError as exc:
            raise VectorStoreUnavailable("qdrant-client is not installed") from exc
        try:
            return QdrantClient(url=url) if url else QdrantClient(":memory:")
        except Exception as exc:  # pragma: no cover - depends on client version
            raise VectorStoreUnavailable(f"Qdrant client could not be created: {exc}") from exc
def _load_qdrant_models() -> Any:
    """Import and return the ``qdrant_client.models`` module.

    Raises:
        VectorStoreUnavailable: if ``qdrant-client`` cannot be imported.
    """
    try:
        from qdrant_client import models as qdrant_models
    except ImportError as missing:
        raise VectorStoreUnavailable("qdrant-client is not installed") from missing
    else:
        return qdrant_models
201def _resolve_limit(*, limit: int | None, top_k: int | None) -> int:
202 return top_k if top_k is not None else (limit if limit is not None else 5)
205def _point_id(point: object) -> str:
206 point_id = getattr(point, "id", None) or getattr(point, "point_id", None)
207 if point_id is None:
208 raise ValueError("vector point is missing an id")
209 return str(point_id)
212def _point_vector(point: object) -> list[float]:
213 vector = getattr(point, "vector", None)
214 if vector is None:
215 raise ValueError("vector point is missing a vector")
216 return [float(value) for value in vector]
219def _point_payload(point: object) -> Payload:
220 payload = getattr(point, "payload", None)
221 if payload is None:
222 return {}
223 return dict(payload)
226def _cosine(left: list[float], right: list[float]) -> float:
227 left_norm = math.sqrt(sum(value * value for value in left))
228 right_norm = math.sqrt(sum(value * value for value in right))
229 if left_norm == 0 or right_norm == 0:
230 return 0.0
231 dot = sum(left_value * right_value for left_value, right_value in zip(left, right, strict=True))
232 return dot / (left_norm * right_norm)