Coverage for src / dataknobs_data / pooling / base.py: 25%

101 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-26 15:45 -0700

1"""Base classes for generic connection pool management.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import atexit 

7import logging 

8from abc import abstractmethod 

9from typing import Any, Generic, Protocol, TypeVar, TYPE_CHECKING 

10from weakref import WeakValueDictionary 

11 

12if TYPE_CHECKING: 

13 from collections.abc import Awaitable, Callable 

14 

15 

# Module-level logger, named after this module so its output can be
# filtered and configured per package.
logger = logging.getLogger(__name__)

17 

18 

class PoolProtocol(Protocol):
    """Structural interface that a connection pool must satisfy.

    Any object exposing awaitable ``acquire`` and ``close`` methods matches
    this protocol; no inheritance is required.
    """

    async def acquire(self):
        """Hand out a connection from the pool."""
        ...

    async def close(self):
        """Shut the pool down."""
        ...

29 

30 

# Type variable for the concrete pool class handled by a manager; it is
# bounded by PoolProtocol, so any substituted type must match that protocol.
PoolType = TypeVar("PoolType", bound=PoolProtocol)

32 

33 

class BasePoolConfig:
    """Base configuration for connection pools.

    Subclasses are expected to override both methods. The class does not
    derive from ``abc.ABC``, so ``@abstractmethod`` alone does not prevent
    instantiation; the default bodies therefore raise ``NotImplementedError``
    instead of silently returning ``None`` (which would otherwise be hashed
    and used as a pool key by callers of ``to_hash_key``).
    """

    @abstractmethod
    def to_connection_string(self) -> str:
        """Convert configuration to a connection string.

        Returns:
            The connection string for this configuration.

        Raises:
            NotImplementedError: If the subclass did not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement to_connection_string()"
        )

    @abstractmethod
    def to_hash_key(self) -> tuple:
        """Create a hashable key for this configuration.

        Returns:
            A hashable tuple identifying this configuration.

        Raises:
            NotImplementedError: If the subclass did not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement to_hash_key()"
        )

46 

47 

class ConnectionPoolManager(Generic[PoolType]):
    """Generic connection pool manager that handles pools per event loop.

    Pools are cached under ``(hash(config.to_hash_key()), id(loop))`` so that
    each event loop gets its own connection pool, preventing cross-loop usage
    errors that can occur with async connections.
    """

    def __init__(self):
        """Initialize the connection pool manager and register exit cleanup."""
        # Map of (config_hash, loop_id) -> pool or (pool, close_func)
        self._pools: dict[tuple, PoolType | tuple[PoolType, Callable | None]] = {}
        # Weak references to event loops, keyed by id(loop). An entry
        # disappears once its loop is garbage-collected, which is what lets
        # _current_pool_key() notice that a loop id has been reused.
        self._loop_refs: WeakValueDictionary = WeakValueDictionary()
        # Register cleanup on exit
        atexit.register(self._cleanup_on_exit)

    @staticmethod
    def _unpack_entry(pool_entry):
        """Split a stored entry into ``(pool, close_func)``.

        Entries are normally ``(pool, close_func)`` tuples; a bare pool
        (non-tuple, kept for backward compatibility) yields ``(pool, None)``.
        """
        if isinstance(pool_entry, tuple):
            pool, close_func = pool_entry
            return pool, close_func
        return pool_entry, None

    def _current_pool_key(self, config: BasePoolConfig):
        """Return ``(pool_key, loop)`` for ``config`` on the running loop.

        ``id()`` values may be reused once an object has been
        garbage-collected, so a brand-new loop can share its id with a dead
        one. When the weak reference recorded for this id is not the running
        loop, every pool stored under that id belongs to a loop that no
        longer exists; those entries are dropped here. They are not awaited
        closed because their loop is gone. Entries placed in ``_pools``
        without a matching ``_loop_refs`` record are treated the same way.

        Raises:
            RuntimeError: If there is no running event loop.
        """
        loop = asyncio.get_running_loop()
        loop_id = id(loop)
        if self._loop_refs.get(loop_id) is not loop:
            stale_keys = [key for key in self._pools if key[1] == loop_id]
            for key in stale_keys:
                del self._pools[key]
            if stale_keys:
                logger.warning(
                    "Discarded %d pool(s) bound to a previous event loop with id %s",
                    len(stale_keys),
                    loop_id,
                )
        return (hash(config.to_hash_key()), loop_id), loop

    async def get_pool(
        self,
        config: BasePoolConfig,
        create_pool_func: Callable[[BasePoolConfig], Awaitable[PoolType]],
        validate_pool_func: Callable[[PoolType], Awaitable[None]] | None = None,
        close_pool_func: Callable[[PoolType], Awaitable[None]] | None = None
    ) -> PoolType:
        """Get or create a connection pool for the current event loop.

        Args:
            config: Pool configuration
            create_pool_func: Async function to create a new pool
            validate_pool_func: Optional async function to validate an existing
                pool; if it raises, the cached pool is closed and replaced
            close_pool_func: Optional async function to close a pool; it is
                stored alongside a newly created pool for later cleanup

        Returns:
            Pool instance for the current event loop

        Raises:
            RuntimeError: If called without a running event loop.
        """
        pool_key, loop = self._current_pool_key(config)
        loop_id = pool_key[1]

        # Reuse the cached pool for this config and loop when there is one
        pool_entry = self._pools.get(pool_key)
        if pool_entry is not None:
            pool, _ = self._unpack_entry(pool_entry)
            if not validate_pool_func:
                return pool
            try:
                await validate_pool_func(pool)
            except Exception as e:
                logger.warning(
                    "Pool for loop %s is invalid: %s. Creating new one.", loop_id, e
                )
                await self._close_pool(pool_key, close_pool_func)
            else:
                return pool

        # Create new pool
        logger.info("Creating new connection pool for loop %s", loop_id)
        pool = await create_pool_func(config)

        # Store pool and loop reference with close function
        self._pools[pool_key] = (pool, close_pool_func)
        self._loop_refs[loop_id] = loop

        return pool

    async def _close_pool(self, pool_key: tuple, close_func: Callable | None = None):
        """Close and remove a pool.

        Args:
            pool_key: ``(config_hash, loop_id)`` key of the pool to close
            close_func: Optional async close function; when omitted, the one
                stored with the pool is used, falling back to ``pool.close()``
        """
        # Take the entry out before awaiting anything: a concurrent caller for
        # the same key then finds nothing to close, instead of closing the
        # pool a second time and failing on a missing key afterwards.
        pool_entry = self._pools.pop(pool_key, None)
        if pool_entry is None:
            return
        pool, stored_close_func = self._unpack_entry(pool_entry)
        close_func = close_func or stored_close_func

        try:
            if asyncio.get_running_loop().is_closed():
                # Event loop is closed, skip async cleanup
                return
        except RuntimeError:
            # No running event loop, skip async cleanup
            return

        try:
            if close_func:
                await close_func(pool)
            elif hasattr(pool, 'close'):
                await pool.close()
        except RuntimeError as e:
            # Silently ignore "Event loop is closed" errors
            if "Event loop is closed" not in str(e):
                logger.error("Error closing pool: %s", e)
        except Exception as e:
            # Closing is best-effort: report the failure but do not propagate
            logger.error("Error closing pool: %s", e)

    async def remove_pool(self, config: BasePoolConfig) -> bool:
        """Remove a pool for the current event loop.

        Args:
            config: Pool configuration

        Returns:
            True if pool was removed, False if not found

        Raises:
            RuntimeError: If called without a running event loop.
        """
        pool_key, _ = self._current_pool_key(config)

        if pool_key in self._pools:
            await self._close_pool(pool_key)
            return True
        return False

    async def close_all(self):
        """Close all connection pools."""
        # Iterate over a snapshot: _close_pool removes entries as it goes
        for pool_key in list(self._pools):
            await self._close_pool(pool_key)

    def get_pool_count(self) -> int:
        """Get the number of active pools."""
        return len(self._pools)

    def get_pool_info(self) -> dict[str, Any]:
        """Get information about all active pools.

        Returns:
            Mapping of ``"config_<hash>_loop_<id>"`` to a dict with the keys
            ``loop_id``, ``config_hash`` and ``pool`` (``str()`` of the pool)
        """
        info = {}
        for (config_hash, loop_id), pool_entry in self._pools.items():
            pool, _ = self._unpack_entry(pool_entry)
            info[f"config_{config_hash}_loop_{loop_id}"] = {
                "loop_id": loop_id,
                "config_hash": config_hash,
                "pool": str(pool)
            }
        return info

    def _cleanup_on_exit(self):
        """Cleanup function called on program exit."""
        if not self._pools:
            return

        logger.debug("Cleaning up %d connection pools on exit", len(self._pools))
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: drive the cleanup on a private, short-lived
            # loop. It is intentionally not installed with set_event_loop(),
            # so no closed loop is left behind as the thread's current loop.
            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(self.close_all())
            finally:
                loop.close()
        else:
            # There's a running loop, schedule cleanup. The task is kept on
            # the instance because the event loop only holds a weak
            # reference to it.
            self._cleanup_task = asyncio.create_task(self.close_all())