Coverage for session_buddy / utils / database_pool.py: 16.67%
154 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-04 00:43 -0800
1#!/usr/bin/env python3
2"""Database connection pooling for DuckDB.
4This module provides efficient connection pooling and management for DuckDB operations.
5"""
7import asyncio
8import atexit
9import threading
10import typing as t
11from collections.abc import AsyncIterator
12from concurrent.futures import ThreadPoolExecutor
13from contextlib import asynccontextmanager
14from pathlib import Path
15from typing import Any
17try:
18 import duckdb
20 DUCKDB_AVAILABLE = True
21except ImportError:
22 DUCKDB_AVAILABLE = False
23 duckdb: t.Any = None # type: ignore[assignment, no-redef]
25from .logging import get_session_logger
27# Lazy-load logger to avoid DI initialization issues during imports
28_logger: t.Any = None
31def _get_logger() -> t.Any:
32 """Get logger instance, initializing on first use."""
33 global _logger
34 if _logger is None:
35 try:
36 _logger = get_session_logger()
37 except Exception:
38 # Fallback to basic logging if DI not initialized
39 import logging
41 _logger = logging.getLogger(__name__)
42 return _logger
45class DatabaseConnectionPool:
46 """Thread-safe connection pool for DuckDB."""
48 def __init__(self, db_path: str, max_connections: int = 5) -> None:
49 self.db_path = db_path
50 self.max_connections = max_connections
51 self._pool: list[Any] = []
52 self._pool_lock = threading.Lock()
53 self._active_connections: dict[int, Any] = {}
54 self._executor: ThreadPoolExecutor | None = None
55 self._closed = False
57 # Ensure database directory exists
58 Path(db_path).parent.mkdir(parents=True, exist_ok=True)
60 # Register cleanup on exit
61 atexit.register(self.close_all)
63 def _create_connection(self) -> Any:
64 """Create a new DuckDB connection."""
65 if not DUCKDB_AVAILABLE:
66 msg = "DuckDB not available"
67 raise ImportError(msg)
69 try:
70 conn = duckdb.connect(self.db_path) if duckdb else None
71 # Set optimal pragmas for performance
72 if conn:
73 conn.execute("PRAGMA threads=4")
74 conn.execute("PRAGMA memory_limit='1GB'")
75 conn.execute("PRAGMA temp_directory='/tmp'")
76 return conn
77 except Exception as e:
78 _get_logger().exception(f"Failed to create database connection: {e}")
79 raise
81 def get_connection(self) -> Any:
82 """Get a connection from the pool or create a new one."""
83 if self._closed:
84 msg = "Connection pool is closed"
85 raise RuntimeError(msg)
87 with self._pool_lock:
88 if self._pool:
89 conn = self._pool.pop()
90 self._active_connections[id(conn)] = conn
91 return conn
92 if len(self._active_connections) < self.max_connections:
93 conn = self._create_connection()
94 self._active_connections[id(conn)] = conn
95 return conn
96 msg = f"Maximum connections ({self.max_connections}) reached"
97 raise RuntimeError(
98 msg,
99 )
101 def return_connection(self, conn: Any) -> None:
102 """Return a connection to the pool."""
103 if self._closed or not conn:
104 return
106 with self._pool_lock:
107 conn_id = id(conn)
108 if conn_id in self._active_connections:
109 del self._active_connections[conn_id]
110 if len(self._pool) < self.max_connections:
111 self._pool.append(conn)
112 else:
113 try:
114 conn.close()
115 except Exception as e:
116 _get_logger().warning(f"Error closing excess connection: {e}")
118 @asynccontextmanager
119 async def get_async_connection(self) -> AsyncIterator[Any]:
120 """Async context manager for getting database connections."""
121 conn = None
122 try:
123 # Get connection in executor to avoid blocking
124 loop = asyncio.get_event_loop()
125 conn = await loop.run_in_executor(self._get_executor(), self.get_connection)
126 yield conn
127 except Exception as e:
128 _get_logger().exception(f"Database connection error: {e}")
129 raise
130 finally:
131 if conn:
132 # Return connection in executor
133 loop = asyncio.get_event_loop()
134 await loop.run_in_executor(
135 self._get_executor(),
136 self.return_connection,
137 conn,
138 )
140 def _get_executor(self) -> Any:
141 """Get or create thread pool executor."""
142 if self._executor is None:
143 self._executor = ThreadPoolExecutor(max_workers=2)
144 return self._executor
146 async def execute_query(
147 self,
148 query: str,
149 parameters: tuple[Any, ...] | None = None,
150 ) -> Any:
151 """Execute a query using a pooled connection."""
152 async with self.get_async_connection() as conn:
153 loop = asyncio.get_event_loop()
155 def _execute() -> Any:
156 try:
157 if parameters:
158 return conn.execute(query, parameters).fetchall()
159 return conn.execute(query).fetchall()
160 except Exception as e:
161 _get_logger().exception(f"Query execution failed: {e}")
162 raise
164 return await loop.run_in_executor(self._get_executor(), _execute)
166 async def execute_many(self, query: str, parameter_list: list[Any]) -> Any:
167 """Execute a query multiple times with different parameters."""
168 async with self.get_async_connection() as conn:
169 loop = asyncio.get_event_loop()
171 def _execute_many() -> Any:
172 try:
173 results = []
174 for params in parameter_list:
175 result = conn.execute(query, params).fetchall()
176 results.append(result)
177 return results
178 except Exception as e:
179 _get_logger().exception(f"Batch query execution error: {e}")
180 raise
182 return await loop.run_in_executor(self._get_executor(), _execute_many)
184 def get_stats(self) -> dict[str, Any]:
185 """Get connection pool statistics."""
186 with self._pool_lock:
187 return {
188 "total_connections": len(self._active_connections) + len(self._pool),
189 "active_connections": len(self._active_connections),
190 "pooled_connections": len(self._pool),
191 "max_connections": self.max_connections,
192 "pool_utilization": len(self._active_connections)
193 / self.max_connections,
194 "db_path": self.db_path,
195 }
197 def close_all(self) -> None:
198 """Close all connections and clean up."""
199 if self._closed:
200 return
202 self._closed = True
204 with self._pool_lock:
205 # Close pooled connections
206 for conn in self._pool:
207 try:
208 conn.close()
209 except Exception as e:
210 _get_logger().warning(f"Error closing pooled connection: {e}")
212 # Close active connections
213 for conn in self._active_connections.values():
214 try:
215 conn.close()
216 except Exception as e:
217 _get_logger().warning(f"Error closing active connection: {e}")
219 self._pool.clear()
220 self._active_connections.clear()
222 # Shutdown executor
223 if self._executor:
224 self._executor.shutdown(wait=True)
225 self._executor = None
227 _get_logger().info("Database connection pool closed")
# Process-wide registry of pools, keyed by database path; access is
# serialized by _pools_lock so concurrent callers share one pool per path.
_connection_pools: dict[str, DatabaseConnectionPool] = {}
_pools_lock = threading.Lock()
def get_database_pool(db_path: str, max_connections: int = 5) -> DatabaseConnectionPool:
    """Return the shared connection pool for *db_path*, creating it on demand.

    Note: ``max_connections`` only takes effect the first time a pool is
    created for a given path; later calls reuse the existing pool.
    """
    with _pools_lock:
        pool = _connection_pools.get(db_path)
        if pool is None:
            pool = DatabaseConnectionPool(db_path, max_connections)
            _connection_pools[db_path] = pool
        return pool
def close_all_pools() -> None:
    """Close and discard every pool created via get_database_pool()."""
    with _pools_lock:
        # Drain the registry, closing each pool as it is removed.
        while _connection_pools:
            _path, pool = _connection_pools.popitem()
            pool.close_all()
# Close every registered pool at interpreter shutdown so connections are
# released even when callers never invoke close_all_pools() themselves.
atexit.register(close_all_pools)