Coverage for src/dataknobs_data/vector/mixins.py: 20%

112 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-26 16:34 -0700

1"""Mixins and protocols for vector-capable databases.""" 

2 

3from __future__ import annotations 

4 

5from abc import ABC, abstractmethod 

6from typing import TYPE_CHECKING, Any, Protocol 

7 

8from ..fields import FieldType 

9from .hybrid import ( 

10 FusionStrategy, 

11 HybridSearchConfig, 

12 HybridSearchResult, 

13 reciprocal_rank_fusion, 

14 weighted_score_fusion, 

15) 

16from .types import DistanceMetric, VectorSearchResult 

17 

18if TYPE_CHECKING: 

19 import numpy as np 

20 from collections.abc import Callable 

21 from ..query import Query 

22 from ..records import Record 

23 

24 

class VectorCapable(Protocol):
    """Structural protocol for backends that expose vector operations."""

    async def has_vector_support(self) -> bool:
        """Report whether vector operations are available on this backend.

        Returns:
            True if vector operations are supported
        """
        ...

    async def enable_vector_support(self) -> bool:
        """Turn on vector support (install extensions, configure indices, etc.).

        Returns:
            True if vector support was successfully enabled
        """
        ...

    async def detect_vector_fields(self, record: Record) -> list[str]:
        """Find the fields of a record that hold vector data.

        Args:
            record: Record to examine

        Returns:
            List of field names that contain vectors
        """
        vector_types = (FieldType.VECTOR, FieldType.SPARSE_VECTOR)
        detected: list[str] = []
        for name, field in record.fields.items():
            if field.type in vector_types:
                detected.append(name)
        return detected

    def get_vector_config(self) -> dict[str, Any]:
        """Return backend-specific vector configuration options.

        Returns:
            Dictionary of vector configuration options (empty by default)
        """
        return {}

67 

class VectorOperationsMixin(ABC):
    """Mixin providing vector operations for databases.

    This mixin should be added to database backend classes that support
    vector operations. It provides abstract methods that must be implemented
    by the concrete backend class.
    """

    @abstractmethod
    async def vector_search(
        self,
        query_vector: np.ndarray | list[float],
        vector_field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric = DistanceMetric.COSINE,
        filter: Query | None = None,
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> list[VectorSearchResult]:
        """Search for similar vectors.

        Args:
            query_vector: The vector to search for
            vector_field: Name of the vector field to search
            k: Number of results to return
            metric: Distance metric to use
            filter: Optional query filter to apply before vector search
            include_source: Whether to include source text in results
            score_threshold: Optional minimum similarity score

        Returns:
            List of search results ordered by similarity
        """
        pass

    @abstractmethod
    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray] | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed
        """
        pass

    async def update_vector(
        self,
        record_id: str,
        vector_field: str,
        vector: np.ndarray | list[float],
        metadata: dict[str, Any] | None = None,
    ) -> bool:
        """Update a vector field for a specific record.

        Args:
            record_id: ID of the record to update
            vector_field: Name of the vector field
            vector: New vector value
            metadata: Optional metadata to attach

        Returns:
            True if update was successful (False if the record was not found)
        """
        # Default implementation: read-modify-write via the backend's
        # standard record operations. Backends may override with an
        # in-place vector update for efficiency.
        from ..fields import VectorField

        record = await self.read(record_id)  # type: ignore
        if not record:
            return False

        # Replace (or create) the vector field on the record.
        record.fields[vector_field] = VectorField(
            name=vector_field,
            value=vector,
            metadata=metadata,
        )

        return await self.update(record_id, record) is not None  # type: ignore

    async def delete_from_index(
        self,
        record_id: str,
        vector_field: str = "embedding",
    ) -> bool:
        """Remove a record from the vector index.

        Args:
            record_id: ID of the record to remove
            vector_field: Name of the vector field

        Returns:
            True if deletion was successful
        """
        # Default implementation deletes the whole record; backends with
        # separate index storage should override to remove only the index entry.
        return await self.delete(record_id)  # type: ignore

    async def create_vector_index(
        self,
        vector_field: str = "embedding",
        dimensions: int | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
        index_type: str = "auto",
        **kwargs: Any,
    ) -> bool:
        """Create an index for vector similarity search.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions (if known)
            metric: Distance metric for the index
            index_type: Type of index to create
            **kwargs: Backend-specific index parameters

        Returns:
            True if index was created successfully
        """
        # Default no-op implementation (brute-force search needs no index).
        return True

    async def drop_vector_index(
        self,
        vector_field: str = "embedding",
    ) -> bool:
        """Drop a vector index.

        Args:
            vector_field: Name of the vector field

        Returns:
            True if index was dropped successfully
        """
        # Default no-op implementation.
        return True

    async def get_vector_index_stats(
        self,
        vector_field: str = "embedding",
    ) -> dict[str, Any]:
        """Get statistics about a vector index.

        Args:
            vector_field: Name of the vector field

        Returns:
            Dictionary of index statistics
        """
        # Default stats for backends without a real index.
        return {
            "field": vector_field,
            "indexed": False,
            "vector_count": 0,
        }

    async def hybrid_search(
        self,
        query_text: str,
        query_vector: np.ndarray | list[float],
        text_fields: list[str] | None = None,
        vector_field: str = "embedding",
        k: int = 10,
        config: HybridSearchConfig | None = None,
        filter: Query | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
    ) -> list[HybridSearchResult]:
        """Perform hybrid search combining text and vector similarity.

        This method combines traditional text search with vector similarity search
        using configurable fusion strategies. The default implementation performs
        both searches and merges results client-side. Backends with native hybrid
        search support (like Elasticsearch) can override this for better performance.

        Args:
            query_text: Text query for keyword/text matching
            query_vector: Vector for semantic similarity search
            text_fields: Fields to search for text matching (default: search all text fields)
            vector_field: Name of the vector field to search
            k: Number of results to return
            config: Hybrid search configuration (weights, fusion strategy)
            filter: Optional additional filters to apply
            metric: Distance metric for vector search

        Returns:
            List of HybridSearchResult ordered by combined score (descending)

        Example:
            ```python
            from dataknobs_data.vector import HybridSearchConfig, FusionStrategy

            # Default RRF fusion
            results = await db.hybrid_search(
                query_text="machine learning",
                query_vector=embedding,
                text_fields=["title", "content"],
                k=10,
            )

            # Custom weighted fusion
            config = HybridSearchConfig(
                text_weight=0.3,
                vector_weight=0.7,
                fusion_strategy=FusionStrategy.WEIGHTED_SUM,
            )
            results = await db.hybrid_search(
                query_text="machine learning",
                query_vector=embedding,
                config=config,
            )
            ```
        """
        config = config or HybridSearchConfig()

        # If using NATIVE strategy but backend doesn't support it, fall back to
        # RRF with the same weights. NOTE(review): when the backend DOES support
        # native hybrid search it is expected to override this method entirely;
        # otherwise a NATIVE config falls through to the weighted-sum branch below.
        if config.fusion_strategy == FusionStrategy.NATIVE:
            if not await self._supports_native_hybrid():  # type: ignore[attr-defined]
                config = HybridSearchConfig(
                    text_weight=config.text_weight,
                    vector_weight=config.vector_weight,
                    fusion_strategy=FusionStrategy.RRF,
                    rrf_k=config.rrf_k,
                    text_fields=config.text_fields,
                )

        # Use config.text_fields if provided, otherwise use parameter.
        search_text_fields = config.text_fields or text_fields

        # Over-fetch candidates so the fusion step has overlap between the two
        # result lists to work with, but never fetch fewer than k.
        # (A plain min(k * 3, 100) would under-fetch whenever k > 100.)
        fetch_k = max(k, min(k * 3, 100))

        # Perform text search.
        text_results = await self._text_search_for_hybrid(
            query_text=query_text,
            text_fields=search_text_fields,
            k=fetch_k,
            filter=filter,
        )

        # Perform vector search.
        vector_results = await self.vector_search(
            query_vector=query_vector,
            vector_field=vector_field,
            k=fetch_k,
            metric=metric,
            filter=filter,
        )

        # Build ID->Record and ID->score maps. Records found by both searches
        # are deduplicated by ID (the vector copy wins, which is harmless since
        # both refer to the same stored record).
        records_by_id: dict[str, Record] = {}
        text_scores: list[tuple[str, float]] = []
        vector_scores: list[tuple[str, float]] = []

        for record, score in text_results:
            record_id = record.id or record.storage_id
            if record_id:
                records_by_id[record_id] = record
                text_scores.append((record_id, score))

        for result in vector_results:
            record_id = result.record.id or result.record.storage_id
            if record_id:
                records_by_id[record_id] = result.record
                vector_scores.append((record_id, result.score))

        # Fuse the two ranked lists into a single combined ranking.
        if config.fusion_strategy == FusionStrategy.RRF:
            fused = reciprocal_rank_fusion(
                text_results=text_scores,
                vector_results=vector_scores,
                k=config.rrf_k,
                text_weight=config.text_weight,
                vector_weight=config.vector_weight,
            )
        else:  # WEIGHTED_SUM
            text_w, vector_w = config.normalize_weights()
            fused = weighted_score_fusion(
                text_results=text_scores,
                vector_results=vector_scores,
                text_weight=text_w,
                vector_weight=vector_w,
                normalize_scores=True,
            )

        # Build HybridSearchResult objects. Rank maps rely on the score lists
        # being in descending-relevance order as returned by each search.
        text_score_map = dict(text_scores)
        vector_score_map = dict(vector_scores)
        text_rank_map = {rid: i + 1 for i, (rid, _) in enumerate(text_scores)}
        vector_rank_map = {rid: i + 1 for i, (rid, _) in enumerate(vector_scores)}

        results: list[HybridSearchResult] = []
        for record_id, combined_score in fused[:k]:
            if record_id not in records_by_id:
                continue

            results.append(HybridSearchResult(
                record=records_by_id[record_id],
                combined_score=combined_score,
                text_score=text_score_map.get(record_id),
                vector_score=vector_score_map.get(record_id),
                text_rank=text_rank_map.get(record_id),
                vector_rank=vector_rank_map.get(record_id),
                metadata={
                    "fusion_strategy": config.fusion_strategy.value,
                    "text_weight": config.text_weight,
                    "vector_weight": config.vector_weight,
                },
            ))

        return results

    async def _text_search_for_hybrid(
        self,
        query_text: str,
        text_fields: list[str] | None,
        k: int,
        filter: Query | None = None,
    ) -> list[tuple[Record, float]]:
        """Perform text search for hybrid search fusion.

        Default implementation uses LIKE query on text fields.
        Backends can override for better text search (e.g., full-text search).

        Args:
            query_text: Text to search for
            text_fields: Fields to search in
            k: Maximum results to return
            filter: Additional filters

        Returns:
            List of (record, score) tuples ordered by relevance
        """
        from ..query import Filter, Operator, Query as Q

        # Build text search query on a copy so the caller's filter is untouched.
        query = filter.copy() if filter else Q()
        query.limit_value = k

        # Add text matching filters.
        # For simple implementation, use LIKE on each text field with OR logic.
        # This is a basic implementation; backends should override for better text search.
        if text_fields:
            # Use first field for simplicity in default implementation.
            # Backends with full-text search should override this.
            for field in text_fields[:1]:  # Only use first field to avoid complex OR
                query.filters.append(Filter(
                    field=field,
                    operator=Operator.LIKE,
                    value=f"%{query_text}%",
                ))

        # Perform search.
        records = await self.search(query)  # type: ignore[attr-defined]

        # Assign basic scores based on match quality.
        results: list[tuple[Record, float]] = []
        query_lower = query_text.lower()
        for i, record in enumerate(records):
            # Rank-based base score: 1, 1/2, 1/3, ...
            score = 1.0 / (i + 1)

            # Boost substring matches (x1.5) and exact matches (additional x2),
            # then clamp to 1.0 so boosted scores stay comparable.
            for field in (text_fields or []):
                value = record.get_value(field)
                if value and isinstance(value, str):
                    if query_lower in value.lower():
                        score *= 1.5
                    if query_lower == value.lower():
                        score *= 2.0

            results.append((record, min(score, 1.0)))

        return results

    async def _supports_native_hybrid(self) -> bool:
        """Check if this backend supports native hybrid search.

        Override in backends that have native hybrid search support
        (e.g., Elasticsearch with RRF).

        Returns:
            True if native hybrid search is supported
        """
        return False

465 

466 

class VectorSyncMixin:
    """Mixin for synchronizing vectors with source text."""

    async def sync_vectors_with_text(
        self,
        records: list[Record],
        text_fields: list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray] | None = None,
        force: bool = False,
    ) -> int:
        """Synchronize vector embeddings with text content.

        A record is re-embedded when ``force`` is set, when it has no vector
        field yet, or when the set of source fields recorded on its existing
        vector differs from ``text_fields``.

        Args:
            records: Records to synchronize
            text_fields: Text fields to generate vectors from
            vector_field: Vector field to update
            embedding_fn: Embedding function (sync or async); required
            force: Force re-generation even if vectors exist

        Returns:
            Number of records updated

        Raises:
            ValueError: If no embedding function is provided
        """
        if not embedding_fn:
            raise ValueError("Embedding function is required for vector synchronization")

        from ..fields import VectorField

        updated = 0
        for record in records:
            # Check if vector needs update.
            needs_update = force or vector_field not in record.fields

            if not needs_update:
                # Re-embed if the source fields recorded on the existing vector
                # differ from the requested fields. Guard against a missing/None
                # metadata mapping on the field.
                vector_meta = record.fields[vector_field].metadata or {}
                source_fields = vector_meta.get("source_field", "").split(",")
                needs_update = set(source_fields) != set(text_fields)

            if not needs_update:
                continue

            # Concatenate non-empty text field values (each read only once).
            parts: list[str] = []
            for field in text_fields:
                value = record.get_value(field)
                if value:
                    parts.append(str(value))
            text_content = " ".join(parts)

            # Nothing to embed -> leave the record untouched.
            if not text_content:
                continue

            result = embedding_fn([text_content])
            # Handle both sync and async embedding functions.
            if hasattr(result, '__await__'):
                embeddings = await result  # type: ignore[misc]
            else:
                embeddings = result
            record.fields[vector_field] = VectorField(
                name=vector_field,
                value=embeddings[0],
                source_field=",".join(text_fields),
            )
            updated += 1

        return updated