Coverage for src \ truenex_memory \ store \ qdrant_store.py: 80%

127 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 10:21 +0200

1"""Optional Qdrant adapter and local in-memory vector store.""" 

2 

3from __future__ import annotations 

4 

5from dataclasses import dataclass, field 

6import math 

7from typing import Any 

8 

9 

# Metadata attached to every stored vector: a string-keyed mapping of arbitrary values.
Payload = dict[str, Any]

11 

12 

class VectorStoreUnavailable(RuntimeError):
    """Signal that an optional vector backend is missing or failed to respond.

    Raised, for example, when ``qdrant-client`` cannot be imported or a call to
    the Qdrant client fails.
    """

15 

16 

@dataclass(frozen=True)
class VectorPoint:
    """A single id/vector/payload record as held by the in-memory store.

    ``frozen=True`` only prevents re-binding the fields; the ``vector`` list
    and ``payload`` dict objects themselves can still be mutated in place.
    Each instance gets its own empty payload dict when none is supplied.
    """

    id: str
    vector: list[float]
    payload: Payload = field(default_factory=dict)

24 

25 

@dataclass(frozen=True)
class VectorSearchHit:
    """One ranked result returned by a vector store ``search`` call.

    ``score`` is the similarity reported for the hit (higher ranks first in
    both stores in this module); ``payload`` is a copy of the stored metadata.
    """

    id: str
    score: float
    payload: Payload

33 

34 

class InMemoryVectorStore:
    """Small cosine-similarity store for unit tests and local fallback.

    Points live in a dict keyed by their string id, so upserting an existing id
    replaces it in place. ``search`` is a brute-force scan over every point.
    """

    def __init__(self, dimensions: int) -> None:
        """Create an empty store for vectors of length ``dimensions``.

        Raises:
            ValueError: if ``dimensions`` is less than 1.
        """
        if dimensions < 1:
            raise ValueError("dimensions must be greater than zero")
        self.dimensions = dimensions
        self._points: dict[str, VectorPoint] = {}

    def upsert(self, points: list[object]) -> None:
        """Insert or replace ``points`` (objects exposing id/vector/payload).

        The whole batch is extracted and validated before anything is stored,
        so one bad point leaves the store unchanged instead of half-updated.
        Within a batch, a repeated id keeps the last point given for it.

        Raises:
            ValueError: if a point has no id or vector, or its vector length
                differs from ``dimensions``.
        """
        staged: dict[str, VectorPoint] = {}
        for point in points:
            vector = _point_vector(point)  # already a fresh list of floats
            point_id = _point_id(point)
            payload = _point_payload(point)  # already a fresh dict
            self._validate_vector(vector)
            staged[point_id] = VectorPoint(id=point_id, vector=vector, payload=payload)
        # Only commit once every point in the batch passed validation.
        self._points.update(staged)

    def search(
        self,
        vector: list[float],
        *,
        limit: int | None = None,
        top_k: int | None = None,
    ) -> list[VectorSearchHit]:
        """Return up to ``limit`` (or ``top_k``) points ranked by cosine similarity.

        ``top_k`` wins when both are given; the default is resolved by
        ``_resolve_limit``. Ties keep insertion order because the sort is stable.

        Raises:
            ValueError: if the resolved limit is below 1 or ``vector`` has the
                wrong number of dimensions.
        """
        limit = _resolve_limit(limit=limit, top_k=top_k)
        self._validate_limit(limit)
        self._validate_vector(vector)
        scored = [(_cosine(vector, point.vector), point) for point in self._points.values()]
        scored.sort(key=lambda pair: pair[0], reverse=True)
        # Copy payloads only for the hits actually returned.
        return [
            VectorSearchHit(id=point.id, score=score, payload=dict(point.payload))
            for score, point in scored[:limit]
        ]

    def delete(self, ids: list[str]) -> None:
        """Remove the given ids; ids that are not stored are ignored."""
        for point_id in ids:
            self._points.pop(point_id, None)

    def count(self) -> int:
        """Return the number of stored points."""
        return len(self._points)

    def _validate_vector(self, vector: list[float]) -> None:
        if len(vector) != self.dimensions:
            raise ValueError(f"vector must have {self.dimensions} dimensions")

    @staticmethod
    def _validate_limit(limit: int) -> None:
        # Static so QdrantVectorStore can reuse the same limit check.
        if limit < 1:
            raise ValueError("limit must be greater than zero")

88 

89 

class QdrantVectorStore:
    """Thin adapter around ``qdrant-client`` with controlled failure modes.

    Backend problems (package missing, client construction, any failing client
    call) surface as ``VectorStoreUnavailable``. Malformed input -- a point
    without an id or vector, a vector whose length differs from ``dimensions``,
    a limit below one, or an unknown ``distance`` name -- raises ``ValueError``
    before any request is sent, mirroring ``InMemoryVectorStore``.

    NOTE(review): ids are passed to Qdrant as strings; Qdrant point ids are
    expected to be unsigned integers or UUIDs -- confirm callers use UUID strings.
    """

    def __init__(
        self,
        *,
        collection_name: str,
        dimensions: int,
        url: str | None = None,
        client: Any | None = None,
        distance: str = "Cosine",
    ) -> None:
        """Configure the adapter.

        Args:
            collection_name: Collection to read and write; must not be blank.
            dimensions: Expected vector length; must be at least 1.
            url: Server URL used when no ``client`` is supplied; without one an
                in-process ``":memory:"`` client is built.
            client: Pre-built client (or test double) used instead of building one.
            distance: Name of a ``Distance`` member, matched case-insensitively;
                only consulted when the collection has to be created.

        Raises:
            ValueError: for ``dimensions < 1`` or a blank ``collection_name``.
            VectorStoreUnavailable: if ``qdrant-client`` cannot be imported or
                the client cannot be created.
        """
        if dimensions < 1:
            raise ValueError("dimensions must be greater than zero")
        if not collection_name.strip():
            raise ValueError("collection_name cannot be empty")
        self.collection_name = collection_name
        self.dimensions = dimensions
        self.distance = distance
        self._client = client if client is not None else self._build_client(url)
        self._models = _load_qdrant_models()

    def initialize(self) -> None:
        """Create the collection unless it already exists.

        Invoked at the start of every upsert, search and delete.

        Raises:
            ValueError: if the collection must be created and ``distance`` does
                not name a ``Distance`` member.
            VectorStoreUnavailable: if creating the collection fails.
        """
        if self._collection_exists():
            return
        # Resolve the metric before contacting the server so a typo is reported
        # as a configuration error rather than "Qdrant is not reachable".
        distance = getattr(self._models.Distance, self.distance.upper(), None)
        if distance is None:
            raise ValueError(f"unsupported distance: {self.distance!r}")
        try:
            self._client.create_collection(
                collection_name=self.collection_name,
                vectors_config=self._models.VectorParams(size=self.dimensions, distance=distance),
            )
        except Exception as exc:  # pragma: no cover - depends on live Qdrant
            raise VectorStoreUnavailable(f"Qdrant is not reachable: {exc}") from exc

    def upsert(self, points: list[object]) -> None:
        """Insert or replace ``points`` (objects exposing id/vector/payload).

        Raises:
            ValueError: if a point has no id or vector, or a vector has the
                wrong number of dimensions (checked before any request).
            VectorStoreUnavailable: if the collection cannot be prepared or the
                upsert request fails.
        """
        # Extract and validate locally first so input mistakes are not
        # misreported as backend failures.
        batch = [(_point_id(point), _point_vector(point), _point_payload(point)) for point in points]
        for _, vector, _ in batch:
            self._validate_vector(vector)
        self.initialize()
        try:
            qdrant_points = [
                self._models.PointStruct(id=point_id, vector=vector, payload=payload)
                for point_id, vector, payload in batch
            ]
            self._client.upsert(collection_name=self.collection_name, points=qdrant_points)
        except Exception as exc:  # pragma: no cover - depends on live Qdrant
            raise VectorStoreUnavailable(f"Qdrant upsert failed: {exc}") from exc

    def search(
        self,
        vector: list[float],
        *,
        limit: int | None = None,
        top_k: int | None = None,
    ) -> list[VectorSearchHit]:
        """Return up to ``limit`` (or ``top_k``) nearest points, as ranked by Qdrant.

        ``top_k`` wins when both are given.

        Raises:
            ValueError: if the resolved limit is below 1 or ``vector`` has the
                wrong number of dimensions.
            VectorStoreUnavailable: if the collection cannot be prepared or the
                query fails.
        """
        limit = _resolve_limit(limit=limit, top_k=top_k)
        InMemoryVectorStore._validate_limit(limit)
        self._validate_vector(vector)
        self.initialize()
        try:
            rows = self._query(vector, limit)
        except Exception as exc:  # pragma: no cover - depends on live Qdrant
            raise VectorStoreUnavailable(f"Qdrant search failed: {exc}") from exc
        return [
            VectorSearchHit(id=str(row.id), score=float(row.score), payload=dict(row.payload or {}))
            for row in rows
        ]

    def delete(self, ids: list[str]) -> None:
        """Delete the points with the given ids.

        Raises:
            VectorStoreUnavailable: if the collection cannot be prepared or the
                delete request fails.
        """
        self.initialize()
        try:
            self._client.delete(
                collection_name=self.collection_name,
                points_selector=self._models.PointIdsList(points=ids),
            )
        except Exception as exc:  # pragma: no cover - depends on live Qdrant
            raise VectorStoreUnavailable(f"Qdrant delete failed: {exc}") from exc

    def _query(self, vector: list[float], limit: int) -> Any:
        """Run a nearest-neighbour query and return rows with id/score/payload."""
        legacy_search = getattr(self._client, "search", None)
        if legacy_search is not None:
            # Preferred when present so existing clients and test doubles behave
            # exactly as before.
            return legacy_search(
                collection_name=self.collection_name,
                query_vector=vector,
                limit=limit,
            )
        # Newer qdrant-client releases expose only ``query_points``.
        response = self._client.query_points(
            collection_name=self.collection_name,
            query=vector,
            limit=limit,
        )
        return response.points

    def _validate_vector(self, vector: list[float]) -> None:
        if len(vector) != self.dimensions:
            raise ValueError(f"vector must have {self.dimensions} dimensions")

    def _collection_exists(self) -> bool:
        # Any error (including an unreachable server) is treated as "missing";
        # the create_collection call in initialize() then reports the failure.
        try:
            self._client.get_collection(self.collection_name)
            return True
        except Exception:
            return False

    @staticmethod
    def _build_client(url: str | None) -> Any:
        try:
            from qdrant_client import QdrantClient
        except ImportError as exc:
            raise VectorStoreUnavailable("qdrant-client is not installed") from exc
        try:
            return QdrantClient(url=url) if url else QdrantClient(":memory:")
        except Exception as exc:  # pragma: no cover - depends on client version
            raise VectorStoreUnavailable(f"Qdrant client could not be created: {exc}") from exc

191 

192 

def _load_qdrant_models() -> Any:
    """Import and hand back the ``qdrant_client.models`` namespace.

    Raises:
        VectorStoreUnavailable: if ``qdrant-client`` cannot be imported.
    """
    try:
        from qdrant_client import models as qdrant_models
    except ImportError as exc:
        raise VectorStoreUnavailable("qdrant-client is not installed") from exc
    else:
        return qdrant_models

199 

200 

201def _resolve_limit(*, limit: int | None, top_k: int | None) -> int: 

202 return top_k if top_k is not None else (limit if limit is not None else 5) 

203 

204 

def _point_id(point: object) -> str:
    """Return the point's id (``id`` attribute, else ``point_id``) as a string.

    Only ``None`` or ``""`` count as missing, so falsy but valid ids such as the
    integer ``0`` are kept; previously ``0`` was treated as absent.

    Raises:
        ValueError: if neither attribute yields a usable id.
    """
    point_id = getattr(point, "id", None)
    if point_id is None or point_id == "":
        point_id = getattr(point, "point_id", None)
    if point_id is None or point_id == "":
        raise ValueError("vector point is missing an id")
    return str(point_id)

210 

211 

def _point_vector(point: object) -> list[float]:
    """Return a new list holding ``point.vector`` converted to floats.

    Raises:
        ValueError: if the point has no ``vector`` attribute or it is ``None``.
    """
    raw = getattr(point, "vector", None)
    if raw is None:
        raise ValueError("vector point is missing a vector")
    return list(map(float, raw))

217 

218 

def _point_payload(point: object) -> Payload:
    """Return a shallow copy of ``point.payload``, or ``{}`` when it is absent/None."""
    raw = getattr(point, "payload", None)
    return {} if raw is None else dict(raw)

224 

225 

226def _cosine(left: list[float], right: list[float]) -> float: 

227 left_norm = math.sqrt(sum(value * value for value in left)) 

228 right_norm = math.sqrt(sum(value * value for value in right)) 

229 if left_norm == 0 or right_norm == 0: 

230 return 0.0 

231 dot = sum(left_value * right_value for left_value, right_value in zip(left, right, strict=True)) 

232 return dot / (left_norm * right_norm)