Coverage for src/dataknobs_data/backends/memory.py: 42%
248 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:15 -0600
1"""In-memory database backend implementation."""
3from __future__ import annotations
5import asyncio
6import threading
7import uuid
8from collections import OrderedDict
9from typing import Any, TYPE_CHECKING
11from dataknobs_config import ConfigurableBase
13from ..database import AsyncDatabase, SyncDatabase
14from ..query_logic import ComplexQuery
15from ..streaming import AsyncStreamingMixin, StreamConfig, StreamingMixin, StreamResult
16from ..vector import VectorOperationsMixin
17from ..vector.bulk_embed_mixin import BulkEmbedMixin
18from ..vector.python_vector_search import PythonVectorSearchMixin
19from .sqlite_mixins import SQLiteVectorSupport
20from .vector_config_mixin import VectorConfigMixin
22if TYPE_CHECKING:
23 from collections.abc import AsyncIterator, Iterator
24 from ..query import Query
25 from ..records import Record
class AsyncMemoryDatabase(  # type: ignore[misc]
    AsyncDatabase,
    AsyncStreamingMixin,
    ConfigurableBase,
    VectorConfigMixin,  # Parse vector config
    SQLiteVectorSupport,  # Provides _compute_similarity
    PythonVectorSearchMixin,  # Provides python_vector_search_async
    BulkEmbedMixin,  # Bulk embedding operations
    VectorOperationsMixin  # Standard vector interface
):
    """Async in-memory database implementation.

    Records live in an insertion-ordered dict keyed by storage ID.  An
    ``asyncio.Lock`` serializes all access to the storage dict, and
    records are deep-copied on the way in and out so callers can never
    mutate internal state through a shared reference.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize storage, the async lock, and vector support.

        Args:
            config: Optional backend configuration; also consulted for
                vector settings by ``_parse_vector_config``.
        """
        super().__init__(config)
        self._storage: OrderedDict[str, Record] = OrderedDict()
        self._lock = asyncio.Lock()

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()  # From SQLiteVectorSupport

    @classmethod
    def from_config(cls, config: dict) -> AsyncMemoryDatabase:
        """Create from config dictionary."""
        return cls(config)

    def _generate_id(self) -> str:
        """Generate a unique ID for a record."""
        return str(uuid.uuid4())

    async def create(self, record: Record) -> str:
        """Create a new record in memory and return its storage ID."""
        async with self._lock:
            # Centralized helper deep-copies the record and assigns the ID
            record_copy, storage_id = self._prepare_record_for_storage(record)
            self._storage[storage_id] = record_copy
            return storage_id

    async def read(self, id: str) -> Record | None:
        """Read a record from memory, or return None if absent."""
        async with self._lock:
            record = self._storage.get(id)
            # Centralized helper handles the deep copy / ID stamping
            return self._prepare_record_from_storage(record, id)

    async def update(self, id: str, record: Record) -> bool:
        """Replace an existing record; return False if the ID is unknown."""
        async with self._lock:
            if id in self._storage:
                self._storage[id] = record.copy(deep=True)
                return True
            return False

    async def delete(self, id: str) -> bool:
        """Delete a record from memory; return False if the ID is unknown."""
        async with self._lock:
            if id in self._storage:
                del self._storage[id]
                return True
            return False

    async def exists(self, id: str) -> bool:
        """Check if a record exists in memory."""
        async with self._lock:
            return id in self._storage

    async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record with the specified ID.

        Supports both ``upsert(id, record)`` and ``upsert(record)``; in the
        latter form a UUID is generated when the record carries no ID.

        Raises:
            ValueError: If an ID string is given without a record.
        """
        # Determine the target ID and the record to store
        if isinstance(id_or_record, str):
            record_id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            record_id = record.id
            if record_id is None:
                # uuid is already imported at module level; the previous
                # redundant local ``import uuid`` was removed.
                record_id = str(uuid.uuid4())  # type: ignore[unreachable]
                record.storage_id = record_id

        # Memory-specific implementation: unconditional overwrite
        async with self._lock:
            self._storage[record_id] = record.copy(deep=True)
            return record_id

    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query."""
        # ComplexQuery is delegated to the shared base implementation
        if isinstance(query, ComplexQuery):
            return await self._search_with_complex_query(query)

        async with self._lock:
            results = []
            for record_id, record in self._storage.items():
                matches = True
                for query_filter in query.filters:
                    # 'id' is virtual: it matches the storage key, not a field
                    if query_filter.field == 'id':
                        field_value = record_id
                    else:
                        field_value = record.get_value(query_filter.field)
                    if not query_filter.matches(field_value):
                        matches = False
                        break
                if matches:
                    results.append((record_id, record))

            # Sorting / pagination / deep-copying handled by the base helper
            return self._process_search_results(results, query, deep_copy=True)

    async def _count_all(self) -> int:
        """Count all records in memory."""
        async with self._lock:
            return len(self._storage)

    async def clear(self) -> int:
        """Clear all records from memory and return how many were removed."""
        async with self._lock:
            count = len(self._storage)
            self._storage.clear()
            return count

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records under a single lock acquisition."""
        async with self._lock:
            ids = []
            for record in records:
                # Centralized helper deep-copies the record and assigns the ID
                record_copy, storage_id = self._prepare_record_for_storage(record)
                self._storage[storage_id] = record_copy
                ids.append(storage_id)
            return ids

    async def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records; missing IDs yield None in position."""
        async with self._lock:
            return [
                self._prepare_record_from_storage(self._storage.get(record_id), record_id)
                for record_id in ids
            ]

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records; result flags success per ID, in order."""
        async with self._lock:
            results = []
            for record_id in ids:
                if record_id in self._storage:
                    del self._storage[record_id]
                    results.append(True)
                else:
                    results.append(False)
            return results

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from memory, yielding deep copies in batches."""
        config = config or StreamConfig()

        # Materialize matching records first (search applies its own lock)
        if query:
            records = await self.search(query)
        else:
            async with self._lock:
                # Ensure records have IDs when taken directly from storage
                records = [
                    self._ensure_record_id(record, record_id)
                    for record_id, record in self._storage.items()
                ]

        # Yield records in batches
        for start in range(0, len(records), config.batch_size):
            for record in records[start:start + config.batch_size]:
                yield record.copy(deep=True)
            # Cooperative yield so other tasks can run between batches
            await asyncio.sleep(0)

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into memory via the mixin's default implementation."""
        return await self._default_stream_write(records, config)

    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations."""
        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )
class SyncMemoryDatabase(  # type: ignore[misc]
    SyncDatabase,
    StreamingMixin,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,
    BulkEmbedMixin,
    VectorOperationsMixin
):
    """Synchronous in-memory database implementation.

    Records live in an insertion-ordered dict keyed by storage ID.  A
    reentrant ``threading.RLock`` guards the storage dict, and records
    are deep-copied on the way in and out so callers cannot mutate
    internal state through a shared reference.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize storage, the reentrant lock, and vector support.

        Args:
            config: Optional backend configuration; also consulted for
                vector settings by ``_parse_vector_config``.
        """
        super().__init__(config)
        self._storage: OrderedDict[str, Record] = OrderedDict()
        self._lock = threading.RLock()

        # Initialize vector support
        self._parse_vector_config(config or {})
        self._init_vector_state()

    @classmethod
    def from_config(cls, config: dict) -> SyncMemoryDatabase:
        """Create from config dictionary."""
        return cls(config)

    def _generate_id(self) -> str:
        """Generate a unique ID for a record."""
        return str(uuid.uuid4())

    def create(self, record: Record) -> str:
        """Create a new record in memory and return its storage ID.

        NOTE(review): unlike the async twin, this does NOT route through
        ``_prepare_record_for_storage`` — confirm whether the sync base
        provides that helper and whether the difference is intentional.
        """
        with self._lock:
            # Use record's ID if it has one, otherwise generate a new one
            record_id = record.id if record.id else self._generate_id()
            self._storage[record_id] = record.copy(deep=True)
            return record_id

    def read(self, id: str) -> Record | None:
        """Read a record from memory, or return None if absent."""
        with self._lock:
            record = self._storage.get(id)
            return record.copy(deep=True) if record else None

    def update(self, id: str, record: Record) -> bool:
        """Replace an existing record; return False if the ID is unknown."""
        with self._lock:
            if id in self._storage:
                self._storage[id] = record.copy(deep=True)
                return True
            return False

    def delete(self, id: str) -> bool:
        """Delete a record from memory; return False if the ID is unknown."""
        with self._lock:
            if id in self._storage:
                del self._storage[id]
                return True
            return False

    def exists(self, id: str) -> bool:
        """Check if a record exists in memory."""
        with self._lock:
            return id in self._storage

    def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record with the specified ID.

        Supports both ``upsert(id, record)`` and ``upsert(record)``; in the
        latter form a UUID is generated when the record carries no ID.

        Raises:
            ValueError: If an ID string is given without a record.
        """
        # Determine the target ID and the record to store
        if isinstance(id_or_record, str):
            record_id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            record_id = record.id
            if record_id is None:
                # uuid is already imported at module level; the previous
                # redundant local ``import uuid`` was removed.
                record_id = str(uuid.uuid4())  # type: ignore[unreachable]
                record.storage_id = record_id

        # Memory-specific implementation: unconditional overwrite
        with self._lock:
            self._storage[record_id] = record.copy(deep=True)
            return record_id

    def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query."""
        # ComplexQuery is delegated to the shared base implementation
        if isinstance(query, ComplexQuery):
            return self._search_with_complex_query(query)

        with self._lock:
            results = []
            for record_id, record in self._storage.items():
                matches = True
                for query_filter in query.filters:
                    # 'id' is virtual: it matches the storage key, not a field
                    if query_filter.field == 'id':
                        field_value = record_id
                    else:
                        field_value = record.get_value(query_filter.field)
                    if not query_filter.matches(field_value):
                        matches = False
                        break
                if matches:
                    results.append((record_id, record))

            # Sorting / pagination / deep-copying handled by the base helper
            return self._process_search_results(results, query, deep_copy=True)

    def _count_all(self) -> int:
        """Count all records in memory."""
        with self._lock:
            return len(self._storage)

    def clear(self) -> int:
        """Clear all records from memory and return how many were removed."""
        with self._lock:
            count = len(self._storage)
            self._storage.clear()
            return count

    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records under a single lock acquisition."""
        with self._lock:
            ids = []
            for record in records:
                # Use record's ID if it has one, otherwise generate a new one
                record_id = record.id if record.id else self._generate_id()
                self._storage[record_id] = record.copy(deep=True)
                ids.append(record_id)
            return ids

    def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records; missing IDs yield None in position."""
        with self._lock:
            results = []
            for record_id in ids:
                record = self._storage.get(record_id)
                results.append(record.copy(deep=True) if record else None)
            return results

    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records; result flags success per ID, in order."""
        with self._lock:
            results = []
            for record_id in ids:
                if record_id in self._storage:
                    del self._storage[record_id]
                    results.append(True)
                else:
                    results.append(False)
            return results

    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from memory, yielding deep copies in batches."""
        config = config or StreamConfig()

        # Materialize matching records first (search applies its own lock)
        if query:
            records = self.search(query)
        else:
            with self._lock:
                # Ensure records have IDs when taken directly from storage
                records = [
                    self._ensure_record_id(record, record_id)
                    for record_id, record in self._storage.items()
                ]

        # Yield records in batches
        for start in range(0, len(records), config.batch_size):
            for record in records[start:start + config.batch_size]:
                yield record.copy(deep=True)

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into memory via the mixin's default implementation."""
        return self._default_stream_write(records, config)

    def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform vector similarity search using Python calculations."""
        return self.python_vector_search_sync(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )