Coverage for src/dataknobs_data/backends/postgres.py: 11%
813 statements
1"""PostgreSQL backend implementation with proper connection management and vector support."""
3from __future__ import annotations
5import json
6import logging
7import time
8import uuid
9from typing import TYPE_CHECKING, Any, cast
11import asyncpg
12from dataknobs_config import ConfigurableBase
14from dataknobs_utils.sql_utils import DotenvPostgresConnector, PostgresDB
16from ..database import AsyncDatabase, SyncDatabase
17from ..pooling import ConnectionPoolManager
18from ..pooling.postgres import PostgresPoolConfig, create_asyncpg_pool, validate_asyncpg_pool
19from ..query import Operator, Query
20from ..query_logic import ComplexQuery
21from ..streaming import (
22 StreamConfig,
23 StreamResult,
24 async_process_batch_with_fallback,
25 process_batch_with_fallback,
26)
27from ..vector.mixins import VectorOperationsMixin
28from .postgres_mixins import (
29 PostgresBaseConfig,
30 PostgresConnectionValidator,
31 PostgresErrorHandler,
32 PostgresTableManager,
33 PostgresVectorSupport,
34)
35from .sql_base import SQLQueryBuilder, SQLRecordSerializer
36from ..vector.types import DistanceMetric
38if TYPE_CHECKING:
39 import numpy as np
41 from collections.abc import AsyncIterator, Iterator, Callable, Awaitable
42 from ..fields import VectorField
43 from ..records import Record
44 from ..vector.types import VectorSearchResult
46logger = logging.getLogger(__name__)
49class SyncPostgresDatabase(
50 SyncDatabase,
51 ConfigurableBase,
52 VectorOperationsMixin,
53 SQLRecordSerializer,
54 PostgresBaseConfig,
55 PostgresTableManager,
56 PostgresVectorSupport,
57 PostgresConnectionValidator,
58 PostgresErrorHandler,
59):
60 """Synchronous PostgreSQL database backend with proper connection management."""
62 def __init__(self, config: dict[str, Any] | None = None):
63 """Initialize PostgreSQL database configuration.
65 Args:
66 config: Configuration with the following optional keys:
67 - host: PostgreSQL host (default: from env/localhost)
68 - port: PostgreSQL port (default: 5432)
69 - database: Database name (default: from env/postgres)
70 - user: Username (default: from env/postgres)
71 - password: Password (default: from env)
72 - table: Table name (default: "records")
73 - schema: Schema name (default: "public")
74 - enable_vector: Enable vector support (default: False)
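
            Example (illustrative sketch; assumes a reachable local PostgreSQL
            and that the defaults above apply):

                db = SyncPostgresDatabase({
                    "host": "localhost",
                    "database": "postgres",
                    "user": "postgres",
                    "table": "records",
                    "enable_vector": True,
                })
                db.connect()
                print(db.exists("some-id"))
                db.close()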
75 """
76 super().__init__(config)
78 # Parse configuration using mixin
79 table_name, schema_name, conn_config = self._parse_postgres_config(config or {})
80 self._init_postgres_attributes(table_name, schema_name)
82 # Store connection config for later use
83 self._conn_config = conn_config
84 self.db = None # Will be initialized in connect()
85 self.query_builder = None # Will be initialized in connect()
87 @classmethod
88 def from_config(cls, config: dict) -> SyncPostgresDatabase:
89 """Create from config dictionary."""
90 return cls(config)
92 def connect(self) -> None:
93 """Connect to the PostgreSQL database."""
94 if self._connected:
95 return # Already connected
97 # Initialize query builder with pyformat style for psycopg2
98 self.query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")
100 # Create connection using existing utilities
101 if not any(key in self._conn_config for key in ["host", "database", "user"]):
102 # Use dotenv connector for environment-based config
103 connector = DotenvPostgresConnector()
104 self.db = PostgresDB(connector)
105 else:
106 # Direct configuration - map 'database' to 'db' for PostgresDB
107 self.db = PostgresDB(
108 host=self._conn_config.get("host", "localhost"),
109 db=self._conn_config.get("database", "postgres"), # Note: PostgresDB expects 'db' not 'database'
110 user=self._conn_config.get("user", "postgres"),
111 pwd=self._conn_config.get("password"), # Note: PostgresDB expects 'pwd' not 'password'
112 port=self._conn_config.get("port", 5432),
113 )
115 # Create table if it doesn't exist
116 self._ensure_table()
118 # Detect and enable vector support if requested
119 if self.vector_enabled:
120 self._detect_vector_support()
122 self._connected = True
123 self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}")
125 def close(self) -> None:
126 """Close the database connection."""
127 if self.db:
128 # PostgresDB manages its own connections via context managers
129 # but we can mark as disconnected
130 self._connected = False # type: ignore[unreachable]
132 def _initialize(self) -> None:
133 """Initialize method - connection setup moved to connect()."""
134 # Configuration parsing stays here, actual connection in connect()
135 pass
137 def _detect_vector_support(self) -> None:
138 """Detect and enable vector support if pgvector is available."""
139 from .postgres_vector import check_pgvector_extension_sync, install_pgvector_extension_sync
141 try:
142 # Check if pgvector is installed
143 if check_pgvector_extension_sync(self.db):
144 self._vector_enabled = True
145 logger.info("pgvector extension detected and enabled")
146 else:
147 # Try to install it
148 if install_pgvector_extension_sync(self.db):
149 self._vector_enabled = True
150 logger.info("pgvector extension installed and enabled")
151 else:
152 logger.debug("pgvector extension not available")
153 except Exception as e:
154 logger.debug(f"Could not enable vector support: {e}")
155 self._vector_enabled = False
157 def _ensure_table(self) -> None:
158 """Ensure the records table exists."""
159 if not self.db:
160 raise RuntimeError("Database not connected. Call connect() first.")
162 create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name) # type: ignore[unreachable]
163 self.db.execute(create_table_sql)
166 def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
167 """Convert a Record to a database row."""
168 return {
169 "id": id or str(uuid.uuid4()),
170 "data": self.record_to_json(record),
171 "metadata": json.dumps(record.metadata) if record.metadata else None,
172 }
174 def _row_to_record(self, row: dict[str, Any]) -> Record:
175 """Convert a database row to a Record."""
176 return self.row_to_record(row)
178 def create(self, record: Record) -> str:
179 """Create a new record."""
180 self._check_connection()
181 # Use record's ID if it has one, otherwise generate a new one
182 id = record.id if record.id else str(uuid.uuid4())
183 row = self._record_to_row(record, id)
185 sql = f"""
186 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
187 VALUES (%(id)s, %(data)s, %(metadata)s)
188 """
189 self.db.execute(sql, row)
190 return id
192 def read(self, id: str) -> Record | None:
193 """Read a record by ID."""
194 self._check_connection()
195 sql = f"""
196 SELECT id, data, metadata
197 FROM {self.schema_name}.{self.table_name}
198 WHERE id = %(id)s
199 """
200 df = self.db.query(sql, {"id": id})
202 if df.empty:
203 return None
205 row = df.iloc[0].to_dict()
206 return self._row_to_record(row)
208 def update(self, id: str, record: Record) -> bool:
209 """Update an existing record.
211 Args:
212 id: The record ID to update
213 record: The record data to update with
215 Returns:
216 True if the record was updated, False if no record with the given ID exists
217 """
218 self._check_connection()
219 row = self._record_to_row(record, id)
221 sql = f"""
222 UPDATE {self.schema_name}.{self.table_name}
223 SET data = %(data)s, metadata = %(metadata)s, updated_at = CURRENT_TIMESTAMP
224 WHERE id = %(id)s
225 """
226 result = self.db.execute(sql, row)
228 # PostgresDB.execute returns number of affected rows
229 rows_affected = result if isinstance(result, int) else 0
231 if rows_affected == 0:
232 logger.warning(f"Update affected 0 rows for id={id}. Record may not exist.")
234 return rows_affected > 0
236 def delete(self, id: str) -> bool:
237 """Delete a record by ID."""
238 self._check_connection()
239 sql = f"""
240 DELETE FROM {self.schema_name}.{self.table_name}
241 WHERE id = %(id)s
242 """
243 result = self.db.execute(sql, {"id": id})
244 return result > 0 if isinstance(result, int) else False
246 def exists(self, id: str) -> bool:
247 """Check if a record exists."""
248 self._check_connection()
249 sql = f"""
250 SELECT 1 FROM {self.schema_name}.{self.table_name}
251 WHERE id = %(id)s
252 LIMIT 1
253 """
254 df = self.db.query(sql, {"id": id})
255 return not df.empty
257 def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
258 """Update or insert a record.
260 Can be called as:
261 - upsert(id, record) - explicit ID and record
262 - upsert(record) - extract ID from record using Record's built-in logic
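
        Example (illustrative; assumes ``record`` is an existing Record instance):

            rid = db.upsert(record)            # ID taken from the record, or generated
            rid = db.upsert("my-id", record)   # explicit ID form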
263 """
264 self._check_connection()
266 # Determine ID and record based on arguments
267 if isinstance(id_or_record, str):
268 id = id_or_record
269 if record is None:
270 raise ValueError("Record required when ID is provided")
271 else:
272 record = id_or_record
273 id = record.id
274 if id is None:
275 import uuid # type: ignore[unreachable]
276 id = str(uuid.uuid4())
277 record.storage_id = id
279 if self.exists(id):
280 self.update(id, record)
281 else:
282 # Insert with specific ID
283 row = self._record_to_row(record, id)
284 sql = f"""
285 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
286 VALUES (%(id)s, %(data)s, %(metadata)s)
287 """
288 self.db.execute(sql, row)
289 return id
291 def search(self, query: Query | ComplexQuery) -> list[Record]:
292 """Search for records matching the query."""
293 self._check_connection()
295 # Handle ComplexQuery with native SQL support
296 if isinstance(query, ComplexQuery):
297 sql_query, params_list = self.query_builder.build_complex_search_query(query)
298 else:
299 sql_query, params_list = self.query_builder.build_search_query(query)
301 # Build params dict for psycopg2
302 # The query builder now generates %(p0)s style placeholders directly
303 params_dict = {}
304 if params_list:
305 for i, param in enumerate(params_list):
306 params_dict[f"p{i}"] = param
308 # Execute query
309 df = self.db.query(sql_query, params_dict)
311 # Convert to records
312 records = []
313 for _, row in df.iterrows():
314 row_dict = row.to_dict()
315 record = self._row_to_record(row_dict)
317 # Populate storage_id from database ID
318 record.storage_id = str(row_dict['id'])
320 # Apply field projection if specified
321 if query.fields:
322 record = record.project(query.fields)
324 records.append(record)
326 return records
328 def _count_all(self) -> int:
329 """Count all records in the database."""
330 self._check_connection()
331 sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"
332 df = self.db.query(sql)
333 return int(df.iloc[0]["count"]) if not df.empty else 0
335 def clear(self) -> int:
336 """Clear all records from the database."""
337 self._check_connection()
338 # Get count first
339 count = self._count_all()
341 # Delete all records
342 sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"
343 self.db.execute(sql)
345 return count
347 def create_batch(self, records: list[Record]) -> list[str]:
348 """Create multiple records efficiently using a single query.
350 Uses multi-value INSERT for better performance.
352 Args:
353 records: List of records to create
355 Returns:
356 List of created record IDs
357 """
358 if not records:
359 return []
361 self._check_connection()
363 # Create a query builder for PostgreSQL with pyformat style
364 from .sql_base import SQLQueryBuilder
365 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")
367 # Use the shared batch create query builder
368 query, params_list, ids = query_builder.build_batch_create_query(records)
370 # Build params dict for psycopg2
371 params_dict = {}
372 for i, param in enumerate(params_list):
373 params_dict[f"p{i}"] = param
375 # Execute the batch insert and get returned IDs
376 result_df = self.db.query(query, params_dict)
378 # PostgreSQL RETURNING clause gives us the actual inserted IDs
379 if not result_df.empty:
380 return result_df['id'].tolist()
381 return ids
383 def delete_batch(self, ids: list[str]) -> list[bool]:
384 """Delete multiple records efficiently using a single query.
386 Uses single DELETE with IN clause for better performance.
388 Args:
389 ids: List of record IDs to delete
391 Returns:
392 List of success flags for each deletion
393 """
394 if not ids:
395 return []
397 self._check_connection()
399 # Create a query builder for PostgreSQL with pyformat style
400 from .sql_base import SQLQueryBuilder
401 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")
403 # Use the shared batch delete query builder (includes RETURNING clause)
404 query, params_list = query_builder.build_batch_delete_query(ids)
406 # Build params dict for psycopg2
407 params_dict = {}
408 for i, param in enumerate(params_list):
409 params_dict[f"p{i}"] = param
411 # Execute the batch delete and get returned IDs
412 result_df = self.db.query(query, params_dict)
414 # Get list of deleted IDs from RETURNING clause
415 deleted_ids = set(result_df['id'].tolist()) if not result_df.empty else set()
417 # Return results based on which IDs were actually deleted
418 results = []
419 for id in ids:
420 results.append(id in deleted_ids)
422 return results
424 def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
425 """Update multiple records efficiently using a single query.
427 Uses PostgreSQL's CASE expressions for batch updates via shared SQL builder.
429 Args:
430 updates: List of (id, record) tuples to update
432 Returns:
433 List of success flags for each update
434 """
435 if not updates:
436 return []
438 self._check_connection()
440 # Create a query builder for PostgreSQL with pyformat style
441 from .sql_base import SQLQueryBuilder
442 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")
444 # Use the shared batch update query builder
445 query, params_list = query_builder.build_batch_update_query(updates)
447 # Build params dict for psycopg2
448 params_dict = {}
449 for i, param in enumerate(params_list):
450 params_dict[f"p{i}"] = param
452 # Execute the batch update and get returned IDs (query now includes RETURNING clause)
453 result_df = self.db.query(query, params_dict)
455 # Get list of updated IDs from RETURNING clause
456 updated_ids = set(result_df['id'].tolist()) if not result_df.empty else set()
458 results = []
459 for record_id, _ in updates:
460 results.append(record_id in updated_ids)
462 return results
464 def stream_read(
465 self,
466 query: Query | None = None,
467 config: StreamConfig | None = None
468 ) -> Iterator[Record]:
469 """Stream records from PostgreSQL."""
470 self._check_connection()
471 config = config or StreamConfig()
473 # Build SQL query
474 sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
475 params = {}
477 if query and query.filters:
478 # Add WHERE clause (simplified for now)
479 where_clauses = []
480 for i, filter in enumerate(query.filters):
481 field_path = f"data->>'{filter.field}'"
482 param_name = f"param_{i}"
484 if filter.operator == Operator.EQ:
485 where_clauses.append(f"{field_path} = %({param_name})s")
486 params[param_name] = str(filter.value)
488 if where_clauses:
489 sql += " WHERE " + " AND ".join(where_clauses)
491 # Use cursor for streaming
492 # Note: PostgresDB may need modification to support cursors
493 # For now, we'll fetch in batches
494 sql += f" LIMIT {config.batch_size} OFFSET %(offset)s"
496 offset = 0
497 while True:
498 params["offset"] = offset
499 df = self.db.query(sql, params)
501 if df.empty:
502 break
504 for _, row in df.iterrows():
505 record = self._row_to_record(row.to_dict())
506 if query and query.fields:
507 record = record.project(query.fields)
508 yield record
510 offset += config.batch_size
512 # If we got less than batch_size, we're done
513 if len(df) < config.batch_size:
514 break
516 def stream_write(
517 self,
518 records: Iterator[Record],
519 config: StreamConfig | None = None
520 ) -> StreamResult:
521 """Stream records into PostgreSQL."""
522 self._check_connection()
523 config = config or StreamConfig()
524 result = StreamResult()
525 start_time = time.time()
526 quitting = False
528 batch = []
529 for record in records:
530 batch.append(record)
532 if len(batch) >= config.batch_size:
533 # Write batch with graceful fallback
534 # Use lambda wrapper for _write_batch
535 continue_processing = process_batch_with_fallback(
536 batch,
537 lambda b: self._write_batch(b),
538 self.create,
539 result,
540 config
541 )
543 if not continue_processing:
544 quitting = True
545 break
547 batch = []
549 # Write remaining batch
550 if batch and not quitting:
551 process_batch_with_fallback(
552 batch,
553 lambda b: self._write_batch(b),
554 self.create,
555 result,
556 config
557 )
559 result.duration = time.time() - start_time
560 return result
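    # Example (illustrative sketch): streaming a large iterable of records in
    # batches. Assumes StreamConfig accepts a `batch_size` argument, matching the
    # `config.batch_size` attribute used above.
    #
    #     result = db.stream_write(iter(records), StreamConfig(batch_size=500))
    #     print(result.duration)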
562 def _write_batch(self, records: list[Record]) -> list[str]:
563 """Write a batch of records to the database.
565 Returns:
566 List of created record IDs
567 """
568 # Build batch insert SQL
569 values = []
570 params = {}
571 ids = []
573 for i, record in enumerate(records):
574 id = str(uuid.uuid4())
575 ids.append(id)
576 row = self._record_to_row(record, id)
577 values.append(f"(%(id_{i})s, %(data_{i})s, %(metadata_{i})s)")
578 params[f"id_{i}"] = row["id"]
579 params[f"data_{i}"] = row["data"]
580 params[f"metadata_{i}"] = row["metadata"]
582 sql = f"""
583 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
584 VALUES {', '.join(values)}
585 """
586 self.db.execute(sql, params)
587 return ids
589 def vector_search(
590 self,
591 query_vector: np.ndarray | list[float] | VectorField,
592 field_name: str,
593 k: int = 10,
594 filter: Query | None = None,
595 metric: DistanceMetric | str = "cosine"
596 ) -> list[VectorSearchResult]:
597 """Search for similar vectors using PostgreSQL pgvector.
599 Args:
600 query_vector: Query vector (numpy array, list, or VectorField)
601 field_name: Name of vector field to search (must be in data JSON)
 602 k: Maximum number of results
 603 filter: Optional filter query to apply
604 metric: Distance metric to use (cosine, euclidean, l2, inner_product)
606 Returns:
607 List of VectorSearchResult objects ordered by similarity
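
        Example (illustrative; assumes pgvector is enabled and 3-dimensional
        embeddings are stored under the "embedding" field):

            hits = db.vector_search([0.1, 0.2, 0.3], field_name="embedding",
                                    k=5, metric="cosine")
            for hit in hits:
                print(hit.score, hit.record)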
608 """
609 if not self._vector_enabled:
610 raise RuntimeError("Vector search not available - pgvector not installed")
612 self._check_connection()
614 from ..fields import VectorField
615 from ..vector.types import DistanceMetric, VectorSearchResult
616 from .postgres_vector import format_vector_for_postgres, get_vector_operator
618 # Convert query vector to proper format
619 if isinstance(query_vector, VectorField):
620 vector_str = format_vector_for_postgres(query_vector.value)
621 else:
622 vector_str = format_vector_for_postgres(query_vector)
624 # Get the appropriate operator
625 if isinstance(metric, DistanceMetric):
626 metric_str = metric.value
627 else:
628 metric_str = str(metric).lower()
630 operator = get_vector_operator(metric_str)
632 # Build the query - vectors are stored in JSON data field
633 # Use centralized vector extraction logic
634 vector_expr = self.get_vector_extraction_sql(field_name, dialect="postgres")
636 # Build the base SQL with pyformat placeholders
637 sql = f"""
638 SELECT
639 id,
640 data,
641 metadata,
642 {vector_expr} {operator} %(p0)s::vector AS distance
643 FROM {self.schema_name}.{self.table_name}
644 WHERE data ? %(p1)s -- Check field exists
645 """
647 params: list[Any] = [vector_str, field_name]
649 # Add filters if provided using the query builder
650 if filter:
651 # Query builder will generate pyformat placeholders since we configured it that way
652 where_clause, filter_params = self.query_builder.build_where_clause(filter, len(params) + 1)
653 if where_clause:
654 sql += where_clause
655 params.extend(filter_params)
657 # Order by distance and limit
658 next_param = len(params)
659 sql += f" ORDER BY distance LIMIT %(p{next_param})s"
660 params.append(k)
662 # Build param dict for psycopg2
663 param_dict = {}
664 for i, param in enumerate(params):
665 param_dict[f"p{i}"] = param
667 df = self.db.query(sql, param_dict)
669 # Convert results
670 results = []
671 for _, row in df.iterrows():
672 record = self._row_to_record(row)
674 # Calculate similarity score from distance
675 distance = row["distance"]
676 if metric_str in ["cosine", "cosine_similarity"]:
677 score = 1.0 - distance # Cosine distance to similarity
678 elif metric_str in ["euclidean", "l2"]:
679 score = 1.0 / (1.0 + distance) # Convert distance to similarity
680 elif metric_str in ["inner_product", "dot_product"]:
 681 score = -distance # pgvector's inner-product operator returns the negated inner product; negate back for similarity
682 else:
683 score = -distance # Default: lower distance = better
685 result = VectorSearchResult(
686 record=record,
687 score=float(score),
688 vector_field=field_name
689 )
690 results.append(result)
692 return results
694 def has_vector_support(self) -> bool:
695 """Check if this database has vector support enabled.
697 Returns:
698 True if vector operations are supported
699 """
700 return self._vector_enabled
702 def enable_vector_support(self) -> bool:
703 """Enable vector support for this database if possible.
705 Returns:
706 True if vector support is now enabled
707 """
708 if self._vector_enabled:
709 return True
711 self._detect_vector_support()
712 return self._vector_enabled
714 def bulk_embed_and_store(
715 self,
716 records: list[Record],
717 text_field: str | list[str],
718 vector_field: str = "embedding",
719 embedding_fn: Any = None,
720 batch_size: int = 100,
721 model_name: str | None = None,
722 model_version: str | None = None,
723 ) -> list[str]:
724 """Embed text fields and store vectors with records (stub for abstract requirement).
726 This is a placeholder implementation to satisfy the abstract method requirement.
 727 A full implementation would require an actual embedding function.
728 """
729 raise NotImplementedError("bulk_embed_and_store requires an embedding function")
732# Global pool manager instance for async PostgreSQL connections
733_pool_manager = ConnectionPoolManager[asyncpg.Pool]()
736class AsyncPostgresDatabase(
737 AsyncDatabase,
738 VectorOperationsMixin,
739 ConfigurableBase,
740 PostgresBaseConfig,
741 PostgresTableManager,
742 PostgresVectorSupport,
743 PostgresConnectionValidator,
744 PostgresErrorHandler,
745):
746 """Native async PostgreSQL database backend with vector support and event loop-aware connection pooling."""
748 def __init__(self, config: dict[str, Any] | None = None):
749 """Initialize async PostgreSQL database."""
750 super().__init__(config)
752 # Parse configuration using mixin
753 table_name, schema_name, conn_config = self._parse_postgres_config(config or {})
754 self._init_postgres_attributes(table_name, schema_name)
756 # Extract pool configuration
757 self._pool_config = PostgresPoolConfig.from_dict(conn_config)
758 self._pool: asyncpg.Pool | None = None
760 @classmethod
761 def from_config(cls, config: dict) -> AsyncPostgresDatabase:
762 """Create from config dictionary."""
763 return cls(config)
765 async def connect(self) -> None:
766 """Connect to the database."""
767 if self._connected:
768 return
770 # Get or create pool for current event loop
771 from ..pooling import BasePoolConfig
772 self._pool = await _pool_manager.get_pool(
773 self._pool_config,
774 cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_asyncpg_pool),
775 validate_asyncpg_pool
776 )
778 # Initialize query builder
779 self.query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
781 # Ensure table exists
782 await self._ensure_table()
784 # Check and enable vector support if requested
785 if self.vector_enabled:
786 await self._detect_vector_support()
788 self._connected = True
789 self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}")
791 async def close(self) -> None:
792 """Close the database connection and properly close the pool."""
793 if self._connected:
794 # Properly close the pool if we have one
795 if self._pool:
796 try:
797 await self._pool.close()
798 except Exception as e:
799 logger.warning(f"Error closing connection pool: {e}")
800 self._pool = None
801 self._connected = False
803 def _initialize(self) -> None:
804 """Initialize is handled in connect."""
805 pass
807 async def _ensure_table(self) -> None:
808 """Ensure the records table exists."""
809 if not self._pool:
810 raise RuntimeError("Database not connected. Call connect() first.")
812 create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name)
814 async with self._pool.acquire() as conn:
815 await conn.execute(create_table_sql)
817 async def _detect_vector_support(self) -> None:
818 """Detect and enable vector support if pgvector is available."""
819 from .postgres_vector import check_pgvector_extension, install_pgvector_extension
821 async with self._pool.acquire() as conn:
822 # Check if pgvector is available
823 if await check_pgvector_extension(conn):
824 self._vector_enabled = True
825 logger.info("pgvector extension detected and enabled")
826 else:
827 # Try to install it
828 if await install_pgvector_extension(conn):
829 self._vector_enabled = True
830 logger.info("pgvector extension installed and enabled")
831 else:
832 logger.debug("pgvector extension not available")
834 async def _ensure_vector_column(self, field_name: str, dimensions: int) -> None:
835 """Ensure a vector column exists for the given field.
837 Args:
838 field_name: Name of the vector field
839 dimensions: Number of dimensions
840 """
841 if not self._vector_enabled:
842 return
844 column_name = f"vector_{field_name}"
846 # Check if column already exists
847 check_sql = """
848 SELECT column_name FROM information_schema.columns
849 WHERE table_schema = $1 AND table_name = $2 AND column_name = $3
850 """
852 async with self._pool.acquire() as conn:
853 existing = await conn.fetchval(check_sql, self.schema_name, self.table_name, column_name)
855 if not existing:
856 # Add vector column
857 alter_sql = f"""
858 ALTER TABLE {self.schema_name}.{self.table_name}
859 ADD COLUMN IF NOT EXISTS {column_name} vector({dimensions})
860 """
861 try:
862 await conn.execute(alter_sql)
863 self._vector_dimensions[field_name] = dimensions
864 logger.info(f"Added vector column {column_name} with {dimensions} dimensions")
866 # Create index for the vector column
867 from .postgres_vector import build_vector_index_sql, get_optimal_index_type
869 # Get row count for optimal index selection
870 count_sql = f"SELECT COUNT(*) FROM {self.schema_name}.{self.table_name}"
871 count = await conn.fetchval(count_sql)
873 index_type, index_params = get_optimal_index_type(count)
874 index_sql = build_vector_index_sql(
875 self.table_name,
876 self.schema_name,
877 column_name,
878 dimensions,
879 metric="cosine",
880 index_type=index_type,
881 index_params=index_params
882 )
884 # Note: IVFFlat requires table to have data before creating index
885 if count > 0 or index_type != "ivfflat":
886 await conn.execute(index_sql)
887 logger.info(f"Created {index_type} index for {column_name}")
889 except Exception as e:
890 logger.warning(f"Could not create vector column {column_name}: {e}")
891 else:
892 self._vector_dimensions[field_name] = dimensions
894 def _check_connection(self) -> None:
895 """Check if async database is connected."""
896 self._check_async_connection()
898 def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
899 """Convert a Record to a database row using common serializer."""
900 from .sql_base import SQLRecordSerializer
902 return {
903 "id": id or str(uuid.uuid4()),
904 "data": SQLRecordSerializer.record_to_json(record),
905 "metadata": json.dumps(record.metadata) if record.metadata else None,
906 }
908 def _row_to_record(self, row: asyncpg.Record) -> Record:
909 """Convert a database row to a Record using the common serializer."""
910 from .sql_base import SQLRecordSerializer
912 # Convert asyncpg.Record to dict format expected by SQLRecordSerializer
913 data_json = row.get("data", {})
914 if not isinstance(data_json, str):
915 data_json = json.dumps(data_json)
917 metadata_json = row.get("metadata")
918 if metadata_json and not isinstance(metadata_json, str):
919 metadata_json = json.dumps(metadata_json)
921 # Use the common serializer to reconstruct the record
922 return SQLRecordSerializer.json_to_record(data_json, metadata_json)
924 async def create(self, record: Record) -> str:
925 """Create a new record with vector support."""
926 self._check_connection()
928 # Check for vector fields and ensure columns exist
929 from ..fields import VectorField
930 for field_name, field_obj in record.fields.items():
931 if isinstance(field_obj, VectorField) and self._vector_enabled:
932 await self._ensure_vector_column(field_name, field_obj.dimensions)
934 # Use record's ID if it has one, otherwise generate a new one
935 id = record.id if record.id else str(uuid.uuid4())
936 row = self._record_to_row(record, id)
938 # Build dynamic SQL based on vector columns present
939 columns = ["id", "data", "metadata"]
940 values = [row["id"], row["data"], row["metadata"]]
941 placeholders = ["$1", "$2", "$3"]
943 # Add vector columns
944 param_num = 4
945 for key, value in row.items():
946 if key.startswith("vector_"):
947 columns.append(key)
948 values.append(value)
949 placeholders.append(f"${param_num}")
950 param_num += 1
952 sql = f"""
953 INSERT INTO {self.schema_name}.{self.table_name} ({', '.join(columns)})
954 VALUES ({', '.join(placeholders)})
955 """
957 async with self._pool.acquire() as conn:
958 await conn.execute(sql, *values)
960 return id
962 async def read(self, id: str) -> Record | None:
963 """Read a record by ID."""
964 self._check_connection()
965 sql = f"""
966 SELECT id, data, metadata
967 FROM {self.schema_name}.{self.table_name}
968 WHERE id = $1
969 """
971 async with self._pool.acquire() as conn:
972 row = await conn.fetchrow(sql, id)
974 if not row:
975 return None
977 return self._row_to_record(row)
979 async def update(self, id: str, record: Record) -> bool:
980 """Update an existing record.
982 Args:
983 id: The record ID to update
984 record: The record data to update with
986 Returns:
987 True if the record was updated, False if no record with the given ID exists
988 """
989 self._check_connection()
990 row = self._record_to_row(record, id)
992 sql = f"""
993 UPDATE {self.schema_name}.{self.table_name}
994 SET data = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP
995 WHERE id = $1
996 """
998 async with self._pool.acquire() as conn:
999 result = await conn.execute(sql, row["id"], row["data"], row["metadata"])
1001 # Returns UPDATE n where n is rows affected
1002 rows_affected = int(result.split()[-1])
1004 if rows_affected == 0:
1005 logger.warning(f"Update affected 0 rows for id={id}. Record may not exist.")
1007 return rows_affected > 0
1009 async def delete(self, id: str) -> bool:
1010 """Delete a record by ID."""
1011 self._check_connection()
1012 sql = f"""
1013 DELETE FROM {self.schema_name}.{self.table_name}
1014 WHERE id = $1
1015 """
1017 async with self._pool.acquire() as conn:
1018 result = await conn.execute(sql, id)
1020 # Returns DELETE n where n is rows affected
1021 return result.split()[-1] != "0"
1023 async def exists(self, id: str) -> bool:
1024 """Check if a record exists."""
1025 self._check_connection()
1026 sql = f"""
1027 SELECT 1 FROM {self.schema_name}.{self.table_name}
1028 WHERE id = $1
1029 LIMIT 1
1030 """
1032 async with self._pool.acquire() as conn:
1033 row = await conn.fetchrow(sql, id)
1035 return row is not None
1037 async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
1038 """Update or insert a record.
1040 Can be called as:
1041 - upsert(id, record) - explicit ID and record
1042 - upsert(record) - extract ID from record using Record's built-in logic
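
        Example (illustrative; assumes ``record`` is an existing Record instance):

            rid = await db.upsert(record)            # ID taken from the record, or generated
            rid = await db.upsert("my-id", record)   # explicit ID form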
1043 """
1044 self._check_connection()
1046 # Determine ID and record based on arguments
1047 if isinstance(id_or_record, str):
1048 id = id_or_record
1049 if record is None:
1050 raise ValueError("Record required when ID is provided")
1051 else:
1052 record = id_or_record
1053 id = record.id
1054 if id is None:
1055 import uuid # type: ignore[unreachable]
1056 id = str(uuid.uuid4())
1057 record.storage_id = id
1059 row = self._record_to_row(record, id)
1061 sql = f"""
1062 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
1063 VALUES ($1, $2, $3)
1064 ON CONFLICT (id) DO UPDATE
1065 SET data = EXCLUDED.data, metadata = EXCLUDED.metadata, updated_at = CURRENT_TIMESTAMP
1066 """
1068 async with self._pool.acquire() as conn:
1069 await conn.execute(sql, row["id"], row["data"], row["metadata"])
1071 return id
1073 async def search(self, query: Query | ComplexQuery) -> list[Record]:
1074 """Search for records matching the query."""
1075 self._check_connection()
1077 # Initialize query builder if not already done
1078 if not hasattr(self, 'query_builder'):
1079 self.query_builder = SQLQueryBuilder(
1080 self.table_name, self.schema_name, dialect="postgres"
1081 )
1083 # Handle ComplexQuery with native SQL support
1084 if isinstance(query, ComplexQuery):
1085 sql, params = self.query_builder.build_complex_search_query(query)
1086 else:
1087 sql, params = self.query_builder.build_search_query(query)
1089 # Execute query with asyncpg (already uses positional parameters)
1090 async with self._pool.acquire() as conn:
1091 rows = await conn.fetch(sql, *params)
1093 # Convert to records
1094 records = []
1095 for row in rows:
1096 record = self._row_to_record(row)
1098 # Populate storage_id from database ID
1099 record.storage_id = str(row['id'])
1101 # Apply field projection if specified
1102 if query.fields:
1103 record = record.project(query.fields)
1105 records.append(record)
1107 return records
1109 async def _count_all(self) -> int:
1110 """Count all records in the database."""
1111 self._check_connection()
1112 sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"
1114 async with self._pool.acquire() as conn:
1115 row = await conn.fetchrow(sql)
1117 return row["count"] if row else 0
1119 async def clear(self) -> int:
1120 """Clear all records from the database."""
1121 self._check_connection()
1122 # Get count first
1123 count = await self._count_all()
1125 # Delete all records
1126 sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"
1128 async with self._pool.acquire() as conn:
1129 await conn.execute(sql)
1131 return count
1133 async def create_batch(self, records: list[Record]) -> list[str]:
1134 """Create multiple records efficiently using a single query.
1136 Uses multi-value INSERT with RETURNING for better performance.
1138 Args:
1139 records: List of records to create
1141 Returns:
1142 List of created record IDs
1143 """
1144 if not records:
1145 return []
1147 self._check_connection()
1149 # Create a query builder for PostgreSQL
1150 from .sql_base import SQLQueryBuilder
1151 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
1153 # Use the shared batch create query builder
1154 query, params, ids = query_builder.build_batch_create_query(records)
1156 # Execute the batch insert with RETURNING
1157 async with self._pool.acquire() as conn:
1158 rows = await conn.fetch(query, *params)
1160 # Return the actual inserted IDs from RETURNING clause
1161 if rows:
1162 return [row["id"] for row in rows]
1163 return ids # Fallback to generated IDs
1165 async def delete_batch(self, ids: list[str]) -> list[bool]:
1166 """Delete multiple records efficiently using a single query.
1168 Uses single DELETE with IN clause and RETURNING for verification.
1170 Args:
1171 ids: List of record IDs to delete
1173 Returns:
1174 List of success flags for each deletion
1175 """
1176 if not ids:
1177 return []
1179 self._check_connection()
1181 # Create a query builder for PostgreSQL
1182 from .sql_base import SQLQueryBuilder
1183 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
1185 # Use the shared batch delete query builder
1186 query, params = query_builder.build_batch_delete_query(ids)
1188 # Execute the batch delete with RETURNING
1189 async with self._pool.acquire() as conn:
1190 rows = await conn.fetch(query, *params)
1192 # Convert returned rows to set of deleted IDs
1193 deleted_ids = {row["id"] for row in rows}
1195 # Return results for each deletion
1196 results = []
1197 for id in ids:
1198 results.append(id in deleted_ids)
1200 return results
1202 async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
1203 """Update multiple records efficiently using a single query.
1205 Uses PostgreSQL's CASE expressions for batch updates with native asyncpg.
1207 Args:
1208 updates: List of (id, record) tuples to update
1210 Returns:
1211 List of success flags for each update
1212 """
1213 if not updates:
1214 return []
1216 self._check_connection()
1218 # Create a query builder for PostgreSQL
1219 from .sql_base import SQLQueryBuilder
1220 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
1222 # Use the shared batch update query builder
1223 # It already produces positional parameters ($1, $2) for PostgreSQL
1224 query, params = query_builder.build_batch_update_query(updates)
1226 # Add RETURNING clause for PostgreSQL to get updated IDs
1227 query = query.rstrip() + " RETURNING id"
1229 # Execute the batch update
1230 async with self._pool.acquire() as conn:
1231 rows = await conn.fetch(query, *params)
1233 # Convert returned rows to set of updated IDs
1234 updated_ids = {row["id"] for row in rows}
1236 # Return results for each update
1237 results = []
1238 for record_id, _ in updates:
1239 results.append(record_id in updated_ids)
1241 return results
1243 async def vector_search(
1244 self,
1245 query_vector: np.ndarray | list[float] | VectorField,
1246 field_name: str,
1247 k: int = 10,
1248 filter: Query | None = None,
1249 metric: DistanceMetric | str = "cosine"
1250 ) -> list[VectorSearchResult]:
1251 """Search for similar vectors using PostgreSQL pgvector.
1253 Args:
1254 query_vector: Query vector (numpy array, list, or VectorField)
1255 field_name: Name of vector field to search
1256 k: Maximum number of results
1257 filter: Optional filter query to apply
1258 metric: Distance metric to use
1260 Returns:
1261 List of VectorSearchResult objects
1262 """
1263 if not self._vector_enabled:
1264 raise RuntimeError("Vector search not available - pgvector not installed")
1266 self._check_connection()
1268 from ..fields import VectorField
1269 from ..vector.types import DistanceMetric, VectorSearchResult
1270 from .postgres_vector import format_vector_for_postgres, get_vector_operator
1272 # Convert query vector to proper format
1273 if isinstance(query_vector, VectorField):
1274 vector_str = format_vector_for_postgres(query_vector.value)
1275 else:
1276 vector_str = format_vector_for_postgres(query_vector)
1278 # Get the appropriate operator
1279 if isinstance(metric, DistanceMetric):
1280 metric_str = metric.value
1281 else:
1282 metric_str = str(metric).lower()
1283 operator = get_vector_operator(metric_str)
1285 vector_column = f"vector_{field_name}"
1287 # Build query
1288 sql = f"""
1289 SELECT id, data, metadata, {vector_column},
1290 {vector_column} {operator} $1::vector AS distance
1291 FROM {self.schema_name}.{self.table_name}
1292 WHERE {vector_column} IS NOT NULL
1293 """
1295 params = [vector_str]
1296 param_num = 2
1298 # Add filters if provided using the query builder
1299 if filter:
1300 # First get the where clause from query builder
1301 where_clause, filter_params = self.query_builder.build_where_clause(filter, param_num)
1302 if where_clause:
1303 # Convert %s placeholders to $N for asyncpg
1304 for param in filter_params:
1305 where_clause = where_clause.replace("%s", f"${param_num}", 1)
1306 params.append(param)
1307 param_num += 1
1308 sql += where_clause
1310 # Order by distance and limit
1311 sql += f"""
1312 ORDER BY distance
1313 LIMIT {k}
1314 """
1316 # Execute query
1317 async with self._pool.acquire() as conn:
1318 rows = await conn.fetch(sql, *params)
1320 # Convert to VectorSearchResult objects
1321 results = []
1322 for row in rows:
1323 record = self._row_to_record(row)
1325 # Convert distance to similarity score (1 - normalized_distance for cosine)
1326 distance = float(row['distance'])
1327 if metric_str == "cosine":
1328 score = 1.0 - min(distance, 2.0) / 2.0 # Normalize cosine distance [0,2] to similarity [0,1]
1329 elif metric_str in ["euclidean", "l2"]:
1330 score = 1.0 / (1.0 + distance) # Convert distance to similarity
1331 else:
1332 score = 1.0 - distance # Generic conversion
1334 result = VectorSearchResult(
1335 record=record,
1336 score=score,
1337 vector_field=field_name,
1338 metadata={"distance": distance, "metric": metric_str}
1339 )
1340 results.append(result)
1342 return results
1344 async def enable_vector_support(self) -> bool:
1345 """Enable vector support for this database.
1347 Returns:
1348 True if vector support is enabled
1349 """
1350 if self._vector_enabled:
1351 return True
1353 await self._detect_vector_support()
1354 return self._vector_enabled
1356 async def has_vector_support(self) -> bool:
1357 """Check if this database has vector support enabled.
1359 Returns:
1360 True if vector support is available
1361 """
1362 return self._vector_enabled
1364 async def bulk_embed_and_store(
1365 self,
1366 records: list[Record],
1367 text_field: str | list[str],
1368 vector_field: str,
1369 embedding_fn: Any | None = None,
1370 batch_size: int = 100,
1371 model_name: str | None = None,
1372 model_version: str | None = None,
1373 ) -> list[str]:
1374 """Embed text fields and store vectors with records.
1376 This implementation:
1377 1. Extracts text from the specified fields
1378 2. Calls the embedding function to generate vectors
1379 3. Stores the vectors alongside the records
1381 Args:
1382 records: Records to process
1383 text_field: Field name(s) containing text to embed
1384 vector_field: Field name to store vectors in
1385 embedding_fn: Function to generate embeddings
1386 batch_size: Number of records to process at once
1387 model_name: Name of the embedding model
1388 model_version: Version of the embedding model
1390 Returns:
1391 List of record IDs that were processed
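
        Example (illustrative sketch; assumes ``records`` is a list of Record
        instances, and uses a stand-in embedding function that returns
        fixed-size zero vectors):

            async def embed(texts: list[str]) -> list[list[float]]:
                return [[0.0] * 384 for _ in texts]

            ids = await db.bulk_embed_and_store(
                records, text_field="content", vector_field="embedding",
                embedding_fn=embed, model_name="demo-model",
            )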
1392 """
1393 if not embedding_fn:
1394 raise ValueError("embedding_fn is required for bulk_embed_and_store")
1396 from ..fields import VectorField
1398 processed_ids = []
1400 # Process in batches
1401 for i in range(0, len(records), batch_size):
1402 batch = records[i:i + batch_size]
1404 # Extract texts
1405 texts = []
1406 for record in batch:
1407 if isinstance(text_field, list):
1408 text = " ".join(str(record.fields.get(f, {}).value) for f in text_field if f in record.fields)
1409 else:
1410 text = str(record.fields.get(text_field, {}).value) if text_field in record.fields else ""
1411 texts.append(text)
1413 # Generate embeddings
1414 if texts:
1415 embeddings = await embedding_fn(texts)
1417 # Store vectors with records
1418 for j, record in enumerate(batch):
1419 if j < len(embeddings):
1420 vector = embeddings[j]
1422 # Add vector field to record
1423 record.fields[vector_field] = VectorField(
1424 name=vector_field,
1425 value=vector,
1426 dimensions=len(vector) if hasattr(vector, '__len__') else None,
1427 source_field=text_field if isinstance(text_field, str) else ",".join(text_field),
1428 model_name=model_name,
1429 model_version=model_version,
1430 )
1432 # Create or update record
1433 if record.has_storage_id():
1434 if record.storage_id is None:
1435 raise ValueError("Record has_storage_id() returned True but storage_id is None")
1436 await self.update(record.storage_id, record)
1437 else:
1438 record_id = await self.create(record)
1439 record.storage_id = record_id
1441 if record.storage_id is None:
1442 raise ValueError("Record storage_id is None after create/update")
1443 processed_ids.append(record.storage_id)
1445 return processed_ids
1447 async def create_vector_index(
1448 self,
1449 vector_field: str,
1450 dimensions: int,
1451 metric: DistanceMetric | str = "cosine",
1452 index_type: str = "ivfflat",
1453 lists: int | None = None,
1454 ) -> bool:
1455 """Create a vector index for efficient similarity search.
1457 Args:
1458 vector_field: Name of the vector field to index
1459 dimensions: Number of dimensions in the vectors
1460 metric: Distance metric for the index
1461 index_type: Type of index (ivfflat, hnsw)
1462 lists: Number of lists for IVFFlat index
1464 Returns:
1465 True if index was created successfully
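
        Example (illustrative; assumes pgvector is enabled and 384-dimensional
        embeddings are stored under the "embedding" field):

            ok = await db.create_vector_index("embedding", dimensions=384,
                                              metric="cosine", index_type="ivfflat")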
1466 """
1467 from .postgres_vector import (
1468 build_vector_column_expression,
1469 build_vector_index_sql,
1470 get_optimal_index_type,
1471 get_vector_count_sql,
1472 )
1474 self._check_connection()
1476 if not self._vector_enabled:
1477 return False
1479 # Determine optimal parameters if not provided
1480 if not lists and index_type == "ivfflat":
1481 # Count vectors to determine optimal lists
1482 count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
1483 async with self._pool.acquire() as conn:
1484 count = await conn.fetchval(count_sql) or 0
1485 _, params = get_optimal_index_type(count)
1486 lists = params.get("lists", 100)
1488 # Convert metric enum to string if needed
1489 if hasattr(metric, 'value'):
1490 metric_str = metric.value
1491 else:
1492 metric_str = str(metric).lower()
1494 # Build vector column expression for index
1495 column_expr = build_vector_column_expression(vector_field, dimensions, for_index=True)
1497 # Build index SQL - pass field_name for proper index naming
1498 index_sql = build_vector_index_sql(
1499 table_name=self.table_name,
1500 schema_name=self.schema_name,
1501 column_name=column_expr,
1502 dimensions=dimensions,
1503 metric=metric_str,
1504 index_type=index_type,
1505 index_params={"lists": lists} if lists else None,
1506 field_name=vector_field
1507 )
1509 # Create the index
1510 try:
1511 logger.debug(f"Creating vector index with SQL: {index_sql}")
1512 async with self._pool.acquire() as conn:
1513 await conn.execute(index_sql)
1514 return True
1515 except Exception as e:
1516 logger.warning(f"Failed to create vector index: {e}")
1517 logger.debug(f"Index SQL was: {index_sql}")
1518 return False
1520 async def drop_vector_index(self, vector_field: str, metric: str = "cosine") -> bool:
1521 """Drop a vector index.
1523 Args:
1524 vector_field: Name of the vector field
1525 metric: Distance metric used in the index
1527 Returns:
1528 True if index was dropped successfully
1529 """
1530 from .postgres_vector import get_vector_index_name
1532 self._check_connection()
1534 index_name = get_vector_index_name(self.table_name, vector_field, metric)
1536 try:
1537 async with self._pool.acquire() as conn:
1538 await conn.execute(f"DROP INDEX IF EXISTS {self.schema_name}.{index_name}")
1539 return True
1540 except Exception as e:
1541 logger.warning(f"Failed to drop vector index: {e}")
1542 return False
1544 async def get_vector_index_stats(self, vector_field: str) -> dict[str, Any]:
1545 """Get statistics about a vector field and its index.
1547 Args:
1548 vector_field: Name of the vector field
1550 Returns:
1551 Dictionary with index statistics
1552 """
1553 from .postgres_vector import get_index_check_sql, get_vector_count_sql
1555 self._check_connection()
1557 stats = {
1558 "field": vector_field,
1559 "indexed": False,
1560 "vector_count": 0,
1561 }
1563 try:
1564 async with self._pool.acquire() as conn:
1565 # Count vectors
1566 count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
1567 stats["vector_count"] = await conn.fetchval(count_sql) or 0
1569 # Check for index
1570 index_sql, params = get_index_check_sql(self.schema_name, self.table_name, vector_field)
1571 stats["indexed"] = await conn.fetchval(index_sql, *params) or False
1572 except Exception as e:
1573 logger.warning(f"Failed to get vector index stats: {e}")
1575 return stats
1577 async def stream_read(
1578 self,
1579 query: Query | None = None,
1580 config: StreamConfig | None = None
1581 ) -> AsyncIterator[Record]:
1582 """Stream records from PostgreSQL using cursor."""
1583 self._check_connection()
1584 config = config or StreamConfig()
1586 # Build SQL query
1587 sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
1588 params = []
1590 if query and query.filters:
1591 where_clauses = []
1592 param_count = 0
1594 for filter in query.filters:
1595 param_count += 1
1596 field_path = f"data->>'{filter.field}'"
1598 if filter.operator == Operator.EQ:
1599 where_clauses.append(f"{field_path} = ${param_count}")
1600 params.append(str(filter.value))
1602 if where_clauses:
1603 sql += " WHERE " + " AND ".join(where_clauses)
1605 # Use cursor for efficient streaming
1606 async with self._pool.acquire() as conn:
1607 async with conn.transaction():
1608 cursor = await conn.cursor(sql, *params)
1610 batch = []
1611 async for row in cursor:
1612 record = self._row_to_record(row)
1613 if query and query.fields:
1614 record = record.project(query.fields)
1616 batch.append(record)
1618 if len(batch) >= config.batch_size:
1619 for rec in batch:
1620 yield rec
1621 batch = []
1623 # Yield remaining records
1624 for rec in batch:
1625 yield rec
1627 async def stream_write(
1628 self,
1629 records: AsyncIterator[Record],
1630 config: StreamConfig | None = None
1631 ) -> StreamResult:
1632 """Stream records into PostgreSQL using batch inserts."""
1633 self._check_connection()
1634 config = config or StreamConfig()
1635 result = StreamResult()
1636 start_time = time.time()
1637 quitting = False
1639 batch = []
1640 async for record in records:
1641 batch.append(record)
1643 if len(batch) >= config.batch_size:
1644 # Write batch with graceful fallback
1645 # Use lambda wrapper for _write_batch
1646 async def batch_func(b):
1647 await self._write_batch(b)
1648 return [r.id for r in b]
1650 continue_processing = await async_process_batch_with_fallback(
1651 batch,
1652 batch_func,
1653 self.create,
1654 result,
1655 config
1656 )
1658 if not continue_processing:
1659 quitting = True
1660 break
1662 batch = []
1664 # Write remaining batch
1665 if batch and not quitting:
1666 async def batch_func(b):
1667 await self._write_batch(b)
1668 return [r.id for r in b]
1670 await async_process_batch_with_fallback(
1671 batch,
1672 batch_func,
1673 self.create,
1674 result,
1675 config
1676 )
1678 result.duration = time.time() - start_time
1679 return result
1681 async def _write_batch(self, records: list[Record]) -> list[str]:
1682 """Write a batch of records using COPY for performance.
1684 Returns:
1685 List of created record IDs
1686 """
1687 if not records:
1688 return []
1690 # Prepare data for COPY
1691 rows = []
1692 ids = []
1693 for record in records:
1694 row_data = self._record_to_row(record)
1695 ids.append(row_data["id"])
1696 rows.append((
1697 row_data["id"],
1698 row_data["data"],
1699 row_data["metadata"]
1700 ))
1702 # Use COPY for efficient bulk insert
1703 async with self._pool.acquire() as conn:
1704 await conn.copy_records_to_table(
1705 self.table_name,  # asyncpg quotes table_name as a single identifier
1706 records=rows,
1707 columns=["id", "data", "metadata"],
1708 schema_name=self.schema_name,
1709 )
1710 return ids
1712 async def _supports_native_hybrid(self) -> bool:
1713 """Check if this PostgreSQL backend supports native hybrid search.
1715 PostgreSQL with pgvector and full-text search (tsvector) supports
1716 native hybrid search.
1718 Returns:
1719 True if vector support is enabled (pgvector available)
1720 """
1721 return self._vector_enabled
1723 async def hybrid_search(
1724 self,
1725 query_text: str,
1726 query_vector: np.ndarray | list[float],
1727 text_fields: list[str] | None = None,
1728 vector_field: str = "embedding",
1729 k: int = 10,
1730 config: Any = None, # HybridSearchConfig
1731 filter: Query | None = None,
1732 metric: DistanceMetric | str = DistanceMetric.COSINE,
1733 ) -> list[Any]: # list[HybridSearchResult]
1734 """Perform hybrid search using PostgreSQL full-text search and pgvector.
1736 Combines PostgreSQL's tsvector full-text search with pgvector similarity
1737 search using configurable fusion strategies.
1739 Args:
1740 query_text: Text query for full-text matching
1741 query_vector: Vector for pgvector similarity search
1742 text_fields: Fields to search for text matching
1743 vector_field: Name of the vector field to search
1744 k: Number of results to return
1745 config: Hybrid search configuration (weights, fusion strategy)
1746 filter: Optional additional filters to apply
1747 metric: Distance metric for vector search
1749 Returns:
1750 List of HybridSearchResult ordered by combined score (descending)
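
        Example (illustrative sketch; weights, fields, and the query vector are arbitrary):

            from dataknobs_data.vector.hybrid import FusionStrategy, HybridSearchConfig

            results = await db.hybrid_search(
                query_text="connection pooling",
                query_vector=[0.1, 0.2, 0.3],
                text_fields=["title", "content"],
                vector_field="embedding",
                k=10,
                config=HybridSearchConfig(
                    text_weight=0.4, vector_weight=0.6,
                    fusion_strategy=FusionStrategy.RRF,
                ),
            )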
1751 """
1752 from ..vector.hybrid import (
1753 FusionStrategy,
1754 HybridSearchConfig,
1755 HybridSearchResult,
1756 reciprocal_rank_fusion,
1757 weighted_score_fusion,
1758 )
1759 from .postgres_vector import format_vector_for_postgres, get_vector_operator
1761 self._check_connection()
1763 config = config or HybridSearchConfig()
1765 # For NATIVE strategy with pgvector, we can do a combined query
1766 # For other strategies, use the parent implementation
1767 if config.fusion_strategy not in (FusionStrategy.NATIVE, FusionStrategy.RRF):
1768 from ..vector.mixins import VectorOperationsMixin
1769 return await VectorOperationsMixin.hybrid_search(
1770 self,
1771 query_text=query_text,
1772 query_vector=query_vector,
1773 text_fields=text_fields,
1774 vector_field=vector_field,
1775 k=k,
1776 config=config,
1777 filter=filter,
1778 metric=metric,
1779 )
1781 # Use config.text_fields if provided, otherwise use parameter
1782 search_text_fields = config.text_fields or text_fields or ["content", "title", "text"]
1784 # Get more results for fusion
1785 fetch_k = min(k * 3, 100)
1787 # Prepare vector search
1788 if isinstance(query_vector, (list, tuple)):
1789 import numpy as np
1790 query_vector = np.array(query_vector, dtype=np.float32)
1792 vector_str = format_vector_for_postgres(query_vector)
1794 # Get metric operator
1795 if isinstance(metric, str):
1796 metric_str = metric.lower()
1797 else:
1798 metric_str = metric.value
1799 operator = get_vector_operator(metric_str)
1801 vector_column = f"vector_{vector_field}"
1803 # Build combined query using CTE for efficient hybrid search
1804 # This performs both searches in a single query
1805 sql = f"""
1806 WITH text_search AS (
1807 SELECT
1808 id,
1809 data,
1810 metadata,
1811 ts_rank_cd(
1812 to_tsvector('english', {self._build_text_field_concat(search_text_fields)}),
1813 plainto_tsquery('english', $1)
1814 ) as text_score,
1815 ROW_NUMBER() OVER (
1816 ORDER BY ts_rank_cd(
1817 to_tsvector('english', {self._build_text_field_concat(search_text_fields)}),
1818 plainto_tsquery('english', $1)
1819 ) DESC
1820 ) as text_rank
1821 FROM {self.schema_name}.{self.table_name}
1822 WHERE to_tsvector('english', {self._build_text_field_concat(search_text_fields)}) @@ plainto_tsquery('english', $1)
1823 LIMIT {fetch_k}
1824 ),
1825 vector_search AS (
1826 SELECT
1827 id,
1828 data,
1829 metadata,
1830 {vector_column},
1831 1.0 - ({vector_column} {operator} $2::vector) as vector_score,
1832 ROW_NUMBER() OVER (
1833 ORDER BY {vector_column} {operator} $2::vector
1834 ) as vector_rank
1835 FROM {self.schema_name}.{self.table_name}
1836 WHERE {vector_column} IS NOT NULL
1837 LIMIT {fetch_k}
1838 ),
1839 combined AS (
1840 SELECT
1841 COALESCE(t.id, v.id) as id,
1842 COALESCE(t.data, v.data) as data,
1843 COALESCE(t.metadata, v.metadata) as metadata,
1844 t.text_score,
1845 t.text_rank,
1846 v.vector_score,
1847 v.vector_rank
1848 FROM text_search t
1849 FULL OUTER JOIN vector_search v ON t.id = v.id
1850 )
1851 SELECT * FROM combined
1852 """
1854 params = [query_text, vector_str]
1856 try:
1857 async with self._pool.acquire() as conn:
1858 rows = await conn.fetch(sql, *params)
1859 except Exception as e:
1860 # If full-text search fails, fall back to client-side fusion
1861 logger.warning(f"Native PostgreSQL hybrid search failed ({e}), falling back to client-side")
1862 from ..vector.mixins import VectorOperationsMixin
1863 return await VectorOperationsMixin.hybrid_search(
1864 self,
1865 query_text=query_text,
1866 query_vector=query_vector,
1867 text_fields=text_fields,
1868 vector_field=vector_field,
1869 k=k,
1870 config=HybridSearchConfig(
1871 text_weight=config.text_weight,
1872 vector_weight=config.vector_weight,
1873 fusion_strategy=FusionStrategy.RRF,
1874 rrf_k=config.rrf_k,
1875 text_fields=config.text_fields,
1876 ),
1877 filter=filter,
1878 metric=metric,
1879 )
1881 # Build result lists for fusion
1882 records_by_id: dict[str, Record] = {}
1883 text_scores: list[tuple[str, float]] = []
1884 vector_scores: list[tuple[str, float]] = []
1886 for row in rows:
1887 record = self._row_to_record(row)
1888 record_id = row['id']
1889 records_by_id[record_id] = record
1891 if row['text_score'] is not None:
1892 text_scores.append((record_id, float(row['text_score'])))
1893 if row['vector_score'] is not None:
1894 vector_scores.append((record_id, float(row['vector_score'])))
1896 # Sort by score for rank-based fusion
1897 text_scores.sort(key=lambda x: x[1], reverse=True)
1898 vector_scores.sort(key=lambda x: x[1], reverse=True)
1900 # Apply RRF fusion
1901 fused = reciprocal_rank_fusion(
1902 text_results=text_scores,
1903 vector_results=vector_scores,
1904 k=config.rrf_k,
1905 text_weight=config.text_weight,
1906 vector_weight=config.vector_weight,
1907 )
1909 # Build HybridSearchResult objects
1910 text_score_map = dict(text_scores)
1911 vector_score_map = dict(vector_scores)
1912 text_rank_map = {rid: i + 1 for i, (rid, _) in enumerate(text_scores)}
1913 vector_rank_map = {rid: i + 1 for i, (rid, _) in enumerate(vector_scores)}
1915 results: list[HybridSearchResult] = []
1916 for record_id, combined_score in fused[:k]:
1917 if record_id not in records_by_id:
1918 continue
1920 results.append(HybridSearchResult(
1921 record=records_by_id[record_id],
1922 combined_score=combined_score,
1923 text_score=text_score_map.get(record_id),
1924 vector_score=vector_score_map.get(record_id),
1925 text_rank=text_rank_map.get(record_id),
1926 vector_rank=vector_rank_map.get(record_id),
1927 metadata={
1928 "fusion_strategy": config.fusion_strategy.value,
1929 "text_weight": config.text_weight,
1930 "vector_weight": config.vector_weight,
1931 "backend": "postgresql",
1932 },
1933 ))
1935 return results
1937 def _build_text_field_concat(self, text_fields: list[str]) -> str:
1938 """Build SQL expression to concatenate text fields for full-text search.
1940 Args:
1941 text_fields: List of field names to concatenate
1943 Returns:
1944 SQL expression for concatenated text fields
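
        Example:
            _build_text_field_concat(["title", "body"]) returns
            "COALESCE(data->>'title', '') || ' ' || COALESCE(data->>'body', '')"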
1945 """
1946 if not text_fields:
1947 return "COALESCE(data->>'content', '')"
1949 parts = [f"COALESCE(data->>'{field}', '')" for field in text_fields]
1950 return " || ' ' || ".join(parts)
1952 async def _text_search_for_hybrid(
1953 self,
1954 query_text: str,
1955 text_fields: list[str] | None,
1956 k: int,
1957 filter: Query | None = None,
1958 ) -> list[tuple[Record, float]]:
1959 """Perform PostgreSQL full-text search for hybrid search fusion.
1961 Uses PostgreSQL's tsvector/tsquery full-text search with ts_rank_cd scoring.
1963 Args:
1964 query_text: Text to search for
1965 text_fields: Fields to search in
1966 k: Maximum results to return
1967 filter: Additional filters
1969 Returns:
1970 List of (record, score) tuples ordered by text relevance
1971 """
1972 self._check_connection()
1974 search_fields = text_fields or ["content", "title", "text"]
1975 text_concat = self._build_text_field_concat(search_fields)
1977 sql = f"""
1978 SELECT
1979 id,
1980 data,
1981 metadata,
1982 ts_rank_cd(
1983 to_tsvector('english', {text_concat}),
1984 plainto_tsquery('english', $1)
1985 ) as score
1986 FROM {self.schema_name}.{self.table_name}
1987 WHERE to_tsvector('english', {text_concat}) @@ plainto_tsquery('english', $1)
1988 ORDER BY score DESC
1989 LIMIT {k}
1990 """
1992 try:
1993 async with self._pool.acquire() as conn:
1994 rows = await conn.fetch(sql, query_text)
1995 except Exception as e:
1996 # Fall back to LIKE-based search if full-text search fails
1997 logger.warning(f"PostgreSQL full-text search failed ({e}), falling back to LIKE")
1998 from ..vector.mixins import VectorOperationsMixin
1999 return await VectorOperationsMixin._text_search_for_hybrid(
2000 self,
2001 query_text=query_text,
2002 text_fields=text_fields,
2003 k=k,
2004 filter=filter,
2005 )
2007 # Normalize scores
2008 results: list[tuple[Record, float]] = []
2009 max_score = max((float(row['score']) for row in rows), default=1.0) or 1.0
2011 for row in rows:
2012 record = self._row_to_record(row)
2013 score = float(row['score']) / max_score if max_score > 0 else 0.0
2014 results.append((record, score))
2016 return results