Coverage for src/dataknobs_data/backends/postgres.py: 12%
734 statements
coverage.py v7.11.3, created at 2025-11-13 11:23 -0700
1"""PostgreSQL backend implementation with proper connection management and vector support."""
3from __future__ import annotations
5import json
6import logging
7import time
8import uuid
9from typing import TYPE_CHECKING, Any, cast
11import asyncpg
12from dataknobs_config import ConfigurableBase
14from dataknobs_utils.sql_utils import DotenvPostgresConnector, PostgresDB
16from ..database import AsyncDatabase, SyncDatabase
17from ..pooling import ConnectionPoolManager
18from ..pooling.postgres import PostgresPoolConfig, create_asyncpg_pool, validate_asyncpg_pool
19from ..query import Operator, Query
20from ..query_logic import ComplexQuery
21from ..streaming import (
22 StreamConfig,
23 StreamResult,
24 async_process_batch_with_fallback,
25 process_batch_with_fallback,
26)
27from ..vector.mixins import VectorOperationsMixin
28from .postgres_mixins import (
29 PostgresBaseConfig,
30 PostgresConnectionValidator,
31 PostgresErrorHandler,
32 PostgresTableManager,
33 PostgresVectorSupport,
34)
35from .sql_base import SQLQueryBuilder, SQLRecordSerializer
37if TYPE_CHECKING:
38 import numpy as np
40 from collections.abc import AsyncIterator, Iterator, Callable, Awaitable
41 from ..fields import VectorField
42 from ..records import Record
43 from ..vector.types import DistanceMetric, VectorSearchResult
45logger = logging.getLogger(__name__)
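# A minimal usage sketch for the SyncPostgresDatabase class below (illustrative; the
# connection values are assumptions for a local setup, and the config keys mirror the
# __init__ docstring):
#
#     db = SyncPostgresDatabase({
#         "host": "localhost",
#         "port": 5432,
#         "database": "postgres",
#         "user": "postgres",
#         "password": "secret",     # or omit host/database/user to use DotenvPostgresConnector
#         "table": "records",
#         "schema": "public",
#         "enable_vector": False,
#     })
#     db.connect()
#     record_id = db.create(record)  # `record` is a dataknobs Record instance
#     db.close()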
48class SyncPostgresDatabase(
49 SyncDatabase,
50 ConfigurableBase,
51 VectorOperationsMixin,
52 SQLRecordSerializer,
53 PostgresBaseConfig,
54 PostgresTableManager,
55 PostgresVectorSupport,
56 PostgresConnectionValidator,
57 PostgresErrorHandler,
58):
59 """Synchronous PostgreSQL database backend with proper connection management."""
61 def __init__(self, config: dict[str, Any] | None = None):
62 """Initialize PostgreSQL database configuration.
64 Args:
65 config: Configuration with the following optional keys:
66 - host: PostgreSQL host (default: from env/localhost)
67 - port: PostgreSQL port (default: 5432)
68 - database: Database name (default: from env/postgres)
69 - user: Username (default: from env/postgres)
70 - password: Password (default: from env)
71 - table: Table name (default: "records")
72 - schema: Schema name (default: "public")
73 - enable_vector: Enable vector support (default: False)
74 """
75 super().__init__(config)
77 # Parse configuration using mixin
78 table_name, schema_name, conn_config = self._parse_postgres_config(config or {})
79 self._init_postgres_attributes(table_name, schema_name)
81 # Store connection config for later use
82 self._conn_config = conn_config
83 self.db = None # Will be initialized in connect()
84 self.query_builder = None # Will be initialized in connect()
86 @classmethod
87 def from_config(cls, config: dict) -> SyncPostgresDatabase:
88 """Create from config dictionary."""
89 return cls(config)
91 def connect(self) -> None:
92 """Connect to the PostgreSQL database."""
93 if self._connected:
94 return # Already connected
96 # Initialize query builder with pyformat style for psycopg2
97 self.query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")
99 # Create connection using existing utilities
100 if not any(key in self._conn_config for key in ["host", "database", "user"]):
101 # Use dotenv connector for environment-based config
102 connector = DotenvPostgresConnector()
103 self.db = PostgresDB(connector)
104 else:
105 # Direct configuration - map 'database' to 'db' for PostgresDB
106 self.db = PostgresDB(
107 host=self._conn_config.get("host", "localhost"),
108 db=self._conn_config.get("database", "postgres"), # Note: PostgresDB expects 'db' not 'database'
109 user=self._conn_config.get("user", "postgres"),
110 pwd=self._conn_config.get("password"), # Note: PostgresDB expects 'pwd' not 'password'
111 port=self._conn_config.get("port", 5432),
112 )
114 # Create table if it doesn't exist
115 self._ensure_table()
117 # Detect and enable vector support if requested
118 if self.vector_enabled:
119 self._detect_vector_support()
121 self._connected = True
122 self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}")
124 def close(self) -> None:
125 """Close the database connection."""
126 if self.db:
 127 # PostgresDB manages its own connections via context managers,
 128 # so we simply mark this instance as disconnected
129 self._connected = False # type: ignore[unreachable]
131 def _initialize(self) -> None:
132 """Initialize method - connection setup moved to connect()."""
133 # Configuration parsing stays here, actual connection in connect()
134 pass
136 def _detect_vector_support(self) -> None:
137 """Detect and enable vector support if pgvector is available."""
138 from .postgres_vector import check_pgvector_extension_sync, install_pgvector_extension_sync
140 try:
141 # Check if pgvector is installed
142 if check_pgvector_extension_sync(self.db):
143 self._vector_enabled = True
144 logger.info("pgvector extension detected and enabled")
145 else:
146 # Try to install it
147 if install_pgvector_extension_sync(self.db):
148 self._vector_enabled = True
149 logger.info("pgvector extension installed and enabled")
150 else:
151 logger.debug("pgvector extension not available")
152 except Exception as e:
153 logger.debug(f"Could not enable vector support: {e}")
154 self._vector_enabled = False
156 def _ensure_table(self) -> None:
157 """Ensure the records table exists."""
158 if not self.db:
159 raise RuntimeError("Database not connected. Call connect() first.")
161 create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name) # type: ignore[unreachable]
162 self.db.execute(create_table_sql)
165 def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
166 """Convert a Record to a database row."""
167 return {
168 "id": id or str(uuid.uuid4()),
169 "data": self.record_to_json(record),
170 "metadata": json.dumps(record.metadata) if record.metadata else None,
171 }
173 def _row_to_record(self, row: dict[str, Any]) -> Record:
174 """Convert a database row to a Record."""
175 return self.row_to_record(row)
177 def create(self, record: Record) -> str:
178 """Create a new record."""
179 self._check_connection()
180 # Use record's ID if it has one, otherwise generate a new one
181 id = record.id if record.id else str(uuid.uuid4())
182 row = self._record_to_row(record, id)
184 sql = f"""
185 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
186 VALUES (%(id)s, %(data)s, %(metadata)s)
187 """
188 self.db.execute(sql, row)
189 return id
191 def read(self, id: str) -> Record | None:
192 """Read a record by ID."""
193 self._check_connection()
194 sql = f"""
195 SELECT id, data, metadata
196 FROM {self.schema_name}.{self.table_name}
197 WHERE id = %(id)s
198 """
199 df = self.db.query(sql, {"id": id})
201 if df.empty:
202 return None
204 row = df.iloc[0].to_dict()
205 return self._row_to_record(row)
207 def update(self, id: str, record: Record) -> bool:
208 """Update an existing record.
210 Args:
211 id: The record ID to update
212 record: The record data to update with
214 Returns:
215 True if the record was updated, False if no record with the given ID exists
216 """
217 self._check_connection()
218 row = self._record_to_row(record, id)
220 sql = f"""
221 UPDATE {self.schema_name}.{self.table_name}
222 SET data = %(data)s, metadata = %(metadata)s, updated_at = CURRENT_TIMESTAMP
223 WHERE id = %(id)s
224 """
225 result = self.db.execute(sql, row)
227 # PostgresDB.execute returns number of affected rows
228 rows_affected = result if isinstance(result, int) else 0
230 if rows_affected == 0:
231 logger.warning(f"Update affected 0 rows for id={id}. Record may not exist.")
233 return rows_affected > 0
235 def delete(self, id: str) -> bool:
236 """Delete a record by ID."""
237 self._check_connection()
238 sql = f"""
239 DELETE FROM {self.schema_name}.{self.table_name}
240 WHERE id = %(id)s
241 """
242 result = self.db.execute(sql, {"id": id})
243 return result > 0 if isinstance(result, int) else False
245 def exists(self, id: str) -> bool:
246 """Check if a record exists."""
247 self._check_connection()
248 sql = f"""
249 SELECT 1 FROM {self.schema_name}.{self.table_name}
250 WHERE id = %(id)s
251 LIMIT 1
252 """
253 df = self.db.query(sql, {"id": id})
254 return not df.empty
256 def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
257 """Update or insert a record.
259 Can be called as:
260 - upsert(id, record) - explicit ID and record
261 - upsert(record) - extract ID from record using Record's built-in logic
262 """
263 self._check_connection()
265 # Determine ID and record based on arguments
266 if isinstance(id_or_record, str):
267 id = id_or_record
268 if record is None:
269 raise ValueError("Record required when ID is provided")
270 else:
271 record = id_or_record
272 id = record.id
273 if id is None:
274 import uuid # type: ignore[unreachable]
275 id = str(uuid.uuid4())
276 record.storage_id = id
278 if self.exists(id):
279 self.update(id, record)
280 else:
281 # Insert with specific ID
282 row = self._record_to_row(record, id)
283 sql = f"""
284 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
285 VALUES (%(id)s, %(data)s, %(metadata)s)
286 """
287 self.db.execute(sql, row)
288 return id
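# Both upsert call forms, shown for illustration (with a hypothetical Record instance `record`):
#
#     db.upsert("existing-id", record)   # explicit ID
#     new_id = db.upsert(record)         # ID taken from record.id, or generated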
290 def search(self, query: Query | ComplexQuery) -> list[Record]:
291 """Search for records matching the query."""
292 self._check_connection()
294 # Handle ComplexQuery with native SQL support
295 if isinstance(query, ComplexQuery):
296 sql_query, params_list = self.query_builder.build_complex_search_query(query)
297 else:
298 sql_query, params_list = self.query_builder.build_search_query(query)
300 # Build params dict for psycopg2
301 # The query builder now generates %(p0)s style placeholders directly
302 params_dict = {}
303 if params_list:
304 for i, param in enumerate(params_list):
305 params_dict[f"p{i}"] = param
307 # Execute query
308 df = self.db.query(sql_query, params_dict)
310 # Convert to records
311 records = []
312 for _, row in df.iterrows():
313 row_dict = row.to_dict()
314 record = self._row_to_record(row_dict)
316 # Populate storage_id from database ID
317 record.storage_id = str(row_dict['id'])
319 # Apply field projection if specified
320 if query.fields:
321 record = record.project(query.fields)
323 records.append(record)
325 return records
327 def _count_all(self) -> int:
328 """Count all records in the database."""
329 self._check_connection()
330 sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"
331 df = self.db.query(sql)
332 return int(df.iloc[0]["count"]) if not df.empty else 0
334 def clear(self) -> int:
335 """Clear all records from the database."""
336 self._check_connection()
337 # Get count first
338 count = self._count_all()
340 # Delete all records
341 sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"
342 self.db.execute(sql)
344 return count
346 def create_batch(self, records: list[Record]) -> list[str]:
347 """Create multiple records efficiently using a single query.
349 Uses multi-value INSERT for better performance.
351 Args:
352 records: List of records to create
354 Returns:
355 List of created record IDs
356 """
357 if not records:
358 return []
360 self._check_connection()
362 # Create a query builder for PostgreSQL with pyformat style
363 from .sql_base import SQLQueryBuilder
364 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")
366 # Use the shared batch create query builder
367 query, params_list, ids = query_builder.build_batch_create_query(records)
369 # Build params dict for psycopg2
370 params_dict = {}
371 for i, param in enumerate(params_list):
372 params_dict[f"p{i}"] = param
374 # Execute the batch insert and get returned IDs
375 result_df = self.db.query(query, params_dict)
377 # PostgreSQL RETURNING clause gives us the actual inserted IDs
378 if not result_df.empty:
379 return result_df['id'].tolist()
380 return ids
382 def delete_batch(self, ids: list[str]) -> list[bool]:
383 """Delete multiple records efficiently using a single query.
 385 Uses a single DELETE with an IN clause for better performance.
387 Args:
388 ids: List of record IDs to delete
390 Returns:
391 List of success flags for each deletion
392 """
393 if not ids:
394 return []
396 self._check_connection()
398 # Create a query builder for PostgreSQL with pyformat style
399 from .sql_base import SQLQueryBuilder
400 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")
402 # Use the shared batch delete query builder (includes RETURNING clause)
403 query, params_list = query_builder.build_batch_delete_query(ids)
405 # Build params dict for psycopg2
406 params_dict = {}
407 for i, param in enumerate(params_list):
408 params_dict[f"p{i}"] = param
410 # Execute the batch delete and get returned IDs
411 result_df = self.db.query(query, params_dict)
413 # Get list of deleted IDs from RETURNING clause
414 deleted_ids = set(result_df['id'].tolist()) if not result_df.empty else set()
416 # Return results based on which IDs were actually deleted
417 results = []
418 for id in ids:
419 results.append(id in deleted_ids)
421 return results
423 def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
424 """Update multiple records efficiently using a single query.
426 Uses PostgreSQL's CASE expressions for batch updates via shared SQL builder.
428 Args:
429 updates: List of (id, record) tuples to update
431 Returns:
432 List of success flags for each update
433 """
434 if not updates:
435 return []
437 self._check_connection()
439 # Create a query builder for PostgreSQL with pyformat style
440 from .sql_base import SQLQueryBuilder
441 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")
443 # Use the shared batch update query builder
444 query, params_list = query_builder.build_batch_update_query(updates)
446 # Build params dict for psycopg2
447 params_dict = {}
448 for i, param in enumerate(params_list):
449 params_dict[f"p{i}"] = param
451 # Execute the batch update and get returned IDs (query now includes RETURNING clause)
452 result_df = self.db.query(query, params_dict)
454 # Get list of updated IDs from RETURNING clause
455 updated_ids = set(result_df['id'].tolist()) if not result_df.empty else set()
457 results = []
458 for record_id, _ in updates:
459 results.append(record_id in updated_ids)
461 return results
463 def stream_read(
464 self,
465 query: Query | None = None,
466 config: StreamConfig | None = None
467 ) -> Iterator[Record]:
468 """Stream records from PostgreSQL."""
469 self._check_connection()
470 config = config or StreamConfig()
472 # Build SQL query
473 sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
474 params = {}
476 if query and query.filters:
477 # Add WHERE clause (simplified for now)
478 where_clauses = []
479 for i, filter in enumerate(query.filters):
480 field_path = f"data->>'{filter.field}'"
481 param_name = f"param_{i}"
483 if filter.operator == Operator.EQ:
484 where_clauses.append(f"{field_path} = %({param_name})s")
485 params[param_name] = str(filter.value)
487 if where_clauses:
488 sql += " WHERE " + " AND ".join(where_clauses)
490 # Use cursor for streaming
491 # Note: PostgresDB may need modification to support cursors
492 # For now, we'll fetch in batches
493 sql += f" LIMIT {config.batch_size} OFFSET %(offset)s"
495 offset = 0
496 while True:
497 params["offset"] = offset
498 df = self.db.query(sql, params)
500 if df.empty:
501 break
503 for _, row in df.iterrows():
504 record = self._row_to_record(row.to_dict())
505 if query and query.fields:
506 record = record.project(query.fields)
507 yield record
509 offset += config.batch_size
 511 # If we got fewer than batch_size rows, we're done
512 if len(df) < config.batch_size:
513 break
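# Note: this LIMIT/OFFSET loop re-executes the query for every batch, so each batch gets
# more expensive as the offset grows, and rows inserted or deleted mid-stream can be
# skipped or repeated. The async backend streams via a server-side cursor instead.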
515 def stream_write(
516 self,
517 records: Iterator[Record],
518 config: StreamConfig | None = None
519 ) -> StreamResult:
520 """Stream records into PostgreSQL."""
521 self._check_connection()
522 config = config or StreamConfig()
523 result = StreamResult()
524 start_time = time.time()
525 quitting = False
527 batch = []
528 for record in records:
529 batch.append(record)
531 if len(batch) >= config.batch_size:
532 # Write batch with graceful fallback
533 # Use lambda wrapper for _write_batch
534 continue_processing = process_batch_with_fallback(
535 batch,
536 lambda b: self._write_batch(b),
537 self.create,
538 result,
539 config
540 )
542 if not continue_processing:
543 quitting = True
544 break
546 batch = []
548 # Write remaining batch
549 if batch and not quitting:
550 process_batch_with_fallback(
551 batch,
552 lambda b: self._write_batch(b),
553 self.create,
554 result,
555 config
556 )
558 result.duration = time.time() - start_time
559 return result
561 def _write_batch(self, records: list[Record]) -> list[str]:
562 """Write a batch of records to the database.
564 Returns:
565 List of created record IDs
566 """
567 # Build batch insert SQL
568 values = []
569 params = {}
570 ids = []
572 for i, record in enumerate(records):
573 id = str(uuid.uuid4())
574 ids.append(id)
575 row = self._record_to_row(record, id)
576 values.append(f"(%(id_{i})s, %(data_{i})s, %(metadata_{i})s)")
577 params[f"id_{i}"] = row["id"]
578 params[f"data_{i}"] = row["data"]
579 params[f"metadata_{i}"] = row["metadata"]
581 sql = f"""
582 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
583 VALUES {', '.join(values)}
584 """
585 self.db.execute(sql, params)
586 return ids
588 def vector_search(
589 self,
590 query_vector: np.ndarray | list[float] | VectorField,
591 field_name: str,
592 k: int = 10,
593 filter: Query | None = None,
594 metric: DistanceMetric | str = "cosine"
595 ) -> list[VectorSearchResult]:
596 """Search for similar vectors using PostgreSQL pgvector.
598 Args:
599 query_vector: Query vector (numpy array, list, or VectorField)
600 field_name: Name of vector field to search (must be in data JSON)
 601 k: Maximum number of results to return
 602 filter: Optional query filter to apply
603 metric: Distance metric to use (cosine, euclidean, l2, inner_product)
605 Returns:
606 List of VectorSearchResult objects ordered by similarity
607 """
608 if not self._vector_enabled:
609 raise RuntimeError("Vector search not available - pgvector not installed")
611 self._check_connection()
613 from ..fields import VectorField
614 from ..vector.types import DistanceMetric, VectorSearchResult
615 from .postgres_vector import format_vector_for_postgres, get_vector_operator
617 # Convert query vector to proper format
618 if isinstance(query_vector, VectorField):
619 vector_str = format_vector_for_postgres(query_vector.value)
620 else:
621 vector_str = format_vector_for_postgres(query_vector)
623 # Get the appropriate operator
624 if isinstance(metric, DistanceMetric):
625 metric_str = metric.value
626 else:
627 metric_str = str(metric).lower()
629 operator = get_vector_operator(metric_str)
631 # Build the query - vectors are stored in JSON data field
632 # Use centralized vector extraction logic
633 vector_expr = self.get_vector_extraction_sql(field_name, dialect="postgres")
635 # Build the base SQL with pyformat placeholders
636 sql = f"""
637 SELECT
638 id,
639 data,
640 metadata,
641 {vector_expr} {operator} %(p0)s::vector AS distance
642 FROM {self.schema_name}.{self.table_name}
643 WHERE data ? %(p1)s -- Check field exists
644 """
646 params: list[Any] = [vector_str, field_name]
648 # Add filters if provided using the query builder
649 if filter:
650 # Query builder will generate pyformat placeholders since we configured it that way
651 where_clause, filter_params = self.query_builder.build_where_clause(filter, len(params) + 1)
652 if where_clause:
653 sql += where_clause
654 params.extend(filter_params)
656 # Order by distance and limit
657 next_param = len(params)
658 sql += f" ORDER BY distance LIMIT %(p{next_param})s"
659 params.append(k)
661 # Build param dict for psycopg2
662 param_dict = {}
663 for i, param in enumerate(params):
664 param_dict[f"p{i}"] = param
666 df = self.db.query(sql, param_dict)
668 # Convert results
669 results = []
670 for _, row in df.iterrows():
671 record = self._row_to_record(row)
673 # Calculate similarity score from distance
674 distance = row["distance"]
675 if metric_str in ["cosine", "cosine_similarity"]:
676 score = 1.0 - distance # Cosine distance to similarity
677 elif metric_str in ["euclidean", "l2"]:
678 score = 1.0 / (1.0 + distance) # Convert distance to similarity
679 elif metric_str in ["inner_product", "dot_product"]:
 680 score = -distance # pgvector's inner-product operator returns the negated inner product, so negate it back
681 else:
682 score = -distance # Default: lower distance = better
684 result = VectorSearchResult(
685 record=record,
686 score=float(score),
687 vector_field=field_name
688 )
689 results.append(result)
691 return results
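# An illustrative call, assuming pgvector is enabled and the stored records carry a
# 3-dimensional "embedding" field:
#
#     hits = db.vector_search([0.1, 0.2, 0.3], field_name="embedding", k=5, metric="cosine")
#     for hit in hits:
#         print(hit.score, hit.record)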
693 def has_vector_support(self) -> bool:
694 """Check if this database has vector support enabled.
696 Returns:
697 True if vector operations are supported
698 """
699 return self._vector_enabled
701 def enable_vector_support(self) -> bool:
702 """Enable vector support for this database if possible.
704 Returns:
705 True if vector support is now enabled
706 """
707 if self._vector_enabled:
708 return True
710 self._detect_vector_support()
711 return self._vector_enabled
713 def bulk_embed_and_store(
714 self,
715 records: list[Record],
716 text_field: str | list[str],
717 vector_field: str = "embedding",
718 embedding_fn: Any = None,
719 batch_size: int = 100,
720 model_name: str | None = None,
721 model_version: str | None = None,
722 ) -> list[str]:
723 """Embed text fields and store vectors with records (stub for abstract requirement).
725 This is a placeholder implementation to satisfy the abstract method requirement.
 726 A full implementation would require an actual embedding function.
727 """
728 raise NotImplementedError("bulk_embed_and_store requires an embedding function")
731# Global pool manager instance for async PostgreSQL connections
732_pool_manager = ConnectionPoolManager[asyncpg.Pool]()
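# The shared manager provides event loop-aware pooling: connect() below asks it for a pool
# bound to the current event loop rather than opening a new pool per call.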
735class AsyncPostgresDatabase(
736 AsyncDatabase,
737 VectorOperationsMixin,
738 ConfigurableBase,
739 PostgresBaseConfig,
740 PostgresTableManager,
741 PostgresVectorSupport,
742 PostgresConnectionValidator,
743 PostgresErrorHandler,
744):
745 """Native async PostgreSQL database backend with vector support and event loop-aware connection pooling."""
747 def __init__(self, config: dict[str, Any] | None = None):
748 """Initialize async PostgreSQL database."""
749 super().__init__(config)
751 # Parse configuration using mixin
752 table_name, schema_name, conn_config = self._parse_postgres_config(config or {})
753 self._init_postgres_attributes(table_name, schema_name)
755 # Extract pool configuration
756 self._pool_config = PostgresPoolConfig.from_dict(conn_config)
757 self._pool: asyncpg.Pool | None = None
759 @classmethod
760 def from_config(cls, config: dict) -> AsyncPostgresDatabase:
761 """Create from config dictionary."""
762 return cls(config)
764 async def connect(self) -> None:
765 """Connect to the database."""
766 if self._connected:
767 return
769 # Get or create pool for current event loop
770 from ..pooling import BasePoolConfig
771 self._pool = await _pool_manager.get_pool(
772 self._pool_config,
773 cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_asyncpg_pool),
774 validate_asyncpg_pool
775 )
777 # Initialize query builder
778 self.query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
780 # Ensure table exists
781 await self._ensure_table()
783 # Check and enable vector support if requested
784 if self.vector_enabled:
785 await self._detect_vector_support()
787 self._connected = True
788 self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}")
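# A minimal async usage sketch (illustrative; connection values are assumptions for a
# local setup):
#
#     import asyncio
#
#     async def main():
#         db = AsyncPostgresDatabase({"host": "localhost", "database": "postgres", "table": "records"})
#         await db.connect()
#         record_id = await db.create(record)  # `record` is a dataknobs Record instance
#         await db.close()
#
#     asyncio.run(main())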
790 async def close(self) -> None:
791 """Close the database connection and properly close the pool."""
792 if self._connected:
793 # Properly close the pool if we have one
794 if self._pool:
795 try:
796 await self._pool.close()
797 except Exception as e:
798 logger.warning(f"Error closing connection pool: {e}")
799 self._pool = None
800 self._connected = False
802 def _initialize(self) -> None:
803 """Initialize is handled in connect."""
804 pass
806 async def _ensure_table(self) -> None:
807 """Ensure the records table exists."""
808 if not self._pool:
809 raise RuntimeError("Database not connected. Call connect() first.")
811 create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name)
813 async with self._pool.acquire() as conn:
814 await conn.execute(create_table_sql)
816 async def _detect_vector_support(self) -> None:
817 """Detect and enable vector support if pgvector is available."""
818 from .postgres_vector import check_pgvector_extension, install_pgvector_extension
820 async with self._pool.acquire() as conn:
821 # Check if pgvector is available
822 if await check_pgvector_extension(conn):
823 self._vector_enabled = True
824 logger.info("pgvector extension detected and enabled")
825 else:
826 # Try to install it
827 if await install_pgvector_extension(conn):
828 self._vector_enabled = True
829 logger.info("pgvector extension installed and enabled")
830 else:
831 logger.debug("pgvector extension not available")
833 async def _ensure_vector_column(self, field_name: str, dimensions: int) -> None:
834 """Ensure a vector column exists for the given field.
836 Args:
837 field_name: Name of the vector field
838 dimensions: Number of dimensions
839 """
840 if not self._vector_enabled:
841 return
843 column_name = f"vector_{field_name}"
845 # Check if column already exists
846 check_sql = """
847 SELECT column_name FROM information_schema.columns
848 WHERE table_schema = $1 AND table_name = $2 AND column_name = $3
849 """
851 async with self._pool.acquire() as conn:
852 existing = await conn.fetchval(check_sql, self.schema_name, self.table_name, column_name)
854 if not existing:
855 # Add vector column
856 alter_sql = f"""
857 ALTER TABLE {self.schema_name}.{self.table_name}
858 ADD COLUMN IF NOT EXISTS {column_name} vector({dimensions})
859 """
860 try:
861 await conn.execute(alter_sql)
862 self._vector_dimensions[field_name] = dimensions
863 logger.info(f"Added vector column {column_name} with {dimensions} dimensions")
865 # Create index for the vector column
866 from .postgres_vector import build_vector_index_sql, get_optimal_index_type
868 # Get row count for optimal index selection
869 count_sql = f"SELECT COUNT(*) FROM {self.schema_name}.{self.table_name}"
870 count = await conn.fetchval(count_sql)
872 index_type, index_params = get_optimal_index_type(count)
873 index_sql = build_vector_index_sql(
874 self.table_name,
875 self.schema_name,
876 column_name,
877 dimensions,
878 metric="cosine",
879 index_type=index_type,
880 index_params=index_params
881 )
 883 # Note: IVFFlat requires the table to contain data before the index is created
884 if count > 0 or index_type != "ivfflat":
885 await conn.execute(index_sql)
886 logger.info(f"Created {index_type} index for {column_name}")
888 except Exception as e:
889 logger.warning(f"Could not create vector column {column_name}: {e}")
890 else:
891 self._vector_dimensions[field_name] = dimensions
893 def _check_connection(self) -> None:
894 """Check if async database is connected."""
895 self._check_async_connection()
897 def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
898 """Convert a Record to a database row using common serializer."""
899 from .sql_base import SQLRecordSerializer
901 return {
902 "id": id or str(uuid.uuid4()),
903 "data": SQLRecordSerializer.record_to_json(record),
904 "metadata": json.dumps(record.metadata) if record.metadata else None,
905 }
907 def _row_to_record(self, row: asyncpg.Record) -> Record:
908 """Convert a database row to a Record using the common serializer."""
909 from .sql_base import SQLRecordSerializer
911 # Convert asyncpg.Record to dict format expected by SQLRecordSerializer
912 data_json = row.get("data", {})
913 if not isinstance(data_json, str):
914 data_json = json.dumps(data_json)
916 metadata_json = row.get("metadata")
917 if metadata_json and not isinstance(metadata_json, str):
918 metadata_json = json.dumps(metadata_json)
920 # Use the common serializer to reconstruct the record
921 return SQLRecordSerializer.json_to_record(data_json, metadata_json)
923 async def create(self, record: Record) -> str:
924 """Create a new record with vector support."""
925 self._check_connection()
927 # Check for vector fields and ensure columns exist
928 from ..fields import VectorField
929 for field_name, field_obj in record.fields.items():
930 if isinstance(field_obj, VectorField) and self._vector_enabled:
931 await self._ensure_vector_column(field_name, field_obj.dimensions)
933 # Use record's ID if it has one, otherwise generate a new one
934 id = record.id if record.id else str(uuid.uuid4())
935 row = self._record_to_row(record, id)
937 # Build dynamic SQL based on vector columns present
938 columns = ["id", "data", "metadata"]
939 values = [row["id"], row["data"], row["metadata"]]
940 placeholders = ["$1", "$2", "$3"]
942 # Add vector columns
943 param_num = 4
944 for key, value in row.items():
945 if key.startswith("vector_"):
946 columns.append(key)
947 values.append(value)
948 placeholders.append(f"${param_num}")
949 param_num += 1
951 sql = f"""
952 INSERT INTO {self.schema_name}.{self.table_name} ({', '.join(columns)})
953 VALUES ({', '.join(placeholders)})
954 """
956 async with self._pool.acquire() as conn:
957 await conn.execute(sql, *values)
959 return id
961 async def read(self, id: str) -> Record | None:
962 """Read a record by ID."""
963 self._check_connection()
964 sql = f"""
965 SELECT id, data, metadata
966 FROM {self.schema_name}.{self.table_name}
967 WHERE id = $1
968 """
970 async with self._pool.acquire() as conn:
971 row = await conn.fetchrow(sql, id)
973 if not row:
974 return None
976 return self._row_to_record(row)
978 async def update(self, id: str, record: Record) -> bool:
979 """Update an existing record.
981 Args:
982 id: The record ID to update
983 record: The record data to update with
985 Returns:
986 True if the record was updated, False if no record with the given ID exists
987 """
988 self._check_connection()
989 row = self._record_to_row(record, id)
991 sql = f"""
992 UPDATE {self.schema_name}.{self.table_name}
993 SET data = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP
994 WHERE id = $1
995 """
997 async with self._pool.acquire() as conn:
998 result = await conn.execute(sql, row["id"], row["data"], row["metadata"])
1000 # Returns UPDATE n where n is rows affected
1001 rows_affected = int(result.split()[-1])
1003 if rows_affected == 0:
1004 logger.warning(f"Update affected 0 rows for id={id}. Record may not exist.")
1006 return rows_affected > 0
1008 async def delete(self, id: str) -> bool:
1009 """Delete a record by ID."""
1010 self._check_connection()
1011 sql = f"""
1012 DELETE FROM {self.schema_name}.{self.table_name}
1013 WHERE id = $1
1014 """
1016 async with self._pool.acquire() as conn:
1017 result = await conn.execute(sql, id)
1019 # Returns DELETE n where n is rows affected
1020 return result.split()[-1] != "0"
1022 async def exists(self, id: str) -> bool:
1023 """Check if a record exists."""
1024 self._check_connection()
1025 sql = f"""
1026 SELECT 1 FROM {self.schema_name}.{self.table_name}
1027 WHERE id = $1
1028 LIMIT 1
1029 """
1031 async with self._pool.acquire() as conn:
1032 row = await conn.fetchrow(sql, id)
1034 return row is not None
1036 async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
1037 """Update or insert a record.
1039 Can be called as:
1040 - upsert(id, record) - explicit ID and record
1041 - upsert(record) - extract ID from record using Record's built-in logic
1042 """
1043 self._check_connection()
1045 # Determine ID and record based on arguments
1046 if isinstance(id_or_record, str):
1047 id = id_or_record
1048 if record is None:
1049 raise ValueError("Record required when ID is provided")
1050 else:
1051 record = id_or_record
1052 id = record.id
1053 if id is None:
1054 import uuid # type: ignore[unreachable]
1055 id = str(uuid.uuid4())
1056 record.storage_id = id
1058 row = self._record_to_row(record, id)
1060 sql = f"""
1061 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
1062 VALUES ($1, $2, $3)
1063 ON CONFLICT (id) DO UPDATE
1064 SET data = EXCLUDED.data, metadata = EXCLUDED.metadata, updated_at = CURRENT_TIMESTAMP
1065 """
1067 async with self._pool.acquire() as conn:
1068 await conn.execute(sql, row["id"], row["data"], row["metadata"])
1070 return id
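# Note: unlike the sync backend's exists()-then-update()/insert sequence, this upsert
# resolves insert-vs-update in a single statement via ON CONFLICT (id) DO UPDATE.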
1072 async def search(self, query: Query | ComplexQuery) -> list[Record]:
1073 """Search for records matching the query."""
1074 self._check_connection()
1076 # Initialize query builder if not already done
1077 if not hasattr(self, 'query_builder'):
1078 self.query_builder = SQLQueryBuilder(
1079 self.table_name, self.schema_name, dialect="postgres"
1080 )
1082 # Handle ComplexQuery with native SQL support
1083 if isinstance(query, ComplexQuery):
1084 sql, params = self.query_builder.build_complex_search_query(query)
1085 else:
1086 sql, params = self.query_builder.build_search_query(query)
1088 # Execute query with asyncpg (already uses positional parameters)
1089 async with self._pool.acquire() as conn:
1090 rows = await conn.fetch(sql, *params)
1092 # Convert to records
1093 records = []
1094 for row in rows:
1095 record = self._row_to_record(row)
1097 # Populate storage_id from database ID
1098 record.storage_id = str(row['id'])
1100 # Apply field projection if specified
1101 if query.fields:
1102 record = record.project(query.fields)
1104 records.append(record)
1106 return records
1108 async def _count_all(self) -> int:
1109 """Count all records in the database."""
1110 self._check_connection()
1111 sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"
1113 async with self._pool.acquire() as conn:
1114 row = await conn.fetchrow(sql)
1116 return row["count"] if row else 0
1118 async def clear(self) -> int:
1119 """Clear all records from the database."""
1120 self._check_connection()
1121 # Get count first
1122 count = await self._count_all()
1124 # Delete all records
1125 sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"
1127 async with self._pool.acquire() as conn:
1128 await conn.execute(sql)
1130 return count
1132 async def create_batch(self, records: list[Record]) -> list[str]:
1133 """Create multiple records efficiently using a single query.
1135 Uses multi-value INSERT with RETURNING for better performance.
1137 Args:
1138 records: List of records to create
1140 Returns:
1141 List of created record IDs
1142 """
1143 if not records:
1144 return []
1146 self._check_connection()
1148 # Create a query builder for PostgreSQL
1149 from .sql_base import SQLQueryBuilder
1150 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
1152 # Use the shared batch create query builder
1153 query, params, ids = query_builder.build_batch_create_query(records)
1155 # Execute the batch insert with RETURNING
1156 async with self._pool.acquire() as conn:
1157 rows = await conn.fetch(query, *params)
1159 # Return the actual inserted IDs from RETURNING clause
1160 if rows:
1161 return [row["id"] for row in rows]
1162 return ids # Fallback to generated IDs
1164 async def delete_batch(self, ids: list[str]) -> list[bool]:
1165 """Delete multiple records efficiently using a single query.
 1167 Uses a single DELETE with an IN clause and RETURNING for verification.
1169 Args:
1170 ids: List of record IDs to delete
1172 Returns:
1173 List of success flags for each deletion
1174 """
1175 if not ids:
1176 return []
1178 self._check_connection()
1180 # Create a query builder for PostgreSQL
1181 from .sql_base import SQLQueryBuilder
1182 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
1184 # Use the shared batch delete query builder
1185 query, params = query_builder.build_batch_delete_query(ids)
1187 # Execute the batch delete with RETURNING
1188 async with self._pool.acquire() as conn:
1189 rows = await conn.fetch(query, *params)
1191 # Convert returned rows to set of deleted IDs
1192 deleted_ids = {row["id"] for row in rows}
1194 # Return results for each deletion
1195 results = []
1196 for id in ids:
1197 results.append(id in deleted_ids)
1199 return results
1201 async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
1202 """Update multiple records efficiently using a single query.
1204 Uses PostgreSQL's CASE expressions for batch updates with native asyncpg.
1206 Args:
1207 updates: List of (id, record) tuples to update
1209 Returns:
1210 List of success flags for each update
1211 """
1212 if not updates:
1213 return []
1215 self._check_connection()
1217 # Create a query builder for PostgreSQL
1218 from .sql_base import SQLQueryBuilder
1219 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
1221 # Use the shared batch update query builder
1222 # It already produces positional parameters ($1, $2) for PostgreSQL
1223 query, params = query_builder.build_batch_update_query(updates)
1225 # Add RETURNING clause for PostgreSQL to get updated IDs
1226 query = query.rstrip() + " RETURNING id"
1228 # Execute the batch update
1229 async with self._pool.acquire() as conn:
1230 rows = await conn.fetch(query, *params)
1232 # Convert returned rows to set of updated IDs
1233 updated_ids = {row["id"] for row in rows}
1235 # Return results for each update
1236 results = []
1237 for record_id, _ in updates:
1238 results.append(record_id in updated_ids)
1240 return results
1242 async def vector_search(
1243 self,
1244 query_vector: np.ndarray | list[float] | VectorField,
1245 field_name: str,
1246 k: int = 10,
1247 filter: Query | None = None,
1248 metric: DistanceMetric | str = "cosine"
1249 ) -> list[VectorSearchResult]:
1250 """Search for similar vectors using PostgreSQL pgvector.
1252 Args:
1253 query_vector: Query vector (numpy array, list, or VectorField)
1254 field_name: Name of vector field to search
 1255 k: Maximum number of results to return
 1256 filter: Optional query filter to apply
1257 metric: Distance metric to use
1259 Returns:
1260 List of VectorSearchResult objects
1261 """
1262 if not self._vector_enabled:
1263 raise RuntimeError("Vector search not available - pgvector not installed")
1265 self._check_connection()
1267 from ..fields import VectorField
1268 from ..vector.types import DistanceMetric, VectorSearchResult
1269 from .postgres_vector import format_vector_for_postgres, get_vector_operator
1271 # Convert query vector to proper format
1272 if isinstance(query_vector, VectorField):
1273 vector_str = format_vector_for_postgres(query_vector.value)
1274 else:
1275 vector_str = format_vector_for_postgres(query_vector)
1277 # Get the appropriate operator
1278 if isinstance(metric, DistanceMetric):
1279 metric_str = metric.value
1280 else:
1281 metric_str = str(metric).lower()
1282 operator = get_vector_operator(metric_str)
1284 vector_column = f"vector_{field_name}"
1286 # Build query
1287 sql = f"""
1288 SELECT id, data, metadata, {vector_column},
1289 {vector_column} {operator} $1::vector AS distance
1290 FROM {self.schema_name}.{self.table_name}
1291 WHERE {vector_column} IS NOT NULL
1292 """
1294 params = [vector_str]
1295 param_num = 2
1297 # Add filters if provided using the query builder
1298 if filter:
1299 # First get the where clause from query builder
1300 where_clause, filter_params = self.query_builder.build_where_clause(filter, param_num)
1301 if where_clause:
1302 # Convert %s placeholders to $N for asyncpg
1303 for param in filter_params:
1304 where_clause = where_clause.replace("%s", f"${param_num}", 1)
1305 params.append(param)
1306 param_num += 1
1307 sql += where_clause
1309 # Order by distance and limit
1310 sql += f"""
1311 ORDER BY distance
1312 LIMIT {k}
1313 """
1315 # Execute query
1316 async with self._pool.acquire() as conn:
1317 rows = await conn.fetch(sql, *params)
1319 # Convert to VectorSearchResult objects
1320 results = []
1321 for row in rows:
1322 record = self._row_to_record(row)
1324 # Convert distance to similarity score (1 - normalized_distance for cosine)
1325 distance = float(row['distance'])
1326 if metric_str == "cosine":
1327 score = 1.0 - min(distance, 2.0) / 2.0 # Normalize cosine distance [0,2] to similarity [0,1]
1328 elif metric_str in ["euclidean", "l2"]:
1329 score = 1.0 / (1.0 + distance) # Convert distance to similarity
1330 else:
1331 score = 1.0 - distance # Generic conversion
1333 result = VectorSearchResult(
1334 record=record,
1335 score=score,
1336 vector_field=field_name,
1337 metadata={"distance": distance, "metric": metric_str}
1338 )
1339 results.append(result)
1341 return results
1343 async def enable_vector_support(self) -> bool:
1344 """Enable vector support for this database.
1346 Returns:
1347 True if vector support is enabled
1348 """
1349 if self._vector_enabled:
1350 return True
1352 await self._detect_vector_support()
1353 return self._vector_enabled
1355 async def has_vector_support(self) -> bool:
1356 """Check if this database has vector support enabled.
1358 Returns:
1359 True if vector support is available
1360 """
1361 return self._vector_enabled
1363 async def bulk_embed_and_store(
1364 self,
1365 records: list[Record],
1366 text_field: str | list[str],
1367 vector_field: str,
1368 embedding_fn: Any | None = None,
1369 batch_size: int = 100,
1370 model_name: str | None = None,
1371 model_version: str | None = None,
1372 ) -> list[str]:
1373 """Embed text fields and store vectors with records.
 1375 For each batch, this method will:
1376 1. Extract text from the specified fields
1377 2. Call the embedding function to generate vectors
1378 3. Store the vectors alongside the records
1380 Args:
1381 records: Records to process
1382 text_field: Field name(s) containing text to embed
1383 vector_field: Field name to store vectors in
1384 embedding_fn: Function to generate embeddings
1385 batch_size: Number of records to process at once
1386 model_name: Name of the embedding model
1387 model_version: Version of the embedding model
1389 Returns:
1390 List of record IDs that were processed
1391 """
1392 if not embedding_fn:
1393 raise ValueError("embedding_fn is required for bulk_embed_and_store")
1395 from ..fields import VectorField
1397 processed_ids = []
1399 # Process in batches
1400 for i in range(0, len(records), batch_size):
1401 batch = records[i:i + batch_size]
1403 # Extract texts
1404 texts = []
1405 for record in batch:
1406 if isinstance(text_field, list):
1407 text = " ".join(str(record.fields.get(f, {}).value) for f in text_field if f in record.fields)
1408 else:
1409 text = str(record.fields.get(text_field, {}).value) if text_field in record.fields else ""
1410 texts.append(text)
1412 # Generate embeddings
1413 if texts:
1414 embeddings = await embedding_fn(texts)
1416 # Store vectors with records
1417 for j, record in enumerate(batch):
1418 if j < len(embeddings):
1419 vector = embeddings[j]
1421 # Add vector field to record
1422 record.fields[vector_field] = VectorField(
1423 name=vector_field,
1424 value=vector,
1425 dimensions=len(vector) if hasattr(vector, '__len__') else None,
1426 source_field=text_field if isinstance(text_field, str) else ",".join(text_field),
1427 model_name=model_name,
1428 model_version=model_version,
1429 )
1431 # Create or update record
1432 if record.has_storage_id():
1433 if record.storage_id is None:
1434 raise ValueError("Record has_storage_id() returned True but storage_id is None")
1435 await self.update(record.storage_id, record)
1436 else:
1437 record_id = await self.create(record)
1438 record.storage_id = record_id
1440 if record.storage_id is None:
1441 raise ValueError("Record storage_id is None after create/update")
1442 processed_ids.append(record.storage_id)
1444 return processed_ids
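# A sketch of the embedding_fn contract implied above: an awaitable that maps a list of
# texts to a same-length list of vectors. The field names, model call, and dimension
# count are placeholders:
#
#     async def embed(texts: list[str]) -> list[list[float]]:
#         # call an embedding model or service here
#         return [[0.0] * 384 for _ in texts]
#
#     ids = await db.bulk_embed_and_store(
#         records, text_field="body", vector_field="embedding", embedding_fn=embed
#     )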
1446 async def create_vector_index(
1447 self,
1448 vector_field: str,
1449 dimensions: int,
1450 metric: DistanceMetric | str = "cosine",
1451 index_type: str = "ivfflat",
1452 lists: int | None = None,
1453 ) -> bool:
1454 """Create a vector index for efficient similarity search.
1456 Args:
1457 vector_field: Name of the vector field to index
1458 dimensions: Number of dimensions in the vectors
1459 metric: Distance metric for the index
1460 index_type: Type of index (ivfflat, hnsw)
1461 lists: Number of lists for IVFFlat index
1463 Returns:
1464 True if index was created successfully
1465 """
1466 from .postgres_vector import (
1467 build_vector_column_expression,
1468 build_vector_index_sql,
1469 get_optimal_index_type,
1470 get_vector_count_sql,
1471 )
1473 self._check_connection()
1475 if not self._vector_enabled:
1476 return False
1478 # Determine optimal parameters if not provided
1479 if not lists and index_type == "ivfflat":
1480 # Count vectors to determine optimal lists
1481 count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
1482 async with self._pool.acquire() as conn:
1483 count = await conn.fetchval(count_sql) or 0
1484 _, params = get_optimal_index_type(count)
1485 lists = params.get("lists", 100)
1487 # Convert metric enum to string if needed
1488 if hasattr(metric, 'value'):
1489 metric_str = metric.value
1490 else:
1491 metric_str = str(metric).lower()
1493 # Build vector column expression for index
1494 column_expr = build_vector_column_expression(vector_field, dimensions, for_index=True)
1496 # Build index SQL - pass field_name for proper index naming
1497 index_sql = build_vector_index_sql(
1498 table_name=self.table_name,
1499 schema_name=self.schema_name,
1500 column_name=column_expr,
1501 dimensions=dimensions,
1502 metric=metric_str,
1503 index_type=index_type,
1504 index_params={"lists": lists} if lists else None,
1505 field_name=vector_field
1506 )
1508 # Create the index
1509 try:
1510 logger.debug(f"Creating vector index with SQL: {index_sql}")
1511 async with self._pool.acquire() as conn:
1512 await conn.execute(index_sql)
1513 return True
1514 except Exception as e:
1515 logger.warning(f"Failed to create vector index: {e}")
1516 logger.debug(f"Index SQL was: {index_sql}")
1517 return False
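# An illustrative call (the dimension count is an assumption):
#
#     created = await db.create_vector_index("embedding", dimensions=384, metric="cosine", index_type="ivfflat")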
1519 async def drop_vector_index(self, vector_field: str, metric: str = "cosine") -> bool:
1520 """Drop a vector index.
1522 Args:
1523 vector_field: Name of the vector field
1524 metric: Distance metric used in the index
1526 Returns:
1527 True if index was dropped successfully
1528 """
1529 from .postgres_vector import get_vector_index_name
1531 self._check_connection()
1533 index_name = get_vector_index_name(self.table_name, vector_field, metric)
1535 try:
1536 async with self._pool.acquire() as conn:
1537 await conn.execute(f"DROP INDEX IF EXISTS {self.schema_name}.{index_name}")
1538 return True
1539 except Exception as e:
1540 logger.warning(f"Failed to drop vector index: {e}")
1541 return False
1543 async def get_vector_index_stats(self, vector_field: str) -> dict[str, Any]:
1544 """Get statistics about a vector field and its index.
1546 Args:
1547 vector_field: Name of the vector field
1549 Returns:
1550 Dictionary with index statistics
1551 """
1552 from .postgres_vector import get_index_check_sql, get_vector_count_sql
1554 self._check_connection()
1556 stats = {
1557 "field": vector_field,
1558 "indexed": False,
1559 "vector_count": 0,
1560 }
1562 try:
1563 async with self._pool.acquire() as conn:
1564 # Count vectors
1565 count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
1566 stats["vector_count"] = await conn.fetchval(count_sql) or 0
1568 # Check for index
1569 index_sql, params = get_index_check_sql(self.schema_name, self.table_name, vector_field)
1570 stats["indexed"] = await conn.fetchval(index_sql, *params) or False
1571 except Exception as e:
1572 logger.warning(f"Failed to get vector index stats: {e}")
1574 return stats
1576 async def stream_read(
1577 self,
1578 query: Query | None = None,
1579 config: StreamConfig | None = None
1580 ) -> AsyncIterator[Record]:
1581 """Stream records from PostgreSQL using cursor."""
1582 self._check_connection()
1583 config = config or StreamConfig()
1585 # Build SQL query
1586 sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
1587 params = []
1589 if query and query.filters:
1590 where_clauses = []
1591 param_count = 0
1593 for filter in query.filters:
1594 param_count += 1
1595 field_path = f"data->>'{filter.field}'"
1597 if filter.operator == Operator.EQ:
1598 where_clauses.append(f"{field_path} = ${param_count}")
1599 params.append(str(filter.value))
1601 if where_clauses:
1602 sql += " WHERE " + " AND ".join(where_clauses)
1604 # Use cursor for efficient streaming
1605 async with self._pool.acquire() as conn:
1606 async with conn.transaction():
1607 cursor = await conn.cursor(sql, *params)
1609 batch = []
1610 async for row in cursor:
1611 record = self._row_to_record(row)
1612 if query and query.fields:
1613 record = record.project(query.fields)
1615 batch.append(record)
1617 if len(batch) >= config.batch_size:
1618 for rec in batch:
1619 yield rec
1620 batch = []
1622 # Yield remaining records
1623 for rec in batch:
1624 yield rec
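# Note: the cursor runs inside a single transaction, so the stream sees a consistent
# snapshot of the table, unlike the sync backend's LIMIT/OFFSET batching.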
1626 async def stream_write(
1627 self,
1628 records: AsyncIterator[Record],
1629 config: StreamConfig | None = None
1630 ) -> StreamResult:
1631 """Stream records into PostgreSQL using batch inserts."""
1632 self._check_connection()
1633 config = config or StreamConfig()
1634 result = StreamResult()
1635 start_time = time.time()
1636 quitting = False
1638 batch = []
1639 async for record in records:
1640 batch.append(record)
1642 if len(batch) >= config.batch_size:
1643 # Write batch with graceful fallback
1644 # Use lambda wrapper for _write_batch
1645 async def batch_func(b):
1646 await self._write_batch(b)
1647 return [r.id for r in b]
1649 continue_processing = await async_process_batch_with_fallback(
1650 batch,
1651 batch_func,
1652 self.create,
1653 result,
1654 config
1655 )
1657 if not continue_processing:
1658 quitting = True
1659 break
1661 batch = []
1663 # Write remaining batch
1664 if batch and not quitting:
1665 async def batch_func(b):
1666 await self._write_batch(b)
1667 return [r.id for r in b]
1669 await async_process_batch_with_fallback(
1670 batch,
1671 batch_func,
1672 self.create,
1673 result,
1674 config
1675 )
1677 result.duration = time.time() - start_time
1678 return result
1680 async def _write_batch(self, records: list[Record]) -> list[str]:
1681 """Write a batch of records using COPY for performance.
1683 Returns:
1684 List of created record IDs
1685 """
1686 if not records:
1687 return []
1689 # Prepare data for COPY
1690 rows = []
1691 ids = []
1692 for record in records:
1693 row_data = self._record_to_row(record)
1694 ids.append(row_data["id"])
1695 rows.append((
1696 row_data["id"],
1697 row_data["data"],
1698 row_data["metadata"]
1699 ))
1701 # Use COPY for efficient bulk insert
1702 async with self._pool.acquire() as conn:
1703 await conn.copy_records_to_table(
1704 f"{self.schema_name}.{self.table_name}",
1705 records=rows,
1706 columns=["id", "data", "metadata"]
1707 )
1709 return ids
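# Note: COPY does not perform conflict resolution, so a duplicate id in a batch raises a
# unique-violation error; stream_write passes create() to async_process_batch_with_fallback
# as the per-record fallback for such failures.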