Coverage for src/dataknobs_data/backends/postgres_async.py: 70%

323 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-15 14:22 -0500

1"""Improved async PostgreSQL backend with native asyncpg and event loop-aware connection pooling.""" 

2 

3import asyncio 

4import asyncpg 

5import json 

6import time 

7import uuid 

8import logging 

9from typing import Any, AsyncIterator, Optional, Dict 

10from weakref import WeakValueDictionary 

11from dataclasses import dataclass 

12 

13from dataknobs_config import ConfigurableBase 

14 

15from ..database import AsyncDatabase 

16from ..query import Operator, Query, SortOrder 

17from ..records import Record 

18from ..streaming import StreamConfig, StreamResult 

19 

20logger = logging.getLogger(__name__) 

21 

22 

@dataclass
class PoolConfig:
    """Configuration for an asyncpg connection pool.

    Attributes mirror the keys accepted by :meth:`from_dict`; ``min_size``
    and ``max_size`` bound the pool, ``table``/``schema`` locate the
    records table.
    """
    host: str = "localhost"
    port: int = 5432
    database: str = "postgres"
    user: str = "postgres"
    password: str = ""
    table: str = "records"
    schema: str = "public"
    min_size: int = 10
    max_size: int = 10
    command_timeout: Optional[float] = None
    ssl: Optional[Any] = None

    @classmethod
    def from_dict(cls, config: dict) -> "PoolConfig":
        """Create a PoolConfig from a configuration dictionary.

        Note the key mapping: ``min_pool_size``/``max_pool_size`` in the
        config dict populate ``min_size``/``max_size``. Unknown keys are
        ignored.
        """
        return cls(
            host=config.get("host", "localhost"),
            port=config.get("port", 5432),
            database=config.get("database", "postgres"),
            user=config.get("user", "postgres"),
            password=config.get("password", ""),
            table=config.get("table", "records"),
            schema=config.get("schema", "public"),
            min_size=config.get("min_pool_size", 10),
            max_size=config.get("max_pool_size", 10),
            command_timeout=config.get("command_timeout"),
            ssl=config.get("ssl"),
        )

    def to_connection_string(self) -> str:
        """Convert to a PostgreSQL connection URI.

        User and password are percent-encoded so credentials containing
        URI-reserved characters (``@``, ``:``, ``/``, ...) still yield a
        valid, unambiguous DSN. Plain alphanumeric credentials are
        unchanged by the encoding.
        """
        from urllib.parse import quote  # local import: only needed here

        user = quote(self.user, safe="")
        password = quote(self.password, safe="")
        return f"postgresql://{user}:{password}@{self.host}:{self.port}/{self.database}"

58 

59 

class ConnectionPoolManager:
    """
    Manages asyncpg connection pools per event loop.

    Each event loop gets its own connection pool, preventing
    "Event loop is closed" errors when the same configuration is used
    from several loops (e.g. successive ``asyncio.run()`` calls).
    """

    def __init__(self):
        """Initialize the connection pool manager."""
        # Map of (config_hash, loop_id) -> pool
        self._pools: Dict[tuple, "asyncpg.Pool"] = {}
        # Weak references to event loops so dead loops don't pin memory
        self._loop_refs: WeakValueDictionary = WeakValueDictionary()

    @staticmethod
    def _pool_key(config: "PoolConfig", loop_id: int) -> tuple:
        """Build the dict key identifying (connection target, event loop)."""
        config_hash = hash((config.host, config.port, config.database, config.user))
        return (config_hash, loop_id)

    async def get_pool(self, config: "PoolConfig") -> "asyncpg.Pool":
        """
        Get or create a connection pool for the current event loop.

        An existing pool is health-checked with ``SELECT 1`` before reuse;
        if the check fails the pool is discarded and a fresh one created.

        Args:
            config: Pool configuration

        Returns:
            asyncpg.Pool instance for the current event loop
        """
        loop = asyncio.get_running_loop()
        loop_id = id(loop)
        pool_key = self._pool_key(config, loop_id)

        # Reuse an existing pool for this config and loop if still valid.
        if pool_key in self._pools:
            pool = self._pools[pool_key]
            try:
                async with pool.acquire() as conn:
                    await conn.fetchval("SELECT 1")
                return pool
            except Exception as e:
                # Any failure (InterfaceError, InternalClientError, ...)
                # means the pool is unusable; replace it.
                logger.warning("Pool for loop %s is invalid: %s. Creating new one.", loop_id, e)
                await self._close_pool(pool_key)

        # Create new pool
        logger.info("Creating new connection pool for loop %s", loop_id)
        pool = await asyncpg.create_pool(
            config.to_connection_string(),
            min_size=config.min_size,
            max_size=config.max_size,
            command_timeout=config.command_timeout,
            ssl=config.ssl,
        )

        # Store pool and loop reference
        self._pools[pool_key] = pool
        self._loop_refs[loop_id] = loop

        return pool

    async def _close_pool(self, pool_key: tuple):
        """Close and remove a pool; the mapping entry is removed even if
        closing raises."""
        if pool_key in self._pools:
            pool = self._pools[pool_key]
            try:
                await pool.close()
            except Exception as e:
                logger.error("Error closing pool: %s", e)
            finally:
                del self._pools[pool_key]

    async def remove_pool(self, config: "PoolConfig") -> bool:
        """Remove the pool for the current event loop.

        Returns:
            True if a pool existed and was closed, False otherwise.
        """
        loop_id = id(asyncio.get_running_loop())
        pool_key = self._pool_key(config, loop_id)

        if pool_key in self._pools:
            await self._close_pool(pool_key)
            return True
        return False

    async def close_all(self):
        """Close all connection pools across all event loops."""
        for pool_key in list(self._pools.keys()):
            await self._close_pool(pool_key)

144 

145 

# Global pool manager instance, shared process-wide: it hands out one
# asyncpg pool per (connection config, event loop) pair so databases on
# different loops never share a pool.
_pool_manager = ConnectionPoolManager()

148 

149 

class AsyncPostgresDatabase(AsyncDatabase, ConfigurableBase):
    """Native async PostgreSQL database backend with event loop-aware
    connection pooling.

    Records are stored in a single table with ``id``, ``data`` and
    ``metadata`` columns; ``data``/``metadata`` are JSONB and GIN-indexed.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async PostgreSQL database.

        Args:
            config: Configuration dict understood by ``PoolConfig.from_dict``.
        """
        super().__init__(config)
        self._pool_config = PoolConfig.from_dict(config or {})
        self._pool: Optional[asyncpg.Pool] = None
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> "AsyncPostgresDatabase":
        """Create an instance from a config dictionary."""
        return cls(config)

    @property
    def _qualified_table(self) -> str:
        """Schema-qualified table name interpolated into SQL statements.

        NOTE(review): schema/table come straight from config and are not
        quoted as identifiers — config must be trusted.
        """
        return f"{self._pool_config.schema}.{self._pool_config.table}"

    async def connect(self) -> None:
        """Connect to the database (idempotent).

        Acquires the pool for the *current* event loop and ensures the
        records table exists.
        """
        if self._connected:
            return

        # Get or create pool for current event loop
        self._pool = await _pool_manager.get_pool(self._pool_config)

        # Ensure table exists
        await self._ensure_table()
        self._connected = True

    async def close(self) -> None:
        """Close the database connection.

        The underlying pool is deliberately left open — the global pool
        manager owns its lifecycle so other instances on the same loop
        can keep using it.
        """
        if self._connected:
            self._pool = None
            self._connected = False

    def _initialize(self) -> None:
        """Initialization is handled in connect()."""
        pass

    async def _ensure_table(self) -> None:
        """Create the records table and its GIN indexes if absent.

        Raises:
            RuntimeError: if called before a pool is available.
        """
        if not self._pool:
            raise RuntimeError("Database not connected. Call connect() first.")

        create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS {self._qualified_table} (
            id VARCHAR(255) PRIMARY KEY,
            data JSONB NOT NULL,
            metadata JSONB,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );

        CREATE INDEX IF NOT EXISTS idx_{self._pool_config.table}_data
        ON {self._qualified_table} USING GIN (data);

        CREATE INDEX IF NOT EXISTS idx_{self._pool_config.table}_metadata
        ON {self._qualified_table} USING GIN (metadata);
        """

        async with self._pool.acquire() as conn:
            await conn.execute(create_table_sql)

    def _check_connection(self) -> None:
        """Raise RuntimeError unless connect() has completed."""
        if not self._connected or not self._pool:
            raise RuntimeError("Database not connected. Call connect() first.")

    def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
        """Convert a Record to a database row dict.

        Args:
            record: Record whose field values become the JSON ``data``.
            id: Explicit row id; a fresh UUID4 is generated when omitted.

        Returns:
            Dict with ``id``, JSON-encoded ``data`` and ``metadata``
            (``metadata`` is None when the record has none).
        """
        data = {field_name: field_obj.value for field_name, field_obj in record.fields.items()}
        return {
            "id": id or str(uuid.uuid4()),
            "data": json.dumps(data),
            "metadata": json.dumps(record.metadata) if record.metadata else None,
        }

    def _row_to_record(self, row: asyncpg.Record) -> Record:
        """Convert a database row to a Record.

        asyncpg may hand JSONB back as a str (no codec registered) or as a
        decoded object, so both forms are accepted.
        """
        data = row.get("data", {})
        if isinstance(data, str):
            data = json.loads(data)

        metadata = row.get("metadata", {})
        if isinstance(metadata, str) and metadata:
            metadata = json.loads(metadata)
        elif not metadata:
            metadata = {}

        return Record(data=data, metadata=metadata)

    async def create(self, record: Record) -> str:
        """Insert a new record and return its generated id."""
        self._check_connection()
        id = str(uuid.uuid4())
        row = self._record_to_row(record, id)

        sql = f"""
        INSERT INTO {self._qualified_table} (id, data, metadata)
        VALUES ($1, $2, $3)
        """

        async with self._pool.acquire() as conn:
            await conn.execute(sql, row["id"], row["data"], row["metadata"])

        return id

    async def read(self, id: str) -> Record | None:
        """Read a record by ID; None when the id does not exist."""
        self._check_connection()
        sql = f"""
        SELECT id, data, metadata
        FROM {self._qualified_table}
        WHERE id = $1
        """

        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql, id)

        return self._row_to_record(row) if row else None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record.

        Returns:
            True if a row was updated, False when the id is unknown.
        """
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
        UPDATE {self._qualified_table}
        SET data = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP
        WHERE id = $1
        """

        async with self._pool.acquire() as conn:
            result = await conn.execute(sql, row["id"], row["data"], row["metadata"])

        # Command tag is "UPDATE n" where n is the number of rows affected
        return result.split()[-1] != "0"

    async def delete(self, id: str) -> bool:
        """Delete a record by ID.

        Returns:
            True if a row was deleted, False when the id is unknown.
        """
        self._check_connection()
        sql = f"""
        DELETE FROM {self._qualified_table}
        WHERE id = $1
        """

        async with self._pool.acquire() as conn:
            result = await conn.execute(sql, id)

        # Command tag is "DELETE n" where n is the number of rows affected
        return result.split()[-1] != "0"

    async def exists(self, id: str) -> bool:
        """Check whether a record with the given id exists."""
        self._check_connection()
        sql = f"""
        SELECT 1 FROM {self._qualified_table}
        WHERE id = $1
        LIMIT 1
        """

        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql, id)

        return row is not None

    async def upsert(self, id: str, record: Record) -> str:
        """Insert or update a record with a specific ID; returns the id."""
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
        INSERT INTO {self._qualified_table} (id, data, metadata)
        VALUES ($1, $2, $3)
        ON CONFLICT (id) DO UPDATE
        SET data = EXCLUDED.data, metadata = EXCLUDED.metadata, updated_at = CURRENT_TIMESTAMP
        """

        async with self._pool.acquire() as conn:
            await conn.execute(sql, row["id"], row["data"], row["metadata"])

        return id

    @staticmethod
    def _filter_to_sql(flt, param_num: int) -> tuple[str, Any] | None:
        """Translate one query filter into a (WHERE clause, parameter) pair.

        Args:
            flt: A filter from ``Query.filters``.
            param_num: 1-based ``$N`` placeholder number to use.

        Returns:
            (clause, value) or None for unsupported operators.
        """
        # Escape embedded single quotes so a field name cannot break out
        # of the JSON path string literal.
        field = str(flt.field).replace("'", "''")
        field_path = f"data->>'{field}'"
        op = flt.operator

        if op in (Operator.EQ, Operator.NEQ):
            sql_op = "=" if op == Operator.EQ else "!="
            # bool must be checked before int: bool is a subclass of int.
            if isinstance(flt.value, bool):
                return f"({field_path})::boolean {sql_op} ${param_num}", flt.value
            if isinstance(flt.value, (int, float)):
                return f"({field_path})::numeric {sql_op} ${param_num}", flt.value
            return f"{field_path} {sql_op} ${param_num}", str(flt.value)
        if op == Operator.GT:
            return f"({field_path})::numeric > ${param_num}", flt.value
        if op == Operator.LT:
            return f"({field_path})::numeric < ${param_num}", flt.value
        if op == Operator.GTE:
            return f"({field_path})::numeric >= ${param_num}", flt.value
        if op == Operator.LTE:
            return f"({field_path})::numeric <= ${param_num}", flt.value
        if op == Operator.LIKE:
            return f"{field_path} LIKE ${param_num}", f"%{flt.value}%"
        if op == Operator.IN:
            return f"{field_path} = ANY(${param_num})", [str(v) for v in flt.value]
        if op == Operator.NOT_IN:
            return f"{field_path} != ALL(${param_num})", [str(v) for v in flt.value]
        return None

    async def search(self, query: Query) -> list[Record]:
        """Search for records matching the query.

        Filters are translated to SQL over the JSONB ``data`` column;
        sorting prefers numeric order when values look numeric, falling
        back to text order. Field projection is applied client-side.
        """
        self._check_connection()

        where_clauses: list[str] = []
        params: list[Any] = []

        # Number placeholders from len(params)+1 so unsupported operators
        # leave no gaps between $N and the params list.
        for flt in query.filters:
            translated = self._filter_to_sql(flt, len(params) + 1)
            if translated is not None:
                clause, value = translated
                where_clauses.append(clause)
                params.append(value)

        sql = f"SELECT id, data, metadata FROM {self._qualified_table}"
        if where_clauses:
            sql += " WHERE " + " AND ".join(where_clauses)

        # Add ORDER BY
        if query.sort_specs:
            order_clauses = []
            for sort_spec in query.sort_specs:
                field_path = f"data->>'{sort_spec.field}'"
                direction = "DESC" if sort_spec.order == SortOrder.DESC else "ASC"
                # Numeric-looking values sort numerically first; everything
                # else falls through to the lexicographic tiebreaker.
                order_clause = f"""
                CASE
                    WHEN {field_path} ~ '^[0-9]+(\\.[0-9]+)?$'
                    THEN ({field_path})::numeric
                    ELSE NULL
                END {direction} NULLS LAST,
                {field_path} {direction}
                """
                order_clauses.append(order_clause)
            sql += " ORDER BY " + ", ".join(order_clauses)

        # Add LIMIT and OFFSET
        if query.limit_value:
            sql += f" LIMIT {query.limit_value}"
        if query.offset_value:
            sql += f" OFFSET {query.offset_value}"

        async with self._pool.acquire() as conn:
            rows = await conn.fetch(sql, *params)

        records = []
        for row in rows:
            record = self._row_to_record(row)
            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)
            records.append(record)

        return records

    async def _count_all(self) -> int:
        """Count all records in the table."""
        self._check_connection()
        sql = f"SELECT COUNT(*) as count FROM {self._qualified_table}"

        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql)

        return row["count"] if row else 0

    async def clear(self) -> int:
        """Remove all records.

        Returns:
            The number of records present before truncation.
        """
        self._check_connection()
        count = await self._count_all()

        sql = f"TRUNCATE TABLE {self._qualified_table}"

        async with self._pool.acquire() as conn:
            await conn.execute(sql)

        return count

    async def stream_read(
        self,
        query: Optional[Query] = None,
        config: Optional[StreamConfig] = None
    ) -> AsyncIterator[Record]:
        """Stream records from PostgreSQL using a server-side cursor.

        Only EQ filters are pushed down to SQL here; other operators are
        ignored (use search() for full filtering). Rows are fetched
        lazily, so large result sets are never fully materialized.
        """
        self._check_connection()
        config = config or StreamConfig()

        sql = f"SELECT id, data, metadata FROM {self._qualified_table}"
        params: list[Any] = []

        if query and query.filters:
            where_clauses = []
            for flt in query.filters:
                if flt.operator == Operator.EQ:
                    # Escape quotes in the field name; number placeholders
                    # by appended params so $N never has gaps.
                    field = str(flt.field).replace("'", "''")
                    where_clauses.append(f"data->>'{field}' = ${len(params) + 1}")
                    params.append(str(flt.value))

            if where_clauses:
                sql += " WHERE " + " AND ".join(where_clauses)

        # asyncpg cursors require an enclosing transaction.
        async with self._pool.acquire() as conn:
            async with conn.transaction():
                cursor = await conn.cursor(sql, *params)

                async for row in cursor:
                    record = self._row_to_record(row)
                    if query and query.fields:
                        record = record.project(query.fields)
                    yield record

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: Optional[StreamConfig] = None
    ) -> StreamResult:
        """Stream records into PostgreSQL using batched COPY inserts.

        Batches of ``config.batch_size`` are flushed via ``_write_batch``;
        a failed batch is counted as failed in its entirety. When
        ``config.on_error`` is set it is consulted per record — the first
        falsy return records the error and stops reporting for that batch.
        """
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()

        batch: list[Record] = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                try:
                    await self._write_batch(batch)
                    result.successful += len(batch)
                    result.total_processed += len(batch)
                except Exception as e:
                    result.failed += len(batch)
                    result.total_processed += len(batch)
                    if config.on_error:
                        for rec in batch:
                            if not config.on_error(e, rec):
                                result.add_error(None, e)
                                break
                    else:
                        result.add_error(None, e)

                batch = []

        # Flush the final partial batch
        if batch:
            try:
                await self._write_batch(batch)
                result.successful += len(batch)
                result.total_processed += len(batch)
            except Exception as e:
                result.failed += len(batch)
                result.total_processed += len(batch)
                result.add_error(None, e)

        result.duration = time.time() - start_time
        return result

    async def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records using COPY for performance.

        Ids are freshly generated per record, so this is insert-only
        (an id collision would raise, matching previous behavior).
        """
        if not records:
            return

        rows = []
        for record in records:
            row_data = self._record_to_row(record)
            rows.append((row_data["id"], row_data["data"], row_data["metadata"]))

        async with self._pool.acquire() as conn:
            # Pass the schema separately: asyncpg quotes table_name as a
            # single identifier, so "schema.table" would be treated as one
            # (invalid) table name.
            await conn.copy_records_to_table(
                self._pool_config.table,
                schema_name=self._pool_config.schema,
                records=rows,
                columns=["id", "data", "metadata"],
            )