Coverage for src/dataknobs_data/backends/postgres.py: 12% of 734 statements (coverage.py v7.11.0, created at 2025-10-29 14:14 -0600)

1"""PostgreSQL backend implementation with proper connection management and vector support.""" 

2 

3from __future__ import annotations 

4 

5import json 

6import logging 

7import time 

8import uuid 

9from typing import TYPE_CHECKING, Any, cast 

10 

11import asyncpg 

12from dataknobs_config import ConfigurableBase 

13 

14from dataknobs_utils.sql_utils import DotenvPostgresConnector, PostgresDB 

15 

16from ..database import AsyncDatabase, SyncDatabase 

17from ..pooling import ConnectionPoolManager 

18from ..pooling.postgres import PostgresPoolConfig, create_asyncpg_pool, validate_asyncpg_pool 

19from ..query import Operator, Query 

20from ..query_logic import ComplexQuery 

21from ..streaming import ( 

22 StreamConfig, 

23 StreamResult, 

24 async_process_batch_with_fallback, 

25 process_batch_with_fallback, 

26) 

27from ..vector.mixins import VectorOperationsMixin 

28from .postgres_mixins import ( 

29 PostgresBaseConfig, 

30 PostgresConnectionValidator, 

31 PostgresErrorHandler, 

32 PostgresTableManager, 

33 PostgresVectorSupport, 

34) 

35from .sql_base import SQLQueryBuilder, SQLRecordSerializer 

36 

37if TYPE_CHECKING: 

38 import numpy as np 

39 

40 from collections.abc import AsyncIterator, Iterator, Callable, Awaitable 

41 from ..fields import VectorField 

42 from ..records import Record 

43 from ..vector.types import DistanceMetric, VectorSearchResult 

44 

45logger = logging.getLogger(__name__) 

46 

47 

class SyncPostgresDatabase(
    SyncDatabase,
    ConfigurableBase,
    VectorOperationsMixin,
    SQLRecordSerializer,
    PostgresBaseConfig,
    PostgresTableManager,
    PostgresVectorSupport,
    PostgresConnectionValidator,
    PostgresErrorHandler,
):
    """Synchronous PostgreSQL database backend with proper connection management."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize PostgreSQL database configuration.

        Args:
            config: Configuration with the following optional keys:
                - host: PostgreSQL host (default: from env/localhost)
                - port: PostgreSQL port (default: 5432)
                - database: Database name (default: from env/postgres)
                - user: Username (default: from env/postgres)
                - password: Password (default: from env)
                - table: Table name (default: "records")
                - schema: Schema name (default: "public")
                - enable_vector: Enable vector support (default: False)
        """
        super().__init__(config)

        # Parse configuration using mixin
        table_name, schema_name, conn_config = self._parse_postgres_config(config or {})
        self._init_postgres_attributes(table_name, schema_name)

        # Store connection config for later use
        self._conn_config = conn_config
        self.db = None  # Will be initialized in connect()
        self.query_builder = None  # Will be initialized in connect()

    @classmethod
    def from_config(cls, config: dict) -> SyncPostgresDatabase:
        """Create from config dictionary."""
        return cls(config)
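
    # Usage sketch (illustrative only, not part of the API surface): constructing
    # the backend from a plain config dict and connecting (see connect() below).
    # Key names follow the __init__ docstring; the host and credential values
    # here are assumptions.
    #
    #     db = SyncPostgresDatabase({
    #         "host": "localhost",
    #         "port": 5432,
    #         "database": "postgres",
    #         "user": "postgres",
    #         "password": "secret",      # assumed placeholder
    #         "table": "records",
    #         "schema": "public",
    #         "enable_vector": True,
    #     })
    #     db.connect()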

    def connect(self) -> None:
        """Connect to the PostgreSQL database."""
        if self._connected:
            return  # Already connected

        # Initialize query builder with pyformat style for psycopg2
        self.query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")

        # Create connection using existing utilities
        if not any(key in self._conn_config for key in ["host", "database", "user"]):
            # Use dotenv connector for environment-based config
            connector = DotenvPostgresConnector()
            self.db = PostgresDB(connector)
        else:
            # Direct configuration - map 'database' to 'db' for PostgresDB
            self.db = PostgresDB(
                host=self._conn_config.get("host", "localhost"),
                db=self._conn_config.get("database", "postgres"),  # Note: PostgresDB expects 'db' not 'database'
                user=self._conn_config.get("user", "postgres"),
                pwd=self._conn_config.get("password"),  # Note: PostgresDB expects 'pwd' not 'password'
                port=self._conn_config.get("port", 5432),
            )

        # Create table if it doesn't exist
        self._ensure_table()

        # Detect and enable vector support if requested
        if self.vector_enabled:
            self._detect_vector_support()

        self._connected = True
        self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}")

    def close(self) -> None:
        """Close the database connection."""
        if self.db:
            # PostgresDB manages its own connections via context managers
            # but we can mark as disconnected
            self._connected = False  # type: ignore[unreachable]

    def _initialize(self) -> None:
        """Initialize method - connection setup moved to connect()."""
        # Configuration parsing stays here, actual connection in connect()
        pass

    def _detect_vector_support(self) -> None:
        """Detect and enable vector support if pgvector is available."""
        from .postgres_vector import check_pgvector_extension_sync, install_pgvector_extension_sync

        try:
            # Check if pgvector is installed
            if check_pgvector_extension_sync(self.db):
                self._vector_enabled = True
                logger.info("pgvector extension detected and enabled")
            else:
                # Try to install it
                if install_pgvector_extension_sync(self.db):
                    self._vector_enabled = True
                    logger.info("pgvector extension installed and enabled")
                else:
                    logger.debug("pgvector extension not available")
        except Exception as e:
            logger.debug(f"Could not enable vector support: {e}")
            self._vector_enabled = False

    def _ensure_table(self) -> None:
        """Ensure the records table exists."""
        if not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

        create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name)  # type: ignore[unreachable]
        self.db.execute(create_table_sql)

    def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
        """Convert a Record to a database row."""
        return {
            "id": id or str(uuid.uuid4()),
            "data": self.record_to_json(record),
            "metadata": json.dumps(record.metadata) if record.metadata else None,
        }

    def _row_to_record(self, row: dict[str, Any]) -> Record:
        """Convert a database row to a Record."""
        return self.row_to_record(row)

    def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()
        # Use record's ID if it has one, otherwise generate a new one
        id = record.id if record.id else str(uuid.uuid4())
        row = self._record_to_row(record, id)

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
            VALUES (%(id)s, %(data)s, %(metadata)s)
        """
        self.db.execute(sql, row)
        return id

    def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()
        sql = f"""
            SELECT id, data, metadata
            FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
        """
        df = self.db.query(sql, {"id": id})

        if df.empty:
            return None

        row = df.iloc[0].to_dict()
        return self._row_to_record(row)

    def update(self, id: str, record: Record) -> bool:
        """Update an existing record.

        Args:
            id: The record ID to update
            record: The record data to update with

        Returns:
            True if the record was updated, False if no record with the given ID exists
        """
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
            UPDATE {self.schema_name}.{self.table_name}
            SET data = %(data)s, metadata = %(metadata)s, updated_at = CURRENT_TIMESTAMP
            WHERE id = %(id)s
        """
        result = self.db.execute(sql, row)

        # PostgresDB.execute returns number of affected rows
        rows_affected = result if isinstance(result, int) else 0

        if rows_affected == 0:
            logger.warning(f"Update affected 0 rows for id={id}. Record may not exist.")

        return rows_affected > 0

    def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()
        sql = f"""
            DELETE FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
        """
        result = self.db.execute(sql, {"id": id})
        return result > 0 if isinstance(result, int) else False

    def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()
        sql = f"""
            SELECT 1 FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
            LIMIT 1
        """
        df = self.db.query(sql, {"id": id})
        return not df.empty

    def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record.

        Can be called as:
        - upsert(id, record) - explicit ID and record
        - upsert(record) - extract ID from record using Record's built-in logic
        """
        self._check_connection()

        # Determine ID and record based on arguments
        if isinstance(id_or_record, str):
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            id = record.id
            if id is None:
                id = str(uuid.uuid4())  # type: ignore[unreachable]
                record.storage_id = id

        if self.exists(id):
            self.update(id, record)
        else:
            # Insert with specific ID
            row = self._record_to_row(record, id)
            sql = f"""
                INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
                VALUES (%(id)s, %(data)s, %(metadata)s)
            """
            self.db.execute(sql, row)
        return id
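
    # Both call forms from the docstring above, as a hedged sketch ("record-123"
    # is a hypothetical ID; `record` is assumed to be a constructed Record):
    #
    #     new_id = db.upsert("record-123", record)   # explicit ID
    #     new_id = db.upsert(record)                 # ID taken from record.id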

    def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # Handle ComplexQuery with native SQL support
        if isinstance(query, ComplexQuery):
            sql_query, params_list = self.query_builder.build_complex_search_query(query)
        else:
            sql_query, params_list = self.query_builder.build_search_query(query)

        # Build params dict for psycopg2
        # The query builder now generates %(p0)s style placeholders directly
        params_dict = {}
        if params_list:
            for i, param in enumerate(params_list):
                params_dict[f"p{i}"] = param

        # Execute query
        df = self.db.query(sql_query, params_dict)

        # Convert to records
        records = []
        for _, row in df.iterrows():
            row_dict = row.to_dict()
            record = self._row_to_record(row_dict)

            # Populate storage_id from database ID
            record.storage_id = str(row_dict['id'])

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records

    def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()
        sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"
        df = self.db.query(sql)
        return int(df.iloc[0]["count"]) if not df.empty else 0

    def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()
        # Get count first
        count = self._count_all()

        # Delete all records
        sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"
        self.db.execute(sql)

        return count

    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using a single query.

        Uses multi-value INSERT for better performance.

        Args:
            records: List of records to create

        Returns:
            List of created record IDs
        """
        if not records:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL with pyformat style
        from .sql_base import SQLQueryBuilder
        query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")

        # Use the shared batch create query builder
        query, params_list, ids = query_builder.build_batch_create_query(records)

        # Build params dict for psycopg2
        params_dict = {}
        for i, param in enumerate(params_list):
            params_dict[f"p{i}"] = param

        # Execute the batch insert and get returned IDs
        result_df = self.db.query(query, params_dict)

        # PostgreSQL RETURNING clause gives us the actual inserted IDs
        if not result_df.empty:
            return result_df['id'].tolist()
        return ids

    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using a single query.

        Uses single DELETE with IN clause for better performance.

        Args:
            ids: List of record IDs to delete

        Returns:
            List of success flags for each deletion
        """
        if not ids:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL with pyformat style
        from .sql_base import SQLQueryBuilder
        query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")

        # Use the shared batch delete query builder (includes RETURNING clause)
        query, params_list = query_builder.build_batch_delete_query(ids)

        # Build params dict for psycopg2
        params_dict = {}
        for i, param in enumerate(params_list):
            params_dict[f"p{i}"] = param

        # Execute the batch delete and get returned IDs
        result_df = self.db.query(query, params_dict)

        # Get list of deleted IDs from RETURNING clause
        deleted_ids = set(result_df['id'].tolist()) if not result_df.empty else set()

        # Return results based on which IDs were actually deleted
        results = []
        for id in ids:
            results.append(id in deleted_ids)

        return results

    def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using a single query.

        Uses PostgreSQL's CASE expressions for batch updates via shared SQL builder.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        if not updates:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL with pyformat style
        from .sql_base import SQLQueryBuilder
        query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")

        # Use the shared batch update query builder
        query, params_list = query_builder.build_batch_update_query(updates)

        # Build params dict for psycopg2
        params_dict = {}
        for i, param in enumerate(params_list):
            params_dict[f"p{i}"] = param

        # Execute the batch update and get returned IDs (query now includes RETURNING clause)
        result_df = self.db.query(query, params_dict)

        # Get list of updated IDs from RETURNING clause
        updated_ids = set(result_df['id'].tolist()) if not result_df.empty else set()

        results = []
        for record_id, _ in updates:
            results.append(record_id in updated_ids)

        return results
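
    # Batch-path sketch (hedged; `records` is an assumed list of Record objects):
    #
    #     ids = db.create_batch(records)                  # one multi-value INSERT
    #     ok = db.update_batch(list(zip(ids, records)))   # [True, ...]
    #     flags = db.delete_batch(ids)                    # [True, ...]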

    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from PostgreSQL."""
        self._check_connection()
        config = config or StreamConfig()

        # Build SQL query
        sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
        params = {}

        if query and query.filters:
            # Add WHERE clause (simplified for now)
            where_clauses = []
            for i, filter in enumerate(query.filters):
                field_path = f"data->>'{filter.field}'"
                param_name = f"param_{i}"

                if filter.operator == Operator.EQ:
                    where_clauses.append(f"{field_path} = %({param_name})s")
                    params[param_name] = str(filter.value)

            if where_clauses:
                sql += " WHERE " + " AND ".join(where_clauses)

        # Use cursor for streaming
        # Note: PostgresDB may need modification to support cursors
        # For now, we'll fetch in batches
        sql += f" LIMIT {config.batch_size} OFFSET %(offset)s"

        offset = 0
        while True:
            params["offset"] = offset
            df = self.db.query(sql, params)

            if df.empty:
                break

            for _, row in df.iterrows():
                record = self._row_to_record(row.to_dict())
                if query and query.fields:
                    record = record.project(query.fields)
                yield record

            offset += config.batch_size

            # If we got less than batch_size, we're done
            if len(df) < config.batch_size:
                break

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into PostgreSQL."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful fallback
                # Use lambda wrapper for _write_batch
                continue_processing = process_batch_with_fallback(
                    batch,
                    lambda b: self._write_batch(b),
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            process_batch_with_fallback(
                batch,
                lambda b: self._write_batch(b),
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result
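
    # Streaming sketch (assumes StreamConfig(batch_size=...) as used above;
    # `make_records()` is a hypothetical generator of Record objects):
    #
    #     result = db.stream_write(make_records(), StreamConfig(batch_size=500))
    #     print(result.duration)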

    def _write_batch(self, records: list[Record]) -> list[str]:
        """Write a batch of records to the database.

        Returns:
            List of created record IDs
        """
        # Build batch insert SQL
        values = []
        params = {}
        ids = []

        for i, record in enumerate(records):
            id = str(uuid.uuid4())
            ids.append(id)
            row = self._record_to_row(record, id)
            values.append(f"(%(id_{i})s, %(data_{i})s, %(metadata_{i})s)")
            params[f"id_{i}"] = row["id"]
            params[f"data_{i}"] = row["data"]
            params[f"metadata_{i}"] = row["metadata"]

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
            VALUES {', '.join(values)}
        """
        self.db.execute(sql, params)
        return ids

    def vector_search(
        self,
        query_vector: np.ndarray | list[float] | VectorField,
        field_name: str,
        k: int = 10,
        filter: Query | None = None,
        metric: DistanceMetric | str = "cosine"
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using PostgreSQL pgvector.

        Args:
            query_vector: Query vector (numpy array, list, or VectorField)
            field_name: Name of vector field to search (must be in data JSON)
            k: Maximum number of results
            filter: Optional filter query to apply
            metric: Distance metric to use (cosine, euclidean, l2, inner_product)

        Returns:
            List of VectorSearchResult objects ordered by similarity
        """
        if not self._vector_enabled:
            raise RuntimeError("Vector search not available - pgvector not installed")

        self._check_connection()

        from ..fields import VectorField
        from ..vector.types import DistanceMetric, VectorSearchResult
        from .postgres_vector import format_vector_for_postgres, get_vector_operator

        # Convert query vector to proper format
        if isinstance(query_vector, VectorField):
            vector_str = format_vector_for_postgres(query_vector.value)
        else:
            vector_str = format_vector_for_postgres(query_vector)

        # Get the appropriate operator
        if isinstance(metric, DistanceMetric):
            metric_str = metric.value
        else:
            metric_str = str(metric).lower()

        operator = get_vector_operator(metric_str)

        # Build the query - vectors are stored in JSON data field
        # Use centralized vector extraction logic
        vector_expr = self.get_vector_extraction_sql(field_name, dialect="postgres")

        # Build the base SQL with pyformat placeholders
        sql = f"""
            SELECT
                id,
                data,
                metadata,
                {vector_expr} {operator} %(p0)s::vector AS distance
            FROM {self.schema_name}.{self.table_name}
            WHERE data ? %(p1)s -- Check field exists
        """

        params: list[Any] = [vector_str, field_name]

        # Add filters if provided using the query builder
        if filter:
            # Query builder will generate pyformat placeholders since we configured it that way
            where_clause, filter_params = self.query_builder.build_where_clause(filter, len(params) + 1)
            if where_clause:
                sql += where_clause
                params.extend(filter_params)

        # Order by distance and limit
        next_param = len(params)
        sql += f" ORDER BY distance LIMIT %(p{next_param})s"
        params.append(k)

        # Build param dict for psycopg2
        param_dict = {}
        for i, param in enumerate(params):
            param_dict[f"p{i}"] = param

        df = self.db.query(sql, param_dict)

        # Convert results
        results = []
        for _, row in df.iterrows():
            record = self._row_to_record(row)

            # Calculate similarity score from distance
            distance = row["distance"]
            if metric_str in ["cosine", "cosine_similarity"]:
                score = 1.0 - distance  # Cosine distance to similarity
            elif metric_str in ["euclidean", "l2"]:
                score = 1.0 / (1.0 + distance)  # Convert distance to similarity
            elif metric_str in ["inner_product", "dot_product"]:
                score = -distance  # Negative because pgvector uses negative for descending
            else:
                score = -distance  # Default: lower distance = better

            result = VectorSearchResult(
                record=record,
                score=float(score),
                vector_field=field_name
            )
            results.append(result)

        return results
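
    # Similarity-search sketch (the field name and 3-d query vector are
    # assumptions; requires pgvector, i.e. has_vector_support() returns True):
    #
    #     results = db.vector_search([0.1, 0.2, 0.3], "embedding", k=5, metric="cosine")
    #     for r in results:
    #         print(r.score, r.record)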

    def has_vector_support(self) -> bool:
        """Check if this database has vector support enabled.

        Returns:
            True if vector operations are supported
        """
        return self._vector_enabled

    def enable_vector_support(self) -> bool:
        """Enable vector support for this database if possible.

        Returns:
            True if vector support is now enabled
        """
        if self._vector_enabled:
            return True

        self._detect_vector_support()
        return self._vector_enabled

    def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Any = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records (stub for abstract requirement).

        This is a placeholder implementation to satisfy the abstract method
        requirement. A full implementation would require an actual embedding
        function.
        """
        raise NotImplementedError("bulk_embed_and_store requires an embedding function")


# Global pool manager instance for async PostgreSQL connections
_pool_manager = ConnectionPoolManager[asyncpg.Pool]()


class AsyncPostgresDatabase(
    AsyncDatabase,
    VectorOperationsMixin,
    ConfigurableBase,
    PostgresBaseConfig,
    PostgresTableManager,
    PostgresVectorSupport,
    PostgresConnectionValidator,
    PostgresErrorHandler,
):
    """Native async PostgreSQL database backend with vector support and event loop-aware connection pooling."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async PostgreSQL database."""
        super().__init__(config)

        # Parse configuration using mixin
        table_name, schema_name, conn_config = self._parse_postgres_config(config or {})
        self._init_postgres_attributes(table_name, schema_name)

        # Extract pool configuration
        self._pool_config = PostgresPoolConfig.from_dict(conn_config)
        self._pool: asyncpg.Pool | None = None

    @classmethod
    def from_config(cls, config: dict) -> AsyncPostgresDatabase:
        """Create from config dictionary."""
        return cls(config)

    async def connect(self) -> None:
        """Connect to the database."""
        if self._connected:
            return

        # Get or create pool for current event loop
        from ..pooling import BasePoolConfig
        self._pool = await _pool_manager.get_pool(
            self._pool_config,
            cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_asyncpg_pool),
            validate_asyncpg_pool
        )

        # Initialize query builder
        self.query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")

        # Ensure table exists
        await self._ensure_table()

        # Check and enable vector support if requested
        if self.vector_enabled:
            await self._detect_vector_support()

        self._connected = True
        self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}")

    async def close(self) -> None:
        """Close the database connection and properly close the pool."""
        if self._connected:
            # Properly close the pool if we have one
            if self._pool:
                try:
                    await self._pool.close()
                except Exception as e:
                    logger.warning(f"Error closing connection pool: {e}")
                self._pool = None
            self._connected = False

    def _initialize(self) -> None:
        """Initialize is handled in connect."""
        pass
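
    # Async usage sketch (config keys mirror the sync backend; host/database
    # values are assumptions):
    #
    #     import asyncio
    #
    #     async def main() -> None:
    #         db = AsyncPostgresDatabase({"host": "localhost", "database": "postgres"})
    #         await db.connect()
    #         try:
    #             ...  # CRUD / search calls
    #         finally:
    #             await db.close()
    #
    #     asyncio.run(main())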

    async def _ensure_table(self) -> None:
        """Ensure the records table exists."""
        if not self._pool:
            raise RuntimeError("Database not connected. Call connect() first.")

        create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name)

        async with self._pool.acquire() as conn:
            await conn.execute(create_table_sql)

    async def _detect_vector_support(self) -> None:
        """Detect and enable vector support if pgvector is available."""
        from .postgres_vector import check_pgvector_extension, install_pgvector_extension

        async with self._pool.acquire() as conn:
            # Check if pgvector is available
            if await check_pgvector_extension(conn):
                self._vector_enabled = True
                logger.info("pgvector extension detected and enabled")
            else:
                # Try to install it
                if await install_pgvector_extension(conn):
                    self._vector_enabled = True
                    logger.info("pgvector extension installed and enabled")
                else:
                    logger.debug("pgvector extension not available")

    async def _ensure_vector_column(self, field_name: str, dimensions: int) -> None:
        """Ensure a vector column exists for the given field.

        Args:
            field_name: Name of the vector field
            dimensions: Number of dimensions
        """
        if not self._vector_enabled:
            return

        column_name = f"vector_{field_name}"

        # Check if column already exists
        check_sql = """
            SELECT column_name FROM information_schema.columns
            WHERE table_schema = $1 AND table_name = $2 AND column_name = $3
        """

        async with self._pool.acquire() as conn:
            existing = await conn.fetchval(check_sql, self.schema_name, self.table_name, column_name)

            if not existing:
                # Add vector column
                alter_sql = f"""
                    ALTER TABLE {self.schema_name}.{self.table_name}
                    ADD COLUMN IF NOT EXISTS {column_name} vector({dimensions})
                """
                try:
                    await conn.execute(alter_sql)
                    self._vector_dimensions[field_name] = dimensions
                    logger.info(f"Added vector column {column_name} with {dimensions} dimensions")

                    # Create index for the vector column
                    from .postgres_vector import build_vector_index_sql, get_optimal_index_type

                    # Get row count for optimal index selection
                    count_sql = f"SELECT COUNT(*) FROM {self.schema_name}.{self.table_name}"
                    count = await conn.fetchval(count_sql)

                    index_type, index_params = get_optimal_index_type(count)
                    index_sql = build_vector_index_sql(
                        self.table_name,
                        self.schema_name,
                        column_name,
                        dimensions,
                        metric="cosine",
                        index_type=index_type,
                        index_params=index_params
                    )

                    # Note: IVFFlat requires table to have data before creating index
                    if count > 0 or index_type != "ivfflat":
                        await conn.execute(index_sql)
                        logger.info(f"Created {index_type} index for {column_name}")

                except Exception as e:
                    logger.warning(f"Could not create vector column {column_name}: {e}")
            else:
                self._vector_dimensions[field_name] = dimensions
893 def _check_connection(self) -> None: 

894 """Check if async database is connected.""" 

895 self._check_async_connection() 

896 

897 def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]: 

898 """Convert a Record to a database row using common serializer.""" 

899 from .sql_base import SQLRecordSerializer 

900 

901 return { 

902 "id": id or str(uuid.uuid4()), 

903 "data": SQLRecordSerializer.record_to_json(record), 

904 "metadata": json.dumps(record.metadata) if record.metadata else None, 

905 } 

906 

907 def _row_to_record(self, row: asyncpg.Record) -> Record: 

908 """Convert a database row to a Record using the common serializer.""" 

909 from .sql_base import SQLRecordSerializer 

910 

911 # Convert asyncpg.Record to dict format expected by SQLRecordSerializer 

912 data_json = row.get("data", {}) 

913 if not isinstance(data_json, str): 

914 data_json = json.dumps(data_json) 

915 

916 metadata_json = row.get("metadata") 

917 if metadata_json and not isinstance(metadata_json, str): 

918 metadata_json = json.dumps(metadata_json) 

919 

920 # Use the common serializer to reconstruct the record 

921 return SQLRecordSerializer.json_to_record(data_json, metadata_json) 


    async def create(self, record: Record) -> str:
        """Create a new record with vector support."""
        self._check_connection()

        # Check for vector fields and ensure columns exist
        from ..fields import VectorField
        for field_name, field_obj in record.fields.items():
            if isinstance(field_obj, VectorField) and self._vector_enabled:
                await self._ensure_vector_column(field_name, field_obj.dimensions)

        # Use record's ID if it has one, otherwise generate a new one
        id = record.id if record.id else str(uuid.uuid4())
        row = self._record_to_row(record, id)

        # Build dynamic SQL based on vector columns present
        columns = ["id", "data", "metadata"]
        values = [row["id"], row["data"], row["metadata"]]
        placeholders = ["$1", "$2", "$3"]

        # Add vector columns
        param_num = 4
        for key, value in row.items():
            if key.startswith("vector_"):
                columns.append(key)
                values.append(value)
                placeholders.append(f"${param_num}")
                param_num += 1

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} ({', '.join(columns)})
            VALUES ({', '.join(placeholders)})
        """

        async with self._pool.acquire() as conn:
            await conn.execute(sql, *values)

        return id

    async def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()
        sql = f"""
            SELECT id, data, metadata
            FROM {self.schema_name}.{self.table_name}
            WHERE id = $1
        """

        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql, id)

        if not row:
            return None

        return self._row_to_record(row)

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record.

        Args:
            id: The record ID to update
            record: The record data to update with

        Returns:
            True if the record was updated, False if no record with the given ID exists
        """
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
            UPDATE {self.schema_name}.{self.table_name}
            SET data = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP
            WHERE id = $1
        """

        async with self._pool.acquire() as conn:
            result = await conn.execute(sql, row["id"], row["data"], row["metadata"])

        # Returns 'UPDATE n' where n is rows affected
        rows_affected = int(result.split()[-1])

        if rows_affected == 0:
            logger.warning(f"Update affected 0 rows for id={id}. Record may not exist.")

        return rows_affected > 0

    async def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()
        sql = f"""
            DELETE FROM {self.schema_name}.{self.table_name}
            WHERE id = $1
        """

        async with self._pool.acquire() as conn:
            result = await conn.execute(sql, id)

        # Returns 'DELETE n' where n is rows affected
        return result.split()[-1] != "0"

    async def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()
        sql = f"""
            SELECT 1 FROM {self.schema_name}.{self.table_name}
            WHERE id = $1
            LIMIT 1
        """

        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql, id)

        return row is not None

    async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record.

        Can be called as:
        - upsert(id, record) - explicit ID and record
        - upsert(record) - extract ID from record using Record's built-in logic
        """
        self._check_connection()

        # Determine ID and record based on arguments
        if isinstance(id_or_record, str):
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            id = record.id
            if id is None:
                id = str(uuid.uuid4())  # type: ignore[unreachable]
                record.storage_id = id

        row = self._record_to_row(record, id)

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
            VALUES ($1, $2, $3)
            ON CONFLICT (id) DO UPDATE
            SET data = EXCLUDED.data, metadata = EXCLUDED.metadata, updated_at = CURRENT_TIMESTAMP
        """

        async with self._pool.acquire() as conn:
            await conn.execute(sql, row["id"], row["data"], row["metadata"])

        return id

    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # Initialize query builder if not already done
        if not hasattr(self, 'query_builder'):
            self.query_builder = SQLQueryBuilder(
                self.table_name, self.schema_name, dialect="postgres"
            )

        # Handle ComplexQuery with native SQL support
        if isinstance(query, ComplexQuery):
            sql, params = self.query_builder.build_complex_search_query(query)
        else:
            sql, params = self.query_builder.build_search_query(query)

        # Execute query with asyncpg (already uses positional parameters)
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(sql, *params)

        # Convert to records
        records = []
        for row in rows:
            record = self._row_to_record(row)

            # Populate storage_id from database ID
            record.storage_id = str(row['id'])

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()
        sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"

        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql)

        return row["count"] if row else 0

    async def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()
        # Get count first
        count = await self._count_all()

        # Delete all records
        sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"

        async with self._pool.acquire() as conn:
            await conn.execute(sql)

        return count

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using a single query.

        Uses multi-value INSERT with RETURNING for better performance.

        Args:
            records: List of records to create

        Returns:
            List of created record IDs
        """
        if not records:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL
        from .sql_base import SQLQueryBuilder
        query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")

        # Use the shared batch create query builder
        query, params, ids = query_builder.build_batch_create_query(records)

        # Execute the batch insert with RETURNING
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(query, *params)

        # Return the actual inserted IDs from RETURNING clause
        if rows:
            return [row["id"] for row in rows]
        return ids  # Fallback to generated IDs

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using a single query.

        Uses single DELETE with IN clause and RETURNING for verification.

        Args:
            ids: List of record IDs to delete

        Returns:
            List of success flags for each deletion
        """
        if not ids:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL
        from .sql_base import SQLQueryBuilder
        query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")

        # Use the shared batch delete query builder
        query, params = query_builder.build_batch_delete_query(ids)

        # Execute the batch delete with RETURNING
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(query, *params)

        # Convert returned rows to set of deleted IDs
        deleted_ids = {row["id"] for row in rows}

        # Return results for each deletion
        results = []
        for id in ids:
            results.append(id in deleted_ids)

        return results

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using a single query.

        Uses PostgreSQL's CASE expressions for batch updates with native asyncpg.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        if not updates:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL
        from .sql_base import SQLQueryBuilder
        query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")

        # Use the shared batch update query builder
        # It already produces positional parameters ($1, $2) for PostgreSQL
        query, params = query_builder.build_batch_update_query(updates)

        # Add RETURNING clause for PostgreSQL to get updated IDs
        query = query.rstrip() + " RETURNING id"

        # Execute the batch update
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(query, *params)

        # Convert returned rows to set of updated IDs
        updated_ids = {row["id"] for row in rows}

        # Return results for each update
        results = []
        for record_id, _ in updates:
            results.append(record_id in updated_ids)

        return results
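
    # Same batch API as the sync backend, awaited (sketch; `records` assumed):
    #
    #     ids = await db.create_batch(records)
    #     flags = await db.delete_batch(ids)    # [True, ...]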

    async def vector_search(
        self,
        query_vector: np.ndarray | list[float] | VectorField,
        field_name: str,
        k: int = 10,
        filter: Query | None = None,
        metric: DistanceMetric | str = "cosine"
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using PostgreSQL pgvector.

        Args:
            query_vector: Query vector (numpy array, list, or VectorField)
            field_name: Name of vector field to search
            k: Maximum number of results
            filter: Optional filter query to apply
            metric: Distance metric to use

        Returns:
            List of VectorSearchResult objects
        """
        if not self._vector_enabled:
            raise RuntimeError("Vector search not available - pgvector not installed")

        self._check_connection()

        from ..fields import VectorField
        from ..vector.types import DistanceMetric, VectorSearchResult
        from .postgres_vector import format_vector_for_postgres, get_vector_operator

        # Convert query vector to proper format
        if isinstance(query_vector, VectorField):
            vector_str = format_vector_for_postgres(query_vector.value)
        else:
            vector_str = format_vector_for_postgres(query_vector)

        # Get the appropriate operator
        if isinstance(metric, DistanceMetric):
            metric_str = metric.value
        else:
            metric_str = str(metric).lower()
        operator = get_vector_operator(metric_str)

        vector_column = f"vector_{field_name}"

        # Build query
        sql = f"""
            SELECT id, data, metadata, {vector_column},
                   {vector_column} {operator} $1::vector AS distance
            FROM {self.schema_name}.{self.table_name}
            WHERE {vector_column} IS NOT NULL
        """

        params = [vector_str]
        param_num = 2

        # Add filters if provided using the query builder
        if filter:
            # First get the where clause from query builder
            where_clause, filter_params = self.query_builder.build_where_clause(filter, param_num)
            if where_clause:
                # Convert %s placeholders to $N for asyncpg
                for param in filter_params:
                    where_clause = where_clause.replace("%s", f"${param_num}", 1)
                    params.append(param)
                    param_num += 1
                sql += where_clause

        # Order by distance and limit
        sql += f"""
            ORDER BY distance
            LIMIT {k}
        """

        # Execute query
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(sql, *params)

        # Convert to VectorSearchResult objects
        results = []
        for row in rows:
            record = self._row_to_record(row)

            # Convert distance to similarity score (1 - normalized_distance for cosine)
            distance = float(row['distance'])
            if metric_str == "cosine":
                score = 1.0 - min(distance, 2.0) / 2.0  # Normalize cosine distance [0,2] to similarity [0,1]
            elif metric_str in ["euclidean", "l2"]:
                score = 1.0 / (1.0 + distance)  # Convert distance to similarity
            else:
                score = 1.0 - distance  # Generic conversion

            result = VectorSearchResult(
                record=record,
                score=score,
                vector_field=field_name,
                metadata={"distance": distance, "metric": metric_str}
            )
            results.append(result)

        return results
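
    # Worked example of the cosine score conversion above: a pgvector cosine
    # distance of 0.3 maps to a similarity of 1.0 - min(0.3, 2.0) / 2.0 = 0.85,
    # so distance 0.0 scores 1.0 and the maximum distance 2.0 scores 0.0.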

    async def enable_vector_support(self) -> bool:
        """Enable vector support for this database.

        Returns:
            True if vector support is enabled
        """
        if self._vector_enabled:
            return True

        await self._detect_vector_support()
        return self._vector_enabled

    async def has_vector_support(self) -> bool:
        """Check if this database has vector support enabled.

        Returns:
            True if vector support is available
        """
        return self._vector_enabled

    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str,
        embedding_fn: Any | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        The steps are:
        1. Extract text from the specified fields
        2. Call the embedding function to generate vectors
        3. Store the vectors alongside the records

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed
        """
        if not embedding_fn:
            raise ValueError("embedding_fn is required for bulk_embed_and_store")

        from ..fields import VectorField

        processed_ids = []

        # Process in batches
        for i in range(0, len(records), batch_size):
            batch = records[i:i + batch_size]

            # Extract texts
            texts = []
            for record in batch:
                if isinstance(text_field, list):
                    text = " ".join(str(record.fields.get(f, {}).value) for f in text_field if f in record.fields)
                else:
                    text = str(record.fields.get(text_field, {}).value) if text_field in record.fields else ""
                texts.append(text)

            # Generate embeddings
            if texts:
                embeddings = await embedding_fn(texts)

                # Store vectors with records
                for j, record in enumerate(batch):
                    if j < len(embeddings):
                        vector = embeddings[j]

                        # Add vector field to record
                        record.fields[vector_field] = VectorField(
                            name=vector_field,
                            value=vector,
                            dimensions=len(vector) if hasattr(vector, '__len__') else None,
                            source_field=text_field if isinstance(text_field, str) else ",".join(text_field),
                            model_name=model_name,
                            model_version=model_version,
                        )

                        # Create or update record
                        if record.has_storage_id():
                            if record.storage_id is None:
                                raise ValueError("Record has_storage_id() returned True but storage_id is None")
                            await self.update(record.storage_id, record)
                        else:
                            record_id = await self.create(record)
                            record.storage_id = record_id

                        if record.storage_id is None:
                            raise ValueError("Record storage_id is None after create/update")
                        processed_ids.append(record.storage_id)

        return processed_ids
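
    # Sketch of an embedding_fn compatible with bulk_embed_and_store: it is
    # awaited with a list of texts and must return one vector per text. The
    # function name and the 3-dimensional zero vectors are placeholder
    # assumptions, not a real embedding model.
    #
    #     async def fake_embed(texts: list[str]) -> list[list[float]]:
    #         return [[0.0, 0.0, 0.0] for _ in texts]
    #
    #     ids = await db.bulk_embed_and_store(records, "text", "embedding", fake_embed)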

    async def create_vector_index(
        self,
        vector_field: str,
        dimensions: int,
        metric: DistanceMetric | str = "cosine",
        index_type: str = "ivfflat",
        lists: int | None = None,
    ) -> bool:
        """Create a vector index for efficient similarity search.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions in the vectors
            metric: Distance metric for the index
            index_type: Type of index (ivfflat, hnsw)
            lists: Number of lists for IVFFlat index

        Returns:
            True if index was created successfully
        """
        from .postgres_vector import (
            build_vector_column_expression,
            build_vector_index_sql,
            get_optimal_index_type,
            get_vector_count_sql,
        )

        self._check_connection()

        if not self._vector_enabled:
            return False

        # Determine optimal parameters if not provided
        if not lists and index_type == "ivfflat":
            # Count vectors to determine optimal lists
            count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
            async with self._pool.acquire() as conn:
                count = await conn.fetchval(count_sql) or 0
            _, params = get_optimal_index_type(count)
            lists = params.get("lists", 100)

        # Convert metric enum to string if needed
        if hasattr(metric, 'value'):
            metric_str = metric.value
        else:
            metric_str = str(metric).lower()

        # Build vector column expression for index
        column_expr = build_vector_column_expression(vector_field, dimensions, for_index=True)

        # Build index SQL - pass field_name for proper index naming
        index_sql = build_vector_index_sql(
            table_name=self.table_name,
            schema_name=self.schema_name,
            column_name=column_expr,
            dimensions=dimensions,
            metric=metric_str,
            index_type=index_type,
            index_params={"lists": lists} if lists else None,
            field_name=vector_field
        )

        # Create the index
        try:
            logger.debug(f"Creating vector index with SQL: {index_sql}")
            async with self._pool.acquire() as conn:
                await conn.execute(index_sql)
            return True
        except Exception as e:
            logger.warning(f"Failed to create vector index: {e}")
            logger.debug(f"Index SQL was: {index_sql}")
            return False

    async def drop_vector_index(self, vector_field: str, metric: str = "cosine") -> bool:
        """Drop a vector index.

        Args:
            vector_field: Name of the vector field
            metric: Distance metric used in the index

        Returns:
            True if index was dropped successfully
        """
        from .postgres_vector import get_vector_index_name

        self._check_connection()

        index_name = get_vector_index_name(self.table_name, vector_field, metric)

        try:
            async with self._pool.acquire() as conn:
                await conn.execute(f"DROP INDEX IF EXISTS {self.schema_name}.{index_name}")
            return True
        except Exception as e:
            logger.warning(f"Failed to drop vector index: {e}")
            return False

    async def get_vector_index_stats(self, vector_field: str) -> dict[str, Any]:
        """Get statistics about a vector field and its index.

        Args:
            vector_field: Name of the vector field

        Returns:
            Dictionary with index statistics
        """
        from .postgres_vector import get_index_check_sql, get_vector_count_sql

        self._check_connection()

        stats = {
            "field": vector_field,
            "indexed": False,
            "vector_count": 0,
        }

        try:
            async with self._pool.acquire() as conn:
                # Count vectors
                count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
                stats["vector_count"] = await conn.fetchval(count_sql) or 0

                # Check for index
                index_sql, params = get_index_check_sql(self.schema_name, self.table_name, vector_field)
                stats["indexed"] = await conn.fetchval(index_sql, *params) or False
        except Exception as e:
            logger.warning(f"Failed to get vector index stats: {e}")

        return stats
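
    # Index management sketch (the field name and 768-dimension count are
    # assumptions):
    #
    #     ok = await db.create_vector_index("embedding", 768, metric="cosine", index_type="ivfflat")
    #     stats = await db.get_vector_index_stats("embedding")
    #     # stats -> {"field": "embedding", "indexed": True, "vector_count": ...}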

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from PostgreSQL using cursor."""
        self._check_connection()
        config = config or StreamConfig()

        # Build SQL query
        sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
        params = []

        if query and query.filters:
            where_clauses = []
            param_count = 0

            for filter in query.filters:
                param_count += 1
                field_path = f"data->>'{filter.field}'"

                if filter.operator == Operator.EQ:
                    where_clauses.append(f"{field_path} = ${param_count}")
                    params.append(str(filter.value))

            if where_clauses:
                sql += " WHERE " + " AND ".join(where_clauses)

        # Use cursor for efficient streaming. Note: asyncpg's conn.cursor()
        # returns a cursor factory that is iterated directly with `async for`;
        # awaiting it would yield a Cursor object that is not async-iterable.
        async with self._pool.acquire() as conn:
            async with conn.transaction():
                cursor = conn.cursor(sql, *params)

                batch = []
                async for row in cursor:
                    record = self._row_to_record(row)
                    if query and query.fields:
                        record = record.project(query.fields)

                    batch.append(record)

                    if len(batch) >= config.batch_size:
                        for rec in batch:
                            yield rec
                        batch = []

                # Yield remaining records
                for rec in batch:
                    yield rec

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into PostgreSQL using batch inserts."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch with graceful fallback
                # Use an async wrapper for _write_batch
                async def batch_func(b):
                    await self._write_batch(b)
                    return [r.id for r in b]

                continue_processing = await async_process_batch_with_fallback(
                    batch,
                    batch_func,
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            async def batch_func(b):
                await self._write_batch(b)
                return [r.id for r in b]

            await async_process_batch_with_fallback(
                batch,
                batch_func,
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result

    async def _write_batch(self, records: list[Record]) -> list[str]:
        """Write a batch of records using COPY for performance.

        Returns:
            List of created record IDs
        """
        if not records:
            return []

        # Prepare data for COPY
        rows = []
        ids = []
        for record in records:
            row_data = self._record_to_row(record)
            ids.append(row_data["id"])
            rows.append((
                row_data["id"],
                row_data["data"],
                row_data["metadata"]
            ))

        # Use COPY for efficient bulk insert. asyncpg expects the bare table
        # name plus a separate schema_name keyword; a dotted "schema.table"
        # string would be quoted as a single identifier and fail.
        async with self._pool.acquire() as conn:
            await conn.copy_records_to_table(
                self.table_name,
                records=rows,
                columns=["id", "data", "metadata"],
                schema_name=self.schema_name,
            )

        return ids