Coverage for src/dataknobs_data/backends/postgres_async.py: 70%
323 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-15 14:22 -0500
1"""Improved async PostgreSQL backend with native asyncpg and event loop-aware connection pooling."""
3import asyncio
4import asyncpg
5import json
6import time
7import uuid
8import logging
9from typing import Any, AsyncIterator, Optional, Dict
10from weakref import WeakValueDictionary
11from dataclasses import dataclass
13from dataknobs_config import ConfigurableBase
15from ..database import AsyncDatabase
16from ..query import Operator, Query, SortOrder
17from ..records import Record
18from ..streaming import StreamConfig, StreamResult
20logger = logging.getLogger(__name__)
@dataclass
class PoolConfig:
    """Configuration for an asyncpg connection pool.

    Field names mirror the keys accepted by :meth:`from_dict` and the
    parameters forwarded to ``asyncpg.create_pool``.
    """

    host: str = "localhost"        # database server host
    port: int = 5432               # database server port
    database: str = "postgres"     # database name
    user: str = "postgres"         # login role
    password: str = ""             # login password (may be empty)
    table: str = "records"         # table used by the backend
    schema: str = "public"         # schema containing the table
    min_size: int = 10             # minimum pooled connections
    max_size: int = 10             # maximum pooled connections
    command_timeout: Optional[float] = None  # per-command timeout, seconds
    ssl: Optional[Any] = None      # ssl context/mode passed through to asyncpg

    @classmethod
    def from_dict(cls, config: dict) -> "PoolConfig":
        """Create a PoolConfig from a configuration dictionary.

        Missing keys fall back to the dataclass defaults; unknown keys
        are ignored.  Note the pool-size keys are ``min_pool_size`` /
        ``max_pool_size``, not the attribute names.
        """
        return cls(
            host=config.get("host", "localhost"),
            port=config.get("port", 5432),
            database=config.get("database", "postgres"),
            user=config.get("user", "postgres"),
            password=config.get("password", ""),
            table=config.get("table", "records"),
            schema=config.get("schema", "public"),
            min_size=config.get("min_pool_size", 10),
            max_size=config.get("max_pool_size", 10),
            command_timeout=config.get("command_timeout"),
            ssl=config.get("ssl"),
        )

    def to_connection_string(self) -> str:
        """Convert to a PostgreSQL DSN.

        User and password are percent-encoded so that credentials
        containing URL-reserved characters (':', '@', '/', '%', ...)
        still yield a parseable DSN.
        """
        # Local import keeps the module's top-level import surface unchanged.
        from urllib.parse import quote

        user = quote(self.user, safe="")
        password = quote(self.password, safe="")
        return f"postgresql://{user}:{password}@{self.host}:{self.port}/{self.database}"
class ConnectionPoolManager:
    """
    Manages connection pools per event loop.

    asyncpg pools are bound to the event loop that created them; reusing a
    pool from a different loop raises "Event loop is closed" errors.  Pools
    are therefore cached under a (config hash, loop id) key so each loop
    gets its own pool.
    """

    def __init__(self):
        """Initialize the connection pool manager."""
        # Map of (config_hash, loop_id) -> pool
        self._pools: Dict[tuple, asyncpg.Pool] = {}
        # Weak references to event loops so dead loops don't pin memory
        self._loop_refs: WeakValueDictionary = WeakValueDictionary()

    @staticmethod
    def _pool_key(config: PoolConfig, loop_id: int) -> tuple:
        """Build the cache key for a config/loop pair.

        Table and schema are deliberately excluded: pools are per
        server+database+user, not per table.
        """
        return hash((config.host, config.port, config.database, config.user)), loop_id

    async def get_pool(self, config: PoolConfig) -> asyncpg.Pool:
        """
        Get or create a connection pool for the current event loop.

        A cached pool is health-checked with ``SELECT 1`` before reuse;
        an unhealthy pool is closed and transparently replaced.

        Args:
            config: Pool configuration

        Returns:
            asyncpg.Pool instance for the current event loop
        """
        loop = asyncio.get_running_loop()
        loop_id = id(loop)
        pool_key = self._pool_key(config, loop_id)

        existing = self._pools.get(pool_key)
        if existing is not None:
            try:
                async with existing.acquire() as conn:
                    await conn.fetchval("SELECT 1")
                return existing
            except Exception as e:
                # Broad catch on purpose: any failure here means the pool
                # is unusable and must be rebuilt.
                logger.warning(f"Pool for loop {loop_id} is invalid: {e}. Creating new one.")
                await self._close_pool(pool_key)

        logger.info(f"Creating new connection pool for loop {loop_id}")
        pool = await asyncpg.create_pool(
            config.to_connection_string(),
            min_size=config.min_size,
            max_size=config.max_size,
            command_timeout=config.command_timeout,
            ssl=config.ssl,
        )

        # Cache the pool and remember (weakly) which loop it belongs to.
        self._pools[pool_key] = pool
        self._loop_refs[loop_id] = loop
        return pool

    async def _close_pool(self, pool_key: tuple):
        """Close and forget a cached pool; close errors are logged, not raised."""
        if pool_key in self._pools:
            pool = self._pools[pool_key]
            try:
                await pool.close()
            except Exception as e:
                logger.error(f"Error closing pool: {e}")
            finally:
                # Drop the entry even if close() failed.
                del self._pools[pool_key]

    async def remove_pool(self, config: PoolConfig) -> bool:
        """Close and remove the pool for the current event loop.

        Returns:
            True if a pool existed and was removed, False otherwise.
        """
        loop_id = id(asyncio.get_running_loop())
        pool_key = self._pool_key(config, loop_id)

        if pool_key in self._pools:
            await self._close_pool(pool_key)
            return True
        return False

    async def close_all(self):
        """Close every cached pool (e.g. at application shutdown)."""
        for pool_key in list(self._pools.keys()):
            await self._close_pool(pool_key)
# Global pool manager instance.
# Module-level singleton: every AsyncPostgresDatabase in this process shares
# it, and it hands out one pool per running event loop.
_pool_manager = ConnectionPoolManager()
class AsyncPostgresDatabase(AsyncDatabase, ConfigurableBase):
    """Native async PostgreSQL database backend with event loop-aware connection pooling.

    Records are stored in a single table as a JSONB ``data`` column plus an
    optional JSONB ``metadata`` column, keyed by a VARCHAR ``id``.  Pools are
    obtained from the module-level pool manager, which keeps one pool per
    event loop.

    NOTE(review): table, schema and query field names are interpolated into
    SQL via f-strings.  They originate from configuration and Query objects
    and are assumed trusted; never feed user-controlled field names in.
    """

    # SQL comparison operators for the purely numeric filter operators.
    _NUMERIC_OPS = {
        Operator.GT: ">",
        Operator.LT: "<",
        Operator.GTE: ">=",
        Operator.LTE: "<=",
    }

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async PostgreSQL database.

        Args:
            config: Optional configuration dict; see PoolConfig.from_dict
                for the recognized keys.
        """
        super().__init__(config)
        self._pool_config = PoolConfig.from_dict(config or {})
        self._pool: Optional[asyncpg.Pool] = None
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> "AsyncPostgresDatabase":
        """Create an instance from a config dictionary."""
        return cls(config)

    @property
    def _table(self) -> str:
        """Schema-qualified table name used in SQL statements."""
        return f"{self._pool_config.schema}.{self._pool_config.table}"

    async def connect(self) -> None:
        """Acquire a pool for the current event loop and ensure the table exists.

        Idempotent: a second call on a connected instance is a no-op.
        """
        if self._connected:
            return

        # Get or create a pool bound to the currently running event loop.
        self._pool = await _pool_manager.get_pool(self._pool_config)

        await self._ensure_table()
        self._connected = True

    async def close(self) -> None:
        """Mark the instance disconnected.

        The pool itself is owned by the shared pool manager and is left open
        for other instances on the same event loop.
        """
        if self._connected:
            self._pool = None
            self._connected = False

    def _initialize(self) -> None:
        """No-op; initialization happens in connect()."""
        pass

    async def _ensure_table(self) -> None:
        """Create the records table and its GIN indexes if they do not exist.

        Raises:
            RuntimeError: If connect() has not provided a pool yet.
        """
        if not self._pool:
            raise RuntimeError("Database not connected. Call connect() first.")

        create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS {self._table} (
            id VARCHAR(255) PRIMARY KEY,
            data JSONB NOT NULL,
            metadata JSONB,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
        CREATE INDEX IF NOT EXISTS idx_{self._pool_config.table}_data
        ON {self._table} USING GIN (data);
        CREATE INDEX IF NOT EXISTS idx_{self._pool_config.table}_metadata
        ON {self._table} USING GIN (metadata);
        """

        async with self._pool.acquire() as conn:
            await conn.execute(create_table_sql)

    def _check_connection(self) -> None:
        """Raise RuntimeError unless connect() has completed."""
        if not self._connected or not self._pool:
            raise RuntimeError("Database not connected. Call connect() first.")

    def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
        """Serialize a Record into a row dict of JSON-encoded columns.

        Args:
            record: Record to serialize.
            id: Row id to use; a fresh UUID4 string is generated when None.
        """
        data = {name: field.value for name, field in record.fields.items()}
        return {
            "id": id or str(uuid.uuid4()),
            "data": json.dumps(data),
            "metadata": json.dumps(record.metadata) if record.metadata else None,
        }

    def _row_to_record(self, row: asyncpg.Record) -> Record:
        """Deserialize a database row back into a Record.

        Handles both JSON text and already-decoded values for the JSONB
        columns; a missing/empty metadata column becomes an empty dict.
        """
        data = row.get("data", {})
        if isinstance(data, str):
            data = json.loads(data)

        metadata = row.get("metadata", {})
        if isinstance(metadata, str) and metadata:
            metadata = json.loads(metadata)
        elif not metadata:
            metadata = {}

        return Record(data=data, metadata=metadata)

    async def create(self, record: Record) -> str:
        """Insert a new record under a freshly generated UUID and return the id."""
        self._check_connection()
        id = str(uuid.uuid4())
        row = self._record_to_row(record, id)

        sql = f"INSERT INTO {self._table} (id, data, metadata) VALUES ($1, $2, $3)"
        async with self._pool.acquire() as conn:
            await conn.execute(sql, row["id"], row["data"], row["metadata"])
        return id

    async def read(self, id: str) -> Record | None:
        """Fetch a record by id, or None if it does not exist."""
        self._check_connection()
        sql = f"SELECT id, data, metadata FROM {self._table} WHERE id = $1"
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql, id)
        return self._row_to_record(row) if row else None

    async def update(self, id: str, record: Record) -> bool:
        """Overwrite an existing record; return True if a row was updated."""
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
        UPDATE {self._table}
        SET data = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP
        WHERE id = $1
        """
        async with self._pool.acquire() as conn:
            result = await conn.execute(sql, row["id"], row["data"], row["metadata"])
        # asyncpg returns a status tag like "UPDATE <n>"; n == 0 means no match.
        return result.split()[-1] != "0"

    async def delete(self, id: str) -> bool:
        """Delete a record by id; return True if a row was removed."""
        self._check_connection()
        sql = f"DELETE FROM {self._table} WHERE id = $1"
        async with self._pool.acquire() as conn:
            result = await conn.execute(sql, id)
        # Status tag is "DELETE <n>".
        return result.split()[-1] != "0"

    async def exists(self, id: str) -> bool:
        """Return True if a record with the given id exists."""
        self._check_connection()
        sql = f"SELECT 1 FROM {self._table} WHERE id = $1 LIMIT 1"
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql, id)
        return row is not None

    async def upsert(self, id: str, record: Record) -> str:
        """Insert or overwrite the record stored under ``id``; return the id."""
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
        INSERT INTO {self._table} (id, data, metadata)
        VALUES ($1, $2, $3)
        ON CONFLICT (id) DO UPDATE
        SET data = EXCLUDED.data, metadata = EXCLUDED.metadata, updated_at = CURRENT_TIMESTAMP
        """
        async with self._pool.acquire() as conn:
            await conn.execute(sql, row["id"], row["data"], row["metadata"])
        return id

    @classmethod
    def _filter_to_sql(cls, filter, param_index: int) -> Optional[tuple]:
        """Translate one query filter into a (SQL clause, bind value) pair.

        Args:
            filter: Query filter carrying field, operator and value.
            param_index: 1-based index of the bind parameter to reference.

        Returns:
            (clause, value), or None when the operator is not supported so
            the caller can skip it WITHOUT consuming a parameter slot.
        """
        field_path = f"data->>'{filter.field}'"
        op = filter.operator

        if op in (Operator.EQ, Operator.NEQ):
            cmp = "=" if op == Operator.EQ else "!="
            # bool must be tested before int/float: bool subclasses int.
            if isinstance(filter.value, bool):
                return f"({field_path})::boolean {cmp} ${param_index}", filter.value
            if isinstance(filter.value, (int, float)):
                return f"({field_path})::numeric {cmp} ${param_index}", filter.value
            return f"{field_path} {cmp} ${param_index}", str(filter.value)
        if op in cls._NUMERIC_OPS:
            return f"({field_path})::numeric {cls._NUMERIC_OPS[op]} ${param_index}", filter.value
        if op == Operator.LIKE:
            # Contains semantics: surround the value with wildcards.
            return f"{field_path} LIKE ${param_index}", f"%{filter.value}%"
        if op == Operator.IN:
            return f"{field_path} = ANY(${param_index})", [str(v) for v in filter.value]
        if op == Operator.NOT_IN:
            return f"{field_path} != ALL(${param_index})", [str(v) for v in filter.value]
        return None

    async def search(self, query: Query) -> list[Record]:
        """Return all records matching the query.

        Supports filtering, numeric-aware sorting, limit/offset and field
        projection.
        """
        self._check_connection()

        # Build WHERE clauses.  Parameter indexes are assigned only for
        # supported operators, so $n placeholders always line up with the
        # params list (previously the index advanced even for skipped
        # operators, producing references to missing parameters).
        where_clauses: list[str] = []
        params: list[Any] = []
        for filter in query.filters:
            compiled = self._filter_to_sql(filter, len(params) + 1)
            if compiled is None:
                continue  # unsupported operator: ignore rather than emit broken SQL
            clause, value = compiled
            where_clauses.append(clause)
            params.append(value)

        sql = f"SELECT id, data, metadata FROM {self._table}"
        if where_clauses:
            sql += " WHERE " + " AND ".join(where_clauses)

        # ORDER BY: sort numerically when the value looks like a number,
        # then fall back to lexicographic order.
        if query.sort_specs:
            order_clauses = []
            for sort_spec in query.sort_specs:
                field_path = f"data->>'{sort_spec.field}'"
                direction = "DESC" if sort_spec.order == SortOrder.DESC else "ASC"
                order_clauses.append(f"""
                CASE
                    WHEN {field_path} ~ '^[0-9]+(\\.[0-9]+)?$'
                    THEN ({field_path})::numeric
                    ELSE NULL
                END {direction} NULLS LAST,
                {field_path} {direction}
                """)
            sql += " ORDER BY " + ", ".join(order_clauses)

        # LIMIT/OFFSET: cast to int so malformed config values cannot be
        # interpolated into the SQL text.  (0/None still means "absent".)
        if query.limit_value:
            sql += f" LIMIT {int(query.limit_value)}"
        if query.offset_value:
            sql += f" OFFSET {int(query.offset_value)}"

        async with self._pool.acquire() as conn:
            rows = await conn.fetch(sql, *params)

        records = []
        for row in rows:
            record = self._row_to_record(row)
            # Apply field projection if requested.
            if query.fields:
                record = record.project(query.fields)
            records.append(record)
        return records

    async def _count_all(self) -> int:
        """Count all records in the table."""
        self._check_connection()
        sql = f"SELECT COUNT(*) as count FROM {self._table}"
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow(sql)
        return row["count"] if row else 0

    async def clear(self) -> int:
        """Remove every record and return how many rows were present.

        Note: the count and the TRUNCATE are separate statements, so the
        returned count can be stale under concurrent writes.
        """
        self._check_connection()
        count = await self._count_all()

        async with self._pool.acquire() as conn:
            await conn.execute(f"TRUNCATE TABLE {self._table}")
        return count

    async def stream_read(
        self,
        query: Optional[Query] = None,
        config: Optional[StreamConfig] = None
    ) -> AsyncIterator[Record]:
        """Stream records from PostgreSQL using a server-side cursor.

        Only EQ filters are applied when streaming; other operators are
        ignored.  Records are yielded in batches of config.batch_size.
        """
        self._check_connection()
        config = config or StreamConfig()

        sql = f"SELECT id, data, metadata FROM {self._table}"
        params: list[Any] = []

        if query and query.filters:
            where_clauses = []
            for filter in query.filters:
                # Parameter index is derived from params so skipped
                # (non-EQ) filters never desynchronize the placeholders.
                if filter.operator == Operator.EQ:
                    where_clauses.append(f"data->>'{filter.field}' = ${len(params) + 1}")
                    params.append(str(filter.value))
            if where_clauses:
                sql += " WHERE " + " AND ".join(where_clauses)

        # asyncpg cursors require an explicit transaction.
        async with self._pool.acquire() as conn:
            async with conn.transaction():
                cursor = await conn.cursor(sql, *params)

                batch: list[Record] = []
                async for row in cursor:
                    record = self._row_to_record(row)
                    if query and query.fields:
                        record = record.project(query.fields)
                    batch.append(record)

                    if len(batch) >= config.batch_size:
                        for rec in batch:
                            yield rec
                        batch = []

                # Yield the final partial batch.
                for rec in batch:
                    yield rec

    async def _flush_batch(
        self,
        batch: list[Record],
        result: StreamResult,
        config: StreamConfig,
    ) -> None:
        """Write one batch and fold the outcome into ``result``.

        Applies identical error handling to every batch; previously the
        final partial batch skipped the config.on_error callback.
        """
        try:
            await self._write_batch(batch)
            result.successful += len(batch)
        except Exception as e:
            result.failed += len(batch)
            if config.on_error:
                # Consult the callback per record; a falsy return stops
                # further error accumulation for this batch.
                for rec in batch:
                    if not config.on_error(e, rec):
                        result.add_error(None, e)
                        break
            else:
                result.add_error(None, e)
        finally:
            result.total_processed += len(batch)

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: Optional[StreamConfig] = None
    ) -> StreamResult:
        """Stream records into PostgreSQL using batched COPY inserts.

        Returns:
            StreamResult with per-record success/failure counts, collected
            errors and total duration.
        """
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()

        batch: list[Record] = []
        async for record in records:
            batch.append(record)
            if len(batch) >= config.batch_size:
                await self._flush_batch(batch, result, config)
                batch = []

        # Flush whatever is left after the iterator is exhausted.
        if batch:
            await self._flush_batch(batch, result, config)

        result.duration = time.time() - start_time
        return result

    async def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records using COPY for performance.

        COPY is insert-only: an id collision with an existing row fails
        the whole batch.
        """
        if not records:
            return

        rows = []
        for record in records:
            row_data = self._record_to_row(record)
            rows.append((row_data["id"], row_data["data"], row_data["metadata"]))

        # The schema must be passed via schema_name: asyncpg quotes
        # table_name as a single identifier, so a pre-joined
        # "schema.table" string would be treated as one (nonexistent)
        # table named literally "schema.table".
        async with self._pool.acquire() as conn:
            await conn.copy_records_to_table(
                self._pool_config.table,
                records=rows,
                columns=["id", "data", "metadata"],
                schema_name=self._pool_config.schema,
            )