Coverage for src/dataknobs_data/backends/postgres_refactored.py: 0%

319 statements  

coverage.py v7.10.3, created at 2025-08-15 10:46 -0500

"""PostgreSQL backend implementation with proper connection management."""

import asyncio
import json
import time
import uuid
from typing import Any, AsyncIterator, Iterator, Optional

from dataknobs_config import ConfigurableBase
from dataknobs_utils.sql_utils import DotenvPostgresConnector, PostgresDB

from ..database import Database, SyncDatabase
from ..query import Operator, Query, SortOrder
from ..records import Record
from ..streaming import StreamConfig, StreamResult


class SyncPostgresDatabase(SyncDatabase, ConfigurableBase):
    """Synchronous PostgreSQL database backend with proper connection management."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize PostgreSQL database configuration.

        Args:
            config: Configuration with the following optional keys:
                - host: PostgreSQL host (default: from env/localhost)
                - port: PostgreSQL port (default: 5432)
                - database: Database name (default: from env/postgres)
                - user: Username (default: from env/postgres)
                - password: Password (default: from env)
                - table: Table name (default: "records")
                - schema: Schema name (default: "public")
        """
        super().__init__(config)
        self.db = None  # Will be initialized in connect()
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> "SyncPostgresDatabase":
        """Create from config dictionary."""
        return cls(config)
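
    # Editor's note (hedged sketch, not in the original file): per connect()
    # below, omitting all of "host", "database", and "user" makes the backend
    # fall back to DotenvPostgresConnector, i.e. environment-based credentials:
    #
    #   db = SyncPostgresDatabase({"table": "events"})  # env supplies the rest
    #   db.connect()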

    def connect(self) -> None:
        """Connect to the PostgreSQL database."""
        if self._connected:
            return  # Already connected

        config = self.config.copy()

        # Extract table configuration
        self.table_name = config.pop("table", "records")
        self.schema_name = config.pop("schema", "public")

        # Create connection using existing utilities
        if not any(key in config for key in ["host", "database", "user"]):
            # Use dotenv connector for environment-based config
            connector = DotenvPostgresConnector()
            self.db = PostgresDB(connector)
        else:
            # Direct configuration - map 'database' to 'db' for PostgresDB
            self.db = PostgresDB(
                host=config.get("host", "localhost"),
                db=config.get("database", "postgres"),  # Note: PostgresDB expects 'db' not 'database'
                user=config.get("user", "postgres"),
                pwd=config.get("password"),  # Note: PostgresDB expects 'pwd' not 'password'
                port=config.get("port", 5432),
            )

        # Create table if it doesn't exist
        self._ensure_table()
        self._connected = True

    def close(self) -> None:
        """Close the database connection."""
        if self.db:
            # PostgresDB manages its own connections via context managers,
            # so we just mark this backend as disconnected here.
            self._connected = False

    def _initialize(self) -> None:
        """Initialize method - connection setup moved to connect()."""
        # Configuration parsing stays here; the actual connection happens in connect()
        pass

    def _ensure_table(self) -> None:
        """Ensure the records table exists."""
        if not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

        create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS {self.schema_name}.{self.table_name} (
            id VARCHAR(255) PRIMARY KEY,
            data JSONB NOT NULL,
            metadata JSONB,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );

        CREATE INDEX IF NOT EXISTS idx_{self.table_name}_data
            ON {self.schema_name}.{self.table_name} USING GIN (data);

        CREATE INDEX IF NOT EXISTS idx_{self.table_name}_metadata
            ON {self.schema_name}.{self.table_name} USING GIN (metadata);
        """
        self.db.execute(create_table_sql)
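
    # Editor's note (hedged): the GIN indexes above serve JSONB containment
    # and existence operators (@>, ?, ?|, ?&), e.g.
    #
    #   SELECT id FROM public.records WHERE data @> '{"name": "Ada"}';
    #
    # They do not cover the ->> text extraction used by search() below, which
    # would need a btree expression index on (data->>'field') to be indexed.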

    def _check_connection(self) -> None:
        """Check if database is connected."""
        if not self._connected or not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

    def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
        """Convert a Record to a database row."""
        data = {}
        for field_name, field_obj in record.fields.items():
            data[field_name] = field_obj.value

        return {
            "id": id or str(uuid.uuid4()),
            "data": json.dumps(data),
            "metadata": json.dumps(record.metadata) if record.metadata else None,
        }

    def _row_to_record(self, row: dict[str, Any]) -> Record:
        """Convert a database row to a Record."""
        data = row.get("data", {})
        if isinstance(data, str):
            data = json.loads(data)

        metadata = row.get("metadata", {})
        if isinstance(metadata, str) and metadata:
            metadata = json.loads(metadata)
        elif not metadata:
            metadata = {}

        return Record(data=data, metadata=metadata)
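
    # Hedged round-trip sketch (editor's addition): for a Record whose fields
    # map to {"name": "Ada", "age": 36}, _record_to_row() yields roughly
    #
    #   {"id": "<uuid4>", "data": '{"name": "Ada", "age": 36}', "metadata": None}
    #
    # and _row_to_record() reverses it, accepting "data"/"metadata" either as
    # JSON strings or as dicts already decoded by the driver.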

    def create(self, record: Record) -> str:
        """Create a new record."""
        self._check_connection()
        id = str(uuid.uuid4())
        row = self._record_to_row(record, id)

        sql = f"""
        INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
        VALUES (%(id)s, %(data)s, %(metadata)s)
        """
        self.db.execute(sql, row)
        return id

    def read(self, id: str) -> Record | None:
        """Read a record by ID."""
        self._check_connection()
        sql = f"""
        SELECT id, data, metadata
        FROM {self.schema_name}.{self.table_name}
        WHERE id = %(id)s
        """
        df = self.db.query(sql, {"id": id})

        if df.empty:
            return None

        row = df.iloc[0].to_dict()
        return self._row_to_record(row)

    def update(self, id: str, record: Record) -> bool:
        """Update an existing record."""
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
        UPDATE {self.schema_name}.{self.table_name}
        SET data = %(data)s, metadata = %(metadata)s, updated_at = CURRENT_TIMESTAMP
        WHERE id = %(id)s
        """
        result = self.db.execute(sql, row)
        # PostgresDB.execute returns number of affected rows
        return result > 0 if isinstance(result, int) else False

    def delete(self, id: str) -> bool:
        """Delete a record by ID."""
        self._check_connection()
        sql = f"""
        DELETE FROM {self.schema_name}.{self.table_name}
        WHERE id = %(id)s
        """
        result = self.db.execute(sql, {"id": id})
        return result > 0 if isinstance(result, int) else False

    def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()
        sql = f"""
        SELECT 1 FROM {self.schema_name}.{self.table_name}
        WHERE id = %(id)s
        LIMIT 1
        """
        df = self.db.query(sql, {"id": id})
        return not df.empty

    def search(self, query: Query) -> list[Record]:
        """Search for records matching the query."""
        self._check_connection()
        # Build SQL query from Query object
        where_clauses = []
        params = {}

        # Build WHERE clauses for filters
        for i, filter in enumerate(query.filters):
            field_path = f"data->>'{filter.field}'"
            param_name = f"param_{i}"

            if filter.operator == Operator.EQUALS:
                where_clauses.append(f"{field_path} = %({param_name})s")
                params[param_name] = str(filter.value)
            elif filter.operator == Operator.NOT_EQUALS:
                where_clauses.append(f"{field_path} != %({param_name})s")
                params[param_name] = str(filter.value)
            elif filter.operator == Operator.GREATER_THAN:
                where_clauses.append(f"({field_path})::float > %({param_name})s")
                params[param_name] = float(filter.value)
            elif filter.operator == Operator.LESS_THAN:
                where_clauses.append(f"({field_path})::float < %({param_name})s")
                params[param_name] = float(filter.value)
            elif filter.operator == Operator.GREATER_THAN_OR_EQUAL:
                where_clauses.append(f"({field_path})::float >= %({param_name})s")
                params[param_name] = float(filter.value)
            elif filter.operator == Operator.LESS_THAN_OR_EQUAL:
                where_clauses.append(f"({field_path})::float <= %({param_name})s")
                params[param_name] = float(filter.value)
            elif filter.operator == Operator.CONTAINS:
                where_clauses.append(f"{field_path} LIKE %({param_name})s")
                params[param_name] = f"%{filter.value}%"
            elif filter.operator == Operator.IN:
                where_clauses.append(f"{field_path} = ANY(%({param_name})s)")
                params[param_name] = list(filter.value)
            elif filter.operator == Operator.NOT_IN:
                where_clauses.append(f"{field_path} != ALL(%({param_name})s)")
                params[param_name] = list(filter.value)

        # Build SQL
        sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
        if where_clauses:
            sql += " WHERE " + " AND ".join(where_clauses)

        # Add ORDER BY
        if query.sort_specs:
            order_clauses = []
            for sort_spec in query.sort_specs:
                field_path = f"data->>'{sort_spec.field}'"
                direction = "DESC" if sort_spec.order == SortOrder.DESC else "ASC"
                order_clauses.append(f"{field_path} {direction}")
            sql += " ORDER BY " + ", ".join(order_clauses)

        # Add LIMIT and OFFSET
        if query.limit_value:
            sql += f" LIMIT {query.limit_value}"
        if query.offset_value:
            sql += f" OFFSET {query.offset_value}"

        # Execute query
        df = self.db.query(sql, params)

        # Convert to records
        records = []
        for _, row in df.iterrows():
            record = self._row_to_record(row.to_dict())

            # Apply field projection if specified
            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records
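
    # Hedged usage sketch (editor's addition; the fluent Query methods shown
    # are assumed, not verified against ..query): a query such as
    #
    #   Query().filter("age", Operator.GREATER_THAN, 21).limit(10)
    #
    # would be rendered by search() as roughly:
    #
    #   SELECT id, data, metadata FROM public.records
    #   WHERE (data->>'age')::float > %(param_0)s LIMIT 10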

    def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()
        sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"
        df = self.db.query(sql)
        return int(df.iloc[0]["count"]) if not df.empty else 0

    def clear(self) -> int:
        """Clear all records from the database."""
        self._check_connection()
        # Get count first
        count = self._count_all()

        # Delete all records
        sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"
        self.db.execute(sql)

        return count

    def stream_read(
        self,
        query: Optional[Query] = None,
        config: Optional[StreamConfig] = None
    ) -> Iterator[Record]:
        """Stream records from PostgreSQL."""
        self._check_connection()
        config = config or StreamConfig()

        # Build SQL query
        sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
        params = {}

        if query and query.filters:
            # Add WHERE clause (simplified for now)
            where_clauses = []
            for i, filter in enumerate(query.filters):
                field_path = f"data->>'{filter.field}'"
                param_name = f"param_{i}"

                if filter.operator == Operator.EQUALS:
                    where_clauses.append(f"{field_path} = %({param_name})s")
                    params[param_name] = str(filter.value)

            if where_clauses:
                sql += " WHERE " + " AND ".join(where_clauses)

        # Use cursor for streaming
        # Note: PostgresDB may need modification to support cursors
        # For now, we'll fetch in batches
        sql += f" LIMIT {config.batch_size} OFFSET %(offset)s"

        offset = 0
        while True:
            params["offset"] = offset
            df = self.db.query(sql, params)

            if df.empty:
                break

            for _, row in df.iterrows():
                record = self._row_to_record(row.to_dict())
                if query and query.fields:
                    record = record.project(query.fields)
                yield record

            offset += config.batch_size

            # If we got less than batch_size, we're done
            if len(df) < config.batch_size:
                break
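
    # Editor's note (hedged): this batched LIMIT/OFFSET scan has no ORDER BY,
    # so concurrent writes can make batches skip or repeat rows. Typical use:
    #
    #   for record in db.stream_read(config=StreamConfig(batch_size=500)):
    #       handle(record)  # handle() is a hypothetical consumer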

    def stream_write(
        self,
        records: Iterator[Record],
        config: Optional[StreamConfig] = None
    ) -> StreamResult:
        """Stream records into PostgreSQL."""
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()

        batch = []
        for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch
                try:
                    self._write_batch(batch)
                    result.successful += len(batch)
                    result.total_processed += len(batch)
                except Exception as e:
                    result.failed += len(batch)
                    result.total_processed += len(batch)
                    if config.on_error:
                        for rec in batch:
                            if not config.on_error(e, rec):
                                result.add_error(None, e)
                                break
                    else:
                        result.add_error(None, e)

                batch = []

        # Write remaining batch
        if batch:
            try:
                self._write_batch(batch)
                result.successful += len(batch)
                result.total_processed += len(batch)
            except Exception as e:
                result.failed += len(batch)
                result.total_processed += len(batch)
                result.add_error(None, e)

        result.duration = time.time() - start_time
        return result
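
    # Hedged usage sketch (editor's addition): stream_write() consumes any
    # iterator of Records; when a batch fails, on_error(exc, record) (if set)
    # returns truthy to keep going and falsy to record the error and stop
    # scanning that batch. For example:
    #
    #   cfg = StreamConfig(batch_size=100, on_error=lambda exc, rec: True)
    #   result = db.stream_write(iter(records), cfg)
    #   print(result.successful, result.failed, result.duration)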

    def _write_batch(self, records: list[Record]) -> None:
        """Write a batch of records to the database."""
        if not records:
            return  # Guard: an empty batch would render an invalid empty VALUES list

        # Build batch insert SQL
        values = []
        params = {}

        for i, record in enumerate(records):
            id = str(uuid.uuid4())
            row = self._record_to_row(record, id)
            values.append(f"(%(id_{i})s, %(data_{i})s, %(metadata_{i})s)")
            params[f"id_{i}"] = row["id"]
            params[f"data_{i}"] = row["data"]
            params[f"metadata_{i}"] = row["metadata"]

        sql = f"""
        INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
        VALUES {', '.join(values)}
        """
        self.db.execute(sql, params)
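

# --- Editor's hedged usage sketch (not part of the original module) ---
# Assumes a reachable PostgreSQL instance; the credentials are illustrative,
# and Record(data=...) taking a plain dict mirrors _row_to_record() above.
def _example_sync_usage() -> None:
    db = SyncPostgresDatabase({
        "host": "localhost",
        "database": "postgres",
        "user": "postgres",
        "password": "secret",
        "table": "records",
    })
    db.connect()
    try:
        record_id = db.create(Record(data={"name": "Ada", "age": 36}))
        assert db.exists(record_id)
        print(db.read(record_id))
        db.delete(record_id)
    finally:
        db.close()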


class PostgresDatabase(Database, ConfigurableBase):
    """Asynchronous PostgreSQL database backend with proper connection management."""

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async PostgreSQL database."""
        # Create sync database for delegation
        self._sync_db = SyncPostgresDatabase(config)
        super().__init__(config)
        self._connected = False

    @classmethod
    def from_config(cls, config: dict) -> "PostgresDatabase":
        """Create from config dictionary."""
        return cls(config)

    async def connect(self) -> None:
        """Connect to the database."""
        if self._connected:
            return

        # Run sync connect in executor
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._sync_db.connect)
        self._connected = True

    async def close(self) -> None:
        """Close the database connection."""
        if self._connected:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._sync_db.close)
            self._connected = False

    def _initialize(self) -> None:
        """Initialize is handled by sync database."""
        pass

    async def create(self, record: Record) -> str:
        """Create a new record asynchronously."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_db.create, record)

    async def read(self, id: str) -> Record | None:
        """Read a record asynchronously."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_db.read, id)

    async def update(self, id: str, record: Record) -> bool:
        """Update a record asynchronously."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_db.update, id, record)

    async def delete(self, id: str) -> bool:
        """Delete a record asynchronously."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_db.delete, id)

    async def exists(self, id: str) -> bool:
        """Check if a record exists asynchronously."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_db.exists, id)

    async def search(self, query: Query) -> list[Record]:
        """Search for records asynchronously."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_db.search, query)

    async def _count_all(self) -> int:
        """Count all records asynchronously."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_db._count_all)

    async def clear(self) -> int:
        """Clear all records asynchronously."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_db.clear)

    async def stream_read(
        self,
        query: Optional[Query] = None,
        config: Optional[StreamConfig] = None
    ) -> AsyncIterator[Record]:
        """Stream records from PostgreSQL asynchronously."""
        loop = asyncio.get_running_loop()

        # stream_read() on the sync backend is a generator function, so the
        # call itself is cheap; each batch is fetched lazily on next(), which
        # blocks. Advance the iterator in the executor so the event loop
        # stays responsive instead of blocking on every batch fetch.
        sync_iter = self._sync_db.stream_read(query, config)
        sentinel = object()
        while True:
            record = await loop.run_in_executor(None, next, sync_iter, sentinel)
            if record is sentinel:
                break
            yield record

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: Optional[StreamConfig] = None
    ) -> StreamResult:
        """Stream records into PostgreSQL asynchronously."""
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()

        batch = []
        async for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                # Write batch in executor
                loop = asyncio.get_running_loop()
                try:
                    await loop.run_in_executor(None, self._sync_db._write_batch, batch)
                    result.successful += len(batch)
                    result.total_processed += len(batch)
                except Exception as e:
                    result.failed += len(batch)
                    result.total_processed += len(batch)
                    if config.on_error:
                        for rec in batch:
                            if not config.on_error(e, rec):
                                result.add_error(None, e)
                                break
                    else:
                        result.add_error(None, e)

                batch = []

        # Write remaining batch
        if batch:
            loop = asyncio.get_running_loop()
            try:
                await loop.run_in_executor(None, self._sync_db._write_batch, batch)
                result.successful += len(batch)
                result.total_processed += len(batch)
            except Exception as e:
                result.failed += len(batch)
                result.total_processed += len(batch)
                result.add_error(None, e)

        result.duration = time.time() - start_time
        return result
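

# --- Editor's hedged usage sketch (not part of the original module) ---
# A minimal async round trip; with no host/database/user keys, connect()
# falls back to environment-based credentials. Run it via
# asyncio.run(_example_async_usage()).
async def _example_async_usage() -> None:
    db = PostgresDatabase({"table": "records"})
    await db.connect()
    try:
        record_id = await db.create(Record(data={"name": "Grace"}))
        print(await db.read(record_id))
        async for record in db.stream_read():
            print(record)
    finally:
        await db.close()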