Coverage for src / dataknobs_data / vector / mixins.py: 20%
112 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-26 16:34 -0700
1"""Mixins and protocols for vector-capable databases."""
3from __future__ import annotations
5from abc import ABC, abstractmethod
6from typing import TYPE_CHECKING, Any, Protocol
8from ..fields import FieldType
9from .hybrid import (
10 FusionStrategy,
11 HybridSearchConfig,
12 HybridSearchResult,
13 reciprocal_rank_fusion,
14 weighted_score_fusion,
15)
16from .types import DistanceMetric, VectorSearchResult
18if TYPE_CHECKING:
19 import numpy as np
20 from collections.abc import Callable
21 from ..query import Query
22 from ..records import Record
class VectorCapable(Protocol):
    """Structural interface for backends that can handle vector operations.

    ``has_vector_support`` and ``enable_vector_support`` are declared only
    (their bodies are ``...``); ``detect_vector_fields`` and
    ``get_vector_config`` carry default implementations.
    """

    async def has_vector_support(self) -> bool:
        """Report whether vector operations are available on this backend.

        Returns:
            True if vector operations are supported
        """
        ...

    async def enable_vector_support(self) -> bool:
        """Turn on vector support (install extensions, configure indices, etc.).

        Returns:
            True if vector support was successfully enabled
        """
        ...

    async def detect_vector_fields(self, record: Record) -> list[str]:
        """Collect the names of the vector-typed fields of a record.

        A field counts as a vector field when its ``type`` is either
        ``FieldType.VECTOR`` or ``FieldType.SPARSE_VECTOR``.

        Args:
            record: Record to examine

        Returns:
            Field names holding vectors, in the iteration order of
            ``record.fields``
        """
        detected: list[str] = []
        for name, candidate in record.fields.items():
            if candidate.type in (FieldType.VECTOR, FieldType.SPARSE_VECTOR):
                detected.append(name)
        return detected

    def get_vector_config(self) -> dict[str, Any]:
        """Return vector-specific configuration for this backend.

        Returns:
            Dictionary of vector configuration options (empty by default)
        """
        options: dict[str, Any] = {}
        return options
class VectorOperationsMixin(ABC):
    """Mixin providing vector operations for databases.

    Add this mixin to database backend classes that support vector
    operations. ``vector_search`` and ``bulk_embed_and_store`` are abstract
    and must be implemented by the concrete backend; every other method has
    a generic default that a backend may override.

    The defaults call ``self.read``, ``self.update``, ``self.delete`` and
    ``self.search``. None of these is defined here, so the class this mixin
    is combined with has to provide them.
    """

    @abstractmethod
    async def vector_search(
        self,
        query_vector: np.ndarray | list[float],
        vector_field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric = DistanceMetric.COSINE,
        filter: Query | None = None,
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> list[VectorSearchResult]:
        """Search for similar vectors.

        Args:
            query_vector: The vector to search for
            vector_field: Name of the vector field to search
            k: Number of results to return
            metric: Distance metric to use
            filter: Optional query filter to apply before vector search
            include_source: Whether to include source text in results
            score_threshold: Optional minimum similarity score

        Returns:
            List of search results ordered by similarity
        """

    @abstractmethod
    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray] | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed
        """

    async def update_vector(
        self,
        record_id: str,
        vector_field: str,
        vector: np.ndarray | list[float],
        metadata: dict[str, Any] | None = None,
    ) -> bool:
        """Update a vector field for a specific record.

        The default implementation reads the record, replaces (or adds) the
        named field with a new ``VectorField`` and writes the record back
        through the standard ``update`` call.

        Args:
            record_id: ID of the record to update
            vector_field: Name of the vector field
            vector: New vector value
            metadata: Optional metadata to attach

        Returns:
            True if the update call returned a non-None result, False if the
            record could not be read
        """
        # Imported lazily, as in the rest of this module's default helpers.
        from ..fields import VectorField

        record = await self.read(record_id)  # type: ignore
        # Explicit None check: only a missing record should short-circuit.
        # NOTE(review): assumes ``read`` signals "not found" with None -- confirm.
        if record is None:
            return False

        # Replace the vector field wholesale with the new value/metadata
        record.fields[vector_field] = VectorField(
            name=vector_field,
            value=vector,
            metadata=metadata,
        )

        return await self.update(record_id, record) is not None  # type: ignore

    async def delete_from_index(
        self,
        record_id: str,
        vector_field: str = "embedding",
    ) -> bool:
        """Remove a record from the vector index.

        Note:
            The default implementation delegates to the standard ``delete``
            call, i.e. it deletes the record itself and ignores
            ``vector_field``. Backends with a separate vector index should
            override this method.

        Args:
            record_id: ID of the record to remove
            vector_field: Name of the vector field (unused by the default)

        Returns:
            The result of the underlying ``delete`` call
        """
        return await self.delete(record_id)  # type: ignore

    async def create_vector_index(
        self,
        vector_field: str = "embedding",
        dimensions: int | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
        index_type: str = "auto",
        **kwargs: Any,
    ) -> bool:
        """Create an index for vector similarity search.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions (if known)
            metric: Distance metric for the index
            index_type: Type of index to create
            **kwargs: Backend-specific index parameters

        Returns:
            True if index was created successfully (the default is a no-op
            that always returns True)
        """
        return True

    async def drop_vector_index(
        self,
        vector_field: str = "embedding",
    ) -> bool:
        """Drop a vector index.

        Args:
            vector_field: Name of the vector field

        Returns:
            True if index was dropped successfully (the default is a no-op
            that always returns True)
        """
        return True

    async def get_vector_index_stats(
        self,
        vector_field: str = "embedding",
    ) -> dict[str, Any]:
        """Get statistics about a vector index.

        Args:
            vector_field: Name of the vector field

        Returns:
            Dictionary of index statistics; the default reports the field as
            not indexed with a vector count of zero
        """
        return {
            "field": vector_field,
            "indexed": False,
            "vector_count": 0,
        }

    async def hybrid_search(
        self,
        query_text: str,
        query_vector: np.ndarray | list[float],
        text_fields: list[str] | None = None,
        vector_field: str = "embedding",
        k: int = 10,
        config: HybridSearchConfig | None = None,
        filter: Query | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
    ) -> list[HybridSearchResult]:
        """Perform hybrid search combining text and vector similarity.

        This method combines traditional text search with vector similarity search
        using configurable fusion strategies. The default implementation performs
        both searches and merges results client-side. Backends with native hybrid
        search support (like Elasticsearch) can override this for better performance.

        Args:
            query_text: Text query for keyword/text matching
            query_vector: Vector for semantic similarity search
            text_fields: Fields to search for text matching; used only when
                ``config.text_fields`` is not set
            vector_field: Name of the vector field to search
            k: Number of results to return; a non-positive value yields an
                empty list without running any search
            config: Hybrid search configuration (weights, fusion strategy)
            filter: Optional additional filters to apply
            metric: Distance metric for vector search

        Returns:
            List of HybridSearchResult ordered by combined score (descending)

        Example:
            ```python
            from dataknobs_data.vector import HybridSearchConfig, FusionStrategy

            # Default RRF fusion
            results = await db.hybrid_search(
                query_text="machine learning",
                query_vector=embedding,
                text_fields=["title", "content"],
                k=10,
            )

            # Custom weighted fusion
            config = HybridSearchConfig(
                text_weight=0.3,
                vector_weight=0.7,
                fusion_strategy=FusionStrategy.WEIGHTED_SUM,
            )
            results = await db.hybrid_search(
                query_text="machine learning",
                query_vector=embedding,
                config=config,
            )
            ```
        """
        # Nothing can be returned for k <= 0; a negative k would also make
        # the final ``fused[:k]`` slice drop results from the tail.
        if k <= 0:
            return []

        config = config or HybridSearchConfig()

        # If using NATIVE strategy but backend doesn't support it, fall back to RRF
        if (
            config.fusion_strategy == FusionStrategy.NATIVE
            and not await self._supports_native_hybrid()
        ):
            config = HybridSearchConfig(
                text_weight=config.text_weight,
                vector_weight=config.vector_weight,
                fusion_strategy=FusionStrategy.RRF,
                rrf_k=config.rrf_k,
                text_fields=config.text_fields,
            )

        # Use config.text_fields if provided, otherwise use parameter
        search_text_fields = config.text_fields or text_fields

        # Over-fetch (3x, capped at 100) so fusion sees candidates beyond the
        # top k of each source -- but never fetch fewer than k itself, which
        # the bare cap would do for k > 100.
        fetch_k = max(k, min(k * 3, 100))

        # Perform text search
        text_results = await self._text_search_for_hybrid(
            query_text=query_text,
            text_fields=search_text_fields,
            k=fetch_k,
            filter=filter,
        )

        # Perform vector search
        vector_results = await self.vector_search(
            query_vector=query_vector,
            vector_field=vector_field,
            k=fetch_k,
            metric=metric,
            filter=filter,
        )

        # Build ID->Record and ordered (ID, score) lists for each source.
        # Records without any usable ID cannot be fused and are skipped.
        records_by_id: dict[str, Record] = {}
        text_scores: list[tuple[str, float]] = []
        vector_scores: list[tuple[str, float]] = []

        for record, score in text_results:
            record_id = record.id or record.storage_id
            if record_id:
                records_by_id[record_id] = record
                text_scores.append((record_id, score))

        for result in vector_results:
            record_id = result.record.id or result.record.storage_id
            if record_id:
                records_by_id[record_id] = result.record
                vector_scores.append((record_id, result.score))

        # Fuse results
        if config.fusion_strategy == FusionStrategy.RRF:
            fused = reciprocal_rank_fusion(
                text_results=text_scores,
                vector_results=vector_scores,
                k=config.rrf_k,
                text_weight=config.text_weight,
                vector_weight=config.vector_weight,
            )
        else:
            # WEIGHTED_SUM. This branch is also taken for NATIVE when the
            # backend reports native support but did not override this method.
            text_w, vector_w = config.normalize_weights()
            fused = weighted_score_fusion(
                text_results=text_scores,
                vector_results=vector_scores,
                text_weight=text_w,
                vector_weight=vector_w,
                normalize_scores=True,
            )

        # Per-source lookups used to annotate each fused hit
        text_score_map = dict(text_scores)
        vector_score_map = dict(vector_scores)
        text_rank_map = {
            rid: rank for rank, (rid, _) in enumerate(text_scores, start=1)
        }
        vector_rank_map = {
            rid: rank for rank, (rid, _) in enumerate(vector_scores, start=1)
        }

        # Build HybridSearchResult objects for the top k fused entries
        results: list[HybridSearchResult] = []
        for record_id, combined_score in fused[:k]:
            if record_id not in records_by_id:
                continue

            results.append(HybridSearchResult(
                record=records_by_id[record_id],
                combined_score=combined_score,
                text_score=text_score_map.get(record_id),
                vector_score=vector_score_map.get(record_id),
                text_rank=text_rank_map.get(record_id),
                vector_rank=vector_rank_map.get(record_id),
                metadata={
                    "fusion_strategy": config.fusion_strategy.value,
                    "text_weight": config.text_weight,
                    "vector_weight": config.vector_weight,
                },
            ))

        return results

    async def _text_search_for_hybrid(
        self,
        query_text: str,
        text_fields: list[str] | None,
        k: int,
        filter: Query | None = None,
    ) -> list[tuple[Record, float]]:
        """Perform text search for hybrid search fusion.

        Default implementation uses a LIKE query on the first text field only.
        When no text fields are given, no text constraint is added at all and
        the search is limited only by ``filter`` and ``k``.
        Backends can override for better text search (e.g., full-text search).

        Args:
            query_text: Text to search for
            text_fields: Fields to search in
            k: Maximum results to return
            filter: Additional filters; copied, never mutated

        Returns:
            List of (record, score) tuples in the order returned by the
            underlying search, with scores capped at 1.0
        """
        from ..query import Filter, Operator, Query as Q

        # Work on a copy so the caller's filter is left untouched. Test
        # against None explicitly so a falsy Query is not silently discarded.
        query = filter.copy() if filter is not None else Q()
        query.limit_value = k

        # Only the first field is constrained, to avoid building a complex OR
        # in this basic implementation; backends with full-text search should
        # override this method.
        if text_fields:
            query.filters.append(Filter(
                field=text_fields[0],
                operator=Operator.LIKE,
                value=f"%{query_text}%",
            ))

        # Perform search
        records = await self.search(query)  # type: ignore[attr-defined]

        # Assign basic scores based on match quality
        results: list[tuple[Record, float]] = []
        query_lower = query_text.lower()
        for position, record in enumerate(records):
            # Rank-based base score: 1, 1/2, 1/3, ...
            score = 1.0 / (position + 1)

            # Boost substring matches, and exact matches further, across all
            # requested text fields (not just the one used in the filter).
            for field in (text_fields or []):
                value = record.get_value(field)
                if value and isinstance(value, str):
                    value_lower = value.lower()
                    if query_lower in value_lower:
                        score *= 1.5
                    if query_lower == value_lower:
                        score *= 2.0

            results.append((record, min(score, 1.0)))

        return results

    async def _supports_native_hybrid(self) -> bool:
        """Check if this backend supports native hybrid search.

        Override in backends that have native hybrid search support
        (e.g., Elasticsearch with RRF).

        Returns:
            True if native hybrid search is supported (False by default)
        """
        return False
class VectorSyncMixin:
    """Mixin for synchronizing vectors with source text."""

    async def sync_vectors_with_text(
        self,
        records: list[Record],
        text_fields: list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray] | None = None,
        force: bool = False,
    ) -> int:
        """Synchronize vector embeddings with text content.

        A record is (re-)embedded when ``force`` is set, when it has no
        ``vector_field`` yet, or when the comma-separated ``source_field``
        entry in the existing vector's metadata names a different set of
        fields than ``text_fields``. Records whose text fields are all empty
        are left untouched. The records are modified in place; this method
        itself does not write them to any storage.

        Args:
            records: Records to synchronize
            text_fields: Text fields to generate vectors from
            vector_field: Vector field to update
            embedding_fn: Embedding function; may be sync or async and is
                called with a one-element list per updated record
            force: Force re-generation even if vectors exist

        Returns:
            Number of records updated

        Raises:
            ValueError: If no embedding function is supplied
        """
        import inspect

        if not embedding_fn:
            raise ValueError("Embedding function is required for vector synchronization")

        wanted_sources = set(text_fields)
        updated = 0
        for record in records:
            # Check if vector needs update
            needs_update = force or vector_field not in record.fields

            if not needs_update:
                # Check if source fields changed. Metadata (or its
                # "source_field" entry) may be absent/None; treat that as
                # "no recorded sources" instead of raising AttributeError.
                vector_meta = record.fields[vector_field].metadata or {}
                recorded_sources = vector_meta.get("source_field") or ""
                needs_update = set(recorded_sources.split(",")) != wanted_sources

            if not needs_update:
                continue

            # Concatenate the non-empty text fields (one lookup per field)
            values = [record.get_value(field) for field in text_fields]
            text_content = " ".join(str(value) for value in values if value)
            if not text_content:
                continue

            # Imported lazily, only once an embedding is actually stored
            from ..fields import VectorField

            result = embedding_fn([text_content])
            # Handle both sync and async embedding functions
            if inspect.isawaitable(result):
                embeddings = await result  # type: ignore[misc]
            else:
                embeddings = result

            record.fields[vector_field] = VectorField(
                name=vector_field,
                value=embeddings[0],
                source_field=",".join(text_fields),
            )
            updated += 1

        return updated