Coverage for src/dataknobs_data/vector/mixins.py: 35%
52 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-31 15:06 -0600
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-31 15:06 -0600
1"""Mixins and protocols for vector-capable databases."""
3from __future__ import annotations
5from abc import ABC, abstractmethod
6from typing import TYPE_CHECKING, Any, Protocol
8from ..fields import FieldType
9from .types import DistanceMetric, VectorSearchResult
11if TYPE_CHECKING:
12 import numpy as np
13 from collections.abc import Callable
14 from ..query import Query
15 from ..records import Record
class VectorCapable(Protocol):
    """Structural protocol for backends that support vector operations."""

    async def has_vector_support(self) -> bool:
        """Report whether vector operations are available on this backend.

        Returns:
            True if vector operations are supported
        """
        ...

    async def enable_vector_support(self) -> bool:
        """Turn on vector support (install extensions, configure indices, etc.).

        Returns:
            True if vector support was successfully enabled
        """
        ...

    async def detect_vector_fields(self, record: Record) -> list[str]:
        """List the names of vector-typed fields on a record.

        Args:
            record: Record to examine

        Returns:
            Names of fields whose type is VECTOR or SPARSE_VECTOR
        """
        vector_types = (FieldType.VECTOR, FieldType.SPARSE_VECTOR)
        names: list[str] = []
        for field_name, field_obj in record.fields.items():
            if field_obj.type in vector_types:
                names.append(field_name)
        return names

    def get_vector_config(self) -> dict[str, Any]:
        """Return backend-specific vector configuration.

        Returns:
            Dictionary of vector configuration options (empty by default)
        """
        return {}
class VectorOperationsMixin(ABC):
    """Mixin adding vector operations to database backend classes.

    Concrete backends mix this in and must implement the abstract
    search/embedding methods; the remaining methods ship conservative
    default implementations that backends may override.
    """

    @abstractmethod
    async def vector_search(
        self,
        query_vector: np.ndarray | list[float],
        vector_field: str = "embedding",
        k: int = 10,
        metric: DistanceMetric = DistanceMetric.COSINE,
        filter: Query | None = None,
        include_source: bool = True,
        score_threshold: float | None = None,
    ) -> list[VectorSearchResult]:
        """Find the records most similar to a query vector.

        Args:
            query_vector: The vector to search for
            vector_field: Name of the vector field to search
            k: Number of results to return
            metric: Distance metric to use
            filter: Optional query filter applied before the vector search
            include_source: Whether to include source text in results
            score_threshold: Optional minimum similarity score

        Returns:
            Search results ordered by similarity
        """
        pass

    @abstractmethod
    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray] | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed the text of the given records and persist the vectors.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function mapping a batch of texts to embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            IDs of the records that were processed
        """
        pass

    async def update_vector(
        self,
        record_id: str,
        vector_field: str,
        vector: np.ndarray | list[float],
        metadata: dict[str, Any] | None = None,
    ) -> bool:
        """Replace the vector stored on a single record.

        Default implementation: read the record, swap in a fresh
        VectorField, and write it back through the standard update path.

        Args:
            record_id: ID of the record to update
            vector_field: Name of the vector field
            vector: New vector value
            metadata: Optional metadata to attach

        Returns:
            True if update was successful
        """
        # Lazy import keeps module load free of a fields dependency cycle.
        from ..fields import VectorField

        record = await self.read(record_id)  # type: ignore
        if not record:
            # No such record; nothing to update.
            return False

        new_field = VectorField(
            name=vector_field,
            value=vector,
            metadata=metadata,
        )
        record.fields[vector_field] = new_field

        return await self.update(record_id, record) is not None  # type: ignore

    async def delete_from_index(
        self,
        record_id: str,
        vector_field: str = "embedding",
    ) -> bool:
        """Remove a record from the vector index.

        Default implementation simply deletes the record through the
        standard path; backends with separate index storage override this.

        Args:
            record_id: ID of the record to remove
            vector_field: Name of the vector field

        Returns:
            True if deletion was successful
        """
        return await self.delete(record_id)  # type: ignore

    async def create_vector_index(
        self,
        vector_field: str = "embedding",
        dimensions: int | None = None,
        metric: DistanceMetric = DistanceMetric.COSINE,
        index_type: str = "auto",
        **kwargs: Any,
    ) -> bool:
        """Create an index for vector similarity search.

        No-op by default; backends with real index support override it.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions (if known)
            metric: Distance metric for the index
            index_type: Type of index to create
            **kwargs: Backend-specific index parameters

        Returns:
            True if index was created successfully
        """
        return True

    async def drop_vector_index(
        self,
        vector_field: str = "embedding",
    ) -> bool:
        """Drop a vector index.

        No-op by default; backends with real index support override it.

        Args:
            vector_field: Name of the vector field

        Returns:
            True if index was dropped successfully
        """
        return True

    async def get_vector_index_stats(
        self,
        vector_field: str = "embedding",
    ) -> dict[str, Any]:
        """Report statistics about a vector index.

        Args:
            vector_field: Name of the vector field

        Returns:
            Dictionary of index statistics
        """
        stats: dict[str, Any] = {
            "field": vector_field,
            "indexed": False,
            "vector_count": 0,
        }
        return stats
class VectorSyncMixin:
    """Mixin for keeping vector embeddings synchronized with source text."""

    async def sync_vectors_with_text(
        self,
        records: list[Record],
        text_fields: list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray] | None = None,
        force: bool = False,
    ) -> int:
        """Synchronize vector embeddings with text content.

        A record is re-embedded when it has no vector yet, when ``force``
        is set, or when the source fields recorded on the vector's
        metadata no longer match ``text_fields``.

        Args:
            records: Records to synchronize
            text_fields: Text fields to generate vectors from
            vector_field: Vector field to update
            embedding_fn: Embedding function (sync, or returning an awaitable)
            force: Force re-generation even if vectors exist

        Returns:
            Number of records updated

        Raises:
            ValueError: If no embedding function is provided
        """
        if not embedding_fn:
            raise ValueError("Embedding function is required for vector synchronization")

        updated = 0
        for record in records:
            # Check if vector needs update
            needs_update = force or vector_field not in record.fields

            if not needs_update:
                # Re-embed if the recorded source fields differ from the
                # requested ones. Guard with `or {}`: a vector written via
                # update_vector() may carry metadata=None, and calling
                # .get() on None would crash here.
                vector_meta = record.fields[vector_field].metadata or {}
                source_fields = vector_meta.get("source_field", "").split(",")
                needs_update = set(source_fields) != set(text_fields)

            if needs_update:
                # Concatenate non-empty text fields; evaluate each field
                # value once instead of twice.
                values = (record.get_value(field) for field in text_fields)
                text_content = " ".join(str(v) for v in values if v)

                # Generate embedding only when there is text to embed
                if text_content:
                    # Lazy import kept here to avoid an import cycle at load
                    from ..fields import VectorField

                    result = embedding_fn([text_content])
                    # Handle both sync and async embedding functions
                    if hasattr(result, '__await__'):
                        embeddings = await result  # type: ignore[misc]
                    else:
                        embeddings = result
                    record.fields[vector_field] = VectorField(
                        name=vector_field,
                        value=embeddings[0],
                        source_field=",".join(text_fields),
                    )
                    updated += 1

        return updated