Coverage for src / sql_tool / core / client.py: 96%

53 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-14 15:28 -0500

1"""PostgreSQL client for SQL Tool. 

2 

3Wraps psycopg v3 synchronous connections with query execution, 

4statement timeout, and exception mapping to SqlToolError hierarchy. 

5""" 

6 

7from __future__ import annotations 

8 

9from typing import TYPE_CHECKING, Any 

10 

11import psycopg 

12import psycopg.errors 

13import structlog 

14 

15from sql_tool.core.exceptions import NetworkError, SqlToolError, TimeoutError 

16from sql_tool.core.models import ColumnMeta, QueryResult 

17 

18if TYPE_CHECKING: 

19 from sql_tool.core.config import ResolvedConfig 

20 

# Module-level structlog logger.
# NOTE(review): not referenced anywhere else in this module — confirm whether
# it is still needed or whether PgClient should be logging through it.
log = structlog.get_logger()

22 

23# Mapping from psycopg type OIDs to human-readable names. 

24# Covers the most common PostgreSQL types; unknown OIDs fall back to "unknown". 

25_TYPE_NAMES: dict[int, str] = { 

26 16: "bool", 

27 20: "int8", 

28 21: "int2", 

29 23: "int4", 

30 25: "text", 

31 26: "oid", 

32 114: "json", 

33 142: "xml", 

34 700: "float4", 

35 701: "float8", 

36 790: "money", 

37 1042: "bpchar", 

38 1043: "varchar", 

39 1082: "date", 

40 1083: "time", 

41 1114: "timestamp", 

42 1184: "timestamptz", 

43 1186: "interval", 

44 1700: "numeric", 

45 2950: "uuid", 

46 3802: "jsonb", 

47} 

48 

49 

class PgClient:
    """Synchronous PostgreSQL client using psycopg v3.

    The connection is opened lazily on the first query and reused until
    ``close()`` is called (directly or by leaving a ``with`` block). Errors
    raised by psycopg are translated into the ``SqlToolError`` hierarchy.
    """

    def __init__(self, config: ResolvedConfig) -> None:
        self.config = config
        # Opened lazily by _connect(); None before first use and after close().
        self._connection: psycopg.Connection[Any] | None = None

    def __enter__(self) -> PgClient:
        return self

    def __exit__(self, *exc: object) -> None:
        # Always release the connection. Returning None lets any exception
        # raised inside the with-body propagate to the caller.
        self.close()

    def _connect(self) -> psycopg.Connection[Any]:
        """Return the cached connection, opening a new one if needed.

        A new connection is opened when none exists yet or when the cached
        one reports ``closed`` (psycopg also reports a broken connection as
        closed, so a lost server connection is re-established here).

        Raises:
            NetworkError: if psycopg raises OperationalError while connecting.
        """
        if self._connection is not None and not self._connection.closed:
            return self._connection

        try:
            self._connection = psycopg.connect(
                host=self.config.host,
                port=self.config.port,
                dbname=self.config.dbname,
                user=self.config.user,
                password=self.config.password,
                sslmode=self.config.sslmode,
                connect_timeout=self.config.connect_timeout,
                application_name=self.config.application_name,
                # Each statement commits on its own: no transaction is left
                # open between queries, and session settings such as
                # statement_timeout take effect immediately.
                autocommit=True,
            )
        except psycopg.OperationalError as e:
            msg = (
                f"Connection failed to {self.config.host}:{self.config.port} "
                f"database '{self.config.dbname}': {e}"
            )
            raise NetworkError(msg) from e

        return self._connection

    def execute_query(
        self, sql: str, params: dict[str, Any] | None = None
    ) -> QueryResult:
        """Execute SQL and return a QueryResult.

        Args:
            sql: Query text; named placeholders such as ``%(name)s`` are bound
                from *params* by psycopg.
            params: Optional mapping of named query parameters.

        Returns:
            A QueryResult holding column metadata and all fetched rows.
            Statements that return no result set (``cursor.description`` is
            None) produce empty ``columns`` and ``rows`` plus the server's
            status message.

        Raises:
            NetworkError: if the connection cannot be opened, or psycopg
                raises OperationalError while executing.
            TimeoutError: if the server cancels the query, e.g. because it
                exceeded ``config.default_timeout`` seconds.
            SqlToolError: for any other psycopg error (syntax errors, missing
                tables/columns, constraint violations, permission errors, ...).
        """
        conn = self._connect()
        timeout_ms = int(self.config.default_timeout * 1000)

        try:
            with conn.cursor() as cur:
                # SET cannot take bound parameters with psycopg 3's server-side
                # binding; timeout_ms is an int, so interpolating it is safe.
                cur.execute(f"SET statement_timeout = {timeout_ms}")
                cur.execute(sql, params)

                columns: list[ColumnMeta] = []
                rows: list[tuple[Any, ...]] = []

                # description is None when the statement produced no result set.
                if cur.description:
                    columns = [
                        ColumnMeta(
                            name=desc.name,
                            type_oid=desc.type_code,
                            type_name=_TYPE_NAMES.get(desc.type_code, "unknown"),
                        )
                        for desc in cur.description
                    ]
                    rows = cur.fetchall()

                return QueryResult(
                    columns=columns,
                    rows=rows,
                    row_count=len(rows),
                    status_message=cur.statusmessage or "",
                )

        except psycopg.errors.QueryCanceled as e:
            # Must precede OperationalError: QueryCanceled is a subclass of it.
            msg = f"Query timed out after {self.config.default_timeout}s: {e}"
            raise TimeoutError(msg) from e
        except psycopg.OperationalError as e:
            raise NetworkError(f"Database error: {e}") from e
        except psycopg.Error as e:
            # Catch-all for the remaining driver/server errors. This covers
            # psycopg.errors.SyntaxError (same message as before) and errors
            # such as UndefinedTable, UniqueViolation or InsufficientPrivilege,
            # which previously escaped as raw psycopg exceptions.
            raise SqlToolError(f"SQL error: {e}") from e

    def close(self) -> None:
        """Close the database connection, if one is open.

        Safe to call repeatedly. The cached connection is detached before
        closing, so the client is reset even if ``close()`` itself raises.
        """
        conn, self._connection = self._connection, None
        if conn is not None:
            conn.close()