Coverage for src/dataknobs_data/backends/postgres.py: 11%

813 statements  

coverage.py v7.13.0, created at 2025-12-26 16:34 -0700

1"""PostgreSQL backend implementation with proper connection management and vector support.""" 

2 

3from __future__ import annotations 

4 

5import json 

6import logging 

7import time 

8import uuid 

9from typing import TYPE_CHECKING, Any, cast 

10 

11import asyncpg 

12from dataknobs_config import ConfigurableBase 

13 

14from dataknobs_utils.sql_utils import DotenvPostgresConnector, PostgresDB 

15 

16from ..database import AsyncDatabase, SyncDatabase 

17from ..pooling import ConnectionPoolManager 

18from ..pooling.postgres import PostgresPoolConfig, create_asyncpg_pool, validate_asyncpg_pool 

19from ..query import Operator, Query 

20from ..query_logic import ComplexQuery 

21from ..streaming import ( 

22 StreamConfig, 

23 StreamResult, 

24 async_process_batch_with_fallback, 

25 process_batch_with_fallback, 

26) 

27from ..vector.mixins import VectorOperationsMixin 

28from .postgres_mixins import ( 

29 PostgresBaseConfig, 

30 PostgresConnectionValidator, 

31 PostgresErrorHandler, 

32 PostgresTableManager, 

33 PostgresVectorSupport, 

34) 

35from .sql_base import SQLQueryBuilder, SQLRecordSerializer 

36from ..vector.types import DistanceMetric 

37 

38if TYPE_CHECKING: 

39 import numpy as np 

40 

41 from collections.abc import AsyncIterator, Iterator, Callable, Awaitable 

42 from ..fields import VectorField 

43 from ..records import Record 

44 from ..vector.types import VectorSearchResult 

45 

46logger = logging.getLogger(__name__) 

47 

48 

49class SyncPostgresDatabase( 

50 SyncDatabase, 

51 ConfigurableBase, 

52 VectorOperationsMixin, 

53 SQLRecordSerializer, 

54 PostgresBaseConfig, 

55 PostgresTableManager, 

56 PostgresVectorSupport, 

57 PostgresConnectionValidator, 

58 PostgresErrorHandler, 

59): 

60 """Synchronous PostgreSQL database backend with proper connection management.""" 

61 

62 def __init__(self, config: dict[str, Any] | None = None): 

63 """Initialize PostgreSQL database configuration. 

64 

65 Args: 

66 config: Configuration with the following optional keys: 

67 - host: PostgreSQL host (default: from env/localhost) 

68 - port: PostgreSQL port (default: 5432) 

69 - database: Database name (default: from env/postgres) 

70 - user: Username (default: from env/postgres) 

71 - password: Password (default: from env) 

72 - table: Table name (default: "records") 

73 - schema: Schema name (default: "public") 

74 - enable_vector: Enable vector support (default: False) 

75 """ 

76 super().__init__(config) 

77 

78 # Parse configuration using mixin 

79 table_name, schema_name, conn_config = self._parse_postgres_config(config or {}) 

80 self._init_postgres_attributes(table_name, schema_name) 

81 

82 # Store connection config for later use 

83 self._conn_config = conn_config 

84 self.db = None # Will be initialized in connect() 

85 self.query_builder = None # Will be initialized in connect() 

86 
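# Illustrative usage sketch (not part of postgres.py): constructing the sync backend from an
# explicit config dict versus environment-based configuration. All concrete values below are
# assumptions for a local PostgreSQL instance.
db = SyncPostgresDatabase({
    "host": "localhost",
    "port": 5432,
    "database": "appdb",
    "user": "postgres",
    "password": "secret",
    "table": "records",
    "schema": "public",
    "enable_vector": True,
})
db.connect()  # creates the table if missing and probes pgvector because enable_vector is set

env_db = SyncPostgresDatabase({"table": "records"})  # no host/database/user -> DotenvPostgresConnector
env_db.connect()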

87 @classmethod 

88 def from_config(cls, config: dict) -> SyncPostgresDatabase: 

89 """Create from config dictionary.""" 

90 return cls(config) 

91 

92 def connect(self) -> None: 

93 """Connect to the PostgreSQL database.""" 

94 if self._connected: 

95 return # Already connected 

96 

97 # Initialize query builder with pyformat style for psycopg2 

98 self.query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat") 

99 

100 # Create connection using existing utilities 

101 if not any(key in self._conn_config for key in ["host", "database", "user"]): 

102 # Use dotenv connector for environment-based config 

103 connector = DotenvPostgresConnector() 

104 self.db = PostgresDB(connector) 

105 else: 

106 # Direct configuration - map 'database' to 'db' for PostgresDB 

107 self.db = PostgresDB( 

108 host=self._conn_config.get("host", "localhost"), 

109 db=self._conn_config.get("database", "postgres"), # Note: PostgresDB expects 'db' not 'database' 

110 user=self._conn_config.get("user", "postgres"), 

111 pwd=self._conn_config.get("password"), # Note: PostgresDB expects 'pwd' not 'password' 

112 port=self._conn_config.get("port", 5432), 

113 ) 

114 

115 # Create table if it doesn't exist 

116 self._ensure_table() 

117 

118 # Detect and enable vector support if requested 

119 if self.vector_enabled: 

120 self._detect_vector_support() 

121 

122 self._connected = True 

123 self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}") 

124 

125 def close(self) -> None: 

126 """Close the database connection.""" 

127 if self.db: 

128 # PostgresDB manages its own connections via context managers 

129 # but we can mark as disconnected 

130 self._connected = False # type: ignore[unreachable] 

131 

132 def _initialize(self) -> None: 

133 """Initialize method - connection setup moved to connect().""" 

134 # Configuration parsing stays here, actual connection in connect() 

135 pass 

136 

137 def _detect_vector_support(self) -> None: 

138 """Detect and enable vector support if pgvector is available.""" 

139 from .postgres_vector import check_pgvector_extension_sync, install_pgvector_extension_sync 

140 

141 try: 

142 # Check if pgvector is installed 

143 if check_pgvector_extension_sync(self.db): 

144 self._vector_enabled = True 

145 logger.info("pgvector extension detected and enabled") 

146 else: 

147 # Try to install it 

148 if install_pgvector_extension_sync(self.db): 

149 self._vector_enabled = True 

150 logger.info("pgvector extension installed and enabled") 

151 else: 

152 logger.debug("pgvector extension not available") 

153 except Exception as e: 

154 logger.debug(f"Could not enable vector support: {e}") 

155 self._vector_enabled = False 

156 
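# Reference sketch (an assumption: the postgres_vector helpers may differ in detail): the
# pgvector detection and installation used above typically reduce to two statements against
# the connected PostgresDB handle, here via the illustrative `db` instance from the earlier sketch.
pgvector_installed = not db.db.query("SELECT 1 FROM pg_extension WHERE extname = 'vector'").empty
if not pgvector_installed:
    db.db.execute("CREATE EXTENSION IF NOT EXISTS vector")  # requires sufficient privileges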

157 def _ensure_table(self) -> None: 

158 """Ensure the records table exists.""" 

159 if not self.db: 

160 raise RuntimeError("Database not connected. Call connect() first.") 

161 

162 create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name) # type: ignore[unreachable] 

163 self.db.execute(create_table_sql) 

164 

165 

166 def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]: 

167 """Convert a Record to a database row.""" 

168 return { 

169 "id": id or str(uuid.uuid4()), 

170 "data": self.record_to_json(record), 

171 "metadata": json.dumps(record.metadata) if record.metadata else None, 

172 } 

173 

174 def _row_to_record(self, row: dict[str, Any]) -> Record: 

175 """Convert a database row to a Record.""" 

176 return self.row_to_record(row) 

177 

178 def create(self, record: Record) -> str: 

179 """Create a new record.""" 

180 self._check_connection() 

181 # Use record's ID if it has one, otherwise generate a new one 

182 id = record.id if record.id else str(uuid.uuid4()) 

183 row = self._record_to_row(record, id) 

184 

185 sql = f""" 

186 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata) 

187 VALUES (%(id)s, %(data)s, %(metadata)s) 

188 """ 

189 self.db.execute(sql, row) 

190 return id 

191 

192 def read(self, id: str) -> Record | None: 

193 """Read a record by ID.""" 

194 self._check_connection() 

195 sql = f""" 

196 SELECT id, data, metadata 

197 FROM {self.schema_name}.{self.table_name} 

198 WHERE id = %(id)s 

199 """ 

200 df = self.db.query(sql, {"id": id}) 

201 

202 if df.empty: 

203 return None 

204 

205 row = df.iloc[0].to_dict() 

206 return self._row_to_record(row) 

207 

208 def update(self, id: str, record: Record) -> bool: 

209 """Update an existing record. 

210 

211 Args: 

212 id: The record ID to update 

213 record: The record data to update with 

214 

215 Returns: 

216 True if the record was updated, False if no record with the given ID exists 

217 """ 

218 self._check_connection() 

219 row = self._record_to_row(record, id) 

220 

221 sql = f""" 

222 UPDATE {self.schema_name}.{self.table_name} 

223 SET data = %(data)s, metadata = %(metadata)s, updated_at = CURRENT_TIMESTAMP 

224 WHERE id = %(id)s 

225 """ 

226 result = self.db.execute(sql, row) 

227 

228 # PostgresDB.execute returns number of affected rows 

229 rows_affected = result if isinstance(result, int) else 0 

230 

231 if rows_affected == 0: 

232 logger.warning(f"Update affected 0 rows for id={id}. Record may not exist.") 

233 

234 return rows_affected > 0 

235 

236 def delete(self, id: str) -> bool: 

237 """Delete a record by ID.""" 

238 self._check_connection() 

239 sql = f""" 

240 DELETE FROM {self.schema_name}.{self.table_name} 

241 WHERE id = %(id)s 

242 """ 

243 result = self.db.execute(sql, {"id": id}) 

244 return result > 0 if isinstance(result, int) else False 

245 

246 def exists(self, id: str) -> bool: 

247 """Check if a record exists.""" 

248 self._check_connection() 

249 sql = f""" 

250 SELECT 1 FROM {self.schema_name}.{self.table_name} 

251 WHERE id = %(id)s 

252 LIMIT 1 

253 """ 

254 df = self.db.query(sql, {"id": id}) 

255 return not df.empty 

256 

257 def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str: 

258 """Update or insert a record. 

259  

260 Can be called as: 

261 - upsert(id, record) - explicit ID and record 

262 - upsert(record) - extract ID from record using Record's built-in logic 

263 """ 

264 self._check_connection() 

265 

266 # Determine ID and record based on arguments 

267 if isinstance(id_or_record, str): 

268 id = id_or_record 

269 if record is None: 

270 raise ValueError("Record required when ID is provided") 

271 else: 

272 record = id_or_record 

273 id = record.id 

274 if id is None: 

275 import uuid # type: ignore[unreachable] 

276 id = str(uuid.uuid4()) 

277 record.storage_id = id 

278 

279 if self.exists(id): 

280 self.update(id, record) 

281 else: 

282 # Insert with specific ID 

283 row = self._record_to_row(record, id) 

284 sql = f""" 

285 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata) 

286 VALUES (%(id)s, %(data)s, %(metadata)s) 

287 """ 

288 self.db.execute(sql, row) 

289 return id 

290 
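# Illustrative sketch of the two upsert calling conventions documented above; `db` is the
# connected instance from the earlier sketch and `my_record` is assumed to be a Record built elsewhere.
explicit_id = db.upsert("customer-42", my_record)  # explicit ID: inserts or updates that row
derived_id = db.upsert(my_record)                  # ID taken from the record, generated if missing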

291 def search(self, query: Query | ComplexQuery) -> list[Record]: 

292 """Search for records matching the query.""" 

293 self._check_connection() 

294 

295 # Handle ComplexQuery with native SQL support 

296 if isinstance(query, ComplexQuery): 

297 sql_query, params_list = self.query_builder.build_complex_search_query(query) 

298 else: 

299 sql_query, params_list = self.query_builder.build_search_query(query) 

300 

301 # Build params dict for psycopg2 

302 # The query builder now generates %(p0)s style placeholders directly 

303 params_dict = {} 

304 if params_list: 

305 for i, param in enumerate(params_list): 

306 params_dict[f"p{i}"] = param 

307 

308 # Execute query 

309 df = self.db.query(sql_query, params_dict) 

310 

311 # Convert to records 

312 records = [] 

313 for _, row in df.iterrows(): 

314 row_dict = row.to_dict() 

315 record = self._row_to_record(row_dict) 

316 

317 # Populate storage_id from database ID 

318 record.storage_id = str(row_dict['id']) 

319 

320 # Apply field projection if specified 

321 if query.fields: 

322 record = record.project(query.fields) 

323 

324 records.append(record) 

325 

326 return records 

327 

328 def _count_all(self) -> int: 

329 """Count all records in the database.""" 

330 self._check_connection() 

331 sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}" 

332 df = self.db.query(sql) 

333 return int(df.iloc[0]["count"]) if not df.empty else 0 

334 

335 def clear(self) -> int: 

336 """Clear all records from the database.""" 

337 self._check_connection() 

338 # Get count first 

339 count = self._count_all() 

340 

341 # Delete all records 

342 sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}" 

343 self.db.execute(sql) 

344 

345 return count 

346 

347 def create_batch(self, records: list[Record]) -> list[str]: 

348 """Create multiple records efficiently using a single query. 

349  

350 Uses multi-value INSERT for better performance. 

351  

352 Args: 

353 records: List of records to create 

354  

355 Returns: 

356 List of created record IDs 

357 """ 

358 if not records: 

359 return [] 

360 

361 self._check_connection() 

362 

363 # Create a query builder for PostgreSQL with pyformat style 

364 from .sql_base import SQLQueryBuilder 

365 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat") 

366 

367 # Use the shared batch create query builder 

368 query, params_list, ids = query_builder.build_batch_create_query(records) 

369 

370 # Build params dict for psycopg2 

371 params_dict = {} 

372 for i, param in enumerate(params_list): 

373 params_dict[f"p{i}"] = param 

374 

375 # Execute the batch insert and get returned IDs 

376 result_df = self.db.query(query, params_dict) 

377 

378 # PostgreSQL RETURNING clause gives us the actual inserted IDs 

379 if not result_df.empty: 

380 return result_df['id'].tolist() 

381 return ids 

382 

383 def delete_batch(self, ids: list[str]) -> list[bool]: 

384 """Delete multiple records efficiently using a single query. 

385  

386 Uses single DELETE with IN clause for better performance. 

387  

388 Args: 

389 ids: List of record IDs to delete 

390  

391 Returns: 

392 List of success flags for each deletion 

393 """ 

394 if not ids: 

395 return [] 

396 

397 self._check_connection() 

398 

399 # Create a query builder for PostgreSQL with pyformat style 

400 from .sql_base import SQLQueryBuilder 

401 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat") 

402 

403 # Use the shared batch delete query builder (includes RETURNING clause) 

404 query, params_list = query_builder.build_batch_delete_query(ids) 

405 

406 # Build params dict for psycopg2 

407 params_dict = {} 

408 for i, param in enumerate(params_list): 

409 params_dict[f"p{i}"] = param 

410 

411 # Execute the batch delete and get returned IDs 

412 result_df = self.db.query(query, params_dict) 

413 

414 # Get list of deleted IDs from RETURNING clause 

415 deleted_ids = set(result_df['id'].tolist()) if not result_df.empty else set() 

416 

417 # Return results based on which IDs were actually deleted 

418 results = [] 

419 for id in ids: 

420 results.append(id in deleted_ids) 

421 

422 return results 

423 

424 def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]: 

425 """Update multiple records efficiently using a single query. 

426  

427 Uses PostgreSQL's CASE expressions for batch updates via shared SQL builder. 

428  

429 Args: 

430 updates: List of (id, record) tuples to update 

431  

432 Returns: 

433 List of success flags for each update 

434 """ 

435 if not updates: 

436 return [] 

437 

438 self._check_connection() 

439 

440 # Create a query builder for PostgreSQL with pyformat style 

441 from .sql_base import SQLQueryBuilder 

442 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat") 

443 

444 # Use the shared batch update query builder 

445 query, params_list = query_builder.build_batch_update_query(updates) 

446 

447 # Build params dict for psycopg2 

448 params_dict = {} 

449 for i, param in enumerate(params_list): 

450 params_dict[f"p{i}"] = param 

451 

452 # Execute the batch update and get returned IDs (query now includes RETURNING clause) 

453 result_df = self.db.query(query, params_dict) 

454 

455 # Get list of updated IDs from RETURNING clause 

456 updated_ids = set(result_df['id'].tolist()) if not result_df.empty else set() 

457 

458 results = [] 

459 for record_id, _ in updates: 

460 results.append(record_id in updated_ids) 

461 

462 return results 

463 
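# Illustrative sketch of the batch APIs above; `records` is assumed to be a list of Record
# instances. Each call issues a single multi-row SQL statement and reads back its RETURNING clause.
ids = db.create_batch(records)                       # list[str], one ID per record
updated = db.update_batch(list(zip(ids, records)))   # list[bool], aligned with input order
deleted = db.delete_batch(ids)                       # list[bool], aligned with input order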

464 def stream_read( 

465 self, 

466 query: Query | None = None, 

467 config: StreamConfig | None = None 

468 ) -> Iterator[Record]: 

469 """Stream records from PostgreSQL.""" 

470 self._check_connection() 

471 config = config or StreamConfig() 

472 

473 # Build SQL query 

474 sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}" 

475 params = {} 

476 

477 if query and query.filters: 

478 # Add WHERE clause (simplified for now) 

479 where_clauses = [] 

480 for i, filter in enumerate(query.filters): 

481 field_path = f"data->>'{filter.field}'" 

482 param_name = f"param_{i}" 

483 

484 if filter.operator == Operator.EQ: 

485 where_clauses.append(f"{field_path} = %({param_name})s") 

486 params[param_name] = str(filter.value) 

487 

488 if where_clauses: 

489 sql += " WHERE " + " AND ".join(where_clauses) 

490 

491 # Use cursor for streaming 

492 # Note: PostgresDB may need modification to support cursors 

493 # For now, we'll fetch in batches 

494 sql += f" LIMIT {config.batch_size} OFFSET %(offset)s" 

495 

496 offset = 0 

497 while True: 

498 params["offset"] = offset 

499 df = self.db.query(sql, params) 

500 

501 if df.empty: 

502 break 

503 

504 for _, row in df.iterrows(): 

505 record = self._row_to_record(row.to_dict()) 

506 if query and query.fields: 

507 record = record.project(query.fields) 

508 yield record 

509 

510 offset += config.batch_size 

511 

512 # If we got less than batch_size, we're done 

513 if len(df) < config.batch_size: 

514 break 

515 

516 def stream_write( 

517 self, 

518 records: Iterator[Record], 

519 config: StreamConfig | None = None 

520 ) -> StreamResult: 

521 """Stream records into PostgreSQL.""" 

522 self._check_connection() 

523 config = config or StreamConfig() 

524 result = StreamResult() 

525 start_time = time.time() 

526 quitting = False 

527 

528 batch = [] 

529 for record in records: 

530 batch.append(record) 

531 

532 if len(batch) >= config.batch_size: 

533 # Write batch with graceful fallback 

534 # Use lambda wrapper for _write_batch 

535 continue_processing = process_batch_with_fallback( 

536 batch, 

537 lambda b: self._write_batch(b), 

538 self.create, 

539 result, 

540 config 

541 ) 

542 

543 if not continue_processing: 

544 quitting = True 

545 break 

546 

547 batch = [] 

548 

549 # Write remaining batch 

550 if batch and not quitting: 

551 process_batch_with_fallback( 

552 batch, 

553 lambda b: self._write_batch(b), 

554 self.create, 

555 result, 

556 config 

557 ) 

558 

559 result.duration = time.time() - start_time 

560 return result 

561 
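# Illustrative streaming sketch; assumes StreamConfig accepts batch_size as a keyword and that
# `records` and process() exist elsewhere. stream_read pages with LIMIT/OFFSET; stream_write
# inserts in batches and falls back to per-record create() when a batch fails.
cfg = StreamConfig(batch_size=500)
for rec in db.stream_read(config=cfg):
    process(rec)

write_result = db.stream_write(iter(records), config=cfg)
print(f"streamed in {write_result.duration:.2f}s")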

562 def _write_batch(self, records: list[Record]) -> list[str]: 

563 """Write a batch of records to the database. 

564  

565 Returns: 

566 List of created record IDs 

567 """ 

568 # Build batch insert SQL 

569 values = [] 

570 params = {} 

571 ids = [] 

572 

573 for i, record in enumerate(records): 

574 id = str(uuid.uuid4()) 

575 ids.append(id) 

576 row = self._record_to_row(record, id) 

577 values.append(f"(%(id_{i})s, %(data_{i})s, %(metadata_{i})s)") 

578 params[f"id_{i}"] = row["id"] 

579 params[f"data_{i}"] = row["data"] 

580 params[f"metadata_{i}"] = row["metadata"] 

581 

582 sql = f""" 

583 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata) 

584 VALUES {', '.join(values)} 

585 """ 

586 self.db.execute(sql, params) 

587 return ids 

588 

589 def vector_search( 

590 self, 

591 query_vector: np.ndarray | list[float] | VectorField, 

592 field_name: str, 

593 k: int = 10, 

594 filter: Query | None = None, 

595 metric: DistanceMetric | str = "cosine" 

596 ) -> list[VectorSearchResult]: 

597 """Search for similar vectors using PostgreSQL pgvector. 

598  

599 Args: 

600 query_vector: Query vector (numpy array, list, or VectorField) 

601 field_name: Name of vector field to search (must be in data JSON) 

602 k: Maximum number of results to return 

603 filter: Optional Query used to filter candidate records 

604 metric: Distance metric to use (cosine, euclidean, l2, inner_product) 

605  

606 Returns: 

607 List of VectorSearchResult objects ordered by similarity 

608 """ 

609 if not self._vector_enabled: 

610 raise RuntimeError("Vector search not available - pgvector not installed") 

611 

612 self._check_connection() 

613 

614 from ..fields import VectorField 

615 from ..vector.types import DistanceMetric, VectorSearchResult 

616 from .postgres_vector import format_vector_for_postgres, get_vector_operator 

617 

618 # Convert query vector to proper format 

619 if isinstance(query_vector, VectorField): 

620 vector_str = format_vector_for_postgres(query_vector.value) 

621 else: 

622 vector_str = format_vector_for_postgres(query_vector) 

623 

624 # Get the appropriate operator 

625 if isinstance(metric, DistanceMetric): 

626 metric_str = metric.value 

627 else: 

628 metric_str = str(metric).lower() 

629 

630 operator = get_vector_operator(metric_str) 

631 

632 # Build the query - vectors are stored in JSON data field 

633 # Use centralized vector extraction logic 

634 vector_expr = self.get_vector_extraction_sql(field_name, dialect="postgres") 

635 

636 # Build the base SQL with pyformat placeholders 

637 sql = f""" 

638 SELECT  

639 id,  

640 data, 

641 metadata, 

642 {vector_expr} {operator} %(p0)s::vector AS distance 

643 FROM {self.schema_name}.{self.table_name} 

644 WHERE data ? %(p1)s -- Check field exists 

645 """ 

646 

647 params: list[Any] = [vector_str, field_name] 

648 

649 # Add filters if provided using the query builder 

650 if filter: 

651 # Query builder will generate pyformat placeholders since we configured it that way 

652 where_clause, filter_params = self.query_builder.build_where_clause(filter, len(params) + 1) 

653 if where_clause: 

654 sql += where_clause 

655 params.extend(filter_params) 

656 

657 # Order by distance and limit 

658 next_param = len(params) 

659 sql += f" ORDER BY distance LIMIT %(p{next_param})s" 

660 params.append(k) 

661 

662 # Build param dict for psycopg2 

663 param_dict = {} 

664 for i, param in enumerate(params): 

665 param_dict[f"p{i}"] = param 

666 

667 df = self.db.query(sql, param_dict) 

668 

669 # Convert results 

670 results = [] 

671 for _, row in df.iterrows(): 

672 record = self._row_to_record(row.to_dict()) 

673 

674 # Calculate similarity score from distance 

675 distance = row["distance"] 

676 if metric_str in ["cosine", "cosine_similarity"]: 

677 score = 1.0 - distance # Cosine distance to similarity 

678 elif metric_str in ["euclidean", "l2"]: 

679 score = 1.0 / (1.0 + distance) # Convert distance to similarity 

680 elif metric_str in ["inner_product", "dot_product"]: 

681 score = -distance # Negative because pgvector uses negative for descending 

682 else: 

683 score = -distance # Default: lower distance = better 

684 

685 result = VectorSearchResult( 

686 record=record, 

687 score=float(score), 

688 vector_field=field_name 

689 ) 

690 results.append(result) 

691 

692 return results 

693 
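# Illustrative similarity-search sketch for the method above; assumes pgvector was detected at
# connect() time and that stored records carry an "embedding" vector field. Query values are made up.
query_vec = [0.12, 0.48, 0.31]  # may also be a numpy array or a VectorField
hits = db.vector_search(query_vec, field_name="embedding", k=5, metric="cosine")
for hit in hits:
    print(hit.score, hit.record)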

694 def has_vector_support(self) -> bool: 

695 """Check if this database has vector support enabled. 

696  

697 Returns: 

698 True if vector operations are supported 

699 """ 

700 return self._vector_enabled 

701 

702 def enable_vector_support(self) -> bool: 

703 """Enable vector support for this database if possible. 

704  

705 Returns: 

706 True if vector support is now enabled 

707 """ 

708 if self._vector_enabled: 

709 return True 

710 

711 self._detect_vector_support() 

712 return self._vector_enabled 

713 

714 def bulk_embed_and_store( 

715 self, 

716 records: list[Record], 

717 text_field: str | list[str], 

718 vector_field: str = "embedding", 

719 embedding_fn: Any = None, 

720 batch_size: int = 100, 

721 model_name: str | None = None, 

722 model_version: str | None = None, 

723 ) -> list[str]: 

724 """Embed text fields and store vectors with records (stub for abstract requirement). 

725  

726 This is a placeholder implementation to satisfy the abstract method requirement. 

727 Full implementation would require actual embedding function. 

728 """ 

729 raise NotImplementedError("bulk_embed_and_store requires an embedding function") 

730 

731 

732# Global pool manager instance for async PostgreSQL connections 

733_pool_manager = ConnectionPoolManager[asyncpg.Pool]() 

734 

735 

736class AsyncPostgresDatabase( 

737 AsyncDatabase, 

738 VectorOperationsMixin, 

739 ConfigurableBase, 

740 PostgresBaseConfig, 

741 PostgresTableManager, 

742 PostgresVectorSupport, 

743 PostgresConnectionValidator, 

744 PostgresErrorHandler, 

745): 

746 """Native async PostgreSQL database backend with vector support and event loop-aware connection pooling.""" 

747 

748 def __init__(self, config: dict[str, Any] | None = None): 

749 """Initialize async PostgreSQL database.""" 

750 super().__init__(config) 

751 

752 # Parse configuration using mixin 

753 table_name, schema_name, conn_config = self._parse_postgres_config(config or {}) 

754 self._init_postgres_attributes(table_name, schema_name) 

755 

756 # Extract pool configuration 

757 self._pool_config = PostgresPoolConfig.from_dict(conn_config) 

758 self._pool: asyncpg.Pool | None = None 

759 

760 @classmethod 

761 def from_config(cls, config: dict) -> AsyncPostgresDatabase: 

762 """Create from config dictionary.""" 

763 return cls(config) 

764 

765 async def connect(self) -> None: 

766 """Connect to the database.""" 

767 if self._connected: 

768 return 

769 

770 # Get or create pool for current event loop 

771 from ..pooling import BasePoolConfig 

772 self._pool = await _pool_manager.get_pool( 

773 self._pool_config, 

774 cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_asyncpg_pool), 

775 validate_asyncpg_pool 

776 ) 

777 

778 # Initialize query builder 

779 self.query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres") 

780 

781 # Ensure table exists 

782 await self._ensure_table() 

783 

784 # Check and enable vector support if requested 

785 if self.vector_enabled: 

786 await self._detect_vector_support() 

787 

788 self._connected = True 

789 self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}") 

790 
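# Illustrative async usage sketch (not part of postgres.py): the async backend shares one
# asyncpg pool per event loop through the module-level _pool_manager. Config values are assumptions.
import asyncio

async def main() -> None:
    adb = AsyncPostgresDatabase({"host": "localhost", "database": "appdb", "user": "postgres"})
    await adb.connect()
    rec = await adb.read("customer-42")  # None if the ID does not exist
    await adb.close()

asyncio.run(main())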

791 async def close(self) -> None: 

792 """Close the database connection and properly close the pool.""" 

793 if self._connected: 

794 # Properly close the pool if we have one 

795 if self._pool: 

796 try: 

797 await self._pool.close() 

798 except Exception as e: 

799 logger.warning(f"Error closing connection pool: {e}") 

800 self._pool = None 

801 self._connected = False 

802 

803 def _initialize(self) -> None: 

804 """Initialize is handled in connect.""" 

805 pass 

806 

807 async def _ensure_table(self) -> None: 

808 """Ensure the records table exists.""" 

809 if not self._pool: 

810 raise RuntimeError("Database not connected. Call connect() first.") 

811 

812 create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name) 

813 

814 async with self._pool.acquire() as conn: 

815 await conn.execute(create_table_sql) 

816 

817 async def _detect_vector_support(self) -> None: 

818 """Detect and enable vector support if pgvector is available.""" 

819 from .postgres_vector import check_pgvector_extension, install_pgvector_extension 

820 

821 async with self._pool.acquire() as conn: 

822 # Check if pgvector is available 

823 if await check_pgvector_extension(conn): 

824 self._vector_enabled = True 

825 logger.info("pgvector extension detected and enabled") 

826 else: 

827 # Try to install it 

828 if await install_pgvector_extension(conn): 

829 self._vector_enabled = True 

830 logger.info("pgvector extension installed and enabled") 

831 else: 

832 logger.debug("pgvector extension not available") 

833 

834 async def _ensure_vector_column(self, field_name: str, dimensions: int) -> None: 

835 """Ensure a vector column exists for the given field. 

836  

837 Args: 

838 field_name: Name of the vector field 

839 dimensions: Number of dimensions 

840 """ 

841 if not self._vector_enabled: 

842 return 

843 

844 column_name = f"vector_{field_name}" 

845 

846 # Check if column already exists 

847 check_sql = """ 

848 SELECT column_name FROM information_schema.columns 

849 WHERE table_schema = $1 AND table_name = $2 AND column_name = $3 

850 """ 

851 

852 async with self._pool.acquire() as conn: 

853 existing = await conn.fetchval(check_sql, self.schema_name, self.table_name, column_name) 

854 

855 if not existing: 

856 # Add vector column 

857 alter_sql = f""" 

858 ALTER TABLE {self.schema_name}.{self.table_name} 

859 ADD COLUMN IF NOT EXISTS {column_name} vector({dimensions}) 

860 """ 

861 try: 

862 await conn.execute(alter_sql) 

863 self._vector_dimensions[field_name] = dimensions 

864 logger.info(f"Added vector column {column_name} with {dimensions} dimensions") 

865 

866 # Create index for the vector column 

867 from .postgres_vector import build_vector_index_sql, get_optimal_index_type 

868 

869 # Get row count for optimal index selection 

870 count_sql = f"SELECT COUNT(*) FROM {self.schema_name}.{self.table_name}" 

871 count = await conn.fetchval(count_sql) 

872 

873 index_type, index_params = get_optimal_index_type(count) 

874 index_sql = build_vector_index_sql( 

875 self.table_name, 

876 self.schema_name, 

877 column_name, 

878 dimensions, 

879 metric="cosine", 

880 index_type=index_type, 

881 index_params=index_params 

882 ) 

883 

884 # Note: IVFFlat requires table to have data before creating index 

885 if count > 0 or index_type != "ivfflat": 

886 await conn.execute(index_sql) 

887 logger.info(f"Created {index_type} index for {column_name}") 

888 

889 except Exception as e: 

890 logger.warning(f"Could not create vector column {column_name}: {e}") 

891 else: 

892 self._vector_dimensions[field_name] = dimensions 

893 

894 def _check_connection(self) -> None: 

895 """Check if async database is connected.""" 

896 self._check_async_connection() 

897 

898 def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]: 

899 """Convert a Record to a database row using common serializer.""" 

900 from .sql_base import SQLRecordSerializer 

901 

902 return { 

903 "id": id or str(uuid.uuid4()), 

904 "data": SQLRecordSerializer.record_to_json(record), 

905 "metadata": json.dumps(record.metadata) if record.metadata else None, 

906 } 

907 

908 def _row_to_record(self, row: asyncpg.Record) -> Record: 

909 """Convert a database row to a Record using the common serializer.""" 

910 from .sql_base import SQLRecordSerializer 

911 

912 # Convert asyncpg.Record to dict format expected by SQLRecordSerializer 

913 data_json = row.get("data", {}) 

914 if not isinstance(data_json, str): 

915 data_json = json.dumps(data_json) 

916 

917 metadata_json = row.get("metadata") 

918 if metadata_json and not isinstance(metadata_json, str): 

919 metadata_json = json.dumps(metadata_json) 

920 

921 # Use the common serializer to reconstruct the record 

922 return SQLRecordSerializer.json_to_record(data_json, metadata_json) 

923 

924 async def create(self, record: Record) -> str: 

925 """Create a new record with vector support.""" 

926 self._check_connection() 

927 

928 # Check for vector fields and ensure columns exist 

929 from ..fields import VectorField 

930 for field_name, field_obj in record.fields.items(): 

931 if isinstance(field_obj, VectorField) and self._vector_enabled: 

932 await self._ensure_vector_column(field_name, field_obj.dimensions) 

933 

934 # Use record's ID if it has one, otherwise generate a new one 

935 id = record.id if record.id else str(uuid.uuid4()) 

936 row = self._record_to_row(record, id) 

937 

938 # Build dynamic SQL based on vector columns present 

939 columns = ["id", "data", "metadata"] 

940 values = [row["id"], row["data"], row["metadata"]] 

941 placeholders = ["$1", "$2", "$3"] 

942 

943 # Add vector columns 

944 param_num = 4 

945 for key, value in row.items(): 

946 if key.startswith("vector_"): 

947 columns.append(key) 

948 values.append(value) 

949 placeholders.append(f"${param_num}") 

950 param_num += 1 

951 

952 sql = f""" 

953 INSERT INTO {self.schema_name}.{self.table_name} ({', '.join(columns)}) 

954 VALUES ({', '.join(placeholders)}) 

955 """ 

956 

957 async with self._pool.acquire() as conn: 

958 await conn.execute(sql, *values) 

959 

960 return id 

961 

962 async def read(self, id: str) -> Record | None: 

963 """Read a record by ID.""" 

964 self._check_connection() 

965 sql = f""" 

966 SELECT id, data, metadata 

967 FROM {self.schema_name}.{self.table_name} 

968 WHERE id = $1 

969 """ 

970 

971 async with self._pool.acquire() as conn: 

972 row = await conn.fetchrow(sql, id) 

973 

974 if not row: 

975 return None 

976 

977 return self._row_to_record(row) 

978 

979 async def update(self, id: str, record: Record) -> bool: 

980 """Update an existing record. 

981 

982 Args: 

983 id: The record ID to update 

984 record: The record data to update with 

985 

986 Returns: 

987 True if the record was updated, False if no record with the given ID exists 

988 """ 

989 self._check_connection() 

990 row = self._record_to_row(record, id) 

991 

992 sql = f""" 

993 UPDATE {self.schema_name}.{self.table_name} 

994 SET data = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP 

995 WHERE id = $1 

996 """ 

997 

998 async with self._pool.acquire() as conn: 

999 result = await conn.execute(sql, row["id"], row["data"], row["metadata"]) 

1000 

1001 # Returns UPDATE n where n is rows affected 

1002 rows_affected = int(result.split()[-1]) 

1003 

1004 if rows_affected == 0: 

1005 logger.warning(f"Update affected 0 rows for id={id}. Record may not exist.") 

1006 

1007 return rows_affected > 0 

1008 

1009 async def delete(self, id: str) -> bool: 

1010 """Delete a record by ID.""" 

1011 self._check_connection() 

1012 sql = f""" 

1013 DELETE FROM {self.schema_name}.{self.table_name} 

1014 WHERE id = $1 

1015 """ 

1016 

1017 async with self._pool.acquire() as conn: 

1018 result = await conn.execute(sql, id) 

1019 

1020 # Returns DELETE n where n is rows affected 

1021 return result.split()[-1] != "0" 

1022 

1023 async def exists(self, id: str) -> bool: 

1024 """Check if a record exists.""" 

1025 self._check_connection() 

1026 sql = f""" 

1027 SELECT 1 FROM {self.schema_name}.{self.table_name} 

1028 WHERE id = $1 

1029 LIMIT 1 

1030 """ 

1031 

1032 async with self._pool.acquire() as conn: 

1033 row = await conn.fetchrow(sql, id) 

1034 

1035 return row is not None 

1036 

1037 async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str: 

1038 """Update or insert a record. 

1039  

1040 Can be called as: 

1041 - upsert(id, record) - explicit ID and record 

1042 - upsert(record) - extract ID from record using Record's built-in logic 

1043 """ 

1044 self._check_connection() 

1045 

1046 # Determine ID and record based on arguments 

1047 if isinstance(id_or_record, str): 

1048 id = id_or_record 

1049 if record is None: 

1050 raise ValueError("Record required when ID is provided") 

1051 else: 

1052 record = id_or_record 

1053 id = record.id 

1054 if id is None: 

1055 import uuid # type: ignore[unreachable] 

1056 id = str(uuid.uuid4()) 

1057 record.storage_id = id 

1058 

1059 row = self._record_to_row(record, id) 

1060 

1061 sql = f""" 

1062 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata) 

1063 VALUES ($1, $2, $3) 

1064 ON CONFLICT (id) DO UPDATE 

1065 SET data = EXCLUDED.data, metadata = EXCLUDED.metadata, updated_at = CURRENT_TIMESTAMP 

1066 """ 

1067 

1068 async with self._pool.acquire() as conn: 

1069 await conn.execute(sql, row["id"], row["data"], row["metadata"]) 

1070 

1071 return id 

1072 

1073 async def search(self, query: Query | ComplexQuery) -> list[Record]: 

1074 """Search for records matching the query.""" 

1075 self._check_connection() 

1076 

1077 # Initialize query builder if not already done 

1078 if not hasattr(self, 'query_builder'): 

1079 self.query_builder = SQLQueryBuilder( 

1080 self.table_name, self.schema_name, dialect="postgres" 

1081 ) 

1082 

1083 # Handle ComplexQuery with native SQL support 

1084 if isinstance(query, ComplexQuery): 

1085 sql, params = self.query_builder.build_complex_search_query(query) 

1086 else: 

1087 sql, params = self.query_builder.build_search_query(query) 

1088 

1089 # Execute query with asyncpg (already uses positional parameters) 

1090 async with self._pool.acquire() as conn: 

1091 rows = await conn.fetch(sql, *params) 

1092 

1093 # Convert to records 

1094 records = [] 

1095 for row in rows: 

1096 record = self._row_to_record(row) 

1097 

1098 # Populate storage_id from database ID 

1099 record.storage_id = str(row['id']) 

1100 

1101 # Apply field projection if specified 

1102 if query.fields: 

1103 record = record.project(query.fields) 

1104 

1105 records.append(record) 

1106 

1107 return records 

1108 

1109 async def _count_all(self) -> int: 

1110 """Count all records in the database.""" 

1111 self._check_connection() 

1112 sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}" 

1113 

1114 async with self._pool.acquire() as conn: 

1115 row = await conn.fetchrow(sql) 

1116 

1117 return row["count"] if row else 0 

1118 

1119 async def clear(self) -> int: 

1120 """Clear all records from the database.""" 

1121 self._check_connection() 

1122 # Get count first 

1123 count = await self._count_all() 

1124 

1125 # Delete all records 

1126 sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}" 

1127 

1128 async with self._pool.acquire() as conn: 

1129 await conn.execute(sql) 

1130 

1131 return count 

1132 

1133 async def create_batch(self, records: list[Record]) -> list[str]: 

1134 """Create multiple records efficiently using a single query. 

1135  

1136 Uses multi-value INSERT with RETURNING for better performance. 

1137  

1138 Args: 

1139 records: List of records to create 

1140  

1141 Returns: 

1142 List of created record IDs 

1143 """ 

1144 if not records: 

1145 return [] 

1146 

1147 self._check_connection() 

1148 

1149 # Create a query builder for PostgreSQL 

1150 from .sql_base import SQLQueryBuilder 

1151 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres") 

1152 

1153 # Use the shared batch create query builder 

1154 query, params, ids = query_builder.build_batch_create_query(records) 

1155 

1156 # Execute the batch insert with RETURNING 

1157 async with self._pool.acquire() as conn: 

1158 rows = await conn.fetch(query, *params) 

1159 

1160 # Return the actual inserted IDs from RETURNING clause 

1161 if rows: 

1162 return [row["id"] for row in rows] 

1163 return ids # Fallback to generated IDs 

1164 

1165 async def delete_batch(self, ids: list[str]) -> list[bool]: 

1166 """Delete multiple records efficiently using a single query. 

1167  

1168 Uses single DELETE with IN clause and RETURNING for verification. 

1169  

1170 Args: 

1171 ids: List of record IDs to delete 

1172  

1173 Returns: 

1174 List of success flags for each deletion 

1175 """ 

1176 if not ids: 

1177 return [] 

1178 

1179 self._check_connection() 

1180 

1181 # Create a query builder for PostgreSQL 

1182 from .sql_base import SQLQueryBuilder 

1183 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres") 

1184 

1185 # Use the shared batch delete query builder 

1186 query, params = query_builder.build_batch_delete_query(ids) 

1187 

1188 # Execute the batch delete with RETURNING 

1189 async with self._pool.acquire() as conn: 

1190 rows = await conn.fetch(query, *params) 

1191 

1192 # Convert returned rows to set of deleted IDs 

1193 deleted_ids = {row["id"] for row in rows} 

1194 

1195 # Return results for each deletion 

1196 results = [] 

1197 for id in ids: 

1198 results.append(id in deleted_ids) 

1199 

1200 return results 

1201 

1202 async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]: 

1203 """Update multiple records efficiently using a single query. 

1204  

1205 Uses PostgreSQL's CASE expressions for batch updates with native asyncpg. 

1206  

1207 Args: 

1208 updates: List of (id, record) tuples to update 

1209  

1210 Returns: 

1211 List of success flags for each update 

1212 """ 

1213 if not updates: 

1214 return [] 

1215 

1216 self._check_connection() 

1217 

1218 # Create a query builder for PostgreSQL 

1219 from .sql_base import SQLQueryBuilder 

1220 query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres") 

1221 

1222 # Use the shared batch update query builder 

1223 # It already produces positional parameters ($1, $2) for PostgreSQL 

1224 query, params = query_builder.build_batch_update_query(updates) 

1225 

1226 # Add RETURNING clause for PostgreSQL to get updated IDs 

1227 query = query.rstrip() + " RETURNING id" 

1228 

1229 # Execute the batch update 

1230 async with self._pool.acquire() as conn: 

1231 rows = await conn.fetch(query, *params) 

1232 

1233 # Convert returned rows to set of updated IDs 

1234 updated_ids = {row["id"] for row in rows} 

1235 

1236 # Return results for each update 

1237 results = [] 

1238 for record_id, _ in updates: 

1239 results.append(record_id in updated_ids) 

1240 

1241 return results 

1242 

1243 async def vector_search( 

1244 self, 

1245 query_vector: np.ndarray | list[float] | VectorField, 

1246 field_name: str, 

1247 k: int = 10, 

1248 filter: Query | None = None, 

1249 metric: DistanceMetric | str = "cosine" 

1250 ) -> list[VectorSearchResult]: 

1251 """Search for similar vectors using PostgreSQL pgvector. 

1252  

1253 Args: 

1254 query_vector: Query vector (numpy array, list, or VectorField) 

1255 field_name: Name of vector field to search 

1256 k: Maximum number of results to return 

1257 filter: Optional Query used to filter candidate records 

1258 metric: Distance metric to use 

1259  

1260 Returns: 

1261 List of VectorSearchResult objects 

1262 """ 

1263 if not self._vector_enabled: 

1264 raise RuntimeError("Vector search not available - pgvector not installed") 

1265 

1266 self._check_connection() 

1267 

1268 from ..fields import VectorField 

1269 from ..vector.types import DistanceMetric, VectorSearchResult 

1270 from .postgres_vector import format_vector_for_postgres, get_vector_operator 

1271 

1272 # Convert query vector to proper format 

1273 if isinstance(query_vector, VectorField): 

1274 vector_str = format_vector_for_postgres(query_vector.value) 

1275 else: 

1276 vector_str = format_vector_for_postgres(query_vector) 

1277 

1278 # Get the appropriate operator 

1279 if isinstance(metric, DistanceMetric): 

1280 metric_str = metric.value 

1281 else: 

1282 metric_str = str(metric).lower() 

1283 operator = get_vector_operator(metric_str) 

1284 

1285 vector_column = f"vector_{field_name}" 

1286 

1287 # Build query 

1288 sql = f""" 

1289 SELECT id, data, metadata, {vector_column}, 

1290 {vector_column} {operator} $1::vector AS distance 

1291 FROM {self.schema_name}.{self.table_name} 

1292 WHERE {vector_column} IS NOT NULL 

1293 """ 

1294 

1295 params = [vector_str] 

1296 param_num = 2 

1297 

1298 # Add filters if provided using the query builder 

1299 if filter: 

1300 # First get the where clause from query builder 

1301 where_clause, filter_params = self.query_builder.build_where_clause(filter, param_num) 

1302 if where_clause: 

1303 # Convert %s placeholders to $N for asyncpg 

1304 for param in filter_params: 

1305 where_clause = where_clause.replace("%s", f"${param_num}", 1) 

1306 params.append(param) 

1307 param_num += 1 

1308 sql += where_clause 

1309 

1310 # Order by distance and limit 

1311 sql += f""" 

1312 ORDER BY distance 

1313 LIMIT {k} 

1314 """ 

1315 

1316 # Execute query 

1317 async with self._pool.acquire() as conn: 

1318 rows = await conn.fetch(sql, *params) 

1319 

1320 # Convert to VectorSearchResult objects 

1321 results = [] 

1322 for row in rows: 

1323 record = self._row_to_record(row) 

1324 

1325 # Convert distance to similarity score (1 - normalized_distance for cosine) 

1326 distance = float(row['distance']) 

1327 if metric_str == "cosine": 

1328 score = 1.0 - min(distance, 2.0) / 2.0 # Normalize cosine distance [0,2] to similarity [0,1] 

1329 elif metric_str in ["euclidean", "l2"]: 

1330 score = 1.0 / (1.0 + distance) # Convert distance to similarity 

1331 else: 

1332 score = 1.0 - distance # Generic conversion 

1333 

1334 result = VectorSearchResult( 

1335 record=record, 

1336 score=score, 

1337 vector_field=field_name, 

1338 metadata={"distance": distance, "metric": metric_str} 

1339 ) 

1340 results.append(result) 

1341 

1342 return results 

1343 
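# Worked check of the cosine conversion above: pgvector's cosine distance (<=>) lies in [0, 2],
# so a distance of 0.4 maps to a similarity of 1.0 - min(0.4, 2.0) / 2.0 = 0.8, and an
# opposite-direction vector (distance 2.0) maps to 0.0.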

1344 async def enable_vector_support(self) -> bool: 

1345 """Enable vector support for this database. 

1346  

1347 Returns: 

1348 True if vector support is enabled 

1349 """ 

1350 if self._vector_enabled: 

1351 return True 

1352 

1353 await self._detect_vector_support() 

1354 return self._vector_enabled 

1355 

1356 async def has_vector_support(self) -> bool: 

1357 """Check if this database has vector support enabled. 

1358  

1359 Returns: 

1360 True if vector support is available 

1361 """ 

1362 return self._vector_enabled 

1363 

1364 async def bulk_embed_and_store( 

1365 self, 

1366 records: list[Record], 

1367 text_field: str | list[str], 

1368 vector_field: str, 

1369 embedding_fn: Any | None = None, 

1370 batch_size: int = 100, 

1371 model_name: str | None = None, 

1372 model_version: str | None = None, 

1373 ) -> list[str]: 

1374 """Embed text fields and store vectors with records. 

1375  

1376 Extracts text from the configured field(s), calls the embedding function in 

1377 batches to generate vectors, attaches each vector to its record as a 

1378 VectorField, and then creates or updates the record in the database, 

1379 returning the storage IDs of the processed records. 

1380  

1381 Args: 

1382 records: Records to process 

1383 text_field: Field name(s) containing text to embed 

1384 vector_field: Field name to store vectors in 

1385 embedding_fn: Function to generate embeddings 

1386 batch_size: Number of records to process at once 

1387 model_name: Name of the embedding model 

1388 model_version: Version of the embedding model 

1389  

1390 Returns: 

1391 List of record IDs that were processed 

1392 """ 

1393 if not embedding_fn: 

1394 raise ValueError("embedding_fn is required for bulk_embed_and_store") 

1395 

1396 from ..fields import VectorField 

1397 

1398 processed_ids = [] 

1399 

1400 # Process in batches 

1401 for i in range(0, len(records), batch_size): 

1402 batch = records[i:i + batch_size] 

1403 

1404 # Extract texts 

1405 texts = [] 

1406 for record in batch: 

1407 if isinstance(text_field, list): 

1408 text = " ".join(str(record.fields.get(f, {}).value) for f in text_field if f in record.fields) 

1409 else: 

1410 text = str(record.fields.get(text_field, {}).value) if text_field in record.fields else "" 

1411 texts.append(text) 

1412 

1413 # Generate embeddings 

1414 if texts: 

1415 embeddings = await embedding_fn(texts) 

1416 

1417 # Store vectors with records 

1418 for j, record in enumerate(batch): 

1419 if j < len(embeddings): 

1420 vector = embeddings[j] 

1421 

1422 # Add vector field to record 

1423 record.fields[vector_field] = VectorField( 

1424 name=vector_field, 

1425 value=vector, 

1426 dimensions=len(vector) if hasattr(vector, '__len__') else None, 

1427 source_field=text_field if isinstance(text_field, str) else ",".join(text_field), 

1428 model_name=model_name, 

1429 model_version=model_version, 

1430 ) 

1431 

1432 # Create or update record 

1433 if record.has_storage_id(): 

1434 if record.storage_id is None: 

1435 raise ValueError("Record has_storage_id() returned True but storage_id is None") 

1436 await self.update(record.storage_id, record) 

1437 else: 

1438 record_id = await self.create(record) 

1439 record.storage_id = record_id 

1440 

1441 if record.storage_id is None: 

1442 raise ValueError("Record storage_id is None after create/update") 

1443 processed_ids.append(record.storage_id) 

1444 

1445 return processed_ids 

1446 
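# Illustrative embedding sketch for the method above; fake_embed stands in for a real embedding
# service, and the 3-dimensional vectors, field names, and model name are assumptions.
async def embed_example(adb: AsyncPostgresDatabase, records: list) -> list[str]:
    async def fake_embed(texts: list[str]) -> list[list[float]]:
        return [[0.0, 0.1, 0.2] for _ in texts]  # one small stand-in vector per input text
    return await adb.bulk_embed_and_store(
        records,
        text_field="content",
        vector_field="embedding",
        embedding_fn=fake_embed,
        batch_size=50,
        model_name="fake-embedder",
    )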

1447 async def create_vector_index( 

1448 self, 

1449 vector_field: str, 

1450 dimensions: int, 

1451 metric: DistanceMetric | str = "cosine", 

1452 index_type: str = "ivfflat", 

1453 lists: int | None = None, 

1454 ) -> bool: 

1455 """Create a vector index for efficient similarity search. 

1456  

1457 Args: 

1458 vector_field: Name of the vector field to index 

1459 dimensions: Number of dimensions in the vectors 

1460 metric: Distance metric for the index 

1461 index_type: Type of index (ivfflat, hnsw) 

1462 lists: Number of lists for IVFFlat index 

1463  

1464 Returns: 

1465 True if index was created successfully 

1466 """ 

1467 from .postgres_vector import ( 

1468 build_vector_column_expression, 

1469 build_vector_index_sql, 

1470 get_optimal_index_type, 

1471 get_vector_count_sql, 

1472 ) 

1473 

1474 self._check_connection() 

1475 

1476 if not self._vector_enabled: 

1477 return False 

1478 

1479 # Determine optimal parameters if not provided 

1480 if not lists and index_type == "ivfflat": 

1481 # Count vectors to determine optimal lists 

1482 count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field) 

1483 async with self._pool.acquire() as conn: 

1484 count = await conn.fetchval(count_sql) or 0 

1485 _, params = get_optimal_index_type(count) 

1486 lists = params.get("lists", 100) 

1487 

1488 # Convert metric enum to string if needed 

1489 if hasattr(metric, 'value'): 

1490 metric_str = metric.value 

1491 else: 

1492 metric_str = str(metric).lower() 

1493 

1494 # Build vector column expression for index 

1495 column_expr = build_vector_column_expression(vector_field, dimensions, for_index=True) 

1496 

1497 # Build index SQL - pass field_name for proper index naming 

1498 index_sql = build_vector_index_sql( 

1499 table_name=self.table_name, 

1500 schema_name=self.schema_name, 

1501 column_name=column_expr, 

1502 dimensions=dimensions, 

1503 metric=metric_str, 

1504 index_type=index_type, 

1505 index_params={"lists": lists} if lists else None, 

1506 field_name=vector_field 

1507 ) 

1508 

1509 # Create the index 

1510 try: 

1511 logger.debug(f"Creating vector index with SQL: {index_sql}") 

1512 async with self._pool.acquire() as conn: 

1513 await conn.execute(index_sql) 

1514 return True 

1515 except Exception as e: 

1516 logger.warning(f"Failed to create vector index: {e}") 

1517 logger.debug(f"Index SQL was: {index_sql}") 

1518 return False 

1519 
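# Illustrative index-management sketch for the methods above and below; the 384 dimensions and
# the "embedding" field name are assumptions and must match the stored vectors.
async def index_example(adb: AsyncPostgresDatabase) -> None:
    created = await adb.create_vector_index("embedding", dimensions=384, metric="cosine", index_type="hnsw")
    dropped = await adb.drop_vector_index("embedding", metric="cosine")
    print(created, dropped)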

1520 async def drop_vector_index(self, vector_field: str, metric: str = "cosine") -> bool: 

1521 """Drop a vector index. 

1522  

1523 Args: 

1524 vector_field: Name of the vector field 

1525 metric: Distance metric used in the index 

1526  

1527 Returns: 

1528 True if index was dropped successfully 

1529 """ 

1530 from .postgres_vector import get_vector_index_name 

1531 

1532 self._check_connection() 

1533 

1534 index_name = get_vector_index_name(self.table_name, vector_field, metric) 

1535 

1536 try: 

1537 async with self._pool.acquire() as conn: 

1538 await conn.execute(f"DROP INDEX IF EXISTS {self.schema_name}.{index_name}") 

1539 return True 

1540 except Exception as e: 

1541 logger.warning(f"Failed to drop vector index: {e}") 

1542 return False 

1543 

1544 async def get_vector_index_stats(self, vector_field: str) -> dict[str, Any]: 

1545 """Get statistics about a vector field and its index. 

1546  

1547 Args: 

1548 vector_field: Name of the vector field 

1549  

1550 Returns: 

1551 Dictionary with index statistics 

1552 """ 

1553 from .postgres_vector import get_index_check_sql, get_vector_count_sql 

1554 

1555 self._check_connection() 

1556 

1557 stats = { 

1558 "field": vector_field, 

1559 "indexed": False, 

1560 "vector_count": 0, 

1561 } 

1562 

1563 try: 

1564 async with self._pool.acquire() as conn: 

1565 # Count vectors 

1566 count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field) 

1567 stats["vector_count"] = await conn.fetchval(count_sql) or 0 

1568 

1569 # Check for index 

1570 index_sql, params = get_index_check_sql(self.schema_name, self.table_name, vector_field) 

1571 stats["indexed"] = await conn.fetchval(index_sql, *params) or False 

1572 except Exception as e: 

1573 logger.warning(f"Failed to get vector index stats: {e}") 

1574 

1575 return stats 

1576 

1577 async def stream_read( 

1578 self, 

1579 query: Query | None = None, 

1580 config: StreamConfig | None = None 

1581 ) -> AsyncIterator[Record]: 

1582 """Stream records from PostgreSQL using cursor.""" 

1583 self._check_connection() 

1584 config = config or StreamConfig() 

1585 

1586 # Build SQL query 

1587 sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}" 

1588 params = [] 

1589 

1590 if query and query.filters: 

1591 where_clauses = [] 

1592 param_count = 0 

1593 

1594 for filter in query.filters: 

1595 param_count += 1 

1596 field_path = f"data->>'{filter.field}'" 

1597 

1598 if filter.operator == Operator.EQ: 

1599 where_clauses.append(f"{field_path} = ${param_count}") 

1600 params.append(str(filter.value)) 

1601 

1602 if where_clauses: 

1603 sql += " WHERE " + " AND ".join(where_clauses) 

1604 

1605 # Use cursor for efficient streaming 

1606 async with self._pool.acquire() as conn: 

1607 async with conn.transaction(): 

1608 cursor = await conn.cursor(sql, *params) 

1609 

1610 batch = [] 

1611 async for row in cursor: 

1612 record = self._row_to_record(row) 

1613 if query and query.fields: 

1614 record = record.project(query.fields) 

1615 

1616 batch.append(record) 

1617 

1618 if len(batch) >= config.batch_size: 

1619 for rec in batch: 

1620 yield rec 

1621 batch = [] 

1622 

1623 # Yield remaining records 

1624 for rec in batch: 

1625 yield rec 

1626 

1627 async def stream_write( 

1628 self, 

1629 records: AsyncIterator[Record], 

1630 config: StreamConfig | None = None 

1631 ) -> StreamResult: 

1632 """Stream records into PostgreSQL using batch inserts.""" 

1633 self._check_connection() 

1634 config = config or StreamConfig() 

1635 result = StreamResult() 

1636 start_time = time.time() 

1637 quitting = False 

1638 

1639 batch = [] 

1640 async for record in records: 

1641 batch.append(record) 

1642 

1643 if len(batch) >= config.batch_size: 

1644 # Write batch with graceful fallback 

1645 # Wrap _write_batch in an async helper that returns the generated IDs 

1646 async def batch_func(b): 

1647 ids = await self._write_batch(b) 

1648 return ids 

1649 

1650 continue_processing = await async_process_batch_with_fallback( 

1651 batch, 

1652 batch_func, 

1653 self.create, 

1654 result, 

1655 config 

1656 ) 

1657 

1658 if not continue_processing: 

1659 quitting = True 

1660 break 

1661 

1662 batch = [] 

1663 

1664 # Write remaining batch 

1665 if batch and not quitting: 

1666 async def batch_func(b): 

1667 ids = await self._write_batch(b) 

1668 return ids 

1669 

1670 await async_process_batch_with_fallback( 

1671 batch, 

1672 batch_func, 

1673 self.create, 

1674 result, 

1675 config 

1676 ) 

1677 

1678 result.duration = time.time() - start_time 

1679 return result 

1680 

1681 async def _write_batch(self, records: list[Record]) -> list[str]: 

1682 """Write a batch of records using COPY for performance. 

1683  

1684 Returns: 

1685 List of created record IDs 

1686 """ 

1687 if not records: 

1688 return [] 

1689 

1690 # Prepare data for COPY 

1691 rows = [] 

1692 ids = [] 

1693 for record in records: 

1694 row_data = self._record_to_row(record) 

1695 ids.append(row_data["id"]) 

1696 rows.append(( 

1697 row_data["id"], 

1698 row_data["data"], 

1699 row_data["metadata"] 

1700 )) 

1701 

1702 # Use COPY for efficient bulk insert 

1703 async with self._pool.acquire() as conn: 

1704 await conn.copy_records_to_table( 

1705 f"{self.schema_name}.{self.table_name}", 

1706 records=rows, 

1707 columns=["id", "data", "metadata"] 

1708 ) 

1709 

1710 return ids 

1711 

1712 async def _supports_native_hybrid(self) -> bool: 

1713 """Check if this PostgreSQL backend supports native hybrid search. 

1714 

1715 PostgreSQL with pgvector and full-text search (tsvector) supports 

1716 native hybrid search. 

1717 

1718 Returns: 

1719 True if vector support is enabled (pgvector available) 

1720 """ 

1721 return self._vector_enabled 

1722 

1723 async def hybrid_search( 

1724 self, 

1725 query_text: str, 

1726 query_vector: np.ndarray | list[float], 

1727 text_fields: list[str] | None = None, 

1728 vector_field: str = "embedding", 

1729 k: int = 10, 

1730 config: Any = None, # HybridSearchConfig 

1731 filter: Query | None = None, 

1732 metric: DistanceMetric | str = DistanceMetric.COSINE, 

1733 ) -> list[Any]: # list[HybridSearchResult] 

1734 """Perform hybrid search using PostgreSQL full-text search and pgvector. 

1735 

1736 Combines PostgreSQL's tsvector full-text search with pgvector similarity 

1737 search using configurable fusion strategies. 

1738 

1739 Args: 

1740 query_text: Text query for full-text matching 

1741 query_vector: Vector for pgvector similarity search 

1742 text_fields: Fields to search for text matching 

1743 vector_field: Name of the vector field to search 

1744 k: Number of results to return 

1745 config: Hybrid search configuration (weights, fusion strategy) 

1746 filter: Optional additional filters to apply 

1747 metric: Distance metric for vector search 

1748 

1749 Returns: 

1750 List of HybridSearchResult ordered by combined score (descending) 

1751 """ 

1752 from ..vector.hybrid import ( 

1753 FusionStrategy, 

1754 HybridSearchConfig, 

1755 HybridSearchResult, 

1756 reciprocal_rank_fusion, 

1757 weighted_score_fusion, 

1758 ) 

1759 from .postgres_vector import format_vector_for_postgres, get_vector_operator 

1760 

1761 self._check_connection() 

1762 

1763 config = config or HybridSearchConfig() 

1764 

1765 # For NATIVE strategy with pgvector, we can do a combined query 

1766 # For other strategies, use the parent implementation 

1767 if config.fusion_strategy not in (FusionStrategy.NATIVE, FusionStrategy.RRF): 

1768 from ..vector.mixins import VectorOperationsMixin 

1769 return await VectorOperationsMixin.hybrid_search( 

1770 self, 

1771 query_text=query_text, 

1772 query_vector=query_vector, 

1773 text_fields=text_fields, 

1774 vector_field=vector_field, 

1775 k=k, 

1776 config=config, 

1777 filter=filter, 

1778 metric=metric, 

1779 ) 

1780 

1781 # Use config.text_fields if provided, otherwise use parameter 

1782 search_text_fields = config.text_fields or text_fields or ["content", "title", "text"] 

1783 

1784 # Get more results for fusion 

1785 fetch_k = min(k * 3, 100) 

1786 

1787 # Prepare vector search 

1788 if isinstance(query_vector, (list, tuple)): 

1789 import numpy as np 

1790 query_vector = np.array(query_vector, dtype=np.float32) 

1791 

1792 vector_str = format_vector_for_postgres(query_vector) 

1793 

1794 # Get metric operator 

1795 if isinstance(metric, str): 

1796 metric_str = metric.lower() 

1797 else: 

1798 metric_str = metric.value 

1799 operator = get_vector_operator(metric_str) 

1800 

1801 vector_column = f"vector_{vector_field}" 

1802 

1803 # Build combined query using CTE for efficient hybrid search 

1804 # This performs both searches in a single query 

1805 sql = f""" 

1806 WITH text_search AS ( 

1807 SELECT 

1808 id, 

1809 data, 

1810 metadata, 

1811 ts_rank_cd( 

1812 to_tsvector('english', {self._build_text_field_concat(search_text_fields)}), 

1813 plainto_tsquery('english', $1) 

1814 ) as text_score, 

1815 ROW_NUMBER() OVER ( 

1816 ORDER BY ts_rank_cd( 

1817 to_tsvector('english', {self._build_text_field_concat(search_text_fields)}), 

1818 plainto_tsquery('english', $1) 

1819 ) DESC 

1820 ) as text_rank 

1821 FROM {self.schema_name}.{self.table_name} 

1822 WHERE to_tsvector('english', {self._build_text_field_concat(search_text_fields)}) @@ plainto_tsquery('english', $1) 

1823 ORDER BY text_score DESC LIMIT {fetch_k} 

1824 ), 

1825 vector_search AS ( 

1826 SELECT 

1827 id, 

1828 data, 

1829 metadata, 

1830 {vector_column}, 

1831 1.0 - ({vector_column} {operator} $2::vector) as vector_score, 

1832 ROW_NUMBER() OVER ( 

1833 ORDER BY {vector_column} {operator} $2::vector 

1834 ) as vector_rank 

1835 FROM {self.schema_name}.{self.table_name} 

1836 WHERE {vector_column} IS NOT NULL 

1837 ORDER BY {vector_column} {operator} $2::vector LIMIT {fetch_k} 

1838 ), 

1839 combined AS ( 

1840 SELECT 

1841 COALESCE(t.id, v.id) as id, 

1842 COALESCE(t.data, v.data) as data, 

1843 COALESCE(t.metadata, v.metadata) as metadata, 

1844 t.text_score, 

1845 t.text_rank, 

1846 v.vector_score, 

1847 v.vector_rank 

1848 FROM text_search t 

1849 FULL OUTER JOIN vector_search v ON t.id = v.id 

1850 ) 

1851 SELECT * FROM combined 

1852 """ 

1853 
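# Note: the FULL OUTER JOIN in the `combined` CTE keeps rows matched by only
# one branch, so a record found by full-text search alone (or by vector
# search alone) still reaches the fusion step below with NULL on the other score.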

1854 params = [query_text, vector_str] 

1855 

1856 try: 

1857 async with self._pool.acquire() as conn: 

1858 rows = await conn.fetch(sql, *params) 

1859 except Exception as e: 

1860 # If full-text search fails, fall back to client-side fusion 

1861 logger.warning(f"Native PostgreSQL hybrid search failed ({e}), falling back to client-side") 

1862 from ..vector.mixins import VectorOperationsMixin 

1863 return await VectorOperationsMixin.hybrid_search( 

1864 self, 

1865 query_text=query_text, 

1866 query_vector=query_vector, 

1867 text_fields=text_fields, 

1868 vector_field=vector_field, 

1869 k=k, 

1870 config=HybridSearchConfig( 

1871 text_weight=config.text_weight, 

1872 vector_weight=config.vector_weight, 

1873 fusion_strategy=FusionStrategy.RRF, 

1874 rrf_k=config.rrf_k, 

1875 text_fields=config.text_fields, 

1876 ), 

1877 filter=filter, 

1878 metric=metric, 

1879 ) 

1880 

1881 # Build result lists for fusion 

1882 records_by_id: dict[str, Record] = {} 

1883 text_scores: list[tuple[str, float]] = [] 

1884 vector_scores: list[tuple[str, float]] = [] 

1885 

1886 for row in rows: 

1887 record = self._row_to_record(row) 

1888 record_id = row['id'] 

1889 records_by_id[record_id] = record 

1890 

1891 if row['text_score'] is not None: 

1892 text_scores.append((record_id, float(row['text_score']))) 

1893 if row['vector_score'] is not None: 

1894 vector_scores.append((record_id, float(row['vector_score']))) 

1895 

1896 # Sort by score for rank-based fusion 

1897 text_scores.sort(key=lambda x: x[1], reverse=True) 

1898 vector_scores.sort(key=lambda x: x[1], reverse=True) 

1899 

1900 # Apply RRF fusion 

1901 fused = reciprocal_rank_fusion( 

1902 text_results=text_scores, 

1903 vector_results=vector_scores, 

1904 k=config.rrf_k, 

1905 text_weight=config.text_weight, 

1906 vector_weight=config.vector_weight, 

1907 ) 

1908 
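# In the standard RRF formulation each record scores
#     sum_over_lists(weight / (rrf_k + rank_in_that_list)),
# e.g. with rrf_k=60 a record ranked 1st by text and 3rd by vector gets
# text_weight/61 + vector_weight/63 (assuming reciprocal_rank_fusion
# implements the usual weighted variant).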

1909 # Build HybridSearchResult objects 

1910 text_score_map = dict(text_scores) 

1911 vector_score_map = dict(vector_scores) 

1912 text_rank_map = {rid: i + 1 for i, (rid, _) in enumerate(text_scores)} 

1913 vector_rank_map = {rid: i + 1 for i, (rid, _) in enumerate(vector_scores)} 

1914 

1915 results: list[HybridSearchResult] = [] 

1916 for record_id, combined_score in fused[:k]: 

1917 if record_id not in records_by_id: 

1918 continue 

1919 

1920 results.append(HybridSearchResult( 

1921 record=records_by_id[record_id], 

1922 combined_score=combined_score, 

1923 text_score=text_score_map.get(record_id), 

1924 vector_score=vector_score_map.get(record_id), 

1925 text_rank=text_rank_map.get(record_id), 

1926 vector_rank=vector_rank_map.get(record_id), 

1927 metadata={ 

1928 "fusion_strategy": config.fusion_strategy.value, 

1929 "text_weight": config.text_weight, 

1930 "vector_weight": config.vector_weight, 

1931 "backend": "postgresql", 

1932 }, 

1933 )) 

1934 

1935 return results 

1936 
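A usage sketch for the hybrid path, to be run inside an async context against a connected, vector-enabled backend. The toy embedding and its dimensionality are illustrative; the config fields and result attributes are the ones read and populated above, and the import path is assumed from the package layout.

from dataknobs_data.vector.hybrid import FusionStrategy, HybridSearchConfig   # assumed import path

async def demo(db):
    results = await db.hybrid_search(
        query_text="postgres connection pooling",
        query_vector=[0.1] * 384,                    # toy embedding; real vectors come from a model
        text_fields=["title", "content"],
        k=5,
        config=HybridSearchConfig(
            fusion_strategy=FusionStrategy.RRF,
            text_weight=0.4,
            vector_weight=0.6,
        ),
    )
    for r in results:
        print(f"{r.combined_score:.4f}", r.text_rank, r.vector_rank, r.record.id)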

1937 def _build_text_field_concat(self, text_fields: list[str]) -> str: 

1938 """Build SQL expression to concatenate text fields for full-text search. 

1939 

1940 Args: 

1941 text_fields: List of field names to concatenate 

1942 

1943 Returns: 

1944 SQL expression for concatenated text fields 

1945 """ 

1946 if not text_fields: 

1947 return "COALESCE(data->>'content', '')" 

1948 

1949 parts = [f"COALESCE(data->>'{field}', '')" for field in text_fields] 

1950 return " || ' ' || ".join(parts) 

1951 
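For example, given ["title", "body"] the helper above produces:

expr = db._build_text_field_concat(["title", "body"])
assert expr == "COALESCE(data->>'title', '') || ' ' || COALESCE(data->>'body', '')"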

1952 async def _text_search_for_hybrid( 

1953 self, 

1954 query_text: str, 

1955 text_fields: list[str] | None, 

1956 k: int, 

1957 filter: Query | None = None, 

1958 ) -> list[tuple[Record, float]]: 

1959 """Perform PostgreSQL full-text search for hybrid search fusion. 

1960 

1961 Uses PostgreSQL's tsvector/tsquery full-text search with ts_rank_cd scoring. 

1962 

1963 Args: 

1964 query_text: Text to search for 

1965 text_fields: Fields to search in 

1966 k: Maximum results to return 

1967 filter: Additional filters (currently not applied by this SQL path) 

1968 

1969 Returns: 

1970 List of (record, score) tuples ordered by text relevance 

1971 """ 

1972 self._check_connection() 

1973 

1974 search_fields = text_fields or ["content", "title", "text"] 

1975 text_concat = self._build_text_field_concat(search_fields) 

1976 

1977 sql = f""" 

1978 SELECT 

1979 id, 

1980 data, 

1981 metadata, 

1982 ts_rank_cd( 

1983 to_tsvector('english', {text_concat}), 

1984 plainto_tsquery('english', $1) 

1985 ) as score 

1986 FROM {self.schema_name}.{self.table_name} 

1987 WHERE to_tsvector('english', {text_concat}) @@ plainto_tsquery('english', $1) 

1988 ORDER BY score DESC 

1989 LIMIT {k} 

1990 """ 

1991 

1992 try: 

1993 async with self._pool.acquire() as conn: 

1994 rows = await conn.fetch(sql, query_text) 

1995 except Exception as e: 

1996 # Fall back to LIKE-based search if full-text search fails 

1997 logger.warning(f"PostgreSQL full-text search failed ({e}), falling back to LIKE") 

1998 from ..vector.mixins import VectorOperationsMixin 

1999 return await VectorOperationsMixin._text_search_for_hybrid( 

2000 self, 

2001 query_text=query_text, 

2002 text_fields=text_fields, 

2003 k=k, 

2004 filter=filter, 

2005 ) 

2006 

2007 # Normalize scores 

2008 results: list[tuple[Record, float]] = [] 

2009 max_score = max((float(row['score']) for row in rows), default=1.0) or 1.0 

2010 

2011 for row in rows: 

2012 record = self._row_to_record(row) 

2013 score = float(row['score']) / max_score if max_score > 0 else 0.0 

2014 results.append((record, score)) 

2015 

2016 return results
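
An illustrative call of the text-search helper above, inside an async context with a connected backend; scores are normalized so the best match is 1.0.

async def show_text_matches(db):
    pairs = await db._text_search_for_hybrid("vector index tuning", ["title", "content"], k=20)
    for record, score in pairs:
        print(f"{score:.3f}", record.id)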