Coverage for session_mgmt_mcp/utils/database_pool.py: 15.91%

142 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-01 05:22 -0700

1#!/usr/bin/env python3 

2"""Database connection pooling for DuckDB. 

3 

4This module provides efficient connection pooling and management for DuckDB operations. 

5""" 

6 

7import asyncio 

8import atexit 

9import threading 

10from contextlib import asynccontextmanager 

11from pathlib import Path 

12from typing import Any 

13 

# DuckDB is an optional dependency: remember whether the import succeeded
# and leave a ``None`` placeholder under the module name when it did not.
try:
    import duckdb
except ImportError:
    duckdb = None
    DUCKDB_AVAILABLE = False
else:
    DUCKDB_AVAILABLE = True

21 

22from .logging import get_session_logger 

23 

24logger = get_session_logger() 

25 

26 

27class DatabaseConnectionPool: 

28 """Thread-safe connection pool for DuckDB.""" 

29 

30 def __init__(self, db_path: str, max_connections: int = 5) -> None: 

31 self.db_path = db_path 

32 self.max_connections = max_connections 

33 self._pool: list = [] 

34 self._pool_lock = threading.Lock() 

35 self._active_connections: dict[int, Any] = {} 

36 self._executor = None 

37 self._closed = False 

38 

39 # Ensure database directory exists 

40 Path(db_path).parent.mkdir(parents=True, exist_ok=True) 

41 

42 # Register cleanup on exit 

43 atexit.register(self.close_all) 

44 

45 def _create_connection(self): 

46 """Create a new DuckDB connection.""" 

47 if not DUCKDB_AVAILABLE: 

48 msg = "DuckDB not available" 

49 raise ImportError(msg) 

50 

51 try: 

52 conn = duckdb.connect(self.db_path) 

53 # Set optimal pragmas for performance 

54 conn.execute("PRAGMA threads=4") 

55 conn.execute("PRAGMA memory_limit='1GB'") 

56 conn.execute("PRAGMA temp_directory='/tmp'") 

57 return conn 

58 except Exception as e: 

59 logger.exception(f"Failed to create database connection: {e}") 

60 raise 

61 

62 def get_connection(self): 

63 """Get a connection from the pool or create a new one.""" 

64 if self._closed: 

65 msg = "Connection pool is closed" 

66 raise RuntimeError(msg) 

67 

68 with self._pool_lock: 

69 if self._pool: 

70 conn = self._pool.pop() 

71 self._active_connections[id(conn)] = conn 

72 return conn 

73 if len(self._active_connections) < self.max_connections: 

74 conn = self._create_connection() 

75 self._active_connections[id(conn)] = conn 

76 return conn 

77 msg = f"Maximum connections ({self.max_connections}) reached" 

78 raise RuntimeError( 

79 msg, 

80 ) 

81 

82 def return_connection(self, conn) -> None: 

83 """Return a connection to the pool.""" 

84 if self._closed or not conn: 

85 return 

86 

87 with self._pool_lock: 

88 conn_id = id(conn) 

89 if conn_id in self._active_connections: 

90 del self._active_connections[conn_id] 

91 if len(self._pool) < self.max_connections: 

92 self._pool.append(conn) 

93 else: 

94 try: 

95 conn.close() 

96 except Exception as e: 

97 logger.warning(f"Error closing excess connection: {e}") 

98 

99 @asynccontextmanager 

100 async def get_async_connection(self): 

101 """Async context manager for getting database connections.""" 

102 conn = None 

103 try: 

104 # Get connection in executor to avoid blocking 

105 loop = asyncio.get_event_loop() 

106 conn = await loop.run_in_executor(self._get_executor(), self.get_connection) 

107 yield conn 

108 except Exception as e: 

109 logger.exception(f"Database connection error: {e}") 

110 raise 

111 finally: 

112 if conn: 

113 # Return connection in executor 

114 loop = asyncio.get_event_loop() 

115 await loop.run_in_executor( 

116 self._get_executor(), 

117 self.return_connection, 

118 conn, 

119 ) 

120 

121 def _get_executor(self): 

122 """Get or create thread pool executor.""" 

123 if self._executor is None: 

124 self._executor = asyncio.ThreadPoolExecutor(max_workers=2) 

125 return self._executor 

126 

127 async def execute_query(self, query: str, parameters: tuple | None = None): 

128 """Execute a query using a pooled connection.""" 

129 async with self.get_async_connection() as conn: 

130 loop = asyncio.get_event_loop() 

131 

132 def _execute(): 

133 try: 

134 if parameters: 

135 return conn.execute(query, parameters).fetchall() 

136 return conn.execute(query).fetchall() 

137 except Exception as e: 

138 logger.exception( 

139 f"Query execution error: {e}", 

140 extra={"query": query[:100]}, 

141 ) 

142 raise 

143 

144 return await loop.run_in_executor(self._get_executor(), _execute) 

145 

146 async def execute_many(self, query: str, parameter_list: list): 

147 """Execute a query multiple times with different parameters.""" 

148 async with self.get_async_connection() as conn: 

149 loop = asyncio.get_event_loop() 

150 

151 def _execute_many(): 

152 try: 

153 results = [] 

154 for params in parameter_list: 

155 result = conn.execute(query, params).fetchall() 

156 results.append(result) 

157 return results 

158 except Exception as e: 

159 logger.exception(f"Batch query execution error: {e}") 

160 raise 

161 

162 return await loop.run_in_executor(self._get_executor(), _execute_many) 

163 

164 def get_stats(self) -> dict[str, Any]: 

165 """Get connection pool statistics.""" 

166 with self._pool_lock: 

167 return { 

168 "total_connections": len(self._active_connections) + len(self._pool), 

169 "active_connections": len(self._active_connections), 

170 "pooled_connections": len(self._pool), 

171 "max_connections": self.max_connections, 

172 "pool_utilization": len(self._active_connections) 

173 / self.max_connections, 

174 "db_path": self.db_path, 

175 } 

176 

177 def close_all(self) -> None: 

178 """Close all connections and clean up.""" 

179 if self._closed: 

180 return 

181 

182 self._closed = True 

183 

184 with self._pool_lock: 

185 # Close pooled connections 

186 for conn in self._pool: 

187 try: 

188 conn.close() 

189 except Exception as e: 

190 logger.warning(f"Error closing pooled connection: {e}") 

191 

192 # Close active connections 

193 for conn in self._active_connections.values(): 

194 try: 

195 conn.close() 

196 except Exception as e: 

197 logger.warning(f"Error closing active connection: {e}") 

198 

199 self._pool.clear() 

200 self._active_connections.clear() 

201 

202 # Shutdown executor 

203 if self._executor: 

204 self._executor.shutdown(wait=True) 

205 self._executor = None 

206 

207 logger.info("Database connection pool closed") 

208 

209 

# Registry of connection pools keyed by database path; guarded by _pools_lock.
_connection_pools: dict[str, DatabaseConnectionPool] = {}
_pools_lock = threading.Lock()


def get_database_pool(db_path: str, max_connections: int = 5) -> DatabaseConnectionPool:
    """Return the shared pool registered for ``db_path``, creating it if needed.

    ``max_connections`` is only used when a new pool is created; a pool
    already registered for the same path is returned as-is.
    """
    with _pools_lock:
        pool = _connection_pools.get(db_path)
        if pool is None:
            pool = DatabaseConnectionPool(db_path, max_connections)
            _connection_pools[db_path] = pool
        return pool

224 

225 

# Registered as an exit handler so pools are cleaned up at interpreter exit.
@atexit.register
def close_all_pools() -> None:
    """Close every registered connection pool, then empty the registry."""
    with _pools_lock:
        for db_path in _connection_pools:
            _connection_pools[db_path].close_all()
        _connection_pools.clear()