# Coverage report: src/dataknobs_data/backends/postgres.py, 12% of 705 statements (coverage.py v7.10.3, 2025-08-31 15:06 -0600).

1"""PostgreSQL backend implementation with proper connection management and vector support.""" 

from __future__ import annotations

import json
import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, cast

import asyncpg
from dataknobs_config import ConfigurableBase

from dataknobs_utils.sql_utils import DotenvPostgresConnector, PostgresDB

from ..database import AsyncDatabase, SyncDatabase
from ..pooling import ConnectionPoolManager
from ..pooling.postgres import PostgresPoolConfig, create_asyncpg_pool, validate_asyncpg_pool
from ..query import Operator, Query
from ..query_logic import ComplexQuery
from ..streaming import (
    StreamConfig,
    StreamResult,
    async_process_batch_with_fallback,
    process_batch_with_fallback,
)
from ..vector.mixins import VectorOperationsMixin
from .postgres_mixins import (
    PostgresBaseConfig,
    PostgresConnectionValidator,
    PostgresErrorHandler,
    PostgresTableManager,
    PostgresVectorSupport,
)
from .sql_base import SQLQueryBuilder, SQLRecordSerializer

if TYPE_CHECKING:
    from collections.abc import AsyncIterator, Awaitable, Callable, Iterator

    import numpy as np

    from ..fields import VectorField
    from ..records import Record
    from ..vector.types import DistanceMetric, VectorSearchResult

logger = logging.getLogger(__name__)


class SyncPostgresDatabase(
    SyncDatabase,
    ConfigurableBase,
    VectorOperationsMixin,
    SQLRecordSerializer,
    PostgresBaseConfig,
    PostgresTableManager,
    PostgresVectorSupport,
    PostgresConnectionValidator,
    PostgresErrorHandler,
):
    """Synchronous PostgreSQL database backend with proper connection management."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize PostgreSQL database configuration.

        Args:
            config: Configuration with the following optional keys:
                - host: PostgreSQL host (default: from env/localhost)
                - port: PostgreSQL port (default: 5432)
                - database: Database name (default: from env/postgres)
                - user: Username (default: from env/postgres)
                - password: Password (default: from env)
                - table: Table name (default: "records")
                - schema: Schema name (default: "public")
                - enable_vector: Enable vector support (default: False)
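
        Example (illustrative values; adjust to your environment):

            config = {
                "host": "localhost",
                "port": 5432,
                "database": "postgres",
                "user": "postgres",
                "password": "secret",
                "table": "records",
                "schema": "public",
                "enable_vector": True,
            }
            db = SyncPostgresDatabase(config)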

74 """ 

75 super().__init__(config) 

76 

77 # Parse configuration using mixin 

78 table_name, schema_name, conn_config = self._parse_postgres_config(config or {}) 

79 self._init_postgres_attributes(table_name, schema_name) 

80 

81 # Store connection config for later use 

82 self._conn_config = conn_config 

83 self.db = None # Will be initialized in connect() 

84 self.query_builder = None # Will be initialized in connect() 

85 

86 @classmethod 

87 def from_config(cls, config: dict) -> SyncPostgresDatabase: 

88 """Create from config dictionary.""" 

89 return cls(config) 

    def connect(self) -> None:
        """Connect to the PostgreSQL database."""
        if self._connected:
            return  # Already connected

        # Initialize query builder with pyformat style for psycopg2
        self.query_builder = SQLQueryBuilder(
            self.table_name, self.schema_name, dialect="postgres", param_style="pyformat"
        )

        # Create connection using existing utilities
        if not any(key in self._conn_config for key in ["host", "database", "user"]):
            # Use dotenv connector for environment-based config
            connector = DotenvPostgresConnector()
            self.db = PostgresDB(connector)
        else:
            # Direct configuration; PostgresDB expects 'db' and 'pwd' rather
            # than 'database' and 'password'
            self.db = PostgresDB(
                host=self._conn_config.get("host", "localhost"),
                db=self._conn_config.get("database", "postgres"),
                user=self._conn_config.get("user", "postgres"),
                pwd=self._conn_config.get("password"),
                port=self._conn_config.get("port", 5432),
            )

        # Create table if it doesn't exist
        self._ensure_table()

        # Detect and enable vector support if requested
        if self.vector_enabled:
            self._detect_vector_support()

        self._connected = True
        self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}")

    def close(self) -> None:
        """Close the database connection."""
        if self.db:
            # PostgresDB manages its own connections via context managers,
            # so we only mark this backend as disconnected
            self._connected = False  # type: ignore[unreachable]

    def _initialize(self) -> None:
        """Initialize method - connection setup moved to connect()."""
        # Configuration parsing stays here, actual connection in connect()

    def _detect_vector_support(self) -> None:
        """Detect and enable vector support if pgvector is available."""
        from .postgres_vector import check_pgvector_extension_sync, install_pgvector_extension_sync

        try:
            # Check if pgvector is installed
            if check_pgvector_extension_sync(self.db):
                self._vector_enabled = True
                logger.info("pgvector extension detected and enabled")
            # Otherwise try to install it
            elif install_pgvector_extension_sync(self.db):
                self._vector_enabled = True
                logger.info("pgvector extension installed and enabled")
            else:
                logger.debug("pgvector extension not available")
        except Exception as e:
            logger.debug(f"Could not enable vector support: {e}")
            self._vector_enabled = False

    def _ensure_table(self) -> None:
        """Ensure the records table exists."""
        if not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

        create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name)  # type: ignore[unreachable]
        self.db.execute(create_table_sql)

    def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
        """Convert a Record to a database row."""
        return {
            "id": id or str(uuid.uuid4()),
            "data": self.record_to_json(record),
            "metadata": json.dumps(record.metadata) if record.metadata else None,
        }

    def _row_to_record(self, row: dict[str, Any]) -> Record:
        """Convert a database row to a Record."""
        return self.row_to_record(row)

    def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()
        # Use record's ID if it has one, otherwise generate a new one
        id = record.id if record.id else str(uuid.uuid4())
        row = self._record_to_row(record, id)

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
            VALUES (%(id)s, %(data)s, %(metadata)s)
        """
        self.db.execute(sql, row)
        return id

    def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()
        sql = f"""
            SELECT id, data, metadata
            FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
        """
        df = self.db.query(sql, {"id": id})

        if df.empty:
            return None

        row = df.iloc[0].to_dict()
        return self._row_to_record(row)

    def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
            UPDATE {self.schema_name}.{self.table_name}
            SET data = %(data)s, metadata = %(metadata)s, updated_at = CURRENT_TIMESTAMP
            WHERE id = %(id)s
        """
        result = self.db.execute(sql, row)
        # PostgresDB.execute returns the number of affected rows
        return result > 0 if isinstance(result, int) else False

    def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()
        sql = f"""
            DELETE FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
        """
        result = self.db.execute(sql, {"id": id})
        return result > 0 if isinstance(result, int) else False

    def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()
        sql = f"""
            SELECT 1 FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
            LIMIT 1
        """
        df = self.db.query(sql, {"id": id})
        return not df.empty

    def upsert(self, id: str, record: Record) -> str:
        """Update or insert a record with a specific ID."""
        self._check_connection()
        if self.exists(id):
            self.update(id, record)
        else:
            # Insert with specific ID
            row = self._record_to_row(record, id)
            sql = f"""
                INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
                VALUES (%(id)s, %(data)s, %(metadata)s)
            """
            self.db.execute(sql, row)
        return id
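
    # Illustrative CRUD round trip (not executed; Record construction is shown
    # schematically and depends on dataknobs_data.records):
    #
    #     record_id = db.create(record)
    #     assert db.exists(record_id)
    #     db.update(record_id, record)
    #     db.upsert(record_id, record)   # insert-or-update by explicit ID
    #     db.delete(record_id)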

    def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # Handle ComplexQuery with native SQL support
        if isinstance(query, ComplexQuery):
            sql_query, params_list = self.query_builder.build_complex_search_query(query)
        else:
            sql_query, params_list = self.query_builder.build_search_query(query)

        # Build params dict for psycopg2; the query builder generates
        # %(p0)s-style placeholders directly
        params_dict = {}
        if params_list:
            for i, param in enumerate(params_list):
                params_dict[f"p{i}"] = param

        # Execute query
        df = self.db.query(sql_query, params_dict)

        # Convert to records
        records = []
        for _, row in df.iterrows():
            record = self._row_to_record(row.to_dict())

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records

    def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()
        sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"
        df = self.db.query(sql)
        return int(df.iloc[0]["count"]) if not df.empty else 0

    def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()
        # Get count first
        count = self._count_all()

        # Delete all records
        sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"
        self.db.execute(sql)

        return count

    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using a single query.

        Uses multi-value INSERT for better performance.

        Args:
            records: List of records to create

        Returns:
            List of created record IDs
        """
        if not records:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL with pyformat style
        query_builder = SQLQueryBuilder(
            self.table_name, self.schema_name, dialect="postgres", param_style="pyformat"
        )

        # Use the shared batch create query builder
        query, params_list, ids = query_builder.build_batch_create_query(records)

        # Build params dict for psycopg2
        params_dict = {}
        for i, param in enumerate(params_list):
            params_dict[f"p{i}"] = param

        # Execute the batch insert and get returned IDs
        result_df = self.db.query(query, params_dict)

        # PostgreSQL RETURNING clause gives us the actual inserted IDs
        if not result_df.empty:
            return result_df["id"].tolist()
        return ids

    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using a single query.

        Uses a single DELETE with an IN clause for better performance.

        Args:
            ids: List of record IDs to delete

        Returns:
            List of success flags for each deletion
        """
        if not ids:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL with pyformat style
        query_builder = SQLQueryBuilder(
            self.table_name, self.schema_name, dialect="postgres", param_style="pyformat"
        )

        # Use the shared batch delete query builder (includes RETURNING clause)
        query, params_list = query_builder.build_batch_delete_query(ids)

        # Build params dict for psycopg2
        params_dict = {}
        for i, param in enumerate(params_list):
            params_dict[f"p{i}"] = param

        # Execute the batch delete and get returned IDs
        result_df = self.db.query(query, params_dict)

        # Get the set of deleted IDs from the RETURNING clause
        deleted_ids = set(result_df["id"].tolist()) if not result_df.empty else set()

        # Report success for each requested ID based on what was actually deleted
        return [id in deleted_ids for id in ids]

    def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using a single query.

        Uses PostgreSQL's CASE expressions for batch updates via the shared SQL builder.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        if not updates:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL with pyformat style
        query_builder = SQLQueryBuilder(
            self.table_name, self.schema_name, dialect="postgres", param_style="pyformat"
        )

        # Use the shared batch update query builder (includes RETURNING clause)
        query, params_list = query_builder.build_batch_update_query(updates)

        # Build params dict for psycopg2
        params_dict = {}
        for i, param in enumerate(params_list):
            params_dict[f"p{i}"] = param

        # Execute the batch update and get returned IDs
        result_df = self.db.query(query, params_dict)

        # Get the set of updated IDs from the RETURNING clause
        updated_ids = set(result_df["id"].tolist()) if not result_df.empty else set()

        return [record_id in updated_ids for record_id, _ in updates]
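
    # Illustrative batch usage (not executed; records as in the CRUD sketch above):
    #
    #     ids = db.create_batch(records)                  # one multi-value INSERT
    #     ok = db.update_batch(list(zip(ids, records)))   # one CASE-based UPDATE
    #     gone = db.delete_batch(ids)                     # one DELETE ... IN (...)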

    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from PostgreSQL."""
        self._check_connection()
        config = config or StreamConfig()

        # Build SQL query
        sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
        params = {}

        if query and query.filters:
            # Add WHERE clause (simplified for now)
            where_clauses = []
            for i, filter in enumerate(query.filters):
                field_path = f"data->>'{filter.field}'"
                param_name = f"param_{i}"

                if filter.operator == Operator.EQ:
                    where_clauses.append(f"{field_path} = %({param_name})s")
                    params[param_name] = str(filter.value)

            if where_clauses:
                sql += " WHERE " + " AND ".join(where_clauses)

        # Use a cursor for streaming.
        # Note: PostgresDB may need modification to support cursors;
        # for now, fetch in batches via LIMIT/OFFSET.
        sql += f" LIMIT {config.batch_size} OFFSET %(offset)s"

        offset = 0
        while True:
            params["offset"] = offset
            df = self.db.query(sql, params)

            if df.empty:
                break

            for _, row in df.iterrows():
                record = self._row_to_record(row.to_dict())
                if query and query.fields:
                    record = record.project(query.fields)
                yield record

            offset += config.batch_size

            # If we got fewer rows than batch_size, we're done
            if len(df) < config.batch_size:
                break

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into PostgreSQL."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch = []
        for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write the batch, falling back to per-record create on failure
                continue_processing = process_batch_with_fallback(
                    batch,
                    lambda b: self._write_batch(b),
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            process_batch_with_fallback(
                batch,
                lambda b: self._write_batch(b),
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result
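
    # Illustrative streaming usage (not executed; StreamConfig defaults assumed):
    #
    #     for record in db.stream_read():           # batched LIMIT/OFFSET reads
    #         process(record)                       # hypothetical consumer
    #
    #     result = db.stream_write(iter(records))   # batched multi-value INSERTs
    #     print(result.duration)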

    def _write_batch(self, records: list[Record]) -> list[str]:
        """Write a batch of records to the database.

        Returns:
            List of created record IDs
        """
        # Build batch insert SQL
        values = []
        params = {}
        ids = []

        for i, record in enumerate(records):
            id = str(uuid.uuid4())
            ids.append(id)
            row = self._record_to_row(record, id)
            values.append(f"(%(id_{i})s, %(data_{i})s, %(metadata_{i})s)")
            params[f"id_{i}"] = row["id"]
            params[f"data_{i}"] = row["data"]
            params[f"metadata_{i}"] = row["metadata"]

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
            VALUES {', '.join(values)}
        """
        self.db.execute(sql, params)
        return ids

    def vector_search(
        self,
        query_vector: np.ndarray | list[float] | VectorField,
        field_name: str,
        k: int = 10,
        filter: Query | None = None,
        metric: DistanceMetric | str = "cosine"
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using PostgreSQL pgvector.

        Args:
            query_vector: Query vector (numpy array, list, or VectorField)
            field_name: Name of vector field to search (must be in data JSON)
            k: Maximum number of results
            filter: Optional filter query to apply
            metric: Distance metric to use (cosine, euclidean, l2, inner_product)

        Returns:
            List of VectorSearchResult objects ordered by similarity
        """

        if not self._vector_enabled:
            raise RuntimeError("Vector search not available - pgvector not installed")

        self._check_connection()

        from ..fields import VectorField
        from ..vector.types import DistanceMetric, VectorSearchResult
        from .postgres_vector import format_vector_for_postgres, get_vector_operator

        # Convert query vector to proper format
        if isinstance(query_vector, VectorField):
            vector_str = format_vector_for_postgres(query_vector.value)
        else:
            vector_str = format_vector_for_postgres(query_vector)

        # Get the appropriate operator
        if isinstance(metric, DistanceMetric):
            metric_str = metric.value
        else:
            metric_str = str(metric).lower()

        operator = get_vector_operator(metric_str)

        # Build the query - vectors are stored in the JSON data field.
        # Use centralized vector extraction logic.
        vector_expr = self.get_vector_extraction_sql(field_name, dialect="postgres")

        # Build the base SQL with pyformat placeholders
        sql = f"""
            SELECT
                id,
                data,
                metadata,
                {vector_expr} {operator} %(p0)s::vector AS distance
            FROM {self.schema_name}.{self.table_name}
            WHERE data ? %(p1)s -- Check field exists
        """

        params: list[Any] = [vector_str, field_name]

        # Add filters if provided using the query builder
        if filter:
            # The query builder generates pyformat placeholders since we configured it that way
            where_clause, filter_params = self.query_builder.build_where_clause(filter, len(params) + 1)
            if where_clause:
                sql += where_clause
                params.extend(filter_params)

        # Order by distance and limit
        next_param = len(params)
        sql += f" ORDER BY distance LIMIT %(p{next_param})s"
        params.append(k)

        # Build param dict for psycopg2
        param_dict = {}
        for i, param in enumerate(params):
            param_dict[f"p{i}"] = param

        df = self.db.query(sql, param_dict)

        # Convert results
        results = []
        for _, row in df.iterrows():
            record = self._row_to_record(row.to_dict())

            # Calculate similarity score from distance
            distance = row["distance"]
            if metric_str in ["cosine", "cosine_similarity"]:
                score = 1.0 - distance  # Cosine distance to similarity
            elif metric_str in ["euclidean", "l2"]:
                score = 1.0 / (1.0 + distance)  # Convert distance to similarity
            elif metric_str in ["inner_product", "dot_product"]:
                score = -distance  # Negative because pgvector uses negative for descending
            else:
                score = -distance  # Default: lower distance = better

            result = VectorSearchResult(
                record=record,
                score=float(score),
                vector_field=field_name
            )
            results.append(result)

        return results
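
    # Illustrative vector search (not executed; assumes pgvector is enabled and
    # records carry an "embedding" vector field):
    #
    #     hits = db.vector_search([0.1, 0.2, 0.3], "embedding", k=5, metric="cosine")
    #     for hit in hits:
    #         print(hit.score, hit.record)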

    def has_vector_support(self) -> bool:
        """Check if this database has vector support enabled.

        Returns:
            True if vector operations are supported
        """
        return self._vector_enabled

    def enable_vector_support(self) -> bool:
        """Enable vector support for this database if possible.

        Returns:
            True if vector support is now enabled
        """
        if self._vector_enabled:
            return True

        self._detect_vector_support()
        return self._vector_enabled

    def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Any = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records (stub for abstract requirement).

        This is a placeholder implementation to satisfy the abstract method requirement.
        A full implementation would require an actual embedding function.
        """
        raise NotImplementedError("bulk_embed_and_store requires an embedding function")


# Global pool manager instance for async PostgreSQL connections
_pool_manager = ConnectionPoolManager[asyncpg.Pool]()


class AsyncPostgresDatabase(
    AsyncDatabase,
    VectorOperationsMixin,
    ConfigurableBase,
    PostgresBaseConfig,
    PostgresTableManager,
    PostgresVectorSupport,
    PostgresConnectionValidator,
    PostgresErrorHandler,
):
    """Native async PostgreSQL database backend with vector support and
    event loop-aware connection pooling."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async PostgreSQL database."""
        super().__init__(config)

        # Parse configuration using mixin
        table_name, schema_name, conn_config = self._parse_postgres_config(config or {})
        self._init_postgres_attributes(table_name, schema_name)

        # Extract pool configuration
        self._pool_config = PostgresPoolConfig.from_dict(conn_config)
        self._pool: asyncpg.Pool | None = None

    @classmethod
    def from_config(cls, config: dict) -> AsyncPostgresDatabase:
        """Create from config dictionary."""
        return cls(config)

    async def connect(self) -> None:
        """Connect to the database."""
        if self._connected:
            return

        # Get or create a pool for the current event loop
        from ..pooling import BasePoolConfig
        self._pool = await _pool_manager.get_pool(
            self._pool_config,
            cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_asyncpg_pool),
            validate_asyncpg_pool
        )

        # Initialize query builder
        self.query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")

        # Ensure table exists
        await self._ensure_table()

        # Check and enable vector support if requested
        if self.vector_enabled:
            await self._detect_vector_support()

        self._connected = True
        self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}")

    async def close(self) -> None:
        """Close the database connection and properly close the pool."""
        if self._connected:
            # Properly close the pool if we have one
            if self._pool:
                try:
                    await self._pool.close()
                except Exception as e:
                    logger.warning(f"Error closing connection pool: {e}")
                self._pool = None
            self._connected = False
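
    # Illustrative async usage (not executed; assumes asyncio is imported and a
    # Record instance as in the sync sketches above):
    #
    #     async def main():
    #         db = AsyncPostgresDatabase({"host": "localhost", "database": "postgres",
    #                                     "user": "postgres", "table": "records"})
    #         await db.connect()
    #         record_id = await db.create(record)
    #         fetched = await db.read(record_id)
    #         await db.close()
    #
    #     asyncio.run(main())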

    def _initialize(self) -> None:
        """Initialize is handled in connect."""

    async def _ensure_table(self) -> None:
        """Ensure the records table exists."""
        if not self._pool:
            raise RuntimeError("Database not connected. Call connect() first.")

        create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name)

        async with self._pool.acquire() as conn:
            await conn.execute(create_table_sql)

    async def _detect_vector_support(self) -> None:
        """Detect and enable vector support if pgvector is available."""
        from .postgres_vector import check_pgvector_extension, install_pgvector_extension

        async with self._pool.acquire() as conn:
            # Check if pgvector is available
            if await check_pgvector_extension(conn):
                self._vector_enabled = True
                logger.info("pgvector extension detected and enabled")
            # Otherwise try to install it
            elif await install_pgvector_extension(conn):
                self._vector_enabled = True
                logger.info("pgvector extension installed and enabled")
            else:
                logger.debug("pgvector extension not available")

    async def _ensure_vector_column(self, field_name: str, dimensions: int) -> None:
        """Ensure a vector column exists for the given field.

        Args:
            field_name: Name of the vector field
            dimensions: Number of dimensions
        """
        if not self._vector_enabled:
            return

        column_name = f"vector_{field_name}"

        # Check if the column already exists
        check_sql = """
            SELECT column_name FROM information_schema.columns
            WHERE table_schema = $1 AND table_name = $2 AND column_name = $3
        """

        async with self._pool.acquire() as conn:
            existing = await conn.fetchval(check_sql, self.schema_name, self.table_name, column_name)

            if not existing:
                # Add vector column
                alter_sql = f"""
                    ALTER TABLE {self.schema_name}.{self.table_name}
                    ADD COLUMN IF NOT EXISTS {column_name} vector({dimensions})
                """
                try:
                    await conn.execute(alter_sql)
                    self._vector_dimensions[field_name] = dimensions
                    logger.info(f"Added vector column {column_name} with {dimensions} dimensions")

                    # Create an index for the vector column
                    from .postgres_vector import build_vector_index_sql, get_optimal_index_type

                    # Get row count for optimal index selection
                    count_sql = f"SELECT COUNT(*) FROM {self.schema_name}.{self.table_name}"
                    count = await conn.fetchval(count_sql)

                    index_type, index_params = get_optimal_index_type(count)
                    index_sql = build_vector_index_sql(
                        self.table_name,
                        self.schema_name,
                        column_name,
                        dimensions,
                        metric="cosine",
                        index_type=index_type,
                        index_params=index_params
                    )

                    # Note: IVFFlat requires the table to have data before creating the index
                    if count > 0 or index_type != "ivfflat":
                        await conn.execute(index_sql)
                        logger.info(f"Created {index_type} index for {column_name}")

                except Exception as e:
                    logger.warning(f"Could not create vector column {column_name}: {e}")
            else:
                self._vector_dimensions[field_name] = dimensions

    def _check_connection(self) -> None:
        """Check if async database is connected."""
        self._check_async_connection()

    def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
        """Convert a Record to a database row using the common serializer."""
        return {
            "id": id or str(uuid.uuid4()),
            "data": SQLRecordSerializer.record_to_json(record),
            "metadata": json.dumps(record.metadata) if record.metadata else None,
        }

    def _row_to_record(self, row: asyncpg.Record) -> Record:
        """Convert a database row to a Record using the common serializer."""
        # Convert asyncpg.Record values to the JSON strings expected by SQLRecordSerializer
        data_json = row.get("data", {})
        if not isinstance(data_json, str):
            data_json = json.dumps(data_json)

        metadata_json = row.get("metadata")
        if metadata_json and not isinstance(metadata_json, str):
            metadata_json = json.dumps(metadata_json)

        # Use the common serializer to reconstruct the record
        return SQLRecordSerializer.json_to_record(data_json, metadata_json)

    async def create(self, record: Record) -> str:
        """Create a new record with vector support."""
        self._check_connection()

        # Check for vector fields and ensure columns exist
        from ..fields import VectorField
        for field_name, field_obj in record.fields.items():
            if isinstance(field_obj, VectorField) and self._vector_enabled:
                await self._ensure_vector_column(field_name, field_obj.dimensions)

        # Use record's ID if it has one, otherwise generate a new one
        id = record.id if record.id else str(uuid.uuid4())
        row = self._record_to_row(record, id)

        # Build dynamic SQL based on the vector columns present
        columns = ["id", "data", "metadata"]
        values = [row["id"], row["data"], row["metadata"]]
        placeholders = ["$1", "$2", "$3"]

        # Add vector columns
        param_num = 4
        for key, value in row.items():
            if key.startswith("vector_"):
                columns.append(key)
                values.append(value)
                placeholders.append(f"${param_num}")
                param_num += 1

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} ({', '.join(columns)})
            VALUES ({', '.join(placeholders)})
        """

        async with self._pool.acquire() as conn:
            await conn.execute(sql, *values)

        return id

    async def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()
        sql = f"""
            SELECT id, data, metadata
            FROM {self.schema_name}.{self.table_name}
            WHERE id = $1
        """

        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql, id)

        if not row:
            return None

        return self._row_to_record(row)

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
            UPDATE {self.schema_name}.{self.table_name}
            SET data = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP
            WHERE id = $1
        """

        async with self._pool.acquire() as conn:
            result = await conn.execute(sql, row["id"], row["data"], row["metadata"])

        # asyncpg returns "UPDATE n" where n is the number of rows affected
        return result.split()[-1] != "0"

    async def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()
        sql = f"""
            DELETE FROM {self.schema_name}.{self.table_name}
            WHERE id = $1
        """

        async with self._pool.acquire() as conn:
            result = await conn.execute(sql, id)

        # asyncpg returns "DELETE n" where n is the number of rows affected
        return result.split()[-1] != "0"

    async def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()
        sql = f"""
            SELECT 1 FROM {self.schema_name}.{self.table_name}
            WHERE id = $1
            LIMIT 1
        """

        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql, id)

        return row is not None

    async def upsert(self, id: str, record: Record) -> str:
        """Update or insert a record with a specific ID."""
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
            VALUES ($1, $2, $3)
            ON CONFLICT (id) DO UPDATE
            SET data = EXCLUDED.data, metadata = EXCLUDED.metadata, updated_at = CURRENT_TIMESTAMP
        """

        async with self._pool.acquire() as conn:
            await conn.execute(sql, row["id"], row["data"], row["metadata"])

        return id

    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()

        # Initialize query builder if not already done
        if not hasattr(self, "query_builder"):
            self.query_builder = SQLQueryBuilder(
                self.table_name, self.schema_name, dialect="postgres"
            )

        # Handle ComplexQuery with native SQL support
        if isinstance(query, ComplexQuery):
            sql, params = self.query_builder.build_complex_search_query(query)
        else:
            sql, params = self.query_builder.build_search_query(query)

        # Execute query with asyncpg (already uses positional parameters)
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(sql, *params)

        # Convert to records
        records = []
        for row in rows:
            record = self._row_to_record(row)

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records

    async def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()
        sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"

        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql)

        return row["count"] if row else 0

    async def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()
        # Get count first
        count = await self._count_all()

        # Delete all records
        sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"

        async with self._pool.acquire() as conn:
            await conn.execute(sql)

        return count

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using a single query.

        Uses multi-value INSERT with RETURNING for better performance.

        Args:
            records: List of records to create

        Returns:
            List of created record IDs
        """
        if not records:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL
        query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")

        # Use the shared batch create query builder
        query, params, ids = query_builder.build_batch_create_query(records)

        # Execute the batch insert with RETURNING
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(query, *params)

        # Return the actual inserted IDs from the RETURNING clause
        if rows:
            return [row["id"] for row in rows]
        return ids  # Fallback to generated IDs

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using a single query.

        Uses a single DELETE with an IN clause and RETURNING for verification.

        Args:
            ids: List of record IDs to delete

        Returns:
            List of success flags for each deletion
        """
        if not ids:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL
        query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")

        # Use the shared batch delete query builder
        query, params = query_builder.build_batch_delete_query(ids)

        # Execute the batch delete with RETURNING
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(query, *params)

        # Convert returned rows to a set of deleted IDs
        deleted_ids = {row["id"] for row in rows}

        # Report success for each requested ID
        return [id in deleted_ids for id in ids]

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using a single query.

        Uses PostgreSQL's CASE expressions for batch updates with native asyncpg.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            List of success flags for each update
        """
        if not updates:
            return []

        self._check_connection()

        # Create a query builder for PostgreSQL
        query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")

        # Use the shared batch update query builder; it already produces
        # positional parameters ($1, $2) for PostgreSQL
        query, params = query_builder.build_batch_update_query(updates)

        # Add RETURNING clause for PostgreSQL to get the updated IDs
        query = query.rstrip() + " RETURNING id"

        # Execute the batch update
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(query, *params)

        # Convert returned rows to a set of updated IDs
        updated_ids = {row["id"] for row in rows}

        # Report success for each update
        return [record_id in updated_ids for record_id, _ in updates]

    async def vector_search(
        self,
        query_vector: np.ndarray | list[float] | VectorField,
        field_name: str,
        k: int = 10,
        filter: Query | None = None,
        metric: DistanceMetric | str = "cosine"
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using PostgreSQL pgvector.

        Args:
            query_vector: Query vector (numpy array, list, or VectorField)
            field_name: Name of vector field to search
            k: Maximum number of results
            filter: Optional filter query to apply
            metric: Distance metric to use

        Returns:
            List of VectorSearchResult objects
        """

        if not self._vector_enabled:
            raise RuntimeError("Vector search not available - pgvector not installed")

        self._check_connection()

        from ..fields import VectorField
        from ..vector.types import DistanceMetric, VectorSearchResult
        from .postgres_vector import format_vector_for_postgres, get_vector_operator

        # Convert query vector to proper format
        if isinstance(query_vector, VectorField):
            vector_str = format_vector_for_postgres(query_vector.value)
        else:
            vector_str = format_vector_for_postgres(query_vector)

        # Get the appropriate operator
        if isinstance(metric, DistanceMetric):
            metric_str = metric.value
        else:
            metric_str = str(metric).lower()
        operator = get_vector_operator(metric_str)

        vector_column = f"vector_{field_name}"

        # Build query
        sql = f"""
            SELECT id, data, metadata, {vector_column},
                   {vector_column} {operator} $1::vector AS distance
            FROM {self.schema_name}.{self.table_name}
            WHERE {vector_column} IS NOT NULL
        """

        params = [vector_str]
        param_num = 2

        # Add filters if provided using the query builder
        if filter:
            # First get the where clause from the query builder
            where_clause, filter_params = self.query_builder.build_where_clause(filter, param_num)
            if where_clause:
                # Convert %s placeholders to $N for asyncpg
                for param in filter_params:
                    where_clause = where_clause.replace("%s", f"${param_num}", 1)
                    params.append(param)
                    param_num += 1
                sql += where_clause

        # Order by distance and limit
        sql += f"""
            ORDER BY distance
            LIMIT {k}
        """

        # Execute query
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(sql, *params)

        # Convert to VectorSearchResult objects
        results = []
        for row in rows:
            record = self._row_to_record(row)

            # Convert distance to a similarity score
            distance = float(row["distance"])
            if metric_str == "cosine":
                # Normalize cosine distance [0, 2] to similarity [0, 1]
                score = 1.0 - min(distance, 2.0) / 2.0
            elif metric_str in ["euclidean", "l2"]:
                score = 1.0 / (1.0 + distance)  # Convert distance to similarity
            else:
                score = 1.0 - distance  # Generic conversion

            result = VectorSearchResult(
                record=record,
                score=score,
                vector_field=field_name,
                metadata={"distance": distance, "metric": metric_str}
            )
            results.append(result)

        return results

    async def enable_vector_support(self) -> bool:
        """Enable vector support for this database.

        Returns:
            True if vector support is enabled
        """
        if self._vector_enabled:
            return True

        await self._detect_vector_support()
        return self._vector_enabled

    async def has_vector_support(self) -> bool:
        """Check if this database has vector support enabled.

        Returns:
            True if vector support is available
        """
        return self._vector_enabled

    async def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str,
        embedding_fn: Any | None = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records.

        The steps are:
        1. Extract text from the specified fields
        2. Call the embedding function to generate vectors
        3. Store the vectors alongside the records

        Args:
            records: Records to process
            text_field: Field name(s) containing text to embed
            vector_field: Field name to store vectors in
            embedding_fn: Function to generate embeddings
            batch_size: Number of records to process at once
            model_name: Name of the embedding model
            model_version: Version of the embedding model

        Returns:
            List of record IDs that were processed
        """
        if not embedding_fn:
            raise ValueError("embedding_fn is required for bulk_embed_and_store")

        from ..fields import VectorField

        processed_ids = []

        # Process in batches
        for i in range(0, len(records), batch_size):
            batch = records[i:i + batch_size]

            # Extract texts
            texts = []
            for record in batch:
                if isinstance(text_field, list):
                    text = " ".join(
                        str(record.fields.get(f, {}).value) for f in text_field if f in record.fields
                    )
                else:
                    text = str(record.fields.get(text_field, {}).value) if text_field in record.fields else ""
                texts.append(text)

            # Generate embeddings
            if texts:
                embeddings = await embedding_fn(texts)

                # Store vectors with records
                for j, record in enumerate(batch):
                    if j < len(embeddings):
                        vector = embeddings[j]

                        # Add vector field to record
                        record.fields[vector_field] = VectorField(
                            name=vector_field,
                            value=vector,
                            dimensions=len(vector) if hasattr(vector, "__len__") else None,
                            source_field=text_field if isinstance(text_field, str) else ",".join(text_field),
                            model_name=model_name,
                            model_version=model_version,
                        )

                        # Create or update record
                        if record.has_storage_id():
                            if record.storage_id is None:
                                raise ValueError("Record has_storage_id() returned True but storage_id is None")
                            await self.update(record.storage_id, record)
                        else:
                            record_id = await self.create(record)
                            record.storage_id = record_id

                        if record.storage_id is None:
                            raise ValueError("Record storage_id is None after create/update")
                        processed_ids.append(record.storage_id)

        return processed_ids
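
    # Illustrative embedding helper (not executed; any async callable that maps
    # a list of texts to a list of equal-length float vectors would do):
    #
    #     async def embed(texts: list[str]) -> list[list[float]]:
    #         return [[0.0] * 384 for _ in texts]   # stand-in for a real model
    #
    #     ids = await db.bulk_embed_and_store(records, text_field="title",
    #                                         vector_field="embedding",
    #                                         embedding_fn=embed)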

    async def create_vector_index(
        self,
        vector_field: str,
        dimensions: int,
        metric: DistanceMetric | str = "cosine",
        index_type: str = "ivfflat",
        lists: int | None = None,
    ) -> bool:
        """Create a vector index for efficient similarity search.

        Args:
            vector_field: Name of the vector field to index
            dimensions: Number of dimensions in the vectors
            metric: Distance metric for the index
            index_type: Type of index (ivfflat, hnsw)
            lists: Number of lists for an IVFFlat index

        Returns:
            True if the index was created successfully
        """
        from .postgres_vector import (
            build_vector_column_expression,
            build_vector_index_sql,
            get_optimal_index_type,
            get_vector_count_sql,
        )

        self._check_connection()

        if not self._vector_enabled:
            return False

        # Determine optimal parameters if not provided
        if not lists and index_type == "ivfflat":
            # Count vectors to determine the optimal number of lists
            count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
            async with self._pool.acquire() as conn:
                count = await conn.fetchval(count_sql) or 0
            _, params = get_optimal_index_type(count)
            lists = params.get("lists", 100)

        # Convert metric enum to string if needed
        if hasattr(metric, "value"):
            metric_str = metric.value
        else:
            metric_str = str(metric).lower()

        # Build vector column expression for the index
        column_expr = build_vector_column_expression(vector_field, dimensions, for_index=True)

        # Build index SQL - pass field_name for proper index naming
        index_sql = build_vector_index_sql(
            table_name=self.table_name,
            schema_name=self.schema_name,
            column_name=column_expr,
            dimensions=dimensions,
            metric=metric_str,
            index_type=index_type,
            index_params={"lists": lists} if lists else None,
            field_name=vector_field
        )

        # Create the index
        try:
            logger.debug(f"Creating vector index with SQL: {index_sql}")
            async with self._pool.acquire() as conn:
                await conn.execute(index_sql)
            return True
        except Exception as e:
            logger.warning(f"Failed to create vector index: {e}")
            logger.debug(f"Index SQL was: {index_sql}")
            return False
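
    # Illustrative index lifecycle (not executed; 384 dimensions assumed):
    #
    #     created = await db.create_vector_index("embedding", dimensions=384,
    #                                            metric="cosine", index_type="ivfflat")
    #     ...
    #     dropped = await db.drop_vector_index("embedding", metric="cosine")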

    async def drop_vector_index(self, vector_field: str, metric: str = "cosine") -> bool:
        """Drop a vector index.

        Args:
            vector_field: Name of the vector field
            metric: Distance metric used in the index

        Returns:
            True if the index was dropped successfully
        """
        from .postgres_vector import get_vector_index_name

        self._check_connection()

        index_name = get_vector_index_name(self.table_name, vector_field, metric)

        try:
            async with self._pool.acquire() as conn:
                await conn.execute(f"DROP INDEX IF EXISTS {self.schema_name}.{index_name}")
            return True
        except Exception as e:
            logger.warning(f"Failed to drop vector index: {e}")
            return False

    async def get_vector_index_stats(self, vector_field: str) -> dict[str, Any]:
        """Get statistics about a vector field and its index.

        Args:
            vector_field: Name of the vector field

        Returns:
            Dictionary with index statistics
        """
        from .postgres_vector import get_index_check_sql, get_vector_count_sql

        self._check_connection()

        stats = {
            "field": vector_field,
            "indexed": False,
            "vector_count": 0,
        }

        try:
            async with self._pool.acquire() as conn:
                # Count vectors
                count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
                stats["vector_count"] = await conn.fetchval(count_sql) or 0

                # Check for an index
                index_sql, params = get_index_check_sql(self.schema_name, self.table_name, vector_field)
                stats["indexed"] = await conn.fetchval(index_sql, *params) or False
        except Exception as e:
            logger.warning(f"Failed to get vector index stats: {e}")

        return stats
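
    # Illustrative stats check (not executed):
    #
    #     stats = await db.get_vector_index_stats("embedding")
    #     print(stats["field"], stats["indexed"], stats["vector_count"])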

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Stream records from PostgreSQL using a cursor."""
        self._check_connection()
        config = config or StreamConfig()

        # Build SQL query
        sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
        params = []

        if query and query.filters:
            where_clauses = []
            param_count = 0

            for filter in query.filters:
                param_count += 1
                field_path = f"data->>'{filter.field}'"

                if filter.operator == Operator.EQ:
                    where_clauses.append(f"{field_path} = ${param_count}")
                    params.append(str(filter.value))

            if where_clauses:
                sql += " WHERE " + " AND ".join(where_clauses)

        # Use a cursor for efficient streaming; asyncpg cursors are iterated
        # via the cursor factory, not an awaited cursor object
        async with self._pool.acquire() as conn:
            async with conn.transaction():
                batch = []
                async for row in conn.cursor(sql, *params):
                    record = self._row_to_record(row)
                    if query and query.fields:
                        record = record.project(query.fields)

                    batch.append(record)

                    if len(batch) >= config.batch_size:
                        for rec in batch:
                            yield rec
                        batch = []

                # Yield remaining records
                for rec in batch:
                    yield rec

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into PostgreSQL using batch inserts."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        # Batch writer used with graceful per-record fallback
        async def batch_func(b):
            await self._write_batch(b)
            return [r.id for r in b]

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                continue_processing = await async_process_batch_with_fallback(
                    batch,
                    batch_func,
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            await async_process_batch_with_fallback(
                batch,
                batch_func,
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result

    async def _write_batch(self, records: list[Record]) -> list[str]:
        """Write a batch of records using COPY for performance.

        Returns:
            List of created record IDs
        """
        if not records:
            return []

        # Prepare data for COPY
        rows = []
        ids = []
        for record in records:
            row_data = self._record_to_row(record)
            ids.append(row_data["id"])
            rows.append((
                row_data["id"],
                row_data["data"],
                row_data["metadata"]
            ))

        # Use COPY for efficient bulk insert; asyncpg expects the schema to be
        # passed separately rather than embedded in the table name
        async with self._pool.acquire() as conn:
            await conn.copy_records_to_table(
                self.table_name,
                records=rows,
                columns=["id", "data", "metadata"],
                schema_name=self.schema_name
            )

        return ids