Coverage for src/sql_tool/core/config.py: 100%
171 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-14 15:28 -0500
"""Configuration management for SQL Tool.

Handles TOML config files, environment variables, named profiles,
and configuration precedence resolution.

Precedence order (highest to lowest):
1. CLI flags (--host, --port, etc.)
2. --dsn flag (parsed into components)
3. Environment variables (PGHOST, PGPORT, PGDATABASE, PGUSER, PGPASSWORD)
4. Named profile (--profile, the SQL_PROFILE env var, or the config file's
   default_profile, checked in that order)
5. Config file defaults
6. Built-in defaults
"""
from __future__ import annotations

import os
import tomllib
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, quote, unquote, urlparse

from pydantic import BaseModel, computed_field, field_validator, model_validator

from sql_tool.core.exceptions import ConfigError
# Config file read by load_config when no explicit path is passed.
DEFAULT_CONFIG_PATH = Path.home().joinpath(".config", "sql-tool", "config.toml")

# Environment variable name -> connection field it overrides in resolve_config.
_PG_ENV_VARS: dict[str, str] = dict(
    PGHOST="host",
    PGPORT="port",
    PGDATABASE="dbname",
    PGUSER="user",
    PGPASSWORD="password",  # pragma: allowlist secret
)

# Built-in connection defaults: the lowest-precedence layer in resolve_config.
# The values mirror PgProfile's field defaults; keep the two in sync.
_PROFILE_DEFAULTS: dict[str, Any] = dict(
    host="localhost",
    port=5432,
    dbname="postgres",
    user=None,
    password=None,
    sslmode="prefer",
    connect_timeout=10,
    application_name="sql-tool",
)
def parse_dsn(dsn: str) -> dict[str, Any]:
    """Parse a ``postgresql://`` or ``postgres://`` URL into connection fields.

    Only components actually present in the URL are returned, so the result
    can be layered over other configuration. Possible keys: ``host``,
    ``port``, ``dbname``, ``user``, ``password``, plus the query parameters
    ``sslmode``, ``connect_timeout`` and ``application_name``.

    Host, user, password and database name are percent-decoded, so
    ``postgresql://bob:p%40ss@db/app`` yields ``password == "p@ss"``.
    Note that ``urlparse`` lower-cases the host.

    Raises:
        ConfigError: if the scheme is not ``postgresql``/``postgres``, or the
            port or ``connect_timeout`` is not a valid integer.
    """
    parsed = urlparse(dsn)
    if parsed.scheme not in ("postgresql", "postgres"):
        msg = f"Invalid DSN scheme: '{parsed.scheme}'. Expected 'postgresql' or 'postgres'"
        raise ConfigError(msg)

    try:
        # .port raises ValueError for a non-numeric or out-of-range port.
        port = parsed.port
    except ValueError as e:
        msg = f"Invalid port in DSN: {e}"
        raise ConfigError(msg) from e

    result: dict[str, Any] = {}
    if parsed.hostname:
        # Decoding lets a %2F-escaped Unix-socket directory be used as host.
        result["host"] = unquote(parsed.hostname)
    if port:
        result["port"] = port
    # Strip the slashes first so an encoded "/" (%2F) inside the name survives.
    dbname = parsed.path.strip("/")
    if dbname:
        result["dbname"] = unquote(dbname)
    if parsed.username:
        result["user"] = unquote(parsed.username)
    if parsed.password:
        result["password"] = unquote(parsed.password)

    # parse_qs already percent-decodes query values.
    query_params = parse_qs(parsed.query)
    if "sslmode" in query_params:
        result["sslmode"] = query_params["sslmode"][0]
    if "connect_timeout" in query_params:
        raw_timeout = query_params["connect_timeout"][0]
        try:
            result["connect_timeout"] = int(raw_timeout)
        except ValueError:
            msg = f"Invalid connect_timeout in DSN: '{raw_timeout}'. Must be an integer"
            raise ConfigError(msg) from None
    if "application_name" in query_params:
        result["application_name"] = query_params["application_name"][0]
    return result
class PgProfile(BaseModel):
    """A named PostgreSQL connection profile from the config file.

    If ``dsn`` is set, its components fill in every connection field that the
    profile does not set explicitly; explicitly set fields always win.
    ``sslmode`` and ``port`` are validated on construction.
    """

    dsn: str | None = None  # optional URL, expanded by parse_dsn_into_components
    host: str = "localhost"
    port: int = 5432
    dbname: str = "postgres"
    user: str | None = None
    password: str | None = None
    sslmode: str = "prefer"
    connect_timeout: int = 10
    application_name: str = "sql-tool"
    default_schema: str | None = None

    @model_validator(mode="before")
    @classmethod
    def parse_dsn_into_components(cls, data: Any) -> Any:
        """Expand ``dsn`` into individual fields before field validation.

        Builds a new dict instead of mutating ``data``. The input is the
        caller's raw mapping (e.g. a table from the parsed TOML document), and
        writing DSN components back into it would leak them into that
        structure. The expanded keys are part of the validated input, so they
        count as explicitly set, exactly as before.
        """
        if isinstance(data, dict) and data.get("dsn"):
            # Later unpacking wins, so explicit profile keys override the DSN.
            return {**parse_dsn(data["dsn"]), **data}
        return data

    @field_validator("sslmode")
    @classmethod
    def validate_sslmode(cls, v: str) -> str:
        """Reject any sslmode outside the fixed set listed below."""
        valid_modes = {
            "disable",
            "allow",
            "prefer",
            "require",
            "verify-ca",
            "verify-full",
        }
        if v not in valid_modes:
            msg = f"Invalid sslmode: '{v}'. Must be one of: {', '.join(sorted(valid_modes))}"
            raise ValueError(msg)
        return v

    @field_validator("port")
    @classmethod
    def validate_port(cls, v: int) -> int:
        """Reject ports outside the TCP range 1-65535."""
        if not (1 <= v <= 65535):
            msg = f"Invalid port: {v}. Must be 1-65535"
            raise ValueError(msg)
        return v

    @computed_field  # type: ignore[prop-decorator]
    @property
    def connection_url(self) -> str:
        """Render the profile as a ``postgresql://`` URL.

        User, password and database name are percent-encoded, so characters
        such as ``@``, ``:`` or ``/`` cannot corrupt the URL structure. An
        IPv6 literal host is wrapped in brackets, and a host that is a
        filesystem path (Unix-socket directory) is percent-encoded. Only
        ``sslmode`` is carried as a query parameter.

        NOTE: as a computed field this is included in serialized output
        (e.g. ``model_dump()``), and it embeds the password when one is set.
        """
        userinfo = ""
        if self.user:
            userinfo = quote(self.user, safe="")
            if self.password:
                userinfo += ":" + quote(self.password, safe="")
            userinfo += "@"
        if self.host.startswith("/"):
            host = quote(self.host, safe="")
        elif ":" in self.host:
            host = f"[{self.host}]"
        else:
            host = self.host
        dbname = quote(self.dbname, safe="")
        return f"postgresql://{userinfo}{host}:{self.port}/{dbname}?sslmode={self.sslmode}"
class AppConfig(BaseModel):
    """Top-level settings parsed from the TOML config file by ``load_config``.

    ``load_config`` returns ``AppConfig()`` (all defaults) when the config
    file does not exist.
    """

    # Global fallback that resolve_config feeds into
    # ResolvedConfig.default_timeout (unit not stated here; presumably
    # seconds, so confirm with the consumer).
    default_timeout: float = 30.0
    # Global fallback that resolve_config feeds into ResolvedConfig.default_format.
    default_format: str = "table"
    # Profile used by resolve_config when neither an explicit profile name nor
    # the SQL_PROFILE environment variable is given.
    default_profile: str | None = None
    # Named connection profiles keyed by profile name. pydantic copies this
    # mutable default per instance, so the shared literal is safe.
    profiles: dict[str, PgProfile] = {}
class ResolvedConfig(BaseModel):
    """Final, merged connection settings produced by ``resolve_config``.

    Values may come from environment variables, a ``--dsn`` URL or CLI flags,
    none of which pass through ``PgProfile``'s validation. ``port`` and
    ``sslmode`` are therefore re-validated here with the same rules.
    """

    host: str = "localhost"
    port: int = 5432
    dbname: str = "postgres"
    user: str | None = None
    password: str | None = None
    sslmode: str = "prefer"
    connect_timeout: int = 10
    application_name: str = "sql-tool"
    default_timeout: float = 30.0
    default_format: str = "table"
    default_schema: str | None = None
    # Name of the profile that was applied, if any.
    active_profile: str | None = None
    # Field name -> label of the layer that supplied it, e.g. "default",
    # "config", "profile: <name>", "env: <VAR>", "dsn", "cli: --<flag>".
    sources: dict[str, str] = {}

    @field_validator("sslmode")
    @classmethod
    def validate_sslmode(cls, v: str) -> str:
        """Reject any sslmode outside the fixed set (same rule as PgProfile)."""
        valid_modes = {
            "disable",
            "allow",
            "prefer",
            "require",
            "verify-ca",
            "verify-full",
        }
        if v not in valid_modes:
            msg = f"Invalid sslmode: '{v}'. Must be one of: {', '.join(sorted(valid_modes))}"
            raise ValueError(msg)
        return v

    @field_validator("port")
    @classmethod
    def validate_port(cls, v: int) -> int:
        """Reject ports outside 1-65535 (same rule as PgProfile)."""
        if not (1 <= v <= 65535):
            msg = f"Invalid port: {v}. Must be 1-65535"
            raise ValueError(msg)
        return v
def load_config(config_path: Path | None = None) -> AppConfig:
    """Load configuration from a TOML file.

    Args:
        config_path: File to read; ``None`` means ``DEFAULT_CONFIG_PATH``.
            A ``str`` path also works, because the value is only passed to
            ``open`` and interpolated into messages.

    Returns:
        The validated ``AppConfig``. If the file, or a parent directory on
        its path, does not exist, a default ``AppConfig()`` is returned.

    Raises:
        ConfigError: if the file exists but cannot be read (permission
            denied, is a directory, ...), is not valid UTF-8 TOML, or does not
            match the ``AppConfig`` schema.
    """
    if config_path is None:
        config_path = DEFAULT_CONFIG_PATH

    try:
        # Open directly rather than checking exists() first. This avoids an
        # exists()/open() race and lets other I/O failures surface as
        # ConfigError instead of escaping raw.
        with open(config_path, "rb") as f:
            data = tomllib.load(f)
    except (FileNotFoundError, NotADirectoryError):
        return AppConfig()
    except (tomllib.TOMLDecodeError, UnicodeDecodeError) as e:
        # tomllib.load decodes the bytes as UTF-8 itself, so a bad encoding
        # surfaces as UnicodeDecodeError rather than TOMLDecodeError.
        msg = f"Malformed TOML in {config_path}: {e}"
        raise ConfigError(msg) from e
    except OSError as e:
        msg = f"Cannot read config file {config_path}: {e}"
        raise ConfigError(msg) from e

    try:
        return AppConfig.model_validate(data)
    except Exception as e:
        # Deliberately broad. Besides pydantic validation errors, a profile's
        # "dsn" is parsed during validation and may raise ConfigError. Either
        # way the error is re-raised with the offending file path.
        msg = f"Invalid configuration in {config_path}: {e}"
        raise ConfigError(msg) from e
def resolve_config(
    config: AppConfig,
    profile_name: str | None = None,
    dsn: str | None = None,
    **cli_overrides: Any,
) -> ResolvedConfig:
    """Resolve the effective settings using the precedence chain.

    Layers are applied from lowest to highest priority, each overwriting the
    previous one. The order is: built-in defaults < config-file globals <
    named profile < environment variables < ``dsn`` < CLI overrides.

    The profile is chosen from ``profile_name``, else the ``SQL_PROFILE``
    environment variable, else ``config.default_profile``. If none of these
    is set, no profile is applied.

    Args:
        config: Parsed application config.
        profile_name: Explicit profile name (e.g. from ``--profile``).
        dsn: Connection URL (e.g. from ``--dsn``).
        **cli_overrides: CLI values keyed by flag name (``host``, ``port``,
            ``database``, ``user``, ``password``, ``sslmode``, ``timeout``,
            ``schema``). ``None`` means "not given"; other keys are ignored.

    Returns:
        A ``ResolvedConfig`` whose ``sources`` maps each field to the layer
        that supplied it.

    Raises:
        ConfigError: for an unknown profile, a non-integer ``PGPORT``, an
            invalid ``dsn``, or merged values the final model rejects.
    """
    sources: dict[str, str] = {}
    resolved: dict[str, Any] = {}

    # Layer 1: Built-in defaults
    resolved.update(_PROFILE_DEFAULTS)
    resolved["default_timeout"] = 30.0
    resolved["default_format"] = "table"
    sources.update(dict.fromkeys(resolved, "default"))

    # Layer 2: Config file global defaults. A value is attributed to "config"
    # when the file set it explicitly (even if it equals the built-in
    # default) or when it differs from the built-in default.
    for key in ("default_timeout", "default_format"):
        value = getattr(config, key)
        if key in config.model_fields_set or value != resolved[key]:
            resolved[key] = value
            sources[key] = "config"

    # Layer 3: Named profile
    effective_profile = (
        profile_name or os.environ.get("SQL_PROFILE") or config.default_profile
    )
    if effective_profile:
        if effective_profile not in config.profiles:
            available = (
                ", ".join(sorted(config.profiles)) if config.profiles else "none"
            )
            msg = f"Unknown profile: '{effective_profile}'. Available profiles: {available}"
            raise ConfigError(msg)
        profile = config.profiles[effective_profile]
        profile_label = f"profile: {effective_profile}"
        for key in profile.model_fields_set:
            # "dsn" itself is not a resolved field; PgProfile has already
            # expanded it into components. Checking against ResolvedConfig's
            # fields, rather than the Layer-1 keys, lets default_schema
            # through. It has no built-in default entry and was previously
            # dropped.
            if key == "dsn" or key not in ResolvedConfig.model_fields:
                continue
            resolved[key] = getattr(profile, key)
            sources[key] = profile_label

    # Layer 4: Environment variables (applied whenever set, even if empty)
    for env_var, field_name in _PG_ENV_VARS.items():
        value = os.environ.get(env_var)
        if value is None:
            continue
        if field_name == "port":
            try:
                resolved[field_name] = int(value)
            except ValueError:
                msg = f"Invalid {env_var} value: '{value}'. Must be an integer"
                raise ConfigError(msg) from None
        else:
            resolved[field_name] = value
        sources[field_name] = f"env: {env_var}"

    # Layer 5: DSN flag
    if dsn:
        for key, value in parse_dsn(dsn).items():
            if key in resolved:
                resolved[key] = value
                sources[key] = "dsn"

    # Layer 6: CLI flags (highest priority)
    cli_to_field = {
        "host": "host",
        "port": "port",
        "database": "dbname",
        "user": "user",
        "password": "password",  # pragma: allowlist secret
        "sslmode": "sslmode",
        "timeout": "default_timeout",
        "schema": "default_schema",
    }
    for cli_name, field_name in cli_to_field.items():
        value = cli_overrides.get(cli_name)
        if value is not None:
            resolved[field_name] = value
            sources[field_name] = f"cli: --{cli_name}"

    resolved["active_profile"] = effective_profile
    resolved["sources"] = sources
    try:
        return ResolvedConfig(**resolved)
    except ValueError as e:
        # pydantic's ValidationError subclasses ValueError; surface it as the
        # ConfigError callers already handle for other configuration problems.
        msg = f"Invalid configuration: {e}"
        raise ConfigError(msg) from e