Coverage for src/dataknobs_data/database.py: 28%
279 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:15 -0600
1from __future__ import annotations
3from abc import ABC, abstractmethod
4from typing import Any, TYPE_CHECKING
6from .database_utils import ensure_record_id, process_search_results
7from .query import Query
8from .schema import DatabaseSchema, FieldSchema
10if TYPE_CHECKING:
11 from collections.abc import AsyncIterator, Callable, Iterator
12 from .query_logic import ComplexQuery
13 from .records import Record
14 from .streaming import StreamConfig, StreamResult
class AsyncDatabase(ABC):
    """Abstract base class for async database implementations.

    Concrete backends implement the abstract CRUD/search/streaming methods;
    the remaining methods provide default implementations (batching, upsert,
    counting, context management) built on top of those primitives.
    """

    def __init__(self, config: dict[str, Any] | None = None, schema: DatabaseSchema | None = None):
        """Initialize the database with optional configuration.

        Args:
            config: Backend-specific configuration parameters (may include a
                'schema' key holding a DatabaseSchema or its dict form)
            schema: Optional database schema (overrides any config schema)
        """
        config = config or {}

        # An explicit `schema` argument wins over one embedded in config.
        if schema is None and "schema" in config:
            schema = self._extract_schema_from_config(config["schema"])
            # Strip the schema key so backends never see it in their config.
            config = {k: v for k, v in config.items() if k != "schema"}

        self.config = config
        self.schema = schema or DatabaseSchema()
        self._initialize()

    @staticmethod
    def _extract_schema_from_config(schema_config: Any) -> DatabaseSchema | None:
        """Coerce a configuration value into a DatabaseSchema.

        Args:
            schema_config: A DatabaseSchema, a dict representation, or anything else

        Returns:
            DatabaseSchema instance, or None if the value is not convertible
        """
        if isinstance(schema_config, DatabaseSchema):
            return schema_config
        if isinstance(schema_config, dict):
            return DatabaseSchema.from_dict(schema_config)
        return None

    def _initialize(self) -> None:  # noqa: B027
        """Initialize the database backend. Override in subclasses if needed."""
        # Default implementation does nothing - backends can override if needed

    def _ensure_record_id(self, record: Record, record_id: str) -> Record:
        """Ensure a record has its ID set (delegates to utility function)."""
        return ensure_record_id(record, record_id)

    def _prepare_record_for_storage(self, record: Record) -> tuple[Record, str]:
        """Prepare a record for storage by ensuring it has a storage_id.

        Args:
            record: The record to prepare

        Returns:
            Tuple of (prepared_record_copy, storage_id)
        """
        import uuid

        # Work on a deep copy so the caller's record is never mutated.
        record_copy = record.copy(deep=True)

        # Generate a storage ID only if the record does not already have one.
        if not record_copy.has_storage_id():
            storage_id = str(uuid.uuid4())
            record_copy.storage_id = storage_id
        else:
            storage_id = record_copy.storage_id

        return record_copy, storage_id

    def _prepare_record_from_storage(self, record: Record | None, storage_id: str) -> Record | None:
        """Prepare a record retrieved from storage by ensuring storage_id is set.

        Args:
            record: The record retrieved from storage (or None)
            storage_id: The storage ID used to retrieve the record

        Returns:
            Record with storage_id set, or None if record was None
        """
        if record:
            record_copy = record.copy(deep=True)
            # Backends may store records without their ID; restore it here.
            if not record_copy.has_storage_id():
                record_copy.storage_id = storage_id
            return record_copy
        return None

    def _process_search_results(
        self,
        results: list[tuple[str, Record]],
        query: Query,
        deep_copy: bool = True
    ) -> list[Record]:
        """Process search results (delegates to utility function)."""
        return process_search_results(results, query, deep_copy)

    def set_schema(self, schema: DatabaseSchema) -> None:
        """Set the database schema.

        Args:
            schema: The database schema to use
        """
        self.schema = schema

    def add_field_schema(self, field_schema: FieldSchema) -> None:
        """Add a field to the database schema.

        Args:
            field_schema: The field schema to add
        """
        self.schema.add_field(field_schema)

    def with_schema(self, **field_definitions) -> AsyncDatabase:
        """Set schema using field definitions.

        Returns self for chaining.

        Examples:
            db = AsyncMemoryDatabase().with_schema(
                content=FieldType.TEXT,
                embedding=(FieldType.VECTOR, {"dimensions": 384, "source_field": "content"})
            )
        """
        self.schema = DatabaseSchema.create(**field_definitions)
        return self

    @abstractmethod
    async def create(self, record: Record) -> str:
        """Create a new record in the database.

        Args:
            record: The record to create

        Returns:
            The ID of the created record
        """
        raise NotImplementedError

    @abstractmethod
    async def read(self, id: str) -> Record | None:
        """Read a record by ID.

        Args:
            id: The record ID

        Returns:
            The record if found, None otherwise
        """
        raise NotImplementedError

    @abstractmethod
    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record.

        Args:
            id: The record ID
            record: The updated record

        Returns:
            True if the record was updated, False if not found
        """
        raise NotImplementedError

    @abstractmethod
    async def delete(self, id: str) -> bool:
        """Delete a record by ID.

        Args:
            id: The record ID

        Returns:
            True if the record was deleted, False if not found
        """
        raise NotImplementedError

    @abstractmethod
    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching a query.

        Args:
            query: The search query (simple or complex)

        Returns:
            List of matching records
        """
        raise NotImplementedError

    async def all(self) -> list[Record]:
        """Get all records from the database.

        Returns:
            List of all records
        """
        # An empty Query matches every record (Query is imported at module scope).
        return await self.search(Query())

    async def _search_with_complex_query(self, query: ComplexQuery) -> list[Record]:
        """Default implementation for ComplexQuery using in-memory filtering.

        Backends can override this for native boolean logic support.

        Args:
            query: Complex query with boolean logic

        Returns:
            List of matching records
        """
        # Fast path: some complex queries reduce to a simple Query.  Keep the
        # try body minimal so a ValueError raised by search() itself is not
        # mistaken for "query cannot be simplified".
        try:
            simple_query = query.to_simple_query()
        except ValueError:
            pass
        else:
            return await self.search(simple_query)

        # Can't convert - fetch everything and filter in memory.
        all_records = await self.search(Query())
        results = [record for record in all_records if query.matches(record)]

        # Apply sorting: iterate specs in reverse so the first spec becomes the
        # primary sort key (list.sort is stable, preserving prior ordering).
        if query.sort_specs:
            for sort_spec in reversed(query.sort_specs):
                results.sort(
                    key=lambda r, f=sort_spec.field: r.get_value(f, ""),
                    reverse=sort_spec.order.value == "desc",
                )

        # Apply offset and limit (pagination).
        if query.offset_value:
            results = results[query.offset_value:]
        if query.limit_value:
            results = results[:query.limit_value]

        # Apply field projection.
        if query.fields:
            results = [r.project(query.fields) for r in results]

        return results

    @abstractmethod
    async def exists(self, id: str) -> bool:
        """Check if a record exists.

        Args:
            id: The record ID

        Returns:
            True if the record exists, False otherwise
        """
        raise NotImplementedError

    async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record.

        Can be called as:
        - upsert(id, record) - explicit ID and record
        - upsert(record) - extract ID from record using Record's built-in logic

        Args:
            id_or_record: Either an ID string or a Record
            record: The record to upsert (if first arg is ID)

        Returns:
            The record ID

        Raises:
            ValueError: If an ID is given without a record
        """
        import uuid

        # Determine ID and record based on arguments.
        if isinstance(id_or_record, str):
            # Called with explicit ID: upsert(id, record)
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            # Called with just record: upsert(record)
            record = id_or_record
            # Record.id encapsulates the ID-resolution logic.
            id = record.id

        if id is None:
            # Generate a new ID if none found
            id = str(uuid.uuid4())  # type: ignore[unreachable]
            # Set it on the record for future reference
            record.storage_id = id

        # Now perform the upsert.
        if await self.exists(id):
            await self.update(id, record)
        else:
            # Ensure the record has the storage_id set for create.
            if not record.storage_id:
                record.storage_id = id
            created_id = await self.create(record)
            # The backend may assign a different ID than the one we provided.
            return created_id or id
        return id

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records in batch.

        Default implementation creates one record at a time; override for
        backends with native bulk insert.

        Args:
            records: List of records to create

        Returns:
            List of created record IDs
        """
        return [await self.create(record) for record in records]

    async def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records by ID.

        Args:
            ids: List of record IDs

        Returns:
            List of records (None for not found)
        """
        return [await self.read(id) for id in ids]

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records by ID.

        Args:
            ids: List of record IDs

        Returns:
            List of deletion results
        """
        return [await self.delete(id) for id in ids]

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records.

        Default implementation calls update() for each ID/record pair.
        Override for better performance.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        return [await self.update(id, record) for id, record in updates]

    async def count(self, query: Query | None = None) -> int:
        """Count records matching a query.

        Args:
            query: Optional search query (counts all if None)

        Returns:
            Number of matching records
        """
        if query:
            results = await self.search(query)
            return len(results)
        return await self._count_all()

    @abstractmethod
    async def _count_all(self) -> int:
        """Count all records in the database."""
        raise NotImplementedError

    async def clear(self) -> int:
        """Clear all records from the database.

        Returns:
            Number of records deleted
        """
        # Deliberately not abstract: optional operation backends may implement.
        raise NotImplementedError

    async def connect(self) -> None:  # noqa: B027
        """Connect to the database. Override in subclasses if needed."""
        # Default implementation does nothing - many backends don't need explicit connection

    async def close(self) -> None:  # noqa: B027
        """Close the database connection. Override in subclasses if needed."""
        # Default implementation does nothing - many backends don't need explicit closing

    async def disconnect(self) -> None:
        """Disconnect from the database (alias for close)."""
        await self.close()

    async def __aenter__(self):
        """Async context manager entry."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()

    @abstractmethod
    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from database.

        Yields records one at a time, fetching in batches internally.

        Args:
            query: Optional query to filter records
            config: Streaming configuration

        Yields:
            Records matching the query
        """
        raise NotImplementedError

    @abstractmethod
    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into database.

        Accepts an iterator and writes in batches.

        Args:
            records: Iterator of records to write
            config: Streaming configuration

        Returns:
            Result of the streaming operation
        """
        raise NotImplementedError

    async def stream_transform(
        self,
        query: Query | None = None,
        transform: Callable[[Record], Record | None] | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records through a transformation.

        Default implementation, can be overridden for efficiency.

        Args:
            query: Optional query to filter records
            transform: Optional transformation function
            config: Streaming configuration

        Yields:
            Transformed records
        """
        async for record in self.stream_read(query, config):
            if transform:
                transformed = transform(record)
                if transformed:  # None means filter out
                    yield transformed
            else:
                yield record

    @classmethod
    async def from_backend(cls, backend: str, config: dict[str, Any] | None = None) -> AsyncDatabase:
        """Factory method to create and connect a database instance.

        Args:
            backend: The backend type ("memory", "file", "s3", "postgres", "elasticsearch")
            config: Backend-specific configuration

        Returns:
            Connected AsyncDatabase instance

        Raises:
            ValueError: If the backend name is not registered
        """
        from .backends import BACKEND_REGISTRY

        backend_class = BACKEND_REGISTRY.get(backend)
        if not backend_class:
            raise ValueError(
                f"Unknown backend: {backend}. Available: {list(BACKEND_REGISTRY.keys())}"
            )

        instance = backend_class(config)
        await instance.connect()
        return instance
class SyncDatabase(ABC):
    """Synchronous variant of the Database abstract base class.

    Mirrors AsyncDatabase's contract without the async/await machinery:
    backends implement the abstract methods; default implementations cover
    batching, upsert, counting, and context management.
    """

    def __init__(self, config: dict[str, Any] | None = None, schema: DatabaseSchema | None = None):
        """Initialize the database with optional configuration.

        Args:
            config: Backend-specific configuration parameters (may include a
                'schema' key holding a DatabaseSchema or its dict form)
            schema: Optional database schema (overrides any config schema)
        """
        config = config or {}

        # An explicit `schema` argument wins over one embedded in config.
        if schema is None and "schema" in config:
            # Reuse the async class's coercion helper to stay consistent.
            schema = AsyncDatabase._extract_schema_from_config(config["schema"])
            # Strip the schema key so backends never see it in their config.
            config = {k: v for k, v in config.items() if k != "schema"}

        self.config = config
        self.schema = schema or DatabaseSchema()
        self._initialize()

    def _initialize(self) -> None:  # noqa: B027
        """Initialize the database backend. Override in subclasses if needed."""
        # Default implementation does nothing - backends can override if needed

    def _ensure_record_id(self, record: Record, record_id: str) -> Record:
        """Ensure a record has its ID set (delegates to utility function)."""
        return ensure_record_id(record, record_id)

    def _prepare_record_for_storage(self, record: Record) -> tuple[Record, str]:
        """Prepare a record for storage by ensuring it has a storage_id.

        Args:
            record: The record to prepare

        Returns:
            Tuple of (prepared_record_copy, storage_id)
        """
        import uuid

        # Work on a deep copy so the caller's record is never mutated.
        record_copy = record.copy(deep=True)

        # Generate a storage ID only if the record does not already have one.
        if not record_copy.has_storage_id():
            storage_id = str(uuid.uuid4())
            record_copy.storage_id = storage_id
        else:
            storage_id = record_copy.storage_id

        return record_copy, storage_id

    def _prepare_record_from_storage(self, record: Record | None, storage_id: str) -> Record | None:
        """Prepare a record retrieved from storage by ensuring storage_id is set.

        Args:
            record: The record retrieved from storage (or None)
            storage_id: The storage ID used to retrieve the record

        Returns:
            Record with storage_id set, or None if record was None
        """
        if record:
            record_copy = record.copy(deep=True)
            # Backends may store records without their ID; restore it here.
            if not record_copy.has_storage_id():
                record_copy.storage_id = storage_id
            return record_copy
        return None

    def _process_search_results(
        self,
        results: list[tuple[str, Record]],
        query: Query,
        deep_copy: bool = True
    ) -> list[Record]:
        """Process search results (delegates to utility function)."""
        return process_search_results(results, query, deep_copy)

    def set_schema(self, schema: DatabaseSchema) -> None:
        """Set the database schema.

        Args:
            schema: The database schema to use
        """
        self.schema = schema

    def add_field_schema(self, field_schema: FieldSchema) -> None:
        """Add a field to the database schema.

        Args:
            field_schema: The field schema to add
        """
        self.schema.add_field(field_schema)

    def with_schema(self, **field_definitions) -> SyncDatabase:
        """Set schema using field definitions.

        Returns self for chaining.

        Examples:
            db = SyncMemoryDatabase().with_schema(
                content=FieldType.TEXT,
                embedding=(FieldType.VECTOR, {"dimensions": 384, "source_field": "content"})
            )
        """
        self.schema = DatabaseSchema.create(**field_definitions)
        return self

    @abstractmethod
    def create(self, record: Record) -> str:
        """Create a new record in the database."""
        raise NotImplementedError

    @abstractmethod
    def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        raise NotImplementedError

    @abstractmethod
    def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        raise NotImplementedError

    @abstractmethod
    def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        raise NotImplementedError

    @abstractmethod
    def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching a query (simple or complex)."""
        raise NotImplementedError

    def all(self) -> list[Record]:
        """Get all records from the database.

        Returns:
            List of all records
        """
        # An empty Query matches every record (Query is imported at module scope).
        return self.search(Query())

    def _search_with_complex_query(self, query: ComplexQuery) -> list[Record]:
        """Default implementation for ComplexQuery using in-memory filtering.

        Backends can override this for native boolean logic support.

        Args:
            query: Complex query with boolean logic

        Returns:
            List of matching records
        """
        # Fast path: some complex queries reduce to a simple Query.  Keep the
        # try body minimal so a ValueError raised by search() itself is not
        # mistaken for "query cannot be simplified".
        try:
            simple_query = query.to_simple_query()
        except ValueError:
            pass
        else:
            return self.search(simple_query)

        # Can't convert - fetch everything and filter in memory.
        all_records = self.search(Query())
        results = [record for record in all_records if query.matches(record)]

        # Apply sorting: iterate specs in reverse so the first spec becomes the
        # primary sort key (list.sort is stable, preserving prior ordering).
        if query.sort_specs:
            for sort_spec in reversed(query.sort_specs):
                results.sort(
                    key=lambda r, f=sort_spec.field: r.get_value(f, ""),
                    reverse=sort_spec.order.value == "desc",
                )

        # Apply offset and limit (pagination).
        if query.offset_value:
            results = results[query.offset_value:]
        if query.limit_value:
            results = results[:query.limit_value]

        # Apply field projection.
        if query.fields:
            results = [r.project(query.fields) for r in results]

        return results

    @abstractmethod
    def exists(self, id: str) -> bool:
        """Check if a record exists."""
        raise NotImplementedError

    def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record.

        Can be called as:
        - upsert(id, record) - explicit ID and record
        - upsert(record) - extract ID from record using Record's built-in logic

        Args:
            id_or_record: Either an ID string or a Record
            record: The record to upsert (if first arg is ID)

        Returns:
            The record ID

        Raises:
            ValueError: If an ID is given without a record
        """
        import uuid

        # Determine ID and record based on arguments.
        if isinstance(id_or_record, str):
            # Called with explicit ID: upsert(id, record)
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            # Called with just record: upsert(record)
            record = id_or_record
            # Record.id encapsulates the ID-resolution logic.
            id = record.id

        if id is None:
            # Generate a new ID if none found
            id = str(uuid.uuid4())  # type: ignore[unreachable]
            # Set it on the record for future reference
            record.storage_id = id

        # Now perform the upsert.
        if self.exists(id):
            self.update(id, record)
        else:
            # Ensure the record has the storage_id set for create.
            if not record.storage_id:
                record.storage_id = id
            created_id = self.create(record)
            # The backend may assign a different ID than the one we provided.
            return created_id or id
        return id

    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records in batch.

        Default implementation creates one record at a time; override for
        backends with native bulk insert.
        """
        return [self.create(record) for record in records]

    def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records by ID (None for each ID not found)."""
        return [self.read(id) for id in ids]

    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records by ID, returning per-ID success flags."""
        return [self.delete(id) for id in ids]

    def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records.

        Default implementation calls update() for each ID/record pair.
        Override for better performance.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        return [self.update(id, record) for id, record in updates]

    def count(self, query: Query | None = None) -> int:
        """Count records matching a query (all records if query is None)."""
        if query:
            results = self.search(query)
            return len(results)
        return self._count_all()

    @abstractmethod
    def _count_all(self) -> int:
        """Count all records in the database."""
        raise NotImplementedError

    def clear(self) -> int:
        """Clear all records from the database, returning the number deleted."""
        # Deliberately not abstract: optional operation backends may implement.
        raise NotImplementedError

    def connect(self) -> None:  # noqa: B027
        """Connect to the database. Override in subclasses if needed."""
        # Default implementation does nothing - many backends don't need explicit connection

    def close(self) -> None:  # noqa: B027
        """Close the database connection. Override in subclasses if needed."""
        # Default implementation does nothing - many backends don't need explicit closing

    def disconnect(self) -> None:
        """Disconnect from the database (alias for close)."""
        self.close()

    def __enter__(self):
        """Context manager entry."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.close()

    @abstractmethod
    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from database.

        Yields records one at a time, fetching in batches internally.

        Args:
            query: Optional query to filter records
            config: Streaming configuration

        Yields:
            Records matching the query
        """
        raise NotImplementedError

    @abstractmethod
    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into database.

        Accepts an iterator and writes in batches.

        Args:
            records: Iterator of records to write
            config: Streaming configuration

        Returns:
            Result of the streaming operation
        """
        raise NotImplementedError

    def stream_transform(
        self,
        query: Query | None = None,
        transform: Callable[[Record], Record | None] | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records through a transformation.

        Default implementation, can be overridden for efficiency.

        Args:
            query: Optional query to filter records
            transform: Optional transformation function
            config: Streaming configuration

        Yields:
            Transformed records
        """
        for record in self.stream_read(query, config):
            if transform:
                transformed = transform(record)
                if transformed:  # None means filter out
                    yield transformed
            else:
                yield record

    @classmethod
    def from_backend(cls, backend: str, config: dict[str, Any] | None = None) -> SyncDatabase:
        """Factory method to create and connect a synchronous database instance.

        Args:
            backend: The backend type ("memory", "file", "s3", "postgres", "elasticsearch")
            config: Backend-specific configuration

        Returns:
            Connected SyncDatabase instance

        Raises:
            ValueError: If the backend name is not registered
        """
        from .backends import SYNC_BACKEND_REGISTRY

        backend_class = SYNC_BACKEND_REGISTRY.get(backend)
        if not backend_class:
            raise ValueError(
                f"Unknown backend: {backend}. Available: {list(SYNC_BACKEND_REGISTRY.keys())}"
            )

        instance = backend_class(config)
        instance.connect()
        return instance