Coverage for src/dataknobs_data/vector/bulk_embed_mixin.py: 9%

81 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-29 14:14 -0600

1"""Mixin providing default bulk_embed_and_store implementation.""" 

2 

3from __future__ import annotations 

4 

5from typing import TYPE_CHECKING, cast 

6 

7from ..fields import VectorField 

8 

9if TYPE_CHECKING: 

10 import numpy as np 

11 from collections.abc import Awaitable, Callable 

12 from ..records import Record 

13 

14 

class BulkEmbedMixin:
    """Mixin providing default implementation of bulk_embed_and_store.

    This mixin can be used by any database backend to provide a standard
    implementation of bulk embedding and storage without circular dependencies.
    """

    def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray] | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed

        Raises:
            ValueError: If embedding_fn is not provided
        """
        if embedding_fn is None:
            raise ValueError("embedding_fn is required for bulk_embed_and_store")

        # Normalize to a list of source field names.
        text_fields = [text_field] if isinstance(text_field, str) else text_field
        # Loop-invariant: which source field(s) produced each vector
        # (joining a single-element list yields that element unchanged).
        source_field_str = ",".join(text_fields)

        processed_ids: list[str] = []

        # Process in batches to bound the size of each embedding call.
        for start in range(0, len(records), batch_size):
            batch = records[start:start + batch_size]

            # Concatenate the text of all requested fields per record
            # (one entry per record, possibly empty).
            texts = []
            for record in batch:
                parts = [
                    str(record.fields[name].value)
                    for name in text_fields
                    if name in record.fields and record.fields[name].value
                ]
                texts.append(" ".join(parts))

            if texts:
                embeddings = embedding_fn(texts)
                # embedding_fn may return a sequence (one vector per text) or
                # a single bare vector; detect which shape we got once per batch.
                is_sequence = hasattr(embeddings, '__len__')

                for j, record in enumerate(batch):
                    # Sequence result: one vector per leading record.
                    # Bare result: only the first record gets it.
                    has_vector = (j < len(embeddings)) if is_sequence else (j == 0)
                    if has_vector:
                        if hasattr(embeddings, '__getitem__'):
                            vector = embeddings[j]
                        else:
                            # Single embedding returned for single text.
                            vector = embeddings
                        record.fields[vector_field] = VectorField(
                            name=vector_field,
                            value=vector,
                            source_field=source_field_str,
                            model_name=model_name,
                            model_version=model_version,
                        )

                    # Update vector dimensions tracking if the backend supports it.
                    if hasattr(self, '_has_vector_fields') and hasattr(self, '_update_vector_dimensions'):
                        if self._has_vector_fields(record):
                            self._update_vector_dimensions(record)

                    # Create or update the record.
                    # Assumes self has create, update, and exists methods (from Database interface).
                    if record.id and self.exists(record.id):  # type: ignore
                        self.update(record.id, record)  # type: ignore
                        processed_ids.append(record.id)
                    else:
                        record_id = self.create(record)  # type: ignore
                        processed_ids.append(record_id)

        return processed_ids

116 

117 

class AsyncBulkEmbedMixin:
    """Async mixin providing default implementation of bulk_embed_and_store.

    This mixin can be used by any async database backend to provide a standard
    implementation of bulk embedding and storage without circular dependencies.
    """

    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Callable[[list[str]], np.ndarray | Awaitable[np.ndarray]] | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings (can be sync or async)
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed

        Raises:
            ValueError: If embedding_fn is not provided
        """
        import inspect

        if embedding_fn is None:
            raise ValueError("embedding_fn is required for bulk_embed_and_store")

        # Decide once whether embedding_fn must be awaited.
        is_async_fn = inspect.iscoroutinefunction(embedding_fn)

        # Normalize to a list of source field names.
        text_fields = [text_field] if isinstance(text_field, str) else text_field
        # Loop-invariant: which source field(s) produced each vector
        # (joining a single-element list yields that element unchanged).
        source_field_str = ",".join(text_fields)

        processed_ids: list[str] = []

        # Process in batches to bound the size of each embedding call.
        for start in range(0, len(records), batch_size):
            batch = records[start:start + batch_size]

            # Concatenate the text of all requested fields per record
            # (one entry per record, possibly empty).
            texts = []
            for record in batch:
                parts = [
                    str(record.fields[name].value)
                    for name in text_fields
                    if name in record.fields and record.fields[name].value
                ]
                texts.append(" ".join(parts))

            if texts:
                if is_async_fn:
                    embeddings = await cast("Awaitable[np.ndarray]", embedding_fn(texts))
                else:
                    embeddings = cast("np.ndarray", embedding_fn(texts))
                # embedding_fn may return a sequence (one vector per text) or
                # a single bare vector; detect which shape we got once per batch.
                is_sequence = hasattr(embeddings, '__len__')

                for j, record in enumerate(batch):
                    # Sequence result: one vector per leading record.
                    # Bare result: only the first record gets it.
                    has_vector = (j < len(embeddings)) if is_sequence else (j == 0)
                    if has_vector:
                        if hasattr(embeddings, '__getitem__'):
                            vector = embeddings[j]
                        else:
                            # Single embedding returned for single text.
                            vector = embeddings
                        record.fields[vector_field] = VectorField(
                            name=vector_field,
                            value=vector,
                            source_field=source_field_str,
                            model_name=model_name,
                            model_version=model_version,
                        )

                    # Update vector dimensions tracking if the backend supports it.
                    if hasattr(self, '_has_vector_fields') and hasattr(self, '_update_vector_dimensions'):
                        if self._has_vector_fields(record):
                            self._update_vector_dimensions(record)

                    # Create or update the record.
                    # Assumes self has async create, update, and exists methods (from AsyncDatabase interface).
                    if record.id and await self.exists(record.id):  # type: ignore
                        await self.update(record.id, record)  # type: ignore
                        processed_ids.append(record.id)
                    else:
                        record_id = await self.create(record)  # type: ignore
                        processed_ids.append(record_id)

        return processed_ids