Coverage for src/dataknobs_data/vector/bulk_embed_mixin.py: 9%
81 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:23 -0700
1"""Mixin providing default bulk_embed_and_store implementation."""
3from __future__ import annotations
5from typing import TYPE_CHECKING, cast
7from ..fields import VectorField
9if TYPE_CHECKING:
10 import numpy as np
11 from collections.abc import Awaitable, Callable
12 from ..records import Record
class BulkEmbedMixin:
    """Mixin providing default implementation of bulk_embed_and_store.

    This mixin can be used by any database backend to provide a standard
    implementation of bulk embedding and storage without circular dependencies.
    """

    @staticmethod
    def _combined_text(record: Record, text_fields: list[str]) -> str:
        """Join the truthy values of *text_fields* from *record* with spaces.

        Missing fields and falsy values are skipped; a record with no usable
        text yields the empty string (it is still passed to the embedder so
        that result indices stay aligned with the batch).
        """
        parts: list[str] = []
        for field_name in text_fields:
            if field_name in record.fields:
                field_value = record.fields[field_name].value
                if field_value:
                    parts.append(str(field_value))
        return " ".join(parts)

    def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray] | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed

        Raises:
            ValueError: If embedding_fn is not provided
        """
        if embedding_fn is None:
            raise ValueError("embedding_fn is required for bulk_embed_and_store")

        # Normalize to a list of source field names.
        text_fields = [text_field] if isinstance(text_field, str) else text_field

        # Loop-invariant metadata: which source field(s) fed the embedding.
        # ",".join on a single-element list is just that element, so no
        # special case is needed.
        source_field_str = ",".join(text_fields)

        processed_ids: list[str] = []

        # Process in batches to bound the size of each embedding call.
        for start in range(0, len(records), batch_size):
            batch = records[start:start + batch_size]
            texts = [self._combined_text(record, text_fields) for record in batch]
            if not texts:
                continue

            embeddings = embedding_fn(texts)

            for j, record in enumerate(batch):
                # Decide whether an embedding exists for this record.
                # Sequence-like results (e.g. np.ndarray, list) are indexed;
                # a bare single vector with no __len__ only covers index 0.
                if hasattr(embeddings, "__len__"):
                    has_vector = j < len(embeddings)
                else:
                    has_vector = j == 0

                if has_vector:
                    if hasattr(embeddings, "__getitem__"):
                        vector = embeddings[j]
                    else:
                        # Single embedding returned for a single text.
                        vector = embeddings
                    record.fields[vector_field] = VectorField(
                        name=vector_field,
                        value=vector,
                        source_field=source_field_str,
                        model_name=model_name,
                        model_version=model_version,
                    )

                # Update vector dimensions tracking if the backend supports it.
                if hasattr(self, "_has_vector_fields") and hasattr(self, "_update_vector_dimensions"):
                    if self._has_vector_fields(record):
                        self._update_vector_dimensions(record)

                # Create or update the record.
                # Assumes self has create, update, and exists methods (from Database interface).
                if record.id and self.exists(record.id):  # type: ignore
                    self.update(record.id, record)  # type: ignore
                    processed_ids.append(record.id)
                else:
                    record_id = self.create(record)  # type: ignore
                    processed_ids.append(record_id)

        return processed_ids
class AsyncBulkEmbedMixin:
    """Async mixin providing default implementation of bulk_embed_and_store.

    This mixin can be used by any async database backend to provide a standard
    implementation of bulk embedding and storage without circular dependencies.
    """

    @staticmethod
    def _combined_text(record: Record, text_fields: list[str]) -> str:
        """Join the truthy values of *text_fields* from *record* with spaces.

        Missing fields and falsy values are skipped; a record with no usable
        text yields the empty string (it is still passed to the embedder so
        that result indices stay aligned with the batch).
        """
        parts: list[str] = []
        for field_name in text_fields:
            if field_name in record.fields:
                field_value = record.fields[field_name].value
                if field_value:
                    parts.append(str(field_value))
        return " ".join(parts)

    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray | Awaitable[np.ndarray]] | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings (can be sync or async)
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed

        Raises:
            ValueError: If embedding_fn is not provided
        """
        import inspect

        if embedding_fn is None:
            raise ValueError("embedding_fn is required for bulk_embed_and_store")

        # Normalize to a list of source field names.
        text_fields = [text_field] if isinstance(text_field, str) else text_field

        # Loop-invariant metadata: which source field(s) fed the embedding.
        source_field_str = ",".join(text_fields)

        processed_ids: list[str] = []

        # Process in batches to bound the size of each embedding call.
        for start in range(0, len(records), batch_size):
            batch = records[start:start + batch_size]
            texts = [self._combined_text(record, text_fields) for record in batch]
            if not texts:
                continue

            # Call the embedder, then await only if it returned an awaitable.
            # Checking the result (not iscoroutinefunction) also handles
            # functools.partial wrappers and objects with an async __call__.
            result = embedding_fn(texts)
            if inspect.isawaitable(result):
                embeddings = await cast("Awaitable[np.ndarray]", result)
            else:
                embeddings = cast("np.ndarray", result)

            for j, record in enumerate(batch):
                # Decide whether an embedding exists for this record.
                # Sequence-like results (e.g. np.ndarray, list) are indexed;
                # a bare single vector with no __len__ only covers index 0.
                if hasattr(embeddings, "__len__"):
                    has_vector = j < len(embeddings)
                else:
                    has_vector = j == 0

                if has_vector:
                    if hasattr(embeddings, "__getitem__"):
                        vector = embeddings[j]
                    else:
                        # Single embedding returned for a single text.
                        vector = embeddings
                    record.fields[vector_field] = VectorField(
                        name=vector_field,
                        value=vector,
                        source_field=source_field_str,
                        model_name=model_name,
                        model_version=model_version,
                    )

                # Update vector dimensions tracking if the backend supports it.
                if hasattr(self, "_has_vector_fields") and hasattr(self, "_update_vector_dimensions"):
                    if self._has_vector_fields(record):
                        self._update_vector_dimensions(record)

                # Create or update the record.
                # Assumes self has async create, update, and exists methods (from AsyncDatabase interface).
                if record.id and await self.exists(record.id):  # type: ignore
                    await self.update(record.id, record)  # type: ignore
                    processed_ids.append(record.id)
                else:
                    record_id = await self.create(record)  # type: ignore
                    processed_ids.append(record_id)

        return processed_ids