Coverage for src/dataknobs_data/backends/postgres_refactored.py: 0% of 319 statements
1"""PostgreSQL backend implementation with proper connection management."""
3import asyncio
4import json
5import time
6import uuid
7from typing import Any, AsyncIterator, Iterator, Optional
9from dataknobs_config import ConfigurableBase
10from dataknobs_utils.sql_utils import DotenvPostgresConnector, PostgresDB
12from ..database import Database, SyncDatabase
13from ..query import Operator, Query, SortOrder
14from ..records import Record
15from ..streaming import StreamConfig, StreamResult


class SyncPostgresDatabase(SyncDatabase, ConfigurableBase):
    """Synchronous PostgreSQL database backend with proper connection management."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize PostgreSQL database configuration.

        Args:
            config: Configuration with the following optional keys:
                - host: PostgreSQL host (default: from env/localhost)
                - port: PostgreSQL port (default: 5432)
                - database: Database name (default: from env/postgres)
                - user: Username (default: from env/postgres)
                - password: Password (default: from env)
                - table: Table name (default: "records")
                - schema: Schema name (default: "public")
        """
        super().__init__(config)
        self.db = None  # Will be initialized in connect()
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> "SyncPostgresDatabase":
        """Create from config dictionary."""
        return cls(config)
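
    # A minimal configuration sketch. The host, credentials, and names below are
    # illustrative placeholders, not library defaults; leaving host/database/user
    # out entirely makes connect() fall back to DotenvPostgresConnector and read
    # the connection settings from the environment instead.
    #
    #   db = SyncPostgresDatabase({
    #       "host": "localhost",
    #       "port": 5432,
    #       "database": "mydb",
    #       "user": "postgres",
    #       "password": "secret",
    #       "table": "records",
    #       "schema": "public",
    #   })
    #   db.connect()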

    def connect(self) -> None:
        """Connect to the PostgreSQL database."""
        if self._connected:
            return  # Already connected

        config = self.config.copy()

        # Extract table configuration
        self.table_name = config.pop("table", "records")
        self.schema_name = config.pop("schema", "public")

        # Create connection using existing utilities
        if not any(key in config for key in ["host", "database", "user"]):
            # Use dotenv connector for environment-based config
            connector = DotenvPostgresConnector()
            self.db = PostgresDB(connector)
        else:
            # Direct configuration - map 'database' to 'db' for PostgresDB
            self.db = PostgresDB(
                host=config.get("host", "localhost"),
                db=config.get("database", "postgres"),  # Note: PostgresDB expects 'db' not 'database'
                user=config.get("user", "postgres"),
                pwd=config.get("password"),  # Note: PostgresDB expects 'pwd' not 'password'
                port=config.get("port", 5432),
            )

        # Create table if it doesn't exist
        self._ensure_table()
        self._connected = True

    def close(self) -> None:
        """Close the database connection."""
        if self.db:
            # PostgresDB manages its own connections via context managers,
            # but we can mark as disconnected
            self._connected = False

    def _initialize(self) -> None:
        """Initialize method - connection setup moved to connect()."""
        # Configuration parsing stays here; actual connection happens in connect()
        pass

    def _ensure_table(self) -> None:
        """Ensure the records table exists."""
        if not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

        create_table_sql = f"""
            CREATE TABLE IF NOT EXISTS {self.schema_name}.{self.table_name} (
                id VARCHAR(255) PRIMARY KEY,
                data JSONB NOT NULL,
                metadata JSONB,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );

            CREATE INDEX IF NOT EXISTS idx_{self.table_name}_data
            ON {self.schema_name}.{self.table_name} USING GIN (data);

            CREATE INDEX IF NOT EXISTS idx_{self.table_name}_metadata
            ON {self.schema_name}.{self.table_name} USING GIN (metadata);
        """
        self.db.execute(create_table_sql)

    def _check_connection(self) -> None:
        """Check if database is connected."""
        if not self._connected or not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

    def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
        """Convert a Record to a database row."""
        data = {}
        for field_name, field_obj in record.fields.items():
            data[field_name] = field_obj.value

        return {
            "id": id or str(uuid.uuid4()),
            "data": json.dumps(data),
            "metadata": json.dumps(record.metadata) if record.metadata else None,
        }

    def _row_to_record(self, row: dict[str, Any]) -> Record:
        """Convert a database row to a Record."""
        data = row.get("data", {})
        if isinstance(data, str):
            data = json.loads(data)

        metadata = row.get("metadata", {})
        if isinstance(metadata, str) and metadata:
            metadata = json.loads(metadata)
        elif not metadata:
            metadata = {}

        return Record(data=data, metadata=metadata)
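
    # Round-trip sketch (field names and values are illustrative): a Record with
    # fields {"name": "Ada", "age": 36} and metadata {"source": "test"} becomes
    #   {"id": "<uuid4>", "data": '{"name": "Ada", "age": 36}',
    #    "metadata": '{"source": "test"}'}
    # and _row_to_record() accepts either those JSON strings or already-decoded
    # JSONB dicts when mapping a row back into a Record.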

    def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()
        id = str(uuid.uuid4())
        row = self._record_to_row(record, id)

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
            VALUES (%(id)s, %(data)s, %(metadata)s)
        """
        self.db.execute(sql, row)
        return id

    def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()
        sql = f"""
            SELECT id, data, metadata
            FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
        """
        df = self.db.query(sql, {"id": id})

        if df.empty:
            return None

        row = df.iloc[0].to_dict()
        return self._row_to_record(row)

    def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
            UPDATE {self.schema_name}.{self.table_name}
            SET data = %(data)s, metadata = %(metadata)s, updated_at = CURRENT_TIMESTAMP
            WHERE id = %(id)s
        """
        result = self.db.execute(sql, row)
        # PostgresDB.execute returns the number of affected rows
        return result > 0 if isinstance(result, int) else False

    def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()
        sql = f"""
            DELETE FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
        """
        result = self.db.execute(sql, {"id": id})
        return result > 0 if isinstance(result, int) else False

    def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()
        sql = f"""
            SELECT 1 FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
            LIMIT 1
        """
        df = self.db.query(sql, {"id": id})
        return not df.empty

    def search(self, query: Query) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()
        # Build SQL query from the Query object
        where_clauses = []
        params = {}

        # Build WHERE clauses for filters
        for i, filter in enumerate(query.filters):
            field_path = f"data->>'{filter.field}'"
            param_name = f"param_{i}"

            if filter.operator == Operator.EQUALS:
                where_clauses.append(f"{field_path} = %({param_name})s")
                params[param_name] = str(filter.value)
            elif filter.operator == Operator.NOT_EQUALS:
                where_clauses.append(f"{field_path} != %({param_name})s")
                params[param_name] = str(filter.value)
            elif filter.operator == Operator.GREATER_THAN:
                where_clauses.append(f"({field_path})::float > %({param_name})s")
                params[param_name] = float(filter.value)
            elif filter.operator == Operator.LESS_THAN:
                where_clauses.append(f"({field_path})::float < %({param_name})s")
                params[param_name] = float(filter.value)
            elif filter.operator == Operator.GREATER_THAN_OR_EQUAL:
                where_clauses.append(f"({field_path})::float >= %({param_name})s")
                params[param_name] = float(filter.value)
            elif filter.operator == Operator.LESS_THAN_OR_EQUAL:
                where_clauses.append(f"({field_path})::float <= %({param_name})s")
                params[param_name] = float(filter.value)
            elif filter.operator == Operator.CONTAINS:
                where_clauses.append(f"{field_path} LIKE %({param_name})s")
                params[param_name] = f"%{filter.value}%"
            elif filter.operator == Operator.IN:
                where_clauses.append(f"{field_path} = ANY(%({param_name})s)")
                params[param_name] = list(filter.value)
            elif filter.operator == Operator.NOT_IN:
                where_clauses.append(f"{field_path} != ALL(%({param_name})s)")
                params[param_name] = list(filter.value)

        # Build SQL
        sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
        if where_clauses:
            sql += " WHERE " + " AND ".join(where_clauses)

        # Add ORDER BY
        if query.sort_specs:
            order_clauses = []
            for sort_spec in query.sort_specs:
                field_path = f"data->>'{sort_spec.field}'"
                direction = "DESC" if sort_spec.order == SortOrder.DESC else "ASC"
                order_clauses.append(f"{field_path} {direction}")
            sql += " ORDER BY " + ", ".join(order_clauses)

        # Add LIMIT and OFFSET
        if query.limit_value:
            sql += f" LIMIT {query.limit_value}"
        if query.offset_value:
            sql += f" OFFSET {query.offset_value}"

        # Execute query
        df = self.db.query(sql, params)

        # Convert to records
        records = []
        for _, row in df.iterrows():
            record = self._row_to_record(row.to_dict())

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records
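
    # Translation sketch (field names are illustrative): a query with an EQUALS
    # filter on "status", a DESC sort on "age", and a limit of 10 produces roughly
    #   SELECT id, data, metadata FROM public.records
    #   WHERE data->>'status' = %(param_0)s
    #   ORDER BY data->>'age' DESC LIMIT 10
    # with params = {"param_0": "active"}. Values compared with the numeric
    # operators are cast via ::float because data->>'...' yields text.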

    def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()
        sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"
        df = self.db.query(sql)
        return int(df.iloc[0]["count"]) if not df.empty else 0

    def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()
        # Get count first
        count = self._count_all()

        # Delete all records
        sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"
        self.db.execute(sql)

        return count

    def stream_read(
        self,
        query: Optional[Query] = None,
        config: Optional[StreamConfig] = None
    ) -> Iterator[Record]:
        """Stream records from PostgreSQL."""
        self._check_connection()
        config = config or StreamConfig()

        # Build SQL query
        sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
        params = {}

        if query and query.filters:
            # Add WHERE clause (simplified for now)
            where_clauses = []
            for i, filter in enumerate(query.filters):
                field_path = f"data->>'{filter.field}'"
                param_name = f"param_{i}"

                if filter.operator == Operator.EQUALS:
                    where_clauses.append(f"{field_path} = %({param_name})s")
                    params[param_name] = str(filter.value)

            if where_clauses:
                sql += " WHERE " + " AND ".join(where_clauses)

        # Use cursor for streaming
        # Note: PostgresDB may need modification to support cursors
        # For now, we fetch in batches
        sql += f" LIMIT {config.batch_size} OFFSET %(offset)s"

        offset = 0
        while True:
            params["offset"] = offset
            df = self.db.query(sql, params)

            if df.empty:
                break

            for _, row in df.iterrows():
                record = self._row_to_record(row.to_dict())
                if query and query.fields:
                    record = record.project(query.fields)
                yield record

            offset += config.batch_size

            # If we got fewer rows than batch_size, we're done
            if len(df) < config.batch_size:
                break
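
    # Usage sketch (assumes a connected SyncPostgresDatabase instance `db`; the
    # `process` callable is hypothetical):
    #
    #   for record in db.stream_read(config=StreamConfig(batch_size=500)):
    #       process(record)
    #
    # Records are fetched in LIMIT/OFFSET pages of batch_size rows, so memory use
    # stays bounded regardless of table size.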

    def stream_write(
        self,
        records: Iterator[Record],
        config: Optional[StreamConfig] = None
    ) -> StreamResult:
        """Stream records into PostgreSQL."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()

        batch = []
        for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch
                try:
                    self._write_batch(batch)
                    result.successful += len(batch)
                    result.total_processed += len(batch)
                except Exception as e:
                    result.failed += len(batch)
                    result.total_processed += len(batch)
                    if config.on_error:
                        for rec in batch:
                            if not config.on_error(e, rec):
                                result.add_error(None, e)
                                break
                    else:
                        result.add_error(None, e)

                batch = []

        # Write remaining batch
        if batch:
            try:
                self._write_batch(batch)
                result.successful += len(batch)
                result.total_processed += len(batch)
            except Exception as e:
                result.failed += len(batch)
                result.total_processed += len(batch)
                result.add_error(None, e)

        result.duration = time.time() - start_time
        return result
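
    # Usage sketch (`records_iter` is any iterator of Record objects):
    #
    #   result = db.stream_write(records_iter, StreamConfig(batch_size=100))
    #   print(result.successful, result.failed, result.total_processed, result.duration)
    #
    # A failed batch marks all of its records as failed, because the multi-row
    # INSERT in _write_batch() succeeds or fails as a whole.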

    def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records to the database."""
        # Build batch insert SQL
        values = []
        params = {}

        for i, record in enumerate(records):
            id = str(uuid.uuid4())
            row = self._record_to_row(record, id)
            values.append(f"(%(id_{i})s, %(data_{i})s, %(metadata_{i})s)")
            params[f"id_{i}"] = row["id"]
            params[f"data_{i}"] = row["data"]
            params[f"metadata_{i}"] = row["metadata"]

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
            VALUES {', '.join(values)}
        """
        self.db.execute(sql, params)
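
    # For a two-record batch the generated statement looks roughly like
    #   INSERT INTO public.records (id, data, metadata)
    #   VALUES (%(id_0)s, %(data_0)s, %(metadata_0)s),
    #          (%(id_1)s, %(data_1)s, %(metadata_1)s)
    # with one parameter set per record, so the whole batch is sent to the server
    # in a single round trip.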


class PostgresDatabase(Database, ConfigurableBase):
    """Asynchronous PostgreSQL database backend with proper connection management."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async PostgreSQL database."""
        # Create sync database for delegation
        self._sync_db = SyncPostgresDatabase(config)
        super().__init__(config)
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> "PostgresDatabase":
        """Create from config dictionary."""
        return cls(config)

    async def connect(self) -> None:
        """Connect to the database."""
        if self._connected:
            return

        # Run sync connect in an executor
        loop = asyncio.get_event_loop()
        await loop.run_in_executor(None, self._sync_db.connect)
        self._connected = True

    async def close(self) -> None:
        """Close the database connection."""
        if self._connected:
            loop = asyncio.get_event_loop()
            await loop.run_in_executor(None, self._sync_db.close)
            self._connected = False

    def _initialize(self) -> None:
        """Initialize is handled by the sync database."""
        pass

    async def create(self, record: Record) -> str:
        """Create a new record asynchronously."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._sync_db.create, record)

    async def read(self, id: str) -> Record | None:
        """Read a record asynchronously."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._sync_db.read, id)

    async def update(self, id: str, record: Record) -> bool:
        """Update a record asynchronously."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._sync_db.update, id, record)

    async def delete(self, id: str) -> bool:
        """Delete a record asynchronously."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._sync_db.delete, id)

    async def exists(self, id: str) -> bool:
        """Check if a record exists asynchronously."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._sync_db.exists, id)

    async def search(self, query: Query) -> list[Record]:
        """Search for records asynchronously."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._sync_db.search, query)

    async def _count_all(self) -> int:
        """Count all records asynchronously."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._sync_db._count_all)

    async def clear(self) -> int:
        """Clear all records asynchronously."""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, self._sync_db.clear)

    async def stream_read(
        self,
        query: Optional[Query] = None,
        config: Optional[StreamConfig] = None
    ) -> AsyncIterator[Record]:
        """Stream records from PostgreSQL asynchronously."""
        loop = asyncio.get_event_loop()

        # Get sync iterator in a thread
        sync_iter = await loop.run_in_executor(
            None,
            self._sync_db.stream_read,
            query,
            config
        )

        # Convert to async iterator
        for record in sync_iter:
            yield record
            # Small yield to prevent blocking the event loop
            await asyncio.sleep(0)

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: Optional[StreamConfig] = None
    ) -> StreamResult:
        """Stream records into PostgreSQL asynchronously."""
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch in an executor
                loop = asyncio.get_event_loop()
                try:
                    await loop.run_in_executor(None, self._sync_db._write_batch, batch)
                    result.successful += len(batch)
                    result.total_processed += len(batch)
                except Exception as e:
                    result.failed += len(batch)
                    result.total_processed += len(batch)
                    if config.on_error:
                        for rec in batch:
                            if not config.on_error(e, rec):
                                result.add_error(None, e)
                                break
                    else:
                        result.add_error(None, e)

                batch = []

        # Write remaining batch
        if batch:
            loop = asyncio.get_event_loop()
            try:
                await loop.run_in_executor(None, self._sync_db._write_batch, batch)
                result.successful += len(batch)
                result.total_processed += len(batch)
            except Exception as e:
                result.failed += len(batch)
                result.total_processed += len(batch)
                result.add_error(None, e)

        result.duration = time.time() - start_time
        return result
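

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the backend API). It assumes a
# reachable PostgreSQL instance whose credentials resolve through
# DotenvPostgresConnector (no host/database/user in the config), and that Record
# accepts the data/metadata keyword arguments used in _row_to_record() above.

if __name__ == "__main__":

    async def _demo() -> None:
        db = PostgresDatabase({"table": "records", "schema": "public"})
        await db.connect()

        # Round-trip a single record through the async wrapper.
        record_id = await db.create(Record(data={"name": "example"}, metadata={}))
        fetched = await db.read(record_id)
        print(record_id, fetched)

        await db.close()

    asyncio.run(_demo())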