Coverage for src/dataknobs_data/database.py: 27%
253 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-31 15:06 -0600
from __future__ import annotations

import uuid
from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING

from .database_utils import ensure_record_id, process_search_results
from .query import Query
from .schema import DatabaseSchema, FieldSchema

if TYPE_CHECKING:
    from collections.abc import AsyncIterator, Callable, Iterator

    from .query_logic import ComplexQuery
    from .records import Record
    from .streaming import StreamConfig, StreamResult
class AsyncDatabase(ABC):
    """Abstract base class for async database implementations.

    Concrete backends implement the abstract CRUD/search/stream primitives;
    this base class provides schema handling, batch helpers, upsert, counting,
    and async context-manager support on top of them.
    """

    def __init__(self, config: dict[str, Any] | None = None, schema: DatabaseSchema | None = None):
        """Initialize the database with optional configuration.

        Args:
            config: Backend-specific configuration parameters (may include a
                'schema' key, consumed here and hidden from the backend)
            schema: Optional database schema (overrides any config schema)
        """
        config = config or {}

        # Extract schema from config if present and no explicit schema provided
        if schema is None and "schema" in config:
            schema = self._extract_schema_from_config(config["schema"])
            # Remove schema from config so backends don't see it
            config = {k: v for k, v in config.items() if k != "schema"}

        self.config = config
        self.schema = schema or DatabaseSchema()
        self._initialize()

    @staticmethod
    def _extract_schema_from_config(schema_config: Any) -> DatabaseSchema | None:
        """Extract schema from configuration.

        Args:
            schema_config: Can be a DatabaseSchema, dict, or None

        Returns:
            DatabaseSchema instance or None if the value is neither a schema
            nor a dict
        """
        if isinstance(schema_config, DatabaseSchema):
            return schema_config
        elif isinstance(schema_config, dict):
            return DatabaseSchema.from_dict(schema_config)
        return None

    def _initialize(self) -> None:  # noqa: B027
        """Initialize the database backend. Override in subclasses if needed."""
        # Default implementation does nothing - backends can override if needed

    def _ensure_record_id(self, record: Record, record_id: str) -> Record:
        """Ensure a record has its ID set (delegates to utility function)."""
        return ensure_record_id(record, record_id)

    def _prepare_record_for_storage(self, record: Record) -> tuple[Record, str]:
        """Prepare a record for storage by ensuring it has a storage_id.

        Args:
            record: The record to prepare

        Returns:
            Tuple of (prepared_record_copy, storage_id)
        """
        # Make a copy to avoid modifying the original
        record_copy = record.copy(deep=True)

        # Generate storage ID if not present
        if not record_copy.has_storage_id():
            storage_id = str(uuid.uuid4())
            record_copy.storage_id = storage_id
        else:
            storage_id = record_copy.storage_id

        return record_copy, storage_id

    def _prepare_record_from_storage(self, record: Record | None, storage_id: str) -> Record | None:
        """Prepare a record retrieved from storage by ensuring storage_id is set.

        Args:
            record: The record retrieved from storage (or None)
            storage_id: The storage ID used to retrieve the record

        Returns:
            Record copy with storage_id set, or None if record was None
        """
        if record:
            record_copy = record.copy(deep=True)
            # Ensure storage_id is set
            if not record_copy.has_storage_id():
                record_copy.storage_id = storage_id
            return record_copy
        return None

    def _process_search_results(
        self,
        results: list[tuple[str, Record]],
        query: Query,
        deep_copy: bool = True
    ) -> list[Record]:
        """Process search results (delegates to utility function)."""
        return process_search_results(results, query, deep_copy)

    def set_schema(self, schema: DatabaseSchema) -> None:
        """Set the database schema.

        Args:
            schema: The database schema to use
        """
        self.schema = schema

    def add_field_schema(self, field_schema: FieldSchema) -> None:
        """Add a field to the database schema.

        Args:
            field_schema: The field schema to add
        """
        self.schema.add_field(field_schema)

    def with_schema(self, **field_definitions) -> AsyncDatabase:
        """Set schema using field definitions.

        Returns self for chaining.

        Examples:
            db = AsyncMemoryDatabase().with_schema(
                content=FieldType.TEXT,
                embedding=(FieldType.VECTOR, {"dimensions": 384, "source_field": "content"})
            )
        """
        self.schema = DatabaseSchema.create(**field_definitions)
        return self

    @abstractmethod
    async def create(self, record: Record) -> str:
        """Create a new record in the database.

        Args:
            record: The record to create

        Returns:
            The ID of the created record
        """
        raise NotImplementedError

    @abstractmethod
    async def read(self, id: str) -> Record | None:
        """Read a record by ID.

        Args:
            id: The record ID

        Returns:
            The record if found, None otherwise
        """
        raise NotImplementedError

    @abstractmethod
    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record.

        Args:
            id: The record ID
            record: The updated record

        Returns:
            True if the record was updated, False if not found
        """
        raise NotImplementedError

    @abstractmethod
    async def delete(self, id: str) -> bool:
        """Delete a record by ID.

        Args:
            id: The record ID

        Returns:
            True if the record was deleted, False if not found
        """
        raise NotImplementedError

    @abstractmethod
    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching a query.

        Args:
            query: The search query (simple or complex)

        Returns:
            List of matching records
        """
        raise NotImplementedError

    async def all(self) -> list[Record]:
        """Get all records from the database.

        Returns:
            List of all records
        """
        # Default implementation: search with an empty (match-all) query.
        # Query is already imported at module level; no local import needed.
        return await self.search(Query())

    async def _search_with_complex_query(self, query: ComplexQuery) -> list[Record]:
        """Default implementation for ComplexQuery using in-memory filtering.

        Backends can override this for native boolean logic support.

        Args:
            query: Complex query with boolean logic

        Returns:
            List of matching records
        """
        # Try to convert to simple query if possible
        try:
            simple_query = query.to_simple_query()
            return await self.search(simple_query)
        except ValueError:
            # Can't convert - need to do in-memory filtering
            # Get all records (or use a base filter if possible)
            all_records = await self.search(Query())

            # Apply complex condition filtering
            results = []
            for record in all_records:
                if query.matches(record):
                    results.append(record)

            # Apply sorting: iterate sort specs in reverse so the primary
            # sort key is applied last (stable sort preserves earlier orders).
            if query.sort_specs:
                for sort_spec in reversed(query.sort_specs):
                    reverse = sort_spec.order.value == "desc"
                    results.sort(
                        # Bind the field as a default arg so the lambda does
                        # not capture the loop variable by reference (B023).
                        key=lambda r, _field=sort_spec.field: r.get_value(_field, ""),
                        reverse=reverse
                    )

            # Apply offset and limit
            if query.offset_value:
                results = results[query.offset_value:]
            if query.limit_value:
                results = results[:query.limit_value]

            # Apply field projection
            if query.fields:
                results = [r.project(query.fields) for r in results]

            return results

    @abstractmethod
    async def exists(self, id: str) -> bool:
        """Check if a record exists.

        Args:
            id: The record ID

        Returns:
            True if the record exists, False otherwise
        """
        raise NotImplementedError

    async def upsert(self, id: str, record: Record) -> str:
        """Update or insert a record.

        Args:
            id: The record ID
            record: The record to upsert

        Returns:
            The record ID. Note: when the record does not exist, the ID
            returned by create() is used, which may differ from the
            requested id (backends may assign their own IDs).
        """
        if await self.exists(id):
            await self.update(id, record)
        else:
            return await self.create(record)
        return id

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records in batch.

        Default implementation calls create() per record; override for
        better performance.

        Args:
            records: List of records to create

        Returns:
            List of created record IDs
        """
        ids = []
        for record in records:
            id = await self.create(record)
            ids.append(id)
        return ids

    async def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records by ID.

        Args:
            ids: List of record IDs

        Returns:
            List of records (None for not found), in the same order as ids
        """
        records = []
        for id in ids:
            record = await self.read(id)
            records.append(record)
        return records

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records by ID.

        Args:
            ids: List of record IDs

        Returns:
            List of deletion results, in the same order as ids
        """
        results = []
        for id in ids:
            result = await self.delete(id)
            results.append(result)
        return results

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records.

        Default implementation calls update() for each ID/record pair.
        Override for better performance.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        results = []
        for id, record in updates:
            result = await self.update(id, record)
            results.append(result)
        return results

    async def count(self, query: Query | None = None) -> int:
        """Count records matching a query.

        Args:
            query: Optional search query (counts all if None)

        Returns:
            Number of matching records
        """
        if query:
            # Default counts by materializing the search results; backends
            # with native counting should override.
            results = await self.search(query)
            return len(results)
        else:
            return await self._count_all()

    @abstractmethod
    async def _count_all(self) -> int:
        """Count all records in the database."""
        raise NotImplementedError

    async def clear(self) -> int:
        """Clear all records from the database.

        Returns:
            Number of records deleted

        Raises:
            NotImplementedError: unless the backend overrides this method.
        """
        raise NotImplementedError

    async def connect(self) -> None:  # noqa: B027
        """Connect to the database. Override in subclasses if needed."""
        # Default implementation does nothing - many backends don't need explicit connection

    async def close(self) -> None:  # noqa: B027
        """Close the database connection. Override in subclasses if needed."""
        # Default implementation does nothing - many backends don't need explicit closing

    async def disconnect(self) -> None:
        """Disconnect from the database (alias for close)."""
        await self.close()

    async def __aenter__(self):
        """Async context manager entry: connect and return self."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: close the connection."""
        await self.close()

    @abstractmethod
    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from database.

        Yields records one at a time, fetching in batches internally.

        Args:
            query: Optional query to filter records
            config: Streaming configuration

        Yields:
            Records matching the query
        """
        raise NotImplementedError

    @abstractmethod
    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into database.

        Accepts an iterator and writes in batches.

        Args:
            records: Iterator of records to write
            config: Streaming configuration

        Returns:
            Result of the streaming operation
        """
        raise NotImplementedError

    async def stream_transform(
        self,
        query: Query | None = None,
        transform: Callable[[Record], Record | None] | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records through a transformation.

        Default implementation, can be overridden for efficiency.

        Args:
            query: Optional query to filter records
            transform: Optional transformation function; a falsy result
                filters the record out of the stream
            config: Streaming configuration

        Yields:
            Transformed records
        """
        async for record in self.stream_read(query, config):
            if transform:
                transformed = transform(record)
                if transformed:  # None means filter out
                    yield transformed
            else:
                yield record

    @classmethod
    async def from_backend(cls, backend: str, config: dict[str, Any] | None = None) -> AsyncDatabase:
        """Factory method to create and connect a database instance.

        Args:
            backend: The backend type ("memory", "file", "s3", "postgres", "elasticsearch")
            config: Backend-specific configuration

        Returns:
            Connected AsyncDatabase instance

        Raises:
            ValueError: If the backend name is not registered.
        """
        from .backends import BACKEND_REGISTRY

        backend_class = BACKEND_REGISTRY.get(backend)
        if not backend_class:
            raise ValueError(
                f"Unknown backend: {backend}. Available: {list(BACKEND_REGISTRY.keys())}"
            )

        instance = backend_class(config)
        await instance.connect()
        return instance
class SyncDatabase(ABC):
    """Synchronous variant of the Database abstract base class.

    Mirrors AsyncDatabase's contract with blocking methods; backends
    implement the abstract CRUD/search/stream primitives.
    """

    def __init__(self, config: dict[str, Any] | None = None, schema: DatabaseSchema | None = None):
        """Initialize the database with optional configuration.

        Args:
            config: Backend-specific configuration parameters (may include a
                'schema' key, consumed here and hidden from the backend)
            schema: Optional database schema (overrides any config schema)
        """
        config = config or {}

        # Extract schema from config if present and no explicit schema provided
        if schema is None and "schema" in config:
            # Reuse the async base's static helper to avoid duplication.
            schema = AsyncDatabase._extract_schema_from_config(config["schema"])
            # Remove schema from config so backends don't see it
            config = {k: v for k, v in config.items() if k != "schema"}

        self.config = config
        self.schema = schema or DatabaseSchema()
        self._initialize()

    def _initialize(self) -> None:  # noqa: B027
        """Initialize the database backend. Override in subclasses if needed."""
        # Default implementation does nothing - backends can override if needed

    def _ensure_record_id(self, record: Record, record_id: str) -> Record:
        """Ensure a record has its ID set (delegates to utility function)."""
        return ensure_record_id(record, record_id)

    def _prepare_record_for_storage(self, record: Record) -> tuple[Record, str]:
        """Prepare a record for storage by ensuring it has a storage_id.

        Args:
            record: The record to prepare

        Returns:
            Tuple of (prepared_record_copy, storage_id)
        """
        # Make a copy to avoid modifying the original
        record_copy = record.copy(deep=True)

        # Generate storage ID if not present
        if not record_copy.has_storage_id():
            storage_id = str(uuid.uuid4())
            record_copy.storage_id = storage_id
        else:
            storage_id = record_copy.storage_id

        return record_copy, storage_id

    def _prepare_record_from_storage(self, record: Record | None, storage_id: str) -> Record | None:
        """Prepare a record retrieved from storage by ensuring storage_id is set.

        Args:
            record: The record retrieved from storage (or None)
            storage_id: The storage ID used to retrieve the record

        Returns:
            Record copy with storage_id set, or None if record was None
        """
        if record:
            record_copy = record.copy(deep=True)
            # Ensure storage_id is set
            if not record_copy.has_storage_id():
                record_copy.storage_id = storage_id
            return record_copy
        return None

    def _process_search_results(
        self,
        results: list[tuple[str, Record]],
        query: Query,
        deep_copy: bool = True
    ) -> list[Record]:
        """Process search results (delegates to utility function)."""
        return process_search_results(results, query, deep_copy)

    def set_schema(self, schema: DatabaseSchema) -> None:
        """Set the database schema.

        Args:
            schema: The database schema to use
        """
        self.schema = schema

    def add_field_schema(self, field_schema: FieldSchema) -> None:
        """Add a field to the database schema.

        Args:
            field_schema: The field schema to add
        """
        self.schema.add_field(field_schema)

    def with_schema(self, **field_definitions) -> SyncDatabase:
        """Set schema using field definitions.

        Returns self for chaining.

        Examples:
            db = SyncMemoryDatabase().with_schema(
                content=FieldType.TEXT,
                embedding=(FieldType.VECTOR, {"dimensions": 384, "source_field": "content"})
            )
        """
        self.schema = DatabaseSchema.create(**field_definitions)
        return self

    @abstractmethod
    def create(self, record: Record) -> str:
        """Create a new record in the database.

        Returns:
            The ID of the created record
        """
        raise NotImplementedError

    @abstractmethod
    def read(self, id: str) -> Record | None:
        """Read a record by ID, returning None if not found."""
        raise NotImplementedError

    @abstractmethod
    def update(self, id: str, record: Record) -> bool:
        """Update an existing record; return True if it was found and updated."""
        raise NotImplementedError

    @abstractmethod
    def delete(self, id: str) -> bool:
        """Delete a record by ID; return True if it was found and deleted."""
        raise NotImplementedError

    @abstractmethod
    def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching a query (simple or complex)."""
        raise NotImplementedError

    def all(self) -> list[Record]:
        """Get all records from the database.

        Returns:
            List of all records
        """
        # Default implementation: search with an empty (match-all) query.
        # Query is already imported at module level; no local import needed.
        return self.search(Query())

    def _search_with_complex_query(self, query: ComplexQuery) -> list[Record]:
        """Default implementation for ComplexQuery using in-memory filtering.

        Backends can override this for native boolean logic support.

        Args:
            query: Complex query with boolean logic

        Returns:
            List of matching records
        """
        # Try to convert to simple query if possible
        try:
            simple_query = query.to_simple_query()
            return self.search(simple_query)
        except ValueError:
            # Can't convert - need to do in-memory filtering
            # Get all records (or use a base filter if possible)
            all_records = self.search(Query())

            # Apply complex condition filtering
            results = []
            for record in all_records:
                if query.matches(record):
                    results.append(record)

            # Apply sorting: iterate sort specs in reverse so the primary
            # sort key is applied last (stable sort preserves earlier orders).
            if query.sort_specs:
                for sort_spec in reversed(query.sort_specs):
                    reverse = sort_spec.order.value == "desc"
                    results.sort(
                        # Bind the field as a default arg so the lambda does
                        # not capture the loop variable by reference (B023).
                        key=lambda r, _field=sort_spec.field: r.get_value(_field, ""),
                        reverse=reverse
                    )

            # Apply offset and limit
            if query.offset_value:
                results = results[query.offset_value:]
            if query.limit_value:
                results = results[:query.limit_value]

            # Apply field projection
            if query.fields:
                results = [r.project(query.fields) for r in results]

            return results

    @abstractmethod
    def exists(self, id: str) -> bool:
        """Check if a record exists."""
        raise NotImplementedError

    def upsert(self, id: str, record: Record) -> str:
        """Update or insert a record.

        Returns:
            The record ID. Note: when the record does not exist, the ID
            returned by create() is used, which may differ from the
            requested id (backends may assign their own IDs).
        """
        if self.exists(id):
            self.update(id, record)
        else:
            return self.create(record)
        return id

    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records in batch; returns the created IDs."""
        ids = []
        for record in records:
            id = self.create(record)
            ids.append(id)
        return ids

    def read_batch(self, ids: list[str]) -> list[Record | None]:
        """Read multiple records by ID (None for not found), in order."""
        records = []
        for id in ids:
            record = self.read(id)
            records.append(record)
        return records

    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records by ID; returns per-ID success flags."""
        results = []
        for id in ids:
            result = self.delete(id)
            results.append(result)
        return results

    def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records.

        Default implementation calls update() for each ID/record pair.
        Override for better performance.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        results = []
        for id, record in updates:
            result = self.update(id, record)
            results.append(result)
        return results

    def count(self, query: Query | None = None) -> int:
        """Count records matching a query (all records if query is None)."""
        if query:
            # Default counts by materializing the search results; backends
            # with native counting should override.
            results = self.search(query)
            return len(results)
        else:
            return self._count_all()

    @abstractmethod
    def _count_all(self) -> int:
        """Count all records in the database."""
        raise NotImplementedError

    def clear(self) -> int:
        """Clear all records from the database.

        Returns:
            Number of records deleted

        Raises:
            NotImplementedError: unless the backend overrides this method.
        """
        raise NotImplementedError

    def connect(self) -> None:  # noqa: B027
        """Connect to the database. Override in subclasses if needed."""
        # Default implementation does nothing - many backends don't need explicit connection

    def close(self) -> None:  # noqa: B027
        """Close the database connection. Override in subclasses if needed."""
        # Default implementation does nothing - many backends don't need explicit closing

    def disconnect(self) -> None:
        """Disconnect from the database (alias for close)."""
        self.close()

    def __enter__(self):
        """Context manager entry: connect and return self."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: close the connection."""
        self.close()

    @abstractmethod
    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from database.

        Yields records one at a time, fetching in batches internally.

        Args:
            query: Optional query to filter records
            config: Streaming configuration

        Yields:
            Records matching the query
        """
        raise NotImplementedError

    @abstractmethod
    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into database.

        Accepts an iterator and writes in batches.

        Args:
            records: Iterator of records to write
            config: Streaming configuration

        Returns:
            Result of the streaming operation
        """
        raise NotImplementedError

    def stream_transform(
        self,
        query: Query | None = None,
        transform: Callable[[Record], Record | None] | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records through a transformation.

        Default implementation, can be overridden for efficiency.

        Args:
            query: Optional query to filter records
            transform: Optional transformation function; a falsy result
                filters the record out of the stream
            config: Streaming configuration

        Yields:
            Transformed records
        """
        for record in self.stream_read(query, config):
            if transform:
                transformed = transform(record)
                if transformed:  # None means filter out
                    yield transformed
            else:
                yield record

    @classmethod
    def from_backend(cls, backend: str, config: dict[str, Any] | None = None) -> SyncDatabase:
        """Factory method to create and connect a synchronous database instance.

        Args:
            backend: The backend type ("memory", "file", "s3", "postgres", "elasticsearch")
            config: Backend-specific configuration

        Returns:
            Connected SyncDatabase instance

        Raises:
            ValueError: If the backend name is not registered.
        """
        from .backends import SYNC_BACKEND_REGISTRY

        backend_class = SYNC_BACKEND_REGISTRY.get(backend)
        if not backend_class:
            raise ValueError(
                f"Unknown backend: {backend}. Available: {list(SYNC_BACKEND_REGISTRY.keys())}"
            )

        instance = backend_class(config)
        instance.connect()
        return instance