Coverage for src/dataknobs_data/utils/pool_manager.py: 96%
90 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-15 14:36 -0500
"""Connection pool management helpers for async database drivers.

Provides a pool manager that keeps one pool per (configuration, event loop)
pair, plus PostgreSQL/asyncpg-specific configuration, factory and
validation helpers.
"""

import asyncio
import logging
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, Optional, Protocol, TypeVar
from weakref import WeakValueDictionary

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class PoolProtocol(Protocol):
    """Structural type for the pools handled by ConnectionPoolManager.

    Only the members this module actually uses are declared: ``acquire``
    (used by validate_asyncpg_pool) and ``close`` (used when a pool is
    discarded).
    """

    def acquire(self) -> Any:
        """Acquire a connection from the pool.

        Declared as a plain method, not ``async def``: callers write
        ``async with pool.acquire() as conn``, which needs ``acquire()`` to
        return an async context manager. A coroutine cannot be used that way.
        """
        ...

    async def close(self) -> None:
        """Close the pool."""
        ...
# Type variable for the concrete pool class held by a ConnectionPoolManager;
# bounded by PoolProtocol so type checkers only accept pool-like types.
PoolType = TypeVar("PoolType", bound=PoolProtocol)
@dataclass
class BasePoolConfig:
    """Base configuration for connection pools.

    Subclasses must override both methods below. This class does not derive
    from ``abc.ABC``, so ``@abstractmethod`` is not enforced when an instance
    is created. The default bodies therefore raise ``NotImplementedError``
    rather than silently returning ``None``. A ``None`` hash key would give
    every incomplete config the same pool key in ConnectionPoolManager.
    """

    @abstractmethod
    def to_connection_string(self) -> str:
        """Convert configuration to a connection string."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement to_connection_string()"
        )

    @abstractmethod
    def to_hash_key(self) -> tuple:
        """Create a hashable key for this configuration.

        Configurations with equal keys are given the same pool on a given
        event loop by ConnectionPoolManager.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement to_hash_key()"
        )
class ConnectionPoolManager(Generic[PoolType]):
    """
    Generic connection pool manager that handles pools per event loop.

    This class ensures that each event loop gets its own connection pool,
    preventing cross-loop usage errors that can occur with async connections.

    Pools are stored under the key ``(hash(config.to_hash_key()), id(loop))``.
    ``id()`` values can be reused once an object is garbage collected, so a
    new event loop may get the same id as a dead one. The manager therefore
    also keeps a weak reference to the loop that created each pool. It never
    hands that pool to a different loop.

    Type Parameters:
        PoolType: The type of pool being managed (e.g., asyncpg.Pool)
    """

    def __init__(self):
        """Initialize the connection pool manager."""
        # Map of (config_hash, loop_id) -> pool
        self._pools: Dict[tuple, PoolType] = {}
        # Weak references to event loops, keyed by loop id
        self._loop_refs: WeakValueDictionary = WeakValueDictionary()
        # Map of (config_hash, loop_id) -> weakref.ref to the loop that
        # created that pool; lets us detect a recycled loop id.
        self._pool_owners: Dict[tuple, Any] = {}

    @staticmethod
    def _make_pool_key(config: BasePoolConfig, loop: asyncio.AbstractEventLoop) -> tuple:
        """Return the ``(config_hash, loop_id)`` key used to index ``_pools``."""
        return (hash(config.to_hash_key()), id(loop))

    def _owned_by(self, pool_key: tuple, loop: asyncio.AbstractEventLoop) -> bool:
        """Return False only if the pool under ``pool_key`` is known to belong to another loop.

        Entries without an owner record (i.e. not stored by get_pool) are
        trusted, which matches the previous key-only behavior.
        """
        owner_ref = self._pool_owners.get(pool_key)
        return owner_ref is None or owner_ref() is loop

    @staticmethod
    async def _close_quietly(pool) -> None:
        """Await ``pool.close()``, logging rather than raising any ``Exception``."""
        try:
            await pool.close()
        except Exception as e:
            logger.error("Error closing pool: %s", e)

    async def get_pool(
        self,
        config: BasePoolConfig,
        create_pool_func,
        validate_pool_func=None
    ) -> PoolType:
        """
        Get or create a connection pool for the current event loop.

        Args:
            config: Pool configuration
            create_pool_func: Async function called as ``create_pool_func(config)``
                that returns a new pool
            validate_pool_func: Optional async function called as
                ``validate_pool_func(pool)``; if it raises, the existing pool is
                closed and a new one is created

        Returns:
            Pool instance for the current event loop

        Raises:
            RuntimeError: If there is no running event loop.
            Any exception from ``create_pool_func`` propagates unchanged.
        """
        import weakref  # local import, matching the lazy import style of this module

        loop = asyncio.get_running_loop()
        loop_id = id(loop)
        pool_key = self._make_pool_key(config, loop)

        if pool_key in self._pools and not self._owned_by(pool_key, loop):
            # The loop that created this pool has been garbage collected and
            # its id now belongs to the running loop. That pool is tied to the
            # dead loop, so it can neither be used nor cleanly closed from here.
            logger.warning(
                "Discarding pool for loop %s left by a previous event loop "
                "with the same id", loop_id
            )
            del self._pools[pool_key]
            del self._pool_owners[pool_key]

        # Check if we already have a pool for this config and loop
        if pool_key in self._pools:
            pool = self._pools[pool_key]
            if validate_pool_func is None:
                return pool
            try:
                await validate_pool_func(pool)
            except Exception as e:
                logger.warning(
                    "Pool for loop %s is invalid: %s. Creating new one.", loop_id, e
                )
                # Another coroutine may already have replaced the pool while
                # we were validating; only close the one we actually checked.
                if self._pools.get(pool_key) is pool:
                    await self._close_pool(pool_key)
            else:
                return pool

        # Create new pool
        logger.info("Creating new connection pool for loop %s", loop_id)
        pool = await create_pool_func(config)

        # create_pool_func() yielded to the loop, so another coroutine may have
        # stored a pool under the same key meanwhile. Keep that one and close
        # ours, instead of overwriting it and leaking an open pool.
        if pool_key in self._pools and self._owned_by(pool_key, loop):
            winner = self._pools[pool_key]
            await self._close_quietly(pool)
            return winner

        # Store pool, its owning loop, and the loop reference
        self._pools[pool_key] = pool
        self._pool_owners[pool_key] = weakref.ref(loop)
        self._loop_refs[loop_id] = loop
        return pool

    async def _close_pool(self, pool_key: tuple):
        """Close and remove a pool; does nothing if ``pool_key`` is unknown."""
        if pool_key not in self._pools:
            return
        # Unregister before awaiting close() so a concurrent get_pool() never
        # hands out a pool that is in the middle of shutting down.
        pool = self._pools.pop(pool_key)
        self._pool_owners.pop(pool_key, None)
        await self._close_quietly(pool)

    async def remove_pool(self, config: BasePoolConfig) -> bool:
        """
        Remove a pool for the current event loop.

        Args:
            config: Pool configuration

        Returns:
            True if pool was removed, False if not found
        """
        pool_key = self._make_pool_key(config, asyncio.get_running_loop())
        if pool_key not in self._pools:
            return False
        await self._close_pool(pool_key)
        return True

    async def close_all(self):
        """Close all connection pools.

        NOTE(review): pools created on other event loops are closed from the
        current loop; whether that succeeds depends on the pool implementation.
        Close errors are logged, not raised.
        """
        for pool_key in list(self._pools):
            await self._close_pool(pool_key)

    def get_pool_count(self) -> int:
        """Get the number of active pools."""
        return len(self._pools)

    def get_pool_info(self) -> Dict[str, Any]:
        """Get information about all active pools.

        Returns:
            Mapping of ``"config_<hash>_loop_<id>"`` to a dict with the loop id,
            the config hash and ``str(pool)``.
        """
        info = {}
        for (config_hash, loop_id), pool in self._pools.items():
            key = f"config_{config_hash}_loop_{loop_id}"
            info[key] = {
                "loop_id": loop_id,
                "config_hash": config_hash,
                "pool": str(pool)
            }
        return info
# PostgreSQL-specific implementation
@dataclass
class PostgresPoolConfig(BasePoolConfig):
    """Configuration for PostgreSQL connection pools."""
    host: str = "localhost"
    port: int = 5432
    database: str = "postgres"
    user: str = "postgres"
    password: str = ""
    min_size: int = 10
    max_size: int = 10
    command_timeout: Optional[float] = None
    ssl: Optional[Any] = None

    def to_connection_string(self) -> str:
        """Convert to a ``postgresql://`` connection URI.

        The user, password and database name are percent-encoded, so
        characters that have meaning in a URI (``@ : / % # ?`` and others)
        cannot corrupt it. An IPv6 literal host (one containing ``:``) is
        wrapped in ``[...]``. Values made only of ASCII letters, digits and
        ``_.-~`` come out exactly as plain interpolation would produce.
        """
        from urllib.parse import quote

        def encode(value: Any) -> str:
            # An explicit None (e.g. from from_dict with a null value) means
            # "not set"; plain interpolation used to emit the text "None".
            return "" if value is None else quote(str(value), safe="")

        host = str(self.host)
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return (
            f"postgresql://{encode(self.user)}:{encode(self.password)}"
            f"@{host}:{self.port}/{encode(self.database)}"
        )

    def to_hash_key(self) -> tuple:
        """Create a hashable key for this configuration.

        Only host, port, database and user are included. Configs that differ
        only in password, pool sizes, command_timeout or ssl therefore
        produce the same key (and share a pool per loop in
        ConnectionPoolManager).
        """
        return (self.host, self.port, self.database, self.user)

    @classmethod
    def from_dict(cls, config: dict) -> "PostgresPoolConfig":
        """Create from configuration dictionary.

        Pool sizes are read from the ``min_pool_size`` / ``max_pool_size`` keys.
        Missing keys fall back to the same defaults as the dataclass fields.
        """
        return cls(
            host=config.get("host", "localhost"),
            port=config.get("port", 5432),
            database=config.get("database", "postgres"),
            user=config.get("user", "postgres"),
            password=config.get("password", ""),
            min_size=config.get("min_pool_size", 10),
            max_size=config.get("max_pool_size", 10),
            command_timeout=config.get("command_timeout"),
            ssl=config.get("ssl")
        )
async def create_asyncpg_pool(config: PostgresPoolConfig):
    """Build an asyncpg pool from a PostgresPoolConfig.

    ``asyncpg`` is imported inside the function (presumably so this module
    can be imported without asyncpg installed; confirm). The DSN comes from
    ``config.to_connection_string()``. Sizing, timeout and SSL options are
    passed straight through from the config.
    """
    import asyncpg

    dsn = config.to_connection_string()
    pool_options = {
        "min_size": config.min_size,
        "max_size": config.max_size,
        "command_timeout": config.command_timeout,
        "ssl": config.ssl,
    }
    return await asyncpg.create_pool(dsn, **pool_options)
async def validate_asyncpg_pool(pool) -> None:
    """Health-check ``pool`` by borrowing one connection and running a trivial query.

    Any error raised while acquiring the connection or running the query
    propagates to the caller. The query's result is discarded.
    """
    health_query = "SELECT 1"
    async with pool.acquire() as connection:
        await connection.fetchval(health_query)