Coverage for src / dataknobs_data / vector / stores / pgvector.py: 10%

288 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-26 15:45 -0700

1"""PostgreSQL pgvector backend implementation.""" 

2 

3from __future__ import annotations 

4 

5import json 

6import logging 

7from typing import TYPE_CHECKING, Any 

8from uuid import uuid4 

9 

10from ..types import DistanceMetric 

11from .base import VectorStore 

12 

13if TYPE_CHECKING: 

14 import numpy as np 

15 

16logger = logging.getLogger(__name__) 

17 

# Optional dependency: pgvector support requires the asyncpg driver.
# The flag lets __init__ raise a friendly error instead of an ImportError
# at module-import time.
try:
    import asyncpg
except ImportError:
    ASYNCPG_AVAILABLE = False
else:
    ASYNCPG_AVAILABLE = True

24 

25 

class PgVectorStore(VectorStore):
    """PostgreSQL pgvector backend for vector similarity search.

    Uses PostgreSQL with the pgvector extension for efficient vector storage
    and similarity search. Supports IVFFlat and HNSW indexes.

    Configuration:
        connection_string: PostgreSQL connection URL
        table_name: Table name (default: knowledge_embeddings)
        schema: Database schema (default: edubot)
        dimensions: Vector dimensions (required)
        metric: Distance metric (cosine, euclidean, inner_product)
        pool_min_size: Minimum connection pool size (default: 2)
        pool_max_size: Maximum connection pool size (default: 10)
        columns: Column name mappings (optional)
        auto_create_table: Create table if missing (default: True)
        id_type: ID column type - 'uuid' or 'text' (default: 'uuid')

    Index configuration:
        index_type: Type of vector index - 'none', 'hnsw', or 'ivfflat' (default: 'none')
        auto_create_index: Automatically create index when conditions are met (default: False)
        min_rows_for_index: Minimum rows before auto-creating IVFFlat index (default: 1000)
        index_params: Parameters for index creation (optional dict)
            - For HNSW: m (default: 16), ef_construction (default: 64)
            - For IVFFlat: lists (default: 100)

    Example - Default schema:
        ```python
        store = PgVectorStore({
            "connection_string": "postgresql://user:pass@host:5432/db",
            "dimensions": 768,
            "metric": "cosine",
            "schema": "edubot",
        })
        ```

    Example - With HNSW index (created immediately, works with any data size):
        ```python
        store = PgVectorStore({
            "connection_string": "postgresql://user:pass@host:5432/db",
            "dimensions": 768,
            "index_type": "hnsw",
            "auto_create_index": True,
            "index_params": {"m": 16, "ef_construction": 64},
        })
        ```

    Example - With IVFFlat index (auto-created when data exceeds threshold):
        ```python
        store = PgVectorStore({
            "connection_string": "postgresql://user:pass@host:5432/db",
            "dimensions": 768,
            "index_type": "ivfflat",
            "auto_create_index": True,
            "min_rows_for_index": 1000,
            "index_params": {"lists": 100},
        })
        ```

    Example - Custom table with column mappings:
        ```python
        store = PgVectorStore({
            "connection_string": "postgresql://user:pass@host:5432/db",
            "dimensions": 768,
            "table_name": "product_embeddings",
            "columns": {
                "id": "product_id",
                "embedding": "vector_data",
                "content": "description",
                "metadata": "attributes",
                "domain_id": "category",
            },
            "id_type": "text",
            "auto_create_table": True,
        })
        ```
    """

    # Default logical-name -> physical-column-name mappings. User-supplied
    # "columns" config entries override individual keys (see _parse_backend_config);
    # all SQL in this class goes through _col() so renames apply everywhere.
    DEFAULT_COLUMNS = {
        "id": "id",
        "embedding": "embedding",
        "content": "content",
        "metadata": "metadata",
        "domain_id": "domain_id",
        "document_id": "document_id",
        "chunk_index": "chunk_index",
        "created_at": "created_at",
    }

115 

    def __init__(self, config: dict[str, Any] | None = None) -> None:
        """Initialize pgvector store.

        Args:
            config: Backend configuration dict (see class docstring for keys).

        Raises:
            ImportError: If the optional asyncpg driver is not installed.
        """
        # Fail fast with an actionable message when the optional driver is absent.
        if not ASYNCPG_AVAILABLE:
            raise ImportError(
                "asyncpg is not installed. Install with: pip install asyncpg"
            )

        super().__init__(config)
        # Connection pool is created lazily in initialize().
        self._pool: asyncpg.Pool | None = None

125 

126 def _parse_backend_config(self) -> None: 

127 """Parse pgvector-specific configuration.""" 

128 import os 

129 

130 self.connection_string = self.config.get("connection_string") 

131 if not self.connection_string: 

132 self.connection_string = os.environ.get("DATABASE_URL") 

133 

134 if not self.connection_string: 

135 raise ValueError( 

136 "connection_string required for pgvector backend. " 

137 "Set in config or DATABASE_URL environment variable." 

138 ) 

139 

140 # Normalize connection string format 

141 if self.connection_string.startswith("postgresql+asyncpg://"): 

142 self.connection_string = self.connection_string.replace( 

143 "postgresql+asyncpg://", "postgresql://" 

144 ) 

145 

146 self.table_name = self.config.get("table_name", "knowledge_embeddings") 

147 self.schema = self.config.get("schema", "edubot") 

148 self.pool_min_size = self.config.get("pool_min_size", 2) 

149 self.pool_max_size = self.config.get("pool_max_size", 10) 

150 

151 # Domain filtering (optional - for multi-tenant isolation) 

152 self.domain_id = self.config.get("domain_id") 

153 

154 # Column mappings - merge user config with defaults 

155 user_columns = self.config.get("columns", {}) 

156 self.columns = {**self.DEFAULT_COLUMNS, **user_columns} 

157 

158 # Table creation options 

159 self.auto_create_table = self.config.get("auto_create_table", True) 

160 self.id_type = self.config.get("id_type", "uuid") 

161 if self.id_type not in ("uuid", "text"): 

162 raise ValueError(f"id_type must be 'uuid' or 'text', got: {self.id_type}") 

163 

164 # Index configuration 

165 self.index_type = self.config.get("index_type", "none") 

166 if self.index_type not in ("none", "hnsw", "ivfflat"): 

167 raise ValueError( 

168 f"index_type must be 'none', 'hnsw', or 'ivfflat', got: {self.index_type}" 

169 ) 

170 self.auto_create_index = self.config.get("auto_create_index", False) 

171 self.min_rows_for_index = self.config.get("min_rows_for_index", 1000) 

172 self.index_params = self.config.get("index_params", {}) 

173 

174 def _col(self, name: str) -> str: 

175 """Get the actual column name for a logical field name.""" 

176 return self.columns.get(name, name) 

177 

178 def _get_operator_class(self) -> str: 

179 """Get the pgvector operator class for the configured metric.""" 

180 if self.metric == DistanceMetric.COSINE: 

181 return "vector_cosine_ops" 

182 elif self.metric in (DistanceMetric.EUCLIDEAN, DistanceMetric.L2): 

183 return "vector_l2_ops" 

184 elif self.metric in (DistanceMetric.DOT_PRODUCT, DistanceMetric.INNER_PRODUCT): 

185 return "vector_ip_ops" 

186 else: 

187 return "vector_cosine_ops" # Default 

188 

189 async def _check_index_exists(self) -> bool: 

190 """Check if a vector index exists on the embedding column. 

191 

192 Queries PostgreSQL's pg_indexes catalog to check for any index 

193 on the embedding column. Works reliably in distributed environments. 

194 """ 

195 if not self._pool: 

196 return False 

197 

198 col_embedding = self._col("embedding") 

199 async with self._pool.acquire() as conn: 

200 # Check pg_indexes for any index on our table that includes the embedding column 

201 exists = await conn.fetchval( 

202 """ 

203 SELECT EXISTS ( 

204 SELECT 1 FROM pg_indexes 

205 WHERE schemaname = $1 

206 AND tablename = $2 

207 AND indexdef LIKE $3 

208 ) 

209 """, 

210 self.schema, 

211 self.table_name, 

212 f"%{col_embedding}%", 

213 ) 

214 return bool(exists) 

215 

216 async def create_index( 

217 self, 

218 index_type: str | None = None, 

219 params: dict[str, Any] | None = None, 

220 if_not_exists: bool = True, 

221 ) -> bool: 

222 """Create a vector index on the embedding column. 

223 

224 Args: 

225 index_type: Type of index - 'hnsw' or 'ivfflat'. Defaults to configured index_type. 

226 params: Index parameters. Defaults to configured index_params. 

227 - For HNSW: m (connections per layer), ef_construction (build quality) 

228 - For IVFFlat: lists (number of clusters) 

229 if_not_exists: Skip creation if index already exists (default: True) 

230 

231 Returns: 

232 True if index was created, False if skipped (already exists) 

233 

234 Raises: 

235 ValueError: If index_type is invalid or 'none' 

236 RuntimeError: If store not initialized 

237 

238 Example: 

239 ```python 

240 # Create HNSW index with custom parameters 

241 await store.create_index("hnsw", {"m": 32, "ef_construction": 128}) 

242 

243 # Create IVFFlat index (requires sufficient data) 

244 await store.create_index("ivfflat", {"lists": 200}) 

245 ``` 

246 """ 

247 if not self._initialized: 

248 raise RuntimeError("Store must be initialized before creating index") 

249 

250 # Use configured values as defaults 

251 idx_type = index_type or self.index_type 

252 idx_params = params if params is not None else self.index_params 

253 

254 if idx_type == "none": 

255 raise ValueError("Cannot create index with index_type='none'") 

256 if idx_type not in ("hnsw", "ivfflat"): 

257 raise ValueError(f"index_type must be 'hnsw' or 'ivfflat', got: {idx_type}") 

258 

259 # Check if index already exists 

260 if if_not_exists and await self._check_index_exists(): 

261 logger.info(f"Index already exists on {self.schema}.{self.table_name}") 

262 return False 

263 

264 col_embedding = self._col("embedding") 

265 operator_class = self._get_operator_class() 

266 index_name = f"idx_{self.table_name}_{col_embedding}_{idx_type}" 

267 

268 async with self._pool.acquire() as conn: 

269 if idx_type == "hnsw": 

270 m = idx_params.get("m", 16) 

271 ef_construction = idx_params.get("ef_construction", 64) 

272 await conn.execute(f""" 

273 CREATE INDEX {"IF NOT EXISTS" if if_not_exists else ""} {index_name} 

274 ON {self.schema}.{self.table_name} 

275 USING hnsw ({col_embedding} {operator_class}) 

276 WITH (m = {m}, ef_construction = {ef_construction}) 

277 """) 

278 else: # ivfflat 

279 lists = idx_params.get("lists", 100) 

280 await conn.execute(f""" 

281 CREATE INDEX {"IF NOT EXISTS" if if_not_exists else ""} {index_name} 

282 ON {self.schema}.{self.table_name} 

283 USING ivfflat ({col_embedding} {operator_class}) 

284 WITH (lists = {lists}) 

285 """) 

286 

287 logger.info( 

288 f"Created {idx_type} index on {self.schema}.{self.table_name}.{col_embedding}" 

289 ) 

290 return True 

291 

292 async def _maybe_create_index(self) -> None: 

293 """Conditionally create index based on configuration and data size. 

294 

295 Called during search() to auto-create IVFFlat index when: 

296 - auto_create_index is True 

297 - index_type is 'ivfflat' 

298 - Row count exceeds min_rows_for_index 

299 - No index exists yet 

300 """ 

301 if not self.auto_create_index: 

302 return 

303 if self.index_type != "ivfflat": 

304 return # HNSW is created at table creation time 

305 

306 # Check if index already exists (distributed-safe) 

307 if await self._check_index_exists(): 

308 return 

309 

310 # Check row count 

311 row_count = await self.count() 

312 if row_count < self.min_rows_for_index: 

313 return 

314 

315 logger.info( 

316 f"Auto-creating IVFFlat index: {row_count} rows >= {self.min_rows_for_index} threshold" 

317 ) 

318 await self.create_index("ivfflat", self.index_params, if_not_exists=True) 

319 

320 async def initialize(self) -> None: 

321 """Initialize database connection pool.""" 

322 if self._initialized: 

323 return 

324 

325 logger.info(f"Initializing pgvector store: {self.schema}.{self.table_name}") 

326 

327 # Create connection pool 

328 self._pool = await asyncpg.create_pool( 

329 self.connection_string, 

330 min_size=self.pool_min_size, 

331 max_size=self.pool_max_size, 

332 command_timeout=30, 

333 ) 

334 

335 # Verify pgvector extension and table exist 

336 async with self._pool.acquire() as conn: 

337 # Check pgvector extension 

338 has_pgvector = await conn.fetchval( 

339 "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'vector')" 

340 ) 

341 if not has_pgvector: 

342 raise RuntimeError( 

343 "pgvector extension not installed. Run: CREATE EXTENSION vector;" 

344 ) 

345 

346 # Check table exists 

347 table_exists = await conn.fetchval( 

348 """ 

349 SELECT EXISTS ( 

350 SELECT FROM information_schema.tables 

351 WHERE table_schema = $1 AND table_name = $2 

352 ) 

353 """, 

354 self.schema, 

355 self.table_name, 

356 ) 

357 if not table_exists: 

358 if self.auto_create_table: 

359 logger.info( 

360 f"Table {self.schema}.{self.table_name} does not exist. " 

361 "Creating with configured schema." 

362 ) 

363 await self._create_table(conn) 

364 else: 

365 raise RuntimeError( 

366 f"Table {self.schema}.{self.table_name} does not exist " 

367 "and auto_create_table is disabled." 

368 ) 

369 

370 self._initialized = True 

371 logger.info("pgvector store initialized successfully") 

372 

    async def _create_table(self, conn: asyncpg.Connection) -> None:
        """Create the embeddings table using configured column names.

        Args:
            conn: An already-acquired asyncpg connection; the caller owns its
                lifetime.
        """
        await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self.schema}")

        # Build ID column definition based on id_type
        if self.id_type == "uuid":
            id_def = f"{self._col('id')} UUID PRIMARY KEY DEFAULT gen_random_uuid()"
        else:
            id_def = f"{self._col('id')} TEXT PRIMARY KEY"

        # Build CREATE TABLE with configured column names
        await conn.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.schema}.{self.table_name} (
                {id_def},
                {self._col('domain_id')} VARCHAR(100),
                {self._col('document_id')} VARCHAR(255),
                {self._col('chunk_index')} INTEGER,
                {self._col('content')} TEXT,
                {self._col('embedding')} vector({self.dimensions}),
                {self._col('metadata')} JSONB DEFAULT '{{}}',
                {self._col('created_at')} TIMESTAMP DEFAULT NOW()
            )
        """)

        # Create HNSW index immediately if configured (HNSW works with empty tables)
        if self.auto_create_index and self.index_type == "hnsw":
            col_embedding = self._col("embedding")
            operator_class = self._get_operator_class()
            index_name = f"idx_{self.table_name}_{col_embedding}_hnsw"
            m = self.index_params.get("m", 16)
            ef_construction = self.index_params.get("ef_construction", 64)
            await conn.execute(f"""
                CREATE INDEX IF NOT EXISTS {index_name}
                ON {self.schema}.{self.table_name}
                USING hnsw ({col_embedding} {operator_class})
                WITH (m = {m}, ef_construction = {ef_construction})
            """)
            logger.info(f"Created HNSW index on {self.schema}.{self.table_name}")

        # Note: IVFFlat index is not created here because it requires existing data.
        # It will be auto-created during search() if auto_create_index=True and
        # row count exceeds min_rows_for_index. Or use create_index() explicitly.

        logger.info(
            f"Created table {self.schema}.{self.table_name} with columns: {self.columns}"
        )

419 

420 async def close(self) -> None: 

421 """Close the connection pool.""" 

422 if self._pool: 

423 await self._pool.close() 

424 self._pool = None 

425 self._initialized = False 

426 logger.info("pgvector store closed") 

427 

428 async def add_vectors( 

429 self, 

430 vectors: np.ndarray | list[np.ndarray], 

431 ids: list[str] | None = None, 

432 metadata: list[dict[str, Any]] | None = None, 

433 ) -> list[str]: 

434 """Add vectors to the store.""" 

435 if not self._initialized: 

436 await self.initialize() 

437 

438 

439 # Prepare vectors 

440 vectors = self._prepare_vector( 

441 vectors, normalize=(self.metric == DistanceMetric.COSINE) 

442 ) 

443 

444 # Generate IDs if not provided 

445 if ids is None: 

446 ids = [str(uuid4()) for _ in range(len(vectors))] 

447 

448 # Generate metadata if not provided 

449 if metadata is None: 

450 metadata = [{} for _ in range(len(vectors))] 

451 

452 # Build ID type cast 

453 id_cast = "::uuid" if self.id_type == "uuid" else "" 

454 

455 async with self._pool.acquire() as conn: 

456 # Batch insert 

457 for i, (vec, vec_id, meta) in enumerate(zip(vectors, ids, metadata)): 

458 # Extract document info from metadata if available 

459 document_id = meta.get("document_id", meta.get("source")) 

460 chunk_index = meta.get("chunk_index", i) 

461 content = meta.get("source_text", meta.get("content", "")) 

462 domain_id = meta.get("domain_id", self.domain_id) 

463 

464 await conn.execute( 

465 f""" 

466 INSERT INTO {self.schema}.{self.table_name} 

467 ({self._col('id')}, {self._col('domain_id')}, 

468 {self._col('document_id')}, {self._col('chunk_index')}, 

469 {self._col('content')}, {self._col('embedding')}, 

470 {self._col('metadata')}) 

471 VALUES ($1{id_cast}, $2, $3, $4, $5, $6::vector, $7::jsonb) 

472 ON CONFLICT ({self._col('id')}) DO UPDATE SET 

473 {self._col('embedding')} = EXCLUDED.{self._col('embedding')}, 

474 {self._col('metadata')} = EXCLUDED.{self._col('metadata')} 

475 """, 

476 vec_id, 

477 domain_id, 

478 document_id, 

479 chunk_index, 

480 content, 

481 f"[{','.join(str(x) for x in vec.tolist())}]", 

482 json.dumps(meta), 

483 ) 

484 

485 logger.debug(f"Added {len(ids)} vectors to pgvector") 

486 return ids 

487 

488 async def get_vectors( 

489 self, 

490 ids: list[str], 

491 include_metadata: bool = True, 

492 ) -> list[tuple[np.ndarray | None, dict[str, Any] | None]]: 

493 """Retrieve vectors by ID.""" 

494 if not self._initialized: 

495 await self.initialize() 

496 

497 import numpy as np 

498 

499 id_cast = "::uuid" if self.id_type == "uuid" else "" 

500 col_embedding = self._col("embedding") 

501 col_metadata = self._col("metadata") 

502 col_id = self._col("id") 

503 

504 results: list[tuple[np.ndarray | None, dict[str, Any] | None]] = [] 

505 async with self._pool.acquire() as conn: 

506 for vec_id in ids: 

507 row = await conn.fetchrow( 

508 f""" 

509 SELECT {col_embedding}::text as embedding, {col_metadata} as metadata 

510 FROM {self.schema}.{self.table_name} 

511 WHERE {col_id} = $1{id_cast} 

512 """, 

513 vec_id, 

514 ) 

515 

516 if row is None: 

517 results.append((None, None)) 

518 else: 

519 # Parse vector from PostgreSQL format 

520 vec_str = row["embedding"] 

521 vec = np.array( 

522 [float(x) for x in vec_str.strip("[]").split(",")], 

523 dtype=np.float32, 

524 ) 

525 # asyncpg returns JSONB as dict or str depending on version 

526 meta = None 

527 if include_metadata and row["metadata"] is not None: 

528 raw_meta = row["metadata"] 

529 if isinstance(raw_meta, dict): 

530 meta = raw_meta 

531 elif isinstance(raw_meta, str): 

532 meta = json.loads(raw_meta) 

533 else: 

534 meta = dict(raw_meta) 

535 results.append((vec, meta)) 

536 

537 return results 

538 

539 async def delete_vectors(self, ids: list[str]) -> int: 

540 """Delete vectors by ID.""" 

541 if not self._initialized: 

542 await self.initialize() 

543 

544 id_array_cast = "::uuid[]" if self.id_type == "uuid" else "::text[]" 

545 col_id = self._col("id") 

546 

547 async with self._pool.acquire() as conn: 

548 result = await conn.execute( 

549 f""" 

550 DELETE FROM {self.schema}.{self.table_name} 

551 WHERE {col_id} = ANY($1{id_array_cast}) 

552 """, 

553 ids, 

554 ) 

555 # Parse "DELETE n" to get count 

556 count = int(result.split()[-1]) 

557 

558 logger.debug(f"Deleted {count} vectors from pgvector") 

559 return count 

560 

561 async def search( 

562 self, 

563 query_vector: np.ndarray, 

564 k: int = 10, 

565 filter: dict[str, Any] | None = None, 

566 include_metadata: bool = True, 

567 ) -> list[tuple[str, float, dict[str, Any] | None]]: 

568 """Search for similar vectors using pgvector.""" 

569 if not self._initialized: 

570 await self.initialize() 

571 

572 # Auto-create IVFFlat index if conditions are met 

573 await self._maybe_create_index() 

574 

575 # Prepare query vector 

576 query = self._prepare_vector( 

577 query_vector, normalize=(self.metric == DistanceMetric.COSINE) 

578 ) 

579 query_str = f"[{','.join(str(x) for x in query[0].tolist())}]" 

580 

581 # Get column names 

582 col_id = self._col("id") 

583 col_embedding = self._col("embedding") 

584 col_metadata = self._col("metadata") 

585 col_content = self._col("content") 

586 col_domain_id = self._col("domain_id") 

587 

588 # Build distance operator based on metric 

589 if self.metric == DistanceMetric.COSINE: 

590 distance_op = "<=>" # Cosine distance 

591 # Convert to similarity 

592 score_expr = f"1 - ({col_embedding} <=> $1::vector)" 

593 elif self.metric in (DistanceMetric.EUCLIDEAN, DistanceMetric.L2): 

594 distance_op = "<->" # L2 distance 

595 score_expr = f"1.0 / (1.0 + ({col_embedding} <-> $1::vector))" 

596 elif self.metric in (DistanceMetric.DOT_PRODUCT, DistanceMetric.INNER_PRODUCT): 

597 distance_op = "<#>" # Negative inner product 

598 # Negate to get actual inner product 

599 score_expr = f"-({col_embedding} <#> $1::vector)" 

600 else: 

601 distance_op = "<=>" 

602 score_expr = f"1 - ({col_embedding} <=> $1::vector)" 

603 

604 # Build WHERE clause for filters 

605 where_clauses = [] 

606 params: list[Any] = [query_str] 

607 param_idx = 2 

608 

609 # Add domain filter if configured 

610 if self.domain_id: 

611 where_clauses.append(f"{col_domain_id} = ${param_idx}") 

612 params.append(self.domain_id) 

613 param_idx += 1 

614 

615 # Add metadata filters 

616 if filter: 

617 for key, value in filter.items(): 

618 where_clauses.append(f"{col_metadata}->>'{key}' = ${param_idx}") 

619 params.append(str(value)) 

620 param_idx += 1 

621 

622 where_sql = "" 

623 if where_clauses: 

624 where_sql = "WHERE " + " AND ".join(where_clauses) 

625 

626 # Execute search query 

627 async with self._pool.acquire() as conn: 

628 rows = await conn.fetch( 

629 f""" 

630 SELECT 

631 {col_id}::text as id, 

632 {score_expr} as score, 

633 {col_metadata} as metadata, 

634 {col_content} as content 

635 FROM {self.schema}.{self.table_name} 

636 {where_sql} 

637 ORDER BY {col_embedding} {distance_op} $1::vector 

638 LIMIT {k} 

639 """, 

640 *params, 

641 ) 

642 

643 results = [] 

644 for row in rows: 

645 meta = None 

646 if include_metadata and row["metadata"] is not None: 

647 raw_meta = row["metadata"] 

648 if isinstance(raw_meta, dict): 

649 meta = raw_meta.copy() 

650 elif isinstance(raw_meta, str): 

651 meta = json.loads(raw_meta) 

652 else: 

653 meta = dict(raw_meta) 

654 # Add content to metadata for convenience 

655 meta["content"] = row["content"] 

656 results.append((row["id"], float(row["score"]), meta)) 

657 

658 return results 

659 

660 async def update_metadata( 

661 self, 

662 ids: list[str], 

663 metadata: list[dict[str, Any]], 

664 ) -> int: 

665 """Update metadata for existing vectors.""" 

666 if not self._initialized: 

667 await self.initialize() 

668 

669 id_cast = "::uuid" if self.id_type == "uuid" else "" 

670 col_id = self._col("id") 

671 col_metadata = self._col("metadata") 

672 

673 updated = 0 

674 async with self._pool.acquire() as conn: 

675 for vec_id, meta in zip(ids, metadata): 

676 result = await conn.execute( 

677 f""" 

678 UPDATE {self.schema}.{self.table_name} 

679 SET {col_metadata} = $2::jsonb 

680 WHERE {col_id} = $1{id_cast} 

681 """, 

682 vec_id, 

683 json.dumps(meta), 

684 ) 

685 if result == "UPDATE 1": 

686 updated += 1 

687 

688 return updated 

689 

690 async def count(self, filter: dict[str, Any] | None = None) -> int: 

691 """Count vectors in the store.""" 

692 if not self._initialized: 

693 await self.initialize() 

694 

695 col_domain_id = self._col("domain_id") 

696 col_metadata = self._col("metadata") 

697 

698 where_clauses = [] 

699 params: list[Any] = [] 

700 param_idx = 1 

701 

702 if self.domain_id: 

703 where_clauses.append(f"{col_domain_id} = ${param_idx}") 

704 params.append(self.domain_id) 

705 param_idx += 1 

706 

707 if filter: 

708 for key, value in filter.items(): 

709 where_clauses.append(f"{col_metadata}->>'{key}' = ${param_idx}") 

710 params.append(str(value)) 

711 param_idx += 1 

712 

713 where_sql = "" 

714 if where_clauses: 

715 where_sql = "WHERE " + " AND ".join(where_clauses) 

716 

717 async with self._pool.acquire() as conn: 

718 count = await conn.fetchval( 

719 f"SELECT COUNT(*) FROM {self.schema}.{self.table_name} {where_sql}", 

720 *params, 

721 ) 

722 

723 return int(count or 0) 

724 

725 async def clear(self) -> None: 

726 """Clear all vectors from the store.""" 

727 if not self._initialized: 

728 await self.initialize() 

729 

730 col_domain_id = self._col("domain_id") 

731 

732 async with self._pool.acquire() as conn: 

733 if self.domain_id: 

734 await conn.execute( 

735 f"DELETE FROM {self.schema}.{self.table_name} " 

736 f"WHERE {col_domain_id} = $1", 

737 self.domain_id, 

738 ) 

739 else: 

740 await conn.execute(f"TRUNCATE {self.schema}.{self.table_name}") 

741 

742 logger.info("Cleared pgvector store")