Coverage for src/dataknobs_data/backends/memory.py: 33% (224 statements)
coverage.py v7.10.3, created at 2025-08-31 15:06 -0600
1"""In-memory database backend implementation."""
3from __future__ import annotations
5import asyncio
6import threading
7import uuid
8from collections import OrderedDict
9from typing import Any, TYPE_CHECKING
11from dataknobs_config import ConfigurableBase
13from ..database import AsyncDatabase, SyncDatabase
14from ..query_logic import ComplexQuery
15from ..streaming import AsyncStreamingMixin, StreamConfig, StreamingMixin, StreamResult
16from ..vector import VectorOperationsMixin
17from ..vector.bulk_embed_mixin import BulkEmbedMixin
18from ..vector.python_vector_search import PythonVectorSearchMixin
19from .sqlite_mixins import SQLiteVectorSupport
20from .vector_config_mixin import VectorConfigMixin
22if TYPE_CHECKING:
23 from collections.abc import AsyncIterator, Iterator
24 from ..query import Query
25 from ..records import Record

class AsyncMemoryDatabase(  # type: ignore[misc]
    AsyncDatabase,
    AsyncStreamingMixin,
    ConfigurableBase,
    VectorConfigMixin,        # Parse vector config
    SQLiteVectorSupport,      # Provides _compute_similarity
    PythonVectorSearchMixin,  # Provides python_vector_search_async
    BulkEmbedMixin,           # Bulk embedding operations
    VectorOperationsMixin     # Standard vector interface
):
    """Async in-memory database implementation."""

    def __init__(self, config: dict[str, Any] | None = None):
        super().__init__(config)
        self._storage: OrderedDict[str, Record] = OrderedDict()
        self._lock = asyncio.Lock()

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()  # From SQLiteVectorSupport

    @classmethod
    def from_config(cls, config: dict) -> AsyncMemoryDatabase:
        """Create from config dictionary."""
        return cls(config)

    def _generate_id(self) -> str:
        """Generate a unique ID for a record."""
        return str(uuid.uuid4())
    async def create(self, record: Record) -> str:
        """Create a new record in memory."""
        async with self._lock:
            # Use centralized method to prepare record
            record_copy, storage_id = self._prepare_record_for_storage(record)

            # Store the record
            self._storage[storage_id] = record_copy
            return storage_id

    async def read(self, id: str) -> Record | None:
        """Read a record from memory."""
        async with self._lock:
            record = self._storage.get(id)
            # Use centralized method to prepare record
            return self._prepare_record_from_storage(record, id)

    async def update(self, id: str, record: Record) -> bool:
        """Update a record in memory."""
        async with self._lock:
            if id in self._storage:
                self._storage[id] = record.copy(deep=True)
                return True
            return False

    async def delete(self, id: str) -> bool:
        """Delete a record from memory."""
        async with self._lock:
            if id in self._storage:
                del self._storage[id]
                return True
            return False

    async def exists(self, id: str) -> bool:
        """Check if a record exists in memory."""
        async with self._lock:
            return id in self._storage

    async def upsert(self, id: str, record: Record) -> str:
        """Update or insert a record with the specified ID."""
        async with self._lock:
            self._storage[id] = record.copy(deep=True)
            return id

    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query."""
        # Handle ComplexQuery using base class implementation
        if isinstance(query, ComplexQuery):
            return await self._search_with_complex_query(query)

        async with self._lock:
            results = []

            for id, record in self._storage.items():
                # Apply filters
                matches = True
                for filter in query.filters:
                    field_value = record.get_value(filter.field)
                    if not filter.matches(field_value):
                        matches = False
                        break

                if matches:
                    results.append((id, record))

            # Use the helper method from base class
            return self._process_search_results(results, query, deep_copy=True)
    async def _count_all(self) -> int:
        """Count all records in memory."""
        async with self._lock:
            return len(self._storage)

    async def clear(self) -> int:
        """Clear all records from memory."""
        async with self._lock:
            count = len(self._storage)
            self._storage.clear()
            return count

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently."""
        async with self._lock:
            ids = []
            for record in records:
                # Use centralized method to prepare record
                record_copy, storage_id = self._prepare_record_for_storage(record)

                # Store the record
                self._storage[storage_id] = record_copy
                ids.append(storage_id)
            return ids

    async def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records efficiently."""
        async with self._lock:
            results = []
            for id in ids:
                record = self._storage.get(id)
                # Use centralized method to prepare record
                results.append(self._prepare_record_from_storage(record, id))
            return results

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently."""
        async with self._lock:
            results = []
            for id in ids:
                if id in self._storage:
                    del self._storage[id]
                    results.append(True)
                else:
                    results.append(False)
            return results
    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from memory."""
        config = config or StreamConfig()

        # Get all matching records
        if query:
            records = await self.search(query)
        else:
            async with self._lock:
                # Ensure records have IDs when getting directly from storage
                records = []
                for record_id, record in self._storage.items():
                    record_copy = self._ensure_record_id(record, record_id)
                    records.append(record_copy)

        # Yield records in batches
        for i in range(0, len(records), config.batch_size):
            batch = records[i:i + config.batch_size]
            for record in batch:
                yield record.copy(deep=True)
            # Small yield to prevent blocking
            await asyncio.sleep(0)

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into memory."""
        # Use the default implementation from mixin
        return await self._default_stream_write(records, config)

    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations."""
        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )
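
# Usage sketch (added for illustration; not part of the original module).
# A minimal async CRUD round trip against AsyncMemoryDatabase, using only the
# methods defined above. The keyword form ``Record(data={...})`` is an
# assumption about the Record constructor; the real
# ``dataknobs_data.records.Record`` API may differ.
async def _async_memory_usage_sketch() -> None:
    from ..records import Record  # runtime import; the module-level import is TYPE_CHECKING only

    db = AsyncMemoryDatabase.from_config({})
    record_id = await db.create(Record(data={"name": "alice"}))  # assumed Record signature
    assert await db.exists(record_id)

    fetched = await db.read(record_id)  # returns a prepared copy, or None if missing
    if fetched is not None:
        await db.update(record_id, fetched)

    await db.delete(record_id)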

class SyncMemoryDatabase(  # type: ignore[misc]
    SyncDatabase,
    StreamingMixin,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,
    BulkEmbedMixin,
    VectorOperationsMixin
):
    """Synchronous in-memory database implementation."""

    def __init__(self, config: dict[str, Any] | None = None):
        super().__init__(config)
        self._storage: OrderedDict[str, Record] = OrderedDict()
        self._lock = threading.RLock()

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()

    @classmethod
    def from_config(cls, config: dict) -> SyncMemoryDatabase:
        """Create from config dictionary."""
        return cls(config)

    def _generate_id(self) -> str:
        """Generate a unique ID for a record."""
        return str(uuid.uuid4())
    def create(self, record: Record) -> str:
        """Create a new record in memory."""
        with self._lock:
            # Use record's ID if it has one, otherwise generate a new one
            id = record.id if record.id else self._generate_id()
            self._storage[id] = record.copy(deep=True)
            return id

    def read(self, id: str) -> Record | None:
        """Read a record from memory."""
        with self._lock:
            record = self._storage.get(id)
            return record.copy(deep=True) if record else None

    def update(self, id: str, record: Record) -> bool:
        """Update a record in memory."""
        with self._lock:
            if id in self._storage:
                self._storage[id] = record.copy(deep=True)
                return True
            return False

    def delete(self, id: str) -> bool:
        """Delete a record from memory."""
        with self._lock:
            if id in self._storage:
                del self._storage[id]
                return True
            return False

    def exists(self, id: str) -> bool:
        """Check if a record exists in memory."""
        with self._lock:
            return id in self._storage

    def upsert(self, id: str, record: Record) -> str:
        """Update or insert a record with the specified ID."""
        with self._lock:
            self._storage[id] = record.copy(deep=True)
            return id

    def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query."""
        # Handle ComplexQuery using base class implementation
        if isinstance(query, ComplexQuery):
            return self._search_with_complex_query(query)

        with self._lock:
            results = []

            for id, record in self._storage.items():
                # Apply filters
                matches = True
                for filter in query.filters:
                    field_value = record.get_value(filter.field)
                    if not filter.matches(field_value):
                        matches = False
                        break

                if matches:
                    results.append((id, record))

            # Use the helper method from base class
            return self._process_search_results(results, query, deep_copy=True)
    def _count_all(self) -> int:
        """Count all records in memory."""
        with self._lock:
            return len(self._storage)

    def clear(self) -> int:
        """Clear all records from memory."""
        with self._lock:
            count = len(self._storage)
            self._storage.clear()
            return count

    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently."""
        with self._lock:
            ids = []
            for record in records:
                # Use record's ID if it has one, otherwise generate a new one
                id = record.id if record.id else self._generate_id()
                self._storage[id] = record.copy(deep=True)
                ids.append(id)
            return ids

    def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records efficiently."""
        with self._lock:
            results = []
            for id in ids:
                record = self._storage.get(id)
                results.append(record.copy(deep=True) if record else None)
            return results

    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently."""
        with self._lock:
            results = []
            for id in ids:
                if id in self._storage:
                    del self._storage[id]
                    results.append(True)
                else:
                    results.append(False)
            return results
    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from memory."""
        config = config or StreamConfig()

        # Get all matching records
        if query:
            records = self.search(query)
        else:
            with self._lock:
                # Ensure records have IDs when getting directly from storage
                records = []
                for record_id, record in self._storage.items():
                    record_copy = self._ensure_record_id(record, record_id)
                    records.append(record_copy)

        # Yield records in batches
        for i in range(0, len(records), config.batch_size):
            batch = records[i:i + config.batch_size]
            for record in batch:
                yield record.copy(deep=True)

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into memory."""
        # Use the default implementation from mixin
        return self._default_stream_write(records, config)

    def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations."""
        return self.python_vector_search_sync(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )
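
# Usage sketch (added for illustration; not part of the original module).
# The synchronous mirror of the sketch above: batch create, stream read, and
# the Python-side vector search. ``Record(data={...})``, storing the vector
# under an "embedding" field, and passing a plain list of floats as
# ``query_vector`` are assumptions about the wider dataknobs_data API.
def _sync_memory_usage_sketch() -> None:
    from ..records import Record  # runtime import; the module-level import is TYPE_CHECKING only

    db = SyncMemoryDatabase.from_config({})
    ids = db.create_batch([
        Record(data={"name": "alice", "embedding": [1.0, 0.0]}),  # assumed Record signature
        Record(data={"name": "bob", "embedding": [0.0, 1.0]}),
    ])
    assert all(record is not None for record in db.read_batch(ids))

    # stream_read yields deep copies, batched by StreamConfig.batch_size
    for record in db.stream_read():
        _ = record

    # Delegates to python_vector_search_sync from PythonVectorSearchMixin
    _ = db.vector_search([1.0, 0.0], vector_field="embedding", k=1)

    db.clear()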