Coverage for src/dataknobs_data/backends/postgres_mixins.py: 44%
62 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:23 -0700
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 11:23 -0700
1"""Shared mixins for PostgreSQL database backends.
3These mixins provide common functionality for both sync and async PostgreSQL implementations,
4reducing code duplication and ensuring consistent behavior.
5"""
7import logging
8from typing import Any
10from ..records import Record
11from .vector_config_mixin import VectorConfigMixin
13logger = logging.getLogger(__name__)
class PostgresBaseConfig(VectorConfigMixin):
    """Configuration helpers common to the PostgreSQL backends."""

    def _parse_postgres_config(self, config: dict[str, Any]) -> tuple[str, str, dict]:
        """Split a backend configuration into table, schema and the remainder.

        The caller's dictionary is never mutated; all work happens on a
        shallow copy.

        Args:
            config: Configuration dictionary (may be empty or falsy)

        Returns:
            Tuple of (table_name, schema_name, connection_config), where
            connection_config is the copy with the handled keys removed
        """
        settings = config.copy() if config else {}

        # Hand the copy to the vector mixin before any keys are stripped.
        self._parse_vector_config(settings)

        # "table" wins over "table_name" and "schema" over "schema_name";
        # both spellings are removed from the remaining settings either way.
        fallback_table = settings.pop("table_name", "records")
        table_name = settings.pop("table", fallback_table)
        fallback_schema = settings.pop("schema_name", "public")
        schema_name = settings.pop("schema", fallback_schema)

        # Vector keys were already consumed above, so drop them from the
        # remaining settings.
        for vector_key in ("vector_enabled", "vector_metric"):
            settings.pop(vector_key, None)

        return table_name, schema_name, settings

    def _init_postgres_attributes(self, table_name: str, schema_name: str) -> None:
        """Set the attributes shared by the PostgreSQL backends.

        Stores the table and schema names, marks the backend as not yet
        connected, and then initializes the vector state via the mixin.

        Args:
            table_name: Name of the database table
            schema_name: Name of the database schema
        """
        self.table_name = table_name
        self.schema_name = schema_name
        self._connected = False

        # Vector bookkeeping is delegated to the mixin.
        self._init_vector_state()
class PostgresTableManager:
    """Shared table management SQL and logic.

    Both helpers are static and only build SQL text; they do not execute
    anything.
    """

    @staticmethod
    def get_create_table_sql(schema_name: str, table_name: str) -> str:
        """Get SQL for creating the records table with indexes.

        The statement creates the table (if missing) with ``id``, ``data``,
        ``metadata``, ``created_at`` and ``updated_at`` columns, plus GIN
        indexes on ``data`` and ``metadata``.

        Args:
            schema_name: Database schema name
            table_name: Database table name

        Returns:
            SQL string for table creation
        """
        # NOTE(review): schema_name and table_name are interpolated as bare
        # identifiers (and table_name is embedded in the index names), so
        # they must come from trusted configuration. Quoting them here would
        # change identifier case folding, so it is deliberately not done.
        return f"""
        CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (
            id TEXT PRIMARY KEY,
            data JSONB NOT NULL,
            metadata JSONB,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );

        CREATE INDEX IF NOT EXISTS idx_{table_name}_data
        ON {schema_name}.{table_name} USING GIN (data);

        CREATE INDEX IF NOT EXISTS idx_{table_name}_metadata
        ON {schema_name}.{table_name} USING GIN (metadata);
        """

    @staticmethod
    def get_table_exists_sql(schema_name: str, table_name: str) -> str:
        """Get SQL to check if table exists.

        The names are embedded as SQL string literals. Any single quote in
        them is doubled so that it cannot terminate the literal early.

        Args:
            schema_name: Database schema name
            table_name: Database table name

        Returns:
            SQL string to check table existence
        """
        # Standard SQL escaping for string constants: ' becomes ''.
        schema_literal = schema_name.replace("'", "''")
        table_literal = table_name.replace("'", "''")
        return f"""
        SELECT EXISTS (
            SELECT FROM information_schema.tables
            WHERE table_schema = '{schema_literal}'
            AND table_name = '{table_literal}'
        )
        """
class PostgresVectorSupport:
    """Helpers for detecting and tracking vector fields on records."""

    def _has_vector_fields(self, record: Record) -> bool:
        """Report whether any field of the record is a VectorField.

        Args:
            record: Record to inspect

        Returns:
            True if at least one field is a VectorField, otherwise False
        """
        from ..fields import VectorField

        for candidate in record.fields.values():
            if isinstance(candidate, VectorField):
                return True
        return False

    def _extract_vector_dimensions(self, record: Record) -> dict[str, int]:
        """Collect the dimensions of the record's vector fields.

        Vector fields whose ``dimensions`` value is falsy are left out.

        Args:
            record: Record that may contain vector fields

        Returns:
            Dictionary mapping field names to dimensions
        """
        from ..fields import VectorField

        return {
            field_name: candidate.dimensions
            for field_name, candidate in record.fields.items()
            if isinstance(candidate, VectorField) and candidate.dimensions
        }

    def _update_vector_dimensions(self, record: Record) -> None:
        """Merge the record's vector dimensions into the tracked mapping.

        Does nothing when the instance has no ``_vector_dimensions``
        attribute.

        Args:
            record: Record containing vector fields
        """
        if not hasattr(self, '_vector_dimensions'):
            return
        found = self._extract_vector_dimensions(record)
        self._vector_dimensions.update(found)
class PostgresErrorHandler:
    """Shared error handling logic for PostgreSQL operations.

    The ``handle_*`` helpers never return normally: they log the failure
    and re-raise it as a RuntimeError chained to the original exception.
    """

    @staticmethod
    def handle_connection_error(e: Exception) -> None:
        """Handle and log connection errors consistently.

        Args:
            e: The exception that occurred

        Raises:
            RuntimeError: Always; carries a user-friendly message and has
                ``e`` as its ``__cause__``
        """
        logger.error("PostgreSQL connection error: %s", e)
        # Chain explicitly so the driver error is kept as the cause even
        # when this helper is called outside an ``except`` block.
        raise RuntimeError(f"Database connection failed: {e}") from e

    @staticmethod
    def handle_query_error(e: Exception, operation: str) -> None:
        """Handle and log query execution errors.

        Args:
            e: The exception that occurred
            operation: The operation that failed (e.g., "create", "update")

        Raises:
            RuntimeError: Always; carries a user-friendly message and has
                ``e`` as its ``__cause__``
        """
        logger.error("PostgreSQL %s error: %s", operation, e)
        # Chain explicitly so the driver error is kept as the cause.
        raise RuntimeError(f"Database {operation} failed: {e}") from e

    @staticmethod
    def log_operation(operation: str, details: str = "") -> None:
        """Log a database operation for debugging.

        Args:
            operation: The operation being performed
            details: Additional details about the operation; omitted from
                the log line when empty
        """
        if details:
            logger.debug("PostgreSQL %s: %s", operation, details)
        else:
            logger.debug("PostgreSQL %s", operation)
class PostgresConnectionValidator:
    """Guards that verify a backend is connected before it is used."""

    def _check_connection(self) -> None:
        """Ensure the backend has been marked as connected.

        A missing ``_connected`` attribute is treated as not connected.

        Raises:
            RuntimeError: If not connected
        """
        is_connected = getattr(self, '_connected', False)
        if is_connected:
            return
        raise RuntimeError("Database not connected. Call connect() first.")

    def _check_async_connection(self) -> None:
        """Ensure the async backend is connected and has a usable pool.

        Missing ``_connected`` or ``_pool`` attributes are treated as not
        connected. The pool is only looked up once ``_connected`` is truthy.

        Raises:
            RuntimeError: If not connected or pool not initialized
        """
        ready = getattr(self, '_connected', False) and getattr(self, '_pool', None)
        if not ready:
            raise RuntimeError("Database not connected. Call connect() first.")