Coverage for session_mgmt_mcp/utils/database_pool.py: 15.91%
142 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-01 05:22 -0700
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-01 05:22 -0700
1#!/usr/bin/env python3
2"""Database connection pooling for DuckDB.
4This module provides efficient connection pooling and management for DuckDB operations.
5"""
7import asyncio
8import atexit
9import threading
10from contextlib import asynccontextmanager
11from pathlib import Path
12from typing import Any
14try:
15 import duckdb
17 DUCKDB_AVAILABLE = True
18except ImportError:
19 DUCKDB_AVAILABLE = False
20 duckdb = None
22from .logging import get_session_logger
24logger = get_session_logger()
27class DatabaseConnectionPool:
28 """Thread-safe connection pool for DuckDB."""
30 def __init__(self, db_path: str, max_connections: int = 5) -> None:
31 self.db_path = db_path
32 self.max_connections = max_connections
33 self._pool: list = []
34 self._pool_lock = threading.Lock()
35 self._active_connections: dict[int, Any] = {}
36 self._executor = None
37 self._closed = False
39 # Ensure database directory exists
40 Path(db_path).parent.mkdir(parents=True, exist_ok=True)
42 # Register cleanup on exit
43 atexit.register(self.close_all)
45 def _create_connection(self):
46 """Create a new DuckDB connection."""
47 if not DUCKDB_AVAILABLE:
48 msg = "DuckDB not available"
49 raise ImportError(msg)
51 try:
52 conn = duckdb.connect(self.db_path)
53 # Set optimal pragmas for performance
54 conn.execute("PRAGMA threads=4")
55 conn.execute("PRAGMA memory_limit='1GB'")
56 conn.execute("PRAGMA temp_directory='/tmp'")
57 return conn
58 except Exception as e:
59 logger.exception(f"Failed to create database connection: {e}")
60 raise
62 def get_connection(self):
63 """Get a connection from the pool or create a new one."""
64 if self._closed:
65 msg = "Connection pool is closed"
66 raise RuntimeError(msg)
68 with self._pool_lock:
69 if self._pool:
70 conn = self._pool.pop()
71 self._active_connections[id(conn)] = conn
72 return conn
73 if len(self._active_connections) < self.max_connections:
74 conn = self._create_connection()
75 self._active_connections[id(conn)] = conn
76 return conn
77 msg = f"Maximum connections ({self.max_connections}) reached"
78 raise RuntimeError(
79 msg,
80 )
82 def return_connection(self, conn) -> None:
83 """Return a connection to the pool."""
84 if self._closed or not conn:
85 return
87 with self._pool_lock:
88 conn_id = id(conn)
89 if conn_id in self._active_connections:
90 del self._active_connections[conn_id]
91 if len(self._pool) < self.max_connections:
92 self._pool.append(conn)
93 else:
94 try:
95 conn.close()
96 except Exception as e:
97 logger.warning(f"Error closing excess connection: {e}")
99 @asynccontextmanager
100 async def get_async_connection(self):
101 """Async context manager for getting database connections."""
102 conn = None
103 try:
104 # Get connection in executor to avoid blocking
105 loop = asyncio.get_event_loop()
106 conn = await loop.run_in_executor(self._get_executor(), self.get_connection)
107 yield conn
108 except Exception as e:
109 logger.exception(f"Database connection error: {e}")
110 raise
111 finally:
112 if conn:
113 # Return connection in executor
114 loop = asyncio.get_event_loop()
115 await loop.run_in_executor(
116 self._get_executor(),
117 self.return_connection,
118 conn,
119 )
121 def _get_executor(self):
122 """Get or create thread pool executor."""
123 if self._executor is None:
124 self._executor = asyncio.ThreadPoolExecutor(max_workers=2)
125 return self._executor
127 async def execute_query(self, query: str, parameters: tuple | None = None):
128 """Execute a query using a pooled connection."""
129 async with self.get_async_connection() as conn:
130 loop = asyncio.get_event_loop()
132 def _execute():
133 try:
134 if parameters:
135 return conn.execute(query, parameters).fetchall()
136 return conn.execute(query).fetchall()
137 except Exception as e:
138 logger.exception(
139 f"Query execution error: {e}",
140 extra={"query": query[:100]},
141 )
142 raise
144 return await loop.run_in_executor(self._get_executor(), _execute)
146 async def execute_many(self, query: str, parameter_list: list):
147 """Execute a query multiple times with different parameters."""
148 async with self.get_async_connection() as conn:
149 loop = asyncio.get_event_loop()
151 def _execute_many():
152 try:
153 results = []
154 for params in parameter_list:
155 result = conn.execute(query, params).fetchall()
156 results.append(result)
157 return results
158 except Exception as e:
159 logger.exception(f"Batch query execution error: {e}")
160 raise
162 return await loop.run_in_executor(self._get_executor(), _execute_many)
164 def get_stats(self) -> dict[str, Any]:
165 """Get connection pool statistics."""
166 with self._pool_lock:
167 return {
168 "total_connections": len(self._active_connections) + len(self._pool),
169 "active_connections": len(self._active_connections),
170 "pooled_connections": len(self._pool),
171 "max_connections": self.max_connections,
172 "pool_utilization": len(self._active_connections)
173 / self.max_connections,
174 "db_path": self.db_path,
175 }
177 def close_all(self) -> None:
178 """Close all connections and clean up."""
179 if self._closed:
180 return
182 self._closed = True
184 with self._pool_lock:
185 # Close pooled connections
186 for conn in self._pool:
187 try:
188 conn.close()
189 except Exception as e:
190 logger.warning(f"Error closing pooled connection: {e}")
192 # Close active connections
193 for conn in self._active_connections.values():
194 try:
195 conn.close()
196 except Exception as e:
197 logger.warning(f"Error closing active connection: {e}")
199 self._pool.clear()
200 self._active_connections.clear()
202 # Shutdown executor
203 if self._executor:
204 self._executor.shutdown(wait=True)
205 self._executor = None
207 logger.info("Database connection pool closed")
# Global registry of connection pools, keyed by database path
_connection_pools: dict[str, DatabaseConnectionPool] = {}
_pools_lock = threading.Lock()


def get_database_pool(db_path: str, max_connections: int = 5) -> DatabaseConnectionPool:
    """Get or create a database connection pool for the given path.

    Pools are cached per ``db_path`` string. ``max_connections`` is only
    used when a new pool is created; it is ignored when an open pool for
    the path already exists.

    A cached pool that has been closed (for example by calling its
    ``close_all()`` directly) is replaced with a fresh one, since a closed
    pool rejects every ``get_connection()`` call.
    """
    with _pools_lock:
        pool = _connection_pools.get(db_path)
        if pool is None or pool._closed:
            pool = DatabaseConnectionPool(db_path, max_connections)
            _connection_pools[db_path] = pool
        return pool
def close_all_pools() -> None:
    """Close every registered connection pool and empty the registry.

    Runs under the registry lock; each pool's ``close_all()`` is invoked in
    registration order before the registry is cleared.
    """
    with _pools_lock:
        for registered_pool in _connection_pools.values():
            registered_pool.close_all()
        _connection_pools.clear()


# Make sure all registered pools are closed when the interpreter exits
atexit.register(close_all_pools)