Coverage for src/sql_tool/core/postgres.py: 95%

149 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-14 15:28 -0500

1"""PostgreSQL service and introspection operations. 

2 

3Framework-agnostic business logic for database commands. 

4CLI layers in cli/commands/service.py and cli/main.py provide the typer interface. 

5""" 

6 

7from __future__ import annotations 

8 

9from typing import TYPE_CHECKING, Any 

10 

11from sql_tool.core.models import ColumnMeta, QueryResult 

12 

13if TYPE_CHECKING: 

14 from collections.abc import Callable 

15 

16 from sql_tool.core.client import PgClient 

17 

18# --------------------------------------------------------------------------- 

19# Service operations (used by cli/commands/service.py) 

20# --------------------------------------------------------------------------- 

21 

22 

def check_server(client: PgClient) -> QueryResult:
    """Check PostgreSQL server connectivity and return server info."""
    # Each probe is a (label, statement) pair; any probe that yields a row
    # contributes one entry to a two-column property/value table.
    probes = (
        ("version", "SELECT version()"),
        ("database", "SELECT current_database()"),
        ("user", "SELECT current_user"),
        ("uptime", "SELECT pg_postmaster_start_time()"),
    )

    info_rows: list[tuple[Any, ...]] = []
    for label, statement in probes:
        probe_result = client.execute_query(statement)
        if probe_result.rows:
            info_rows.append((label, str(probe_result.rows[0][0])))

    return QueryResult(
        columns=[
            ColumnMeta(name="property", type_oid=25, type_name="text"),
            ColumnMeta(name="value", type_oid=25, type_name="text"),
        ],
        rows=info_rows,
        row_count=len(info_rows),
        status_message=f"SELECT {len(info_rows)}",
    )

47 

48 

def list_user_tables(client: PgClient) -> list[str]:
    """List all user tables (excluding system schemas)."""
    stmt = """
    SELECT schemaname || '.' || tablename
    FROM pg_catalog.pg_tables
    WHERE schemaname NOT IN ('pg_catalog', 'information_schema')
    """
    # Each row holds a single schema-qualified table name.
    rows = client.execute_query(stmt).rows
    return [row[0] for row in rows]

58 

59 

def vacuum_tables(
    client: PgClient, table_names: list[str], *, full: bool = False
) -> int:
    """Run VACUUM ANALYZE on the given tables. Returns count of tables vacuumed."""
    # NOTE(review): table names are interpolated directly into the statement;
    # callers are expected to pass trusted identifiers (e.g. the output of
    # list_user_tables), since VACUUM targets cannot be bound as parameters.
    command = "VACUUM FULL ANALYZE" if full else "VACUUM ANALYZE"
    for table in table_names:
        client.execute_query(f"{command} {table}")
    return len(table_names)

68 

69 

def kill_backend(client: PgClient, pid: int, *, cancel: bool = False) -> bool:
    """Terminate or cancel a PostgreSQL backend. Returns True if successful."""
    # pg_cancel_backend stops only the current query; pg_terminate_backend
    # kills the whole connection.
    func = "pg_cancel_backend" if cancel else "pg_terminate_backend"
    outcome = client.execute_query(f"SELECT {func}({pid})")
    return bool(outcome.rows and outcome.rows[0][0])

79 

80 

81# --------------------------------------------------------------------------- 

82# Database introspection (used by cli/main.py) 

83# --------------------------------------------------------------------------- 

84 

85 

def list_databases(client: PgClient) -> QueryResult:
    """Query all databases with size and owner info, sorted by size DESC."""
    # Rows: (name, owner, encoding, size_bytes), biggest database first.
    return client.execute_query(
        """
    SELECT
        d.datname AS name,
        pg_catalog.pg_get_userbyid(d.datdba) AS owner,
        pg_catalog.pg_encoding_to_char(d.encoding) AS encoding,
        pg_catalog.pg_database_size(d.datname) AS size_bytes
    FROM pg_catalog.pg_database d
    ORDER BY pg_catalog.pg_database_size(d.datname) DESC
    """
    )

98 

99 

def list_all_database_names(client: PgClient) -> list[str]:
    """List all non-template database names."""
    query = """
    SELECT datname FROM pg_catalog.pg_database
    WHERE datistemplate = false
    ORDER BY datname
    """
    # Single-column result: the database name.
    return [row[0] for row in client.execute_query(query).rows]

109 

110 

# Per-schema table counts and total on-disk sizes from pg_tables, skipping
# PostgreSQL system schemas and TimescaleDB-internal schemas.
# Result rows: (schema, tables, total_bytes).
_SCHEMA_SQL = """
SELECT
    schemaname AS schema,
    COUNT(*) AS tables,
    SUM(pg_catalog.pg_total_relation_size(
        quote_ident(schemaname)||'.'||quote_ident(tablename)
    )) AS total_bytes
FROM pg_catalog.pg_tables
WHERE schemaname NOT IN (
    'pg_catalog', 'information_schema',
    '_timescaledb_cache', '_timescaledb_catalog',
    '_timescaledb_internal', '_timescaledb_config'
)
GROUP BY schemaname
"""

# Per-schema TimescaleDB hypertable sizes.  For each hypertable, the
# pre-compression bytes come from hypertable_compression_stats(); when that
# returns NULL (no compression configured), the COALESCE fallback sums the
# raw chunk relation sizes from _timescaledb_internal instead.
# Result rows: (hypertable_schema, before_bytes, after_bytes, ht_total_bytes).
# Only valid on servers with the TimescaleDB extension installed.
_CHUNK_SQL = """
SELECT
    h.hypertable_schema,
    SUM(COALESCE(cs.before_compression_total_bytes,
        (SELECT SUM(pg_total_relation_size(
            ('_timescaledb_internal.' || quote_ident(c.chunk_name))::regclass))
         FROM timescaledb_information.chunks c
         WHERE c.hypertable_schema = h.hypertable_schema
           AND c.hypertable_name = h.hypertable_name)
    ))::bigint AS before_bytes,
    SUM(cs.after_compression_total_bytes)::bigint AS after_bytes,
    SUM(hypertable_size(
        (quote_ident(h.hypertable_schema) || '.' || quote_ident(h.hypertable_name))::regclass
    )) AS ht_total_bytes
FROM timescaledb_information.hypertables h
LEFT JOIN LATERAL (
    SELECT
        SUM(before_compression_total_bytes)::bigint AS before_compression_total_bytes,
        SUM(after_compression_total_bytes)::bigint AS after_compression_total_bytes
    FROM hypertable_compression_stats(
        (quote_ident(h.hypertable_schema) || '.' || quote_ident(h.hypertable_name))::regclass
    )
) cs ON true
GROUP BY h.hypertable_schema
"""

152 

153 

def _query_chunk_map(client: PgClient) -> dict[str, tuple[int, int, int]]:
    """Query TimescaleDB chunk stats per schema. Returns empty dict if unavailable."""
    # Best-effort: the TimescaleDB views referenced by _CHUNK_SQL do not exist
    # on a vanilla PostgreSQL server, so any failure yields an empty mapping.
    mapping: dict[str, tuple[int, int, int]] = {}
    try:
        stats = client.execute_query(_CHUNK_SQL)
        for schema, before, after, total in stats.rows:
            mapping[schema] = (before or 0, after or 0, total or 0)
    except Exception:
        pass
    return mapping

164 

165 

def list_schemas(
    client: PgClient,
) -> tuple[QueryResult, dict[str, tuple[int, int, int]]]:
    """Query schemas with table counts/sizes and optional TimescaleDB chunk stats.

    Returns (schema_result, chunk_map) where:
        - schema_result rows: (schema, tables, total_bytes)
        - chunk_map: schema -> (before_bytes, after_bytes, ht_total_bytes)
    """
    # Left-to-right evaluation: the schema query runs before the chunk query.
    return client.execute_query(_SCHEMA_SQL), _query_chunk_map(client)

178 

179 

def list_schemas_all_databases(
    db_names: list[str],
    client_factory: Callable[[str], PgClient],
) -> tuple[list[tuple[str, str, int, int, int, int, int]], bool]:
    """Query schemas across multiple databases.

    Returns (raw_data, has_chunks) where raw_data rows are:
        (db_name, schema, tables, total_bytes, before_b, after_b, ht_total).
    """
    collected: list[tuple[str, str, int, int, int, int, int]] = []
    any_chunks = False

    for name in db_names:
        try:
            with client_factory(name) as db_client:
                schemas, chunks = list_schemas(db_client)
                any_chunks = any_chunks or bool(chunks)

                for schema, table_count, size_bytes in schemas.rows:
                    before, after, ht_total = chunks.get(schema, (0, 0, 0))
                    collected.append(
                        (
                            name,
                            schema,
                            table_count or 0,
                            size_bytes or 0,
                            before,
                            after,
                            ht_total,
                        )
                    )
        except Exception:
            # A database we cannot reach still gets a placeholder row so the
            # caller's report covers every requested database.
            collected.append((name, "(connection failed)", 0, 0, 0, 0, 0))

    return collected, any_chunks

216 

217 

def _query_hypertable_map(
    client: PgClient,
) -> dict[tuple[str, str], tuple[int, int, int | None, int | None]]:
    """Best-effort per-hypertable TimescaleDB stats, keyed by (schema, name).

    Mirrors _query_chunk_map: returns an empty mapping when the TimescaleDB
    views are unavailable (plain PostgreSQL server).  Values are
    (uncompressed_chunks, compressed_chunks, before_bytes, after_bytes);
    the byte figures stay None when compression stats are absent.
    """
    ht_sql = """
    SELECT
        h.hypertable_schema,
        h.hypertable_name,
        (SELECT COUNT(*) FILTER (WHERE NOT c.is_compressed) FROM timescaledb_information.chunks c
         WHERE c.hypertable_schema = h.hypertable_schema AND c.hypertable_name = h.hypertable_name) AS uncompr_chunks,
        (SELECT COUNT(*) FILTER (WHERE c.is_compressed) FROM timescaledb_information.chunks c
         WHERE c.hypertable_schema = h.hypertable_schema AND c.hypertable_name = h.hypertable_name) AS compr_chunks,
        COALESCE(d.before_compression_total_bytes,
            (SELECT SUM(pg_total_relation_size(
                ('_timescaledb_internal.' || quote_ident(c.chunk_name))::regclass))
             FROM timescaledb_information.chunks c
             WHERE c.hypertable_schema = h.hypertable_schema
               AND c.hypertable_name = h.hypertable_name)
        ) AS before_bytes,
        d.after_compression_total_bytes AS after_bytes
    FROM timescaledb_information.hypertables h
    LEFT JOIN LATERAL (
        SELECT
            SUM(before_compression_total_bytes)::bigint AS before_compression_total_bytes,
            SUM(after_compression_total_bytes)::bigint AS after_compression_total_bytes
        FROM hypertable_compression_stats(
            (quote_ident(h.hypertable_schema) || '.' || quote_ident(h.hypertable_name))::regclass
        )
    ) d ON true
    """
    ht_map: dict[tuple[str, str], tuple[int, int, int | None, int | None]] = {}
    try:
        ht_result = client.execute_query(ht_sql)
        for ht_schema, ht_name, uncompr, compr, before_bytes, after_bytes in ht_result.rows:
            ht_map[(ht_schema, ht_name)] = (
                uncompr or 0,
                compr or 0,
                before_bytes,
                after_bytes,
            )
    except Exception:
        # No TimescaleDB (or no permission on its views): report no stats.
        pass
    return ht_map


def list_tables(
    client: PgClient,
    *,
    schema_filter: str | None = None,
    include_internal_tables: bool = False,
) -> tuple[QueryResult, dict[tuple[str, str], tuple[int, int, int | None, int | None]]]:
    """Query tables with size breakdown and optional hypertable stats.

    Returns (table_result, ht_map) where:
        - table_result rows: if schema_filter set: (name, table_size, index_size, total)
          otherwise: (schema, name, table_size, index_size, total)
        - ht_map: (schema, name) -> (uncompr_chunks, compr_chunks, before_bytes, after_bytes)
    """
    # Build the WHERE clause and bind parameters for the pg_tables query.
    if schema_filter:
        where_clause = "WHERE schemaname = %(schema)s"
        params: dict[str, Any] = {"schema": schema_filter}
        schema_column = ""  # schema is implied by the filter; omit the column
    else:
        if include_internal_tables:
            where_clause = (
                "WHERE schemaname NOT IN ('pg_catalog', 'information_schema')"
            )
        else:
            # Also hide TimescaleDB bookkeeping schemas by default.
            where_clause = """WHERE schemaname NOT IN ('pg_catalog', 'information_schema',
                '_timescaledb_cache', '_timescaledb_catalog',
                '_timescaledb_internal', '_timescaledb_config')"""
        params = {}
        schema_column = "schemaname AS schema,"

    sql = f"""
    SELECT
        {schema_column}
        tablename AS name,
        pg_catalog.pg_table_size(quote_ident(schemaname)||'.'||quote_ident(tablename)) AS table_size_bytes,
        pg_catalog.pg_indexes_size(quote_ident(schemaname)||'.'||quote_ident(tablename)) AS index_size_bytes,
        pg_catalog.pg_total_relation_size(quote_ident(schemaname)||'.'||quote_ident(tablename)) AS total_bytes
    FROM pg_catalog.pg_tables
    {where_clause}
    ORDER BY schemaname, tablename
    """

    # Main table listing first, then the best-effort hypertable stats,
    # preserving the original query order.
    result = client.execute_query(sql, params if params else None)
    return result, _query_hypertable_map(client)

309 

310 

def describe_table(client: PgClient, schema_name: str, table_name: str) -> QueryResult:
    """Get column definitions for a table."""
    bind = {"schema": schema_name, "table": table_name}
    return client.execute_query(
        """
    SELECT
        column_name,
        data_type,
        is_nullable,
        column_default
    FROM information_schema.columns
    WHERE table_schema = %(schema)s AND table_name = %(table)s
    ORDER BY ordinal_position
    """,
        bind,
    )

324 

325 

def get_time_column(client: PgClient, schema_name: str, table_name: str) -> str | None:
    """Get the primary time dimension column for a hypertable. Returns None if not a hypertable."""
    stmt = """
    SELECT d.column_name
    FROM timescaledb_information.dimensions d
    WHERE d.hypertable_schema = %(schema)s
      AND d.hypertable_name = %(table)s
      AND d.dimension_number = 1
    LIMIT 1
    """
    # Best-effort: the dimensions view is absent on plain PostgreSQL, so any
    # failure is treated as "not a hypertable".
    try:
        found = client.execute_query(stmt, {"schema": schema_name, "table": table_name})
        if found.rows:
            return str(found.rows[0][0])
    except Exception:
        pass
    return None

343 

344 

def get_timestamp_range(
    client: PgClient,
    schema_name: str,
    table_name: str,
    time_column: str,
) -> QueryResult:
    """Get min/max timestamps for a table's time column."""
    # NOTE(review): identifiers are interpolated, consistent with the rest of
    # this module; callers must supply trusted schema/table/column names.
    qualified = f"{schema_name}.{table_name}"
    return client.execute_query(
        f"""
    SELECT
        MIN({time_column})::text AS min_timestamp,
        MAX({time_column})::text AS max_timestamp
    FROM {qualified}
    """
    )

359 

360 

def preview_table(
    client: PgClient,
    schema_name: str,
    table_name: str,
    *,
    head: int | None = None,
    tail: int | None = None,
    sample: int | None = None,
    time_column: str | None = None,
) -> QueryResult | None:
    """Preview table data with head/tail/sample modes. Returns None if no rows."""
    # head wins over tail wins over sample when several are supplied.
    row_limit = head or tail or sample
    if not row_limit:
        return None

    target = f"{schema_name}.{table_name}"

    if sample is not None:
        if time_column:
            # Random sample restricted to the most recent 7 days of data.
            stmt = f"""
            SELECT * FROM {target}
            WHERE {time_column} >= (
                SELECT MAX({time_column}) - interval '7 days'
                FROM {target}
            )
            ORDER BY random()
            LIMIT %(limit)s
            """
        else:
            stmt = f"""
            SELECT * FROM {target}
            TABLESAMPLE BERNOULLI(1)
            LIMIT %(limit)s
            """
    else:
        # Order by the time column when one is known, so head/tail mean
        # oldest/newest rows; otherwise the order is unspecified.
        if head is not None and time_column:
            ordering = f"ORDER BY {time_column} ASC"
        elif tail is not None and time_column:
            ordering = f"ORDER BY {time_column} DESC"
        else:
            ordering = ""

        stmt = f"""
        SELECT * FROM {target}
        {ordering}
        LIMIT %(limit)s
        """

    preview = client.execute_query(stmt, {"limit": row_limit})
    return preview if preview.rows else None

409 

410 

def list_connections(
    client: PgClient,
    *,
    include_all: bool = False,
    min_duration: float | None = None,
    filter_user: str | None = None,
    filter_db: str | None = None,
    filter_state: str | None = None,
) -> QueryResult:
    """Query pg_stat_activity with filters.

    Returns rows of (pid, user, db, app, client_addr, state, wait_event,
    connected_since, connected_seconds, query_start, query_seconds, query).
    """
    # Always exclude our own backend from the listing.
    conditions = ["pid != pg_backend_pid()"]
    bind: dict[str, Any] = {}

    if not include_all:
        conditions.append("state IS NOT NULL AND state != 'idle'")

    if min_duration is not None:
        # min_duration is a float per the signature, so interpolation here
        # cannot inject arbitrary SQL.
        conditions.append(f"(now() - query_start) > interval '{min_duration} seconds'")

    if filter_user is not None:
        conditions.append("usename = %(filter_user)s")
        bind["filter_user"] = filter_user

    if filter_db is not None:
        conditions.append("datname = %(filter_db)s")
        bind["filter_db"] = filter_db

    if filter_state is not None:
        conditions.append("state = %(filter_state)s")
        bind["filter_state"] = filter_state

    predicate = " AND ".join(f"({c})" for c in conditions)

    return client.execute_query(
        f"""
    SELECT
        pid,
        usename AS user,
        datname AS database,
        application_name,
        client_addr::text AS client_address,
        state,
        wait_event,
        backend_start::text AS connected_since,
        EXTRACT(EPOCH FROM (now() - backend_start)) AS connected_seconds,
        query_start::text AS query_start,
        EXTRACT(EPOCH FROM (now() - query_start)) AS query_seconds,
        query
    FROM pg_stat_activity
    WHERE {predicate}
    ORDER BY query_start
    """,
        bind if bind else None,
    )

468 

469 

def connections_summary(client: PgClient) -> QueryResult:
    """Connection counts grouped by state plus memory configuration settings."""
    # ROLLUP adds a NULL-state grand-total row, relabelled 'total' and
    # sorted to the end.
    state_sql = """
    SELECT
        COALESCE(state, 'total') AS state,
        COUNT(*) AS count
    FROM pg_stat_activity
    WHERE pid != pg_backend_pid()
    GROUP BY ROLLUP(state)
    ORDER BY CASE WHEN state IS NULL THEN 1 ELSE 0 END, count DESC
    """

    settings_sql = """
    SELECT name AS setting, current_setting(name) AS value
    FROM pg_settings
    WHERE name IN (
        'max_connections', 'shared_buffers', 'effective_cache_size',
        'work_mem', 'maintenance_work_mem'
    )
    ORDER BY name
    """

    state_rows = client.execute_query(state_sql).rows
    setting_rows = client.execute_query(settings_sql).rows

    # Connection counts first, a visual divider row, then the memory settings.
    merged: list[tuple[Any, ...]] = [(r[0], str(r[1])) for r in state_rows]
    merged.append(("---", "---"))
    merged.extend((r[0], str(r[1])) for r in setting_rows)

    return QueryResult(
        columns=[
            ColumnMeta(name="property", type_oid=25, type_name="text"),
            ColumnMeta(name="value", type_oid=25, type_name="text"),
        ],
        rows=merged,
        row_count=len(merged),
        status_message=f"SELECT {len(merged)}",
    )