Coverage for session_buddy/utils/database_pool.py: 16.67%

154 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-04 00:43 -0800

1#!/usr/bin/env python3 

2"""Database connection pooling for DuckDB. 

3 

4This module provides efficient connection pooling and management for DuckDB operations. 

5""" 

6 

7import asyncio 

8import atexit 

9import threading 

10import typing as t 

11from collections.abc import AsyncIterator 

12from concurrent.futures import ThreadPoolExecutor 

13from contextlib import asynccontextmanager 

14from pathlib import Path 

15from typing import Any 

16 

17try: 

18 import duckdb 

19 

20 DUCKDB_AVAILABLE = True 

21except ImportError: 

22 DUCKDB_AVAILABLE = False 

23 duckdb: t.Any = None # type: ignore[assignment, no-redef] 

24 

25from .logging import get_session_logger 

26 

27# Lazy-load logger to avoid DI initialization issues during imports 

28_logger: t.Any = None 

29 

30 

31def _get_logger() -> t.Any: 

32 """Get logger instance, initializing on first use.""" 

33 global _logger 

34 if _logger is None: 

35 try: 

36 _logger = get_session_logger() 

37 except Exception: 

38 # Fallback to basic logging if DI not initialized 

39 import logging 

40 

41 _logger = logging.getLogger(__name__) 

42 return _logger 

43 

44 

45class DatabaseConnectionPool: 

46 """Thread-safe connection pool for DuckDB.""" 

47 

48 def __init__(self, db_path: str, max_connections: int = 5) -> None: 

49 self.db_path = db_path 

50 self.max_connections = max_connections 

51 self._pool: list[Any] = [] 

52 self._pool_lock = threading.Lock() 

53 self._active_connections: dict[int, Any] = {} 

54 self._executor: ThreadPoolExecutor | None = None 

55 self._closed = False 

56 

57 # Ensure database directory exists 

58 Path(db_path).parent.mkdir(parents=True, exist_ok=True) 

59 

60 # Register cleanup on exit 

61 atexit.register(self.close_all) 

62 

63 def _create_connection(self) -> Any: 

64 """Create a new DuckDB connection.""" 

65 if not DUCKDB_AVAILABLE: 

66 msg = "DuckDB not available" 

67 raise ImportError(msg) 

68 

69 try: 

70 conn = duckdb.connect(self.db_path) if duckdb else None 

71 # Set optimal pragmas for performance 

72 if conn: 

73 conn.execute("PRAGMA threads=4") 

74 conn.execute("PRAGMA memory_limit='1GB'") 

75 conn.execute("PRAGMA temp_directory='/tmp'") 

76 return conn 

77 except Exception as e: 

78 _get_logger().exception(f"Failed to create database connection: {e}") 

79 raise 

80 

81 def get_connection(self) -> Any: 

82 """Get a connection from the pool or create a new one.""" 

83 if self._closed: 

84 msg = "Connection pool is closed" 

85 raise RuntimeError(msg) 

86 

87 with self._pool_lock: 

88 if self._pool: 

89 conn = self._pool.pop() 

90 self._active_connections[id(conn)] = conn 

91 return conn 

92 if len(self._active_connections) < self.max_connections: 

93 conn = self._create_connection() 

94 self._active_connections[id(conn)] = conn 

95 return conn 

96 msg = f"Maximum connections ({self.max_connections}) reached" 

97 raise RuntimeError( 

98 msg, 

99 ) 

100 

101 def return_connection(self, conn: Any) -> None: 

102 """Return a connection to the pool.""" 

103 if self._closed or not conn: 

104 return 

105 

106 with self._pool_lock: 

107 conn_id = id(conn) 

108 if conn_id in self._active_connections: 

109 del self._active_connections[conn_id] 

110 if len(self._pool) < self.max_connections: 

111 self._pool.append(conn) 

112 else: 

113 try: 

114 conn.close() 

115 except Exception as e: 

116 _get_logger().warning(f"Error closing excess connection: {e}") 

117 

118 @asynccontextmanager 

119 async def get_async_connection(self) -> AsyncIterator[Any]: 

120 """Async context manager for getting database connections.""" 

121 conn = None 

122 try: 

123 # Get connection in executor to avoid blocking 

124 loop = asyncio.get_event_loop() 

125 conn = await loop.run_in_executor(self._get_executor(), self.get_connection) 

126 yield conn 

127 except Exception as e: 

128 _get_logger().exception(f"Database connection error: {e}") 

129 raise 

130 finally: 

131 if conn: 

132 # Return connection in executor 

133 loop = asyncio.get_event_loop() 

134 await loop.run_in_executor( 

135 self._get_executor(), 

136 self.return_connection, 

137 conn, 

138 ) 

139 

140 def _get_executor(self) -> Any: 

141 """Get or create thread pool executor.""" 

142 if self._executor is None: 

143 self._executor = ThreadPoolExecutor(max_workers=2) 

144 return self._executor 

145 

146 async def execute_query( 

147 self, 

148 query: str, 

149 parameters: tuple[Any, ...] | None = None, 

150 ) -> Any: 

151 """Execute a query using a pooled connection.""" 

152 async with self.get_async_connection() as conn: 

153 loop = asyncio.get_event_loop() 

154 

155 def _execute() -> Any: 

156 try: 

157 if parameters: 

158 return conn.execute(query, parameters).fetchall() 

159 return conn.execute(query).fetchall() 

160 except Exception as e: 

161 _get_logger().exception(f"Query execution failed: {e}") 

162 raise 

163 

164 return await loop.run_in_executor(self._get_executor(), _execute) 

165 

166 async def execute_many(self, query: str, parameter_list: list[Any]) -> Any: 

167 """Execute a query multiple times with different parameters.""" 

168 async with self.get_async_connection() as conn: 

169 loop = asyncio.get_event_loop() 

170 

171 def _execute_many() -> Any: 

172 try: 

173 results = [] 

174 for params in parameter_list: 

175 result = conn.execute(query, params).fetchall() 

176 results.append(result) 

177 return results 

178 except Exception as e: 

179 _get_logger().exception(f"Batch query execution error: {e}") 

180 raise 

181 

182 return await loop.run_in_executor(self._get_executor(), _execute_many) 

183 

184 def get_stats(self) -> dict[str, Any]: 

185 """Get connection pool statistics.""" 

186 with self._pool_lock: 

187 return { 

188 "total_connections": len(self._active_connections) + len(self._pool), 

189 "active_connections": len(self._active_connections), 

190 "pooled_connections": len(self._pool), 

191 "max_connections": self.max_connections, 

192 "pool_utilization": len(self._active_connections) 

193 / self.max_connections, 

194 "db_path": self.db_path, 

195 } 

196 

197 def close_all(self) -> None: 

198 """Close all connections and clean up.""" 

199 if self._closed: 

200 return 

201 

202 self._closed = True 

203 

204 with self._pool_lock: 

205 # Close pooled connections 

206 for conn in self._pool: 

207 try: 

208 conn.close() 

209 except Exception as e: 

210 _get_logger().warning(f"Error closing pooled connection: {e}") 

211 

212 # Close active connections 

213 for conn in self._active_connections.values(): 

214 try: 

215 conn.close() 

216 except Exception as e: 

217 _get_logger().warning(f"Error closing active connection: {e}") 

218 

219 self._pool.clear() 

220 self._active_connections.clear() 

221 

222 # Shutdown executor 

223 if self._executor: 

224 self._executor.shutdown(wait=True) 

225 self._executor = None 

226 

227 _get_logger().info("Database connection pool closed") 

228 

229 

# Process-wide registry of pools, one per database path.
# (Annotations are string-form so this section does not force evaluation
# of DatabaseConnectionPool at import time; typing tools resolve them.)
_connection_pools: "dict[str, DatabaseConnectionPool]" = {}
_pools_lock = threading.Lock()


def get_database_pool(db_path: str, max_connections: int = 5) -> "DatabaseConnectionPool":
    """Return the shared pool for *db_path*, creating it on first request."""
    with _pools_lock:
        pool = _connection_pools.get(db_path)
        if pool is None:
            pool = DatabaseConnectionPool(
                db_path,
                max_connections,
            )
            _connection_pools[db_path] = pool
        return pool


def close_all_pools() -> None:
    """Shut down every registered pool and empty the registry."""
    with _pools_lock:
        for pool in tuple(_connection_pools.values()):
            pool.close_all()
        _connection_pools.clear()


# Ensure all pools are cleaned up when the interpreter exits.
atexit.register(close_all_pools)