Coverage for src / sql_tool / core / config.py: 100%

171 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-14 15:28 -0500

1"""Configuration management for SQL Tool. 

2 

3Handles TOML config files, environment variables, named profiles, 

4and configuration precedence resolution. 

5 

6Precedence order (highest to lowest): 

71. CLI flags (--host, --port, etc.) 

82. --dsn flag (parsed into components) 

93. Environment variables (PGHOST, PGPORT, PGDATABASE, PGUSER, PGPASSWORD) 

104. Named profile (--profile or SQL_PROFILE env var) 

115. Config file defaults 

126. Built-in defaults 

13""" 

14 

from __future__ import annotations

import os
import tomllib
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, quote, unquote, urlparse

from pydantic import BaseModel, computed_field, field_validator, model_validator

from sql_tool.core.exceptions import ConfigError

26 

# Config file read when no explicit path is supplied:
# ~/.config/sql-tool/config.toml
DEFAULT_CONFIG_PATH = Path.home().joinpath(".config", "sql-tool", "config.toml")

# Environment variable name -> connection field it supplies.
# Insertion order is the order these variables are consulted.
_PG_ENV_VARS: dict[str, str] = dict(
    PGHOST="host",
    PGPORT="port",
    PGDATABASE="dbname",
    PGUSER="user",
    PGPASSWORD="password",  # pragma: allowlist secret
)

# Built-in connection defaults: the lowest-precedence configuration layer.
_PROFILE_DEFAULTS: dict[str, Any] = dict(
    host="localhost",
    port=5432,
    dbname="postgres",
    user=None,
    password=None,
    sslmode="prefer",
    connect_timeout=10,
    application_name="sql-tool",
)

47 

48 

def parse_dsn(dsn: str) -> dict[str, Any]:
    """Parse a PostgreSQL connection URI into connection components.

    Supports the ``postgresql://`` and ``postgres://`` schemes. Only the
    components actually present in the URI are returned (keys among
    ``host``, ``port``, ``dbname``, ``user``, ``password``, ``sslmode``,
    ``connect_timeout``, ``application_name``). For repeated query
    parameters the first occurrence wins.

    User, password and database name are percent-decoded, as libpq does for
    connection URIs, so ``postgresql://u:p%40ss@h/db`` yields password
    ``p@ss``. Query values are decoded by ``parse_qs``.

    Raises:
        ConfigError: on an unsupported scheme, a non-numeric or out-of-range
            port, or a non-integer ``connect_timeout``.
    """
    parsed = urlparse(dsn)
    if parsed.scheme not in ("postgresql", "postgres"):
        msg = f"Invalid DSN scheme: '{parsed.scheme}'. Expected 'postgresql' or 'postgres'"
        raise ConfigError(msg)

    # urlparse validates the port lazily and raises a bare ValueError; surface
    # it as ConfigError. The DSN itself is not echoed because it may contain
    # a password.
    try:
        port = parsed.port
    except ValueError as e:
        msg = f"Invalid port in DSN: {e}"
        raise ConfigError(msg) from e

    result: dict[str, Any] = {}
    if parsed.hostname:
        result["host"] = parsed.hostname
    if port:
        result["port"] = port
    dbname = parsed.path.strip("/")
    if dbname:
        result["dbname"] = unquote(dbname)
    # urlparse does not percent-decode userinfo; do it here so reserved
    # characters (@ : / %) in credentials round-trip.
    if parsed.username:
        result["user"] = unquote(parsed.username)
    if parsed.password:
        result["password"] = unquote(parsed.password)

    query_params = parse_qs(parsed.query)
    if "sslmode" in query_params:
        result["sslmode"] = query_params["sslmode"][0]
    if "connect_timeout" in query_params:
        raw_timeout = query_params["connect_timeout"][0]
        try:
            result["connect_timeout"] = int(raw_timeout)
        except ValueError:
            msg = f"Invalid connect_timeout in DSN: '{raw_timeout}'. Must be an integer"
            raise ConfigError(msg) from None
    if "application_name" in query_params:
        result["application_name"] = query_params["application_name"][0]
    return result

75 

76 

class PgProfile(BaseModel):
    """One named PostgreSQL connection profile (an entry of ``AppConfig.profiles``).

    A profile may give a ``dsn``; components parsed from it fill in any
    connection field the profile does not set explicitly, so explicit keys
    always win over the DSN.
    """

    # Optional connection URI; expanded into the fields below before validation.
    dsn: str | None = None
    host: str = "localhost"
    port: int = 5432
    dbname: str = "postgres"
    user: str | None = None
    password: str | None = None
    sslmode: str = "prefer"
    connect_timeout: int = 10
    application_name: str = "sql-tool"
    default_schema: str | None = None

    @model_validator(mode="before")
    @classmethod
    def parse_dsn_into_components(cls, data: Any) -> Any:
        """Expand ``dsn`` into individual fields before field validation.

        Keys already present in *data* take precedence over DSN components.
        A copy is returned so the caller's mapping (e.g. raw TOML data) is
        not mutated; expanded keys still count as explicitly set in
        ``model_fields_set``.

        Raises:
            ConfigError: propagated from ``parse_dsn`` for a malformed DSN.
        """
        if isinstance(data, dict) and data.get("dsn"):
            merged = dict(data)
            for key, value in parse_dsn(data["dsn"]).items():
                merged.setdefault(key, value)
            return merged
        return data

    @field_validator("sslmode")
    @classmethod
    def validate_sslmode(cls, v: str) -> str:
        """Reject sslmode values outside the accepted set."""
        valid_modes = {
            "disable",
            "allow",
            "prefer",
            "require",
            "verify-ca",
            "verify-full",
        }
        if v not in valid_modes:
            msg = f"Invalid sslmode: '{v}'. Must be one of: {', '.join(sorted(valid_modes))}"
            raise ValueError(msg)
        return v

    @field_validator("port")
    @classmethod
    def validate_port(cls, v: int) -> int:
        """Reject ports outside the TCP range 1-65535."""
        if not (1 <= v <= 65535):
            msg = f"Invalid port: {v}. Must be 1-65535"
            raise ValueError(msg)
        return v

    @computed_field  # type: ignore[prop-decorator]
    @property
    def connection_url(self) -> str:
        """Build a ``postgresql://`` URL from this profile's components.

        User, password and database name are percent-encoded so values
        containing reserved characters (``@ : / ? # %``) produce a valid URL
        that ``parse_dsn`` decodes back to the same values. An IPv6 literal
        host is wrapped in brackets. ``connect_timeout`` and
        ``application_name`` are not included.

        NOTE(review): as a computed field this URL, including any password,
        is part of ``model_dump()`` / JSON output.
        """
        userinfo = ""
        if self.user:
            userinfo = quote(self.user, safe="")
            if self.password:
                userinfo += ":" + quote(self.password, safe="")
            userinfo += "@"
        host = self.host
        if ":" in host and not host.startswith("["):
            # Bare IPv6 literal: brackets keep its colons apart from the port.
            host = f"[{host}]"
        return (
            f"postgresql://{userinfo}{host}:{self.port}"
            f"/{quote(self.dbname, safe='')}?sslmode={self.sslmode}"
        )

136 

137 

class AppConfig(BaseModel):
    """Top-level contents of the TOML config file.

    Produced by ``load_config`` via ``AppConfig.model_validate``; every field
    is optional and falls back to the default shown.
    """

    # Global timeout; unit not visible here (presumably seconds -- TODO confirm).
    default_timeout: float = 30.0
    # Default output format name; not validated against a fixed set here.
    default_format: str = "table"
    # Profile used when neither the profile argument nor SQL_PROFILE selects one.
    default_profile: str | None = None
    # Named connection profiles, keyed by profile name.
    profiles: dict[str, PgProfile] = {}

143 

144 

class ResolvedConfig(BaseModel):
    """Final merged settings returned by ``resolve_config``.

    NOTE(review): unlike PgProfile this model defines no sslmode/port
    validators, so values arriving from env vars or CLI flags are not
    checked by the model itself.
    """

    # Connection parameters.
    host: str = "localhost"
    port: int = 5432
    dbname: str = "postgres"
    user: str | None = None
    password: str | None = None
    sslmode: str = "prefer"
    connect_timeout: int = 10
    application_name: str = "sql-tool"
    # Tool-level settings.
    default_timeout: float = 30.0
    default_format: str = "table"
    default_schema: str | None = None
    # Profile name selected during resolution, if any.
    active_profile: str | None = None
    # Field name -> label of the layer that supplied it
    # (e.g. "default", "config", "env: PGHOST", "dsn", "cli: --host").
    sources: dict[str, str] = {}

159 

160 

def load_config(config_path: Path | None = None) -> AppConfig:
    """Load configuration from a TOML file.

    Args:
        config_path: File to read; ``DEFAULT_CONFIG_PATH`` when omitted.

    Returns:
        The validated ``AppConfig``, or a default ``AppConfig()`` when the
        file (or one of its parent directories) does not exist.

    Raises:
        ConfigError: if the file cannot be read (e.g. permission denied, or
            the path is a directory), is not valid UTF-8 TOML, or does not
            validate as an ``AppConfig``.
    """
    if config_path is None:
        config_path = DEFAULT_CONFIG_PATH

    # EAFP: open directly instead of exists()-then-open(). This avoids the
    # race and turns unreadable paths into ConfigError instead of a raw OSError.
    try:
        with open(config_path, "rb") as f:
            data = tomllib.load(f)
    except (FileNotFoundError, NotADirectoryError):
        # Same cases Path.exists() reported as "missing": fall back to defaults.
        return AppConfig()
    except (tomllib.TOMLDecodeError, UnicodeDecodeError) as e:
        # tomllib decodes the bytes as UTF-8 before parsing, so invalid
        # encoding surfaces as UnicodeDecodeError rather than TOMLDecodeError.
        msg = f"Malformed TOML in {config_path}: {e}"
        raise ConfigError(msg) from e
    except OSError as e:
        msg = f"Cannot read config file {config_path}: {e}"
        raise ConfigError(msg) from e

    try:
        return AppConfig.model_validate(data)
    except Exception as e:
        # Broad on purpose: besides pydantic's ValidationError, custom
        # validators (e.g. DSN parsing) may raise other exception types that
        # pydantic lets propagate; all are reported as invalid configuration.
        msg = f"Invalid configuration in {config_path}: {e}"
        raise ConfigError(msg) from e

185 

186 

def resolve_config(
    config: AppConfig,
    profile_name: str | None = None,
    dsn: str | None = None,
    **cli_overrides: Any,
) -> ResolvedConfig:
    """Resolve configuration using the precedence chain.

    CLI > DSN > env > profile > config defaults > built-in defaults.
    Layers are applied lowest priority first, each overwriting earlier ones:

    1. Built-in defaults.
    2. Config-file globals (``default_timeout``, ``default_format``) that are
       explicitly set on *config*.
    3. The named profile: *profile_name*, else ``$SQL_PROFILE``, else
       ``config.default_profile`` (first truthy value wins). Only fields set
       explicitly in the profile are applied.
    4. ``PGHOST``, ``PGPORT``, ``PGDATABASE``, ``PGUSER``, ``PGPASSWORD``.
    5. Components parsed from *dsn*.
    6. *cli_overrides*: ``host``, ``port``, ``database``, ``user``,
       ``password``, ``sslmode``, ``timeout``, ``schema``. ``None`` values and
       unrecognised keys are ignored.

    Returns:
        A ``ResolvedConfig`` whose ``sources`` maps each field to the layer
        that supplied it (``"default"``, ``"config"``,
        ``"profile: <name>"``, ``"env: <VAR>"``, ``"dsn"``, ``"cli: --<flag>"``).

    Raises:
        ConfigError: for an unknown profile, a ``PGPORT`` that is not an
            integer in 1-65535, or (via ``parse_dsn``) a malformed *dsn*.
    """
    # Layer 1: built-in defaults. default_schema is seeded too; otherwise a
    # profile's default_schema fails the "known field" check in layer 3 and
    # is silently dropped.
    resolved: dict[str, Any] = {
        **_PROFILE_DEFAULTS,
        "default_timeout": 30.0,
        "default_format": "table",
        "default_schema": None,
    }
    sources: dict[str, str] = dict.fromkeys(resolved, "default")

    # Layer 2: config-file globals. model_fields_set (as in layer 3) also
    # attributes an explicit value that happens to equal the default.
    for key in ("default_timeout", "default_format"):
        if key in config.model_fields_set:
            resolved[key] = getattr(config, key)
            sources[key] = "config"

    # Layer 3: named profile.
    effective_profile = (
        profile_name or os.environ.get("SQL_PROFILE") or config.default_profile
    )
    if effective_profile:
        if effective_profile not in config.profiles:
            available = ", ".join(sorted(config.profiles)) or "none"
            msg = f"Unknown profile: '{effective_profile}'. Available profiles: {available}"
            raise ConfigError(msg)
        profile = config.profiles[effective_profile]
        # Only explicitly-set profile fields override lower layers; the raw
        # dsn string is not itself a resolved setting.
        for key in profile.model_fields_set:
            if key != "dsn" and key in resolved:
                resolved[key] = getattr(profile, key)
                sources[key] = f"profile: {effective_profile}"

    # Layer 4: environment variables.
    for env_var, field_name in _PG_ENV_VARS.items():
        value = os.environ.get(env_var)
        if value is None:
            continue
        if field_name == "port":
            try:
                port = int(value)
            except ValueError:
                msg = f"Invalid {env_var} value: '{value}'. Must be an integer"
                raise ConfigError(msg) from None
            # Same range PgProfile.validate_port enforces for config values.
            if not 1 <= port <= 65535:
                msg = f"Invalid {env_var} value: '{value}'. Must be 1-65535"
                raise ConfigError(msg)
            resolved[field_name] = port
        else:
            resolved[field_name] = value
        sources[field_name] = f"env: {env_var}"

    # Layer 5: --dsn flag.
    if dsn:
        for key, value in parse_dsn(dsn).items():
            if key in resolved:
                resolved[key] = value
                sources[key] = "dsn"

    # Layer 6: CLI flags (highest priority). Maps CLI option name -> field.
    cli_to_field = {
        "host": "host",
        "port": "port",
        "database": "dbname",
        "user": "user",
        "password": "password",  # pragma: allowlist secret
        "sslmode": "sslmode",
        "timeout": "default_timeout",
        "schema": "default_schema",
    }
    for cli_name, field_name in cli_to_field.items():
        value = cli_overrides.get(cli_name)
        if value is not None:
            resolved[field_name] = value
            sources[field_name] = f"cli: --{cli_name}"

    resolved["active_profile"] = effective_profile
    resolved["sources"] = sources
    return ResolvedConfig(**resolved)