Coverage for src/dataknobs_data/pooling/base.py: 25%
101 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 14:14 -0600
1"""Base classes for generic connection pool management."""
from __future__ import annotations

import asyncio
import atexit
import inspect
import logging
from abc import abstractmethod
from typing import Any, Generic, Protocol, TypeVar, TYPE_CHECKING
from weakref import WeakValueDictionary

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable
# Module-level logger named after this module; ConnectionPoolManager uses it
# to report pool creation, invalid pools, and close failures.
logger = logging.getLogger(__name__)
class PoolProtocol(Protocol):
    """Structural interface that a managed connection pool is expected to satisfy.

    Used for static type checking only: any object that provides coroutine
    methods ``acquire()`` and ``close()`` matches, without inheriting from
    this class.
    """

    async def acquire(self):
        """Obtain a connection from the pool (supplied by the concrete pool)."""

    async def close(self):
        """Shut the pool down (supplied by the concrete pool)."""
# Type variable for the concrete pool type handled by ConnectionPoolManager;
# bounded by PoolProtocol so type checkers require acquire()/close().
PoolType = TypeVar('PoolType', bound=PoolProtocol)
class BasePoolConfig:
    """Base configuration for connection pools.

    Subclasses must override :meth:`to_connection_string` and
    :meth:`to_hash_key`. This class does not use ``ABCMeta``, so the
    ``@abstractmethod`` markers are not enforced when an instance is created.
    The base implementations therefore raise ``NotImplementedError`` rather
    than silently returning ``None``.
    """

    @abstractmethod
    def to_connection_string(self) -> str:
        """Convert configuration to a connection string.

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement to_connection_string()"
        )

    @abstractmethod
    def to_hash_key(self) -> tuple:
        """Create a hashable key for this configuration.

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        # Returning None here would let every subclass that forgot to
        # override this method hash to the same value, so unrelated
        # configurations would share a pool key.
        raise NotImplementedError(
            f"{type(self).__name__} must implement to_hash_key()"
        )
class ConnectionPoolManager(Generic[PoolType]):
    """Generic connection pool manager that handles pools per event loop.

    This class ensures that each event loop gets its own connection pool,
    preventing cross-loop usage errors that can occur with async connections.

    Pools are keyed by ``(hash(config.to_hash_key()), id(loop))``. Entries are
    stored as ``(pool, close_func)`` so that later cleanup (``remove_pool``,
    ``close_all``, exit cleanup) closes each pool with the function it was
    registered with. A bare pool value is also accepted for backward
    compatibility.
    """

    def __init__(self):
        """Initialize the manager and register exit-time cleanup with atexit."""
        # Map of (config_hash, loop_id) -> pool or (pool, close_func)
        self._pools: dict[tuple, PoolType | tuple[PoolType, Callable | None]] = {}
        # Weak references to event loops, keyed by id(loop)
        self._loop_refs: WeakValueDictionary = WeakValueDictionary()
        # Strong references to cleanup tasks scheduled by _cleanup_on_exit.
        # The event loop keeps only weak references to tasks, so an
        # unreferenced task may be garbage collected before it completes.
        self._cleanup_tasks: set[asyncio.Task] = set()
        # Register cleanup on exit. atexit holds a strong reference to this
        # bound method, which keeps the manager alive until interpreter exit.
        atexit.register(self._cleanup_on_exit)

    @staticmethod
    def _split_entry(pool_entry: Any) -> tuple[Any, Callable | None]:
        """Return ``(pool, stored_close_func)`` for either entry format."""
        if isinstance(pool_entry, tuple):
            pool, close_func = pool_entry
            return pool, close_func
        # Non-tuple format (backward compatibility): no stored close function
        return pool_entry, None

    async def get_pool(
        self,
        config: BasePoolConfig,
        create_pool_func: Callable[[BasePoolConfig], Awaitable[PoolType]],
        validate_pool_func: Callable[[PoolType], Awaitable[None]] | None = None,
        close_pool_func: Callable[[PoolType], Awaitable[None]] | None = None
    ) -> PoolType:
        """Get or create a connection pool for the current event loop.

        Args:
            config: Pool configuration
            create_pool_func: Async function to create a new pool
            validate_pool_func: Optional async function to validate an existing
                pool; if it raises, the pool is closed and a new one is created
            close_pool_func: Optional async function to close a pool; stored
                alongside a newly created pool for later cleanup

        Returns:
            Pool instance for the current event loop

        Raises:
            Exception: Anything raised by ``config.to_hash_key()`` or
                ``create_pool_func`` propagates; no new entry is stored.
        """
        loop = asyncio.get_running_loop()
        loop_id = id(loop)
        config_hash = hash(config.to_hash_key())
        pool_key = (config_hash, loop_id)

        # Reuse an existing pool for this config and loop if there is one
        if pool_key in self._pools:
            pool, _ = self._split_entry(self._pools[pool_key])
            if validate_pool_func is None:
                return pool
            try:
                await validate_pool_func(pool)
            except Exception as e:
                logger.warning(
                    "Pool for loop %s is invalid: %s. Creating new one.", loop_id, e
                )
                await self._close_pool(pool_key, close_pool_func)
            else:
                return pool

        # Create new pool
        logger.info("Creating new connection pool for loop %s", loop_id)
        pool = await create_pool_func(config)

        # create_pool_func was awaited, so another coroutine on this loop may
        # have stored a pool for the same key in the meantime. Overwriting that
        # entry would drop the manager's only reference to it, and it would
        # never be closed. Instead, keep the stored pool and close the duplicate.
        if pool_key in self._pools:
            existing, _ = self._split_entry(self._pools[pool_key])
            logger.debug(
                "Pool for loop %s was created concurrently; closing duplicate",
                loop_id,
            )
            await self._close_pool_object(pool, close_pool_func)
            return existing

        # Store pool and loop reference with close function
        self._pools[pool_key] = (pool, close_pool_func)
        self._loop_refs[loop_id] = loop
        return pool

    @staticmethod
    async def _close_pool_object(pool: Any, close_func: Callable | None) -> None:
        """Close ``pool`` on a best-effort basis; failures are logged, not raised.

        Uses ``close_func`` when given. Otherwise it calls the pool's own
        ``close()`` method, sync or async, if the pool has one.
        """
        try:
            # Check if we have a running event loop
            try:
                loop = asyncio.get_running_loop()
                if loop.is_closed():
                    # Event loop is closed, skip async cleanup
                    return
            except RuntimeError:
                # No running event loop, skip async cleanup
                return

            if close_func:
                await close_func(pool)
            elif hasattr(pool, 'close'):
                result = pool.close()
                # A synchronous close() returns a non-awaitable (e.g. None).
                # Awaiting it would raise TypeError and log a spurious failure.
                if inspect.isawaitable(result):
                    await result
        except RuntimeError as e:
            # Silently ignore "Event loop is closed" errors
            if "Event loop is closed" not in str(e):
                logger.error("Error closing pool: %s", e)
        except Exception as e:
            logger.error("Error closing pool: %s", e)

    async def _close_pool(self, pool_key: tuple, close_func: Callable | None = None):
        """Close and remove a pool.

        Args:
            pool_key: ``(config_hash, loop_id)`` key of the entry to close;
                a key that is not present is a no-op.
            close_func: Close function to use. Falls back to the function
                stored with the pool, then to the pool's own ``close()``.
        """
        if pool_key not in self._pools:
            return
        pool_entry = self._pools[pool_key]
        pool, stored_close_func = self._split_entry(pool_entry)
        try:
            await self._close_pool_object(pool, close_func or stored_close_func)
        finally:
            # Remove the entry even if closing failed. Check identity first:
            # while the close was awaited, a concurrent close_all/remove_pool
            # may already have removed it. get_pool may also have stored a new
            # pool under the same key, and that pool must not be discarded.
            if pool_key in self._pools and self._pools[pool_key] is pool_entry:
                del self._pools[pool_key]

    async def remove_pool(self, config: BasePoolConfig) -> bool:
        """Remove a pool for the current event loop.

        Args:
            config: Pool configuration

        Returns:
            True if pool was removed, False if not found
        """
        loop_id = id(asyncio.get_running_loop())
        config_hash = hash(config.to_hash_key())
        pool_key = (config_hash, loop_id)

        if pool_key in self._pools:
            await self._close_pool(pool_key)
            return True
        return False

    async def close_all(self):
        """Close all connection pools managed by this instance.

        NOTE(review): this closes pools registered for every loop from the
        current one. Closing a pool bound to a different loop may fail; such
        failures are logged rather than raised.
        """
        # Snapshot the keys because _close_pool mutates self._pools
        for pool_key in list(self._pools):
            await self._close_pool(pool_key)

    def get_pool_count(self) -> int:
        """Get the number of active pools."""
        return len(self._pools)

    def get_pool_info(self) -> dict[str, Any]:
        """Get information about all active pools.

        Returns:
            Mapping of ``"config_<hash>_loop_<id>"`` to a dict with the
            ``loop_id``, ``config_hash`` and ``str()`` of the pool.
        """
        info: dict[str, Any] = {}
        for (config_hash, loop_id), pool_entry in self._pools.items():
            pool, _ = self._split_entry(pool_entry)
            info[f"config_{config_hash}_loop_{loop_id}"] = {
                "loop_id": loop_id,
                "config_hash": config_hash,
                "pool": str(pool),
            }
        return info

    def _cleanup_on_exit(self):
        """Close any remaining pools at interpreter exit (atexit handler)."""
        if not self._pools:
            return
        logger.debug("Cleaning up %d connection pools on exit", len(self._pools))
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: run the cleanup on a temporary loop
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                loop.run_until_complete(self.close_all())
            finally:
                loop.close()
                # Don't leave the closed temporary loop installed as current
                asyncio.set_event_loop(None)
        else:
            # There's a running loop: schedule cleanup, and keep a strong
            # reference so the task is not garbage collected before it runs.
            task = loop.create_task(self.close_all())
            self._cleanup_tasks.add(task)
            task.add_done_callback(self._cleanup_tasks.discard)