Coverage for src/dataknobs_data/utils/pool_manager.py: 96%

90 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-15 14:36 -0500

1"""General-purpose connection pool management utilities for async database connections.""" 

2 

3import asyncio 

4import logging 

5from typing import Dict, Optional, Any, Protocol, TypeVar, Generic 

6from weakref import WeakValueDictionary 

7from dataclasses import dataclass 

8from abc import abstractmethod 

9 

10logger = logging.getLogger(__name__) 

11 

12 

class PoolProtocol(Protocol):
    """Structural type describing the minimal pool interface used here.

    Any object exposing awaitable ``acquire`` and ``close`` methods
    satisfies this protocol; no inheritance is required.
    """

    async def acquire(self):
        """Hand out a connection taken from the pool."""
        ...

    async def close(self):
        """Shut the pool down."""
        ...

23 

24 

25PoolType = TypeVar('PoolType', bound=PoolProtocol) 

26 

27 

@dataclass
class BasePoolConfig:
    """Base configuration for connection pools.

    Subclasses are expected to override both methods below.

    Note: this class does not use ``abc.ABCMeta``, so the
    ``@abstractmethod`` markers are not enforced at instantiation time.
    To avoid an un-overridden method silently returning ``None`` (which
    would, for example, make every configuration hash to the same pool
    key), the base implementations raise ``NotImplementedError``.
    """

    @abstractmethod
    def to_connection_string(self) -> str:
        """Convert configuration to a connection string.

        Returns:
            The connection string for this configuration.

        Raises:
            NotImplementedError: If the subclass did not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement to_connection_string()"
        )

    @abstractmethod
    def to_hash_key(self) -> tuple:
        """Create a hashable key for this configuration.

        Returns:
            A hashable tuple identifying this configuration.

        Raises:
            NotImplementedError: If the subclass did not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement to_hash_key()"
        )

41 

42 

class ConnectionPoolManager(Generic[PoolType]):
    """
    Generic connection pool manager that handles pools per event loop.

    This class ensures that each event loop gets its own connection pool,
    preventing cross-loop usage errors that can occur with async connections.

    Pools are keyed by ``(hash(config.to_hash_key()), id(loop))``.

    Type Parameters:
        PoolType: The type of pool being managed (e.g., asyncpg.Pool)
    """

    def __init__(self):
        """Initialize the connection pool manager."""
        # Map of (config_hash, loop_id) -> pool
        self._pools: Dict[tuple, PoolType] = {}
        # loop_id -> event loop, held weakly. Used by get_pool() to detect
        # that the loop a pool was created on has been garbage collected
        # and its id() handed to a different loop.
        self._loop_refs: WeakValueDictionary = WeakValueDictionary()

    async def get_pool(
        self,
        config: BasePoolConfig,
        create_pool_func,
        validate_pool_func=None
    ) -> PoolType:
        """
        Get or create a connection pool for the current event loop.

        Args:
            config: Pool configuration
            create_pool_func: Async function taking ``config`` and returning
                a new pool
            validate_pool_func: Optional async function taking an existing
                pool; raising any ``Exception`` marks the pool as invalid

        Returns:
            Pool instance for the current event loop
        """
        loop = asyncio.get_running_loop()
        loop_id = id(loop)
        config_hash = hash(config.to_hash_key())
        pool_key = (config_hash, loop_id)

        pool = self._pools.get(pool_key)

        # id() values can be reused once an object is garbage collected, so
        # a key match alone does not prove the pool belongs to *this* loop.
        # The weak reference disappears when the original loop dies.
        if pool is not None and self._loop_refs.get(loop_id) is not loop:
            logger.warning(
                f"Discarding stale pool for loop {loop_id}: "
                "the event loop it was created on no longer exists."
            )
            # Only drop the reference: closing here would run on a
            # different loop than the one the pool was created on.
            del self._pools[pool_key]
            pool = None

        if pool is not None:
            if not validate_pool_func:
                return pool

            try:
                await validate_pool_func(pool)
            except Exception as e:
                logger.warning(f"Pool for loop {loop_id} is invalid: {e}. Creating new one.")
                current = self._pools.get(pool_key)
                if current is pool:
                    await self._close_pool(pool_key)
                elif current is not None:
                    # Another caller already replaced the invalid pool
                    # while we were validating; use the replacement.
                    return current
            else:
                return pool

        # Create new pool
        logger.info(f"Creating new connection pool for loop {loop_id}")
        pool = await create_pool_func(config)

        # Creation awaited, so a concurrent caller on this loop may have
        # stored a pool for the same key in the meantime. Keep theirs and
        # close ours so no untracked pool is left open.
        winner = self._pools.get(pool_key)
        if winner is not None:
            await self._close_quietly(pool)
            return winner

        # Store pool and loop reference
        self._pools[pool_key] = pool
        self._loop_refs[loop_id] = loop

        return pool

    async def _close_quietly(self, pool) -> None:
        """Close ``pool``, logging (not raising) any ``Exception``."""
        try:
            await pool.close()
        except Exception as e:
            logger.error(f"Error closing pool: {e}")

    async def _close_pool(self, pool_key: tuple) -> None:
        """Remove the pool stored under ``pool_key`` and close it.

        The entry is removed *before* awaiting ``close()`` so that a
        concurrent caller can neither be handed a pool that is shutting
        down nor trigger a ``KeyError`` by removing the same key twice.
        A missing key is a no-op.
        """
        pool = self._pools.pop(pool_key, None)
        if pool is not None:
            await self._close_quietly(pool)

    async def remove_pool(self, config: BasePoolConfig) -> bool:
        """
        Remove a pool for the current event loop.

        Args:
            config: Pool configuration

        Returns:
            True if pool was removed, False if not found
        """
        loop_id = id(asyncio.get_running_loop())
        config_hash = hash(config.to_hash_key())
        pool_key = (config_hash, loop_id)

        if pool_key not in self._pools:
            return False

        await self._close_pool(pool_key)
        return True

    async def close_all(self) -> None:
        """Close all connection pools, one after another."""
        # Snapshot the keys: _close_pool mutates the dict while we iterate.
        for pool_key in list(self._pools):
            await self._close_pool(pool_key)

    def get_pool_count(self) -> int:
        """Get the number of active pools."""
        return len(self._pools)

    def get_pool_info(self) -> Dict[str, Any]:
        """Get information about all active pools.

        Returns:
            Mapping of ``"config_<hash>_loop_<id>"`` to a dict with the
            keys ``"loop_id"``, ``"config_hash"`` and ``"pool"`` (the
            ``str()`` of the pool object).
        """
        return {
            f"config_{config_hash}_loop_{loop_id}": {
                "loop_id": loop_id,
                "config_hash": config_hash,
                "pool": str(pool),
            }
            for (config_hash, loop_id), pool in self._pools.items()
        }

158 

159 

160# PostgreSQL-specific implementation 

@dataclass
class PostgresPoolConfig(BasePoolConfig):
    """Configuration for PostgreSQL connection pools."""
    host: str = "localhost"
    port: int = 5432
    database: str = "postgres"
    user: str = "postgres"
    password: str = ""
    min_size: int = 10
    max_size: int = 10
    command_timeout: Optional[float] = None
    ssl: Optional[Any] = None

    def to_connection_string(self) -> str:
        """Convert to PostgreSQL connection string.

        User, password and database name are percent-encoded so that
        reserved URL characters (``@``, ``:``, ``/``, ``#``, ``?`` ...)
        in any of them cannot corrupt the structure of the URL. Values
        made only of unreserved characters are emitted unchanged.

        Returns:
            A ``postgresql://user:password@host:port/database`` URL.
        """
        # Local import: keeps this block self-contained without touching
        # the module's import section.
        from urllib.parse import quote

        user = quote(str(self.user), safe="")
        password = quote(str(self.password), safe="")
        database = quote(str(self.database), safe="")
        return f"postgresql://{user}:{password}@{self.host}:{self.port}/{database}"

    def to_hash_key(self) -> tuple:
        """Create a hashable key for this configuration.

        Only host, port, database and user take part in the key; the
        password, pool sizes, timeout and ssl settings do not.
        """
        return (self.host, self.port, self.database, self.user)

    @classmethod
    def from_dict(cls, config: dict) -> "PostgresPoolConfig":
        """Create from configuration dictionary.

        Args:
            config: Mapping that may contain ``host``, ``port``,
                ``database``, ``user``, ``password``, ``min_pool_size``,
                ``max_pool_size``, ``command_timeout`` and ``ssl``.
                Missing keys fall back to the dataclass defaults.

        Returns:
            A new ``PostgresPoolConfig``.
        """
        return cls(
            host=config.get("host", "localhost"),
            port=config.get("port", 5432),
            database=config.get("database", "postgres"),
            user=config.get("user", "postgres"),
            password=config.get("password", ""),
            min_size=config.get("min_pool_size", 10),
            max_size=config.get("max_pool_size", 10),
            command_timeout=config.get("command_timeout"),
            ssl=config.get("ssl")
        )

196 

197 

async def create_asyncpg_pool(config: PostgresPoolConfig):
    """Build an asyncpg pool from ``config``.

    The DSN comes from ``config.to_connection_string()``; pool sizing,
    command timeout and ssl settings are forwarded as keyword arguments.
    """
    # Imported lazily so the module can be loaded without asyncpg installed.
    import asyncpg

    dsn = config.to_connection_string()
    pool_options = {
        "min_size": config.min_size,
        "max_size": config.max_size,
        "command_timeout": config.command_timeout,
        "ssl": config.ssl,
    }
    return await asyncpg.create_pool(dsn, **pool_options)

208 

209 

async def validate_asyncpg_pool(pool) -> None:
    """Check that ``pool`` is usable by running ``SELECT 1`` on it.

    Any exception raised while acquiring a connection or executing the
    query propagates to the caller.
    """
    probe_query = "SELECT 1"
    acquisition = pool.acquire()
    async with acquisition as connection:
        await connection.fetchval(probe_query)