Coverage for src/clauth/config.py: 86%
145 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-28 14:38 -0400
1# Copyright (c) 2025 Mahmood Khordoo
2#
3# This software is licensed under the MIT License.
4# See the LICENSE file in the root directory for details.
"""
Configuration management for CLAUTH.

This module handles persistent configuration storage and management for CLAUTH.
It provides a configuration system using TOML files with support for multiple
profiles, environment variable overrides, and platform-appropriate storage locations.

Classes:
    AWSConfig: AWS-related configuration settings
    ModelConfig: Model selection and provider configuration
    CLIConfig: CLI behavior and appearance settings
    ClauthConfig: Main configuration container
    ConfigManager: Configuration file management and operations

Functions:
    get_config_manager: Return the lazily-created module-level ConfigManager
    get_config: Load configuration (optionally for a profile) via that manager
"""
21import os
22import toml
23from pathlib import Path
24from typing import Optional, Dict, Any
25from pydantic import BaseModel, Field, field_validator, ConfigDict
class AWSConfig(BaseModel):
    """AWS-related configuration settings."""

    profile: str = Field(default="clauth", description="AWS profile name")
    region: str = Field(default="ap-southeast-2", description="Default AWS region")
    sso_start_url: Optional[str] = Field(
        default=None,
        description="IAM Identity Center (SSO) start URL (e.g., https://d-xxxxxxxxxx.awsapps.com/start/)",
    )
    sso_region: str = Field(default="ap-southeast-2", description="SSO region")
    session_name: str = Field(default="clauth-session", description="SSO session name")
    output_format: str = Field(default="json", description="AWS CLI output format")

    @field_validator("sso_start_url")
    @classmethod
    def validate_sso_url(cls, v: Optional[str]) -> Optional[str]:
        """Ensure the SSO start URL, when provided, uses HTTPS.

        Args:
            v: The candidate ``sso_start_url`` value; ``None`` means "not
                configured yet" and is accepted unchanged.

        Returns:
            The value unchanged.

        Raises:
            ValueError: If ``v`` is a string that does not start with
                ``https://``.

        Note:
            This validator runs when the model is constructed. Because the
            model does not enable ``validate_assignment``, plain attribute
            assignment (``cfg.sso_start_url = ...``) bypasses it.
        """
        if v is not None and not v.startswith("https://"):
            raise ValueError("SSO start URL must be HTTPS")
        return v
class ModelConfig(BaseModel):
    """Model selection settings.

    Holds the preferred provider plus the IDs and ARNs of the "default" and
    "fast" models. All model fields start out as ``None``.
    """

    # Preferred model provider.
    provider_filter: str = Field(default="anthropic", description="Preferred model provider")

    # Model identifiers for the default and fast/small slots.
    default_model: Optional[str] = Field(default=None, description="Default model ID")
    fast_model: Optional[str] = Field(default=None, description="Fast/small model ID")

    # Model ARNs for the default and fast slots.
    default_model_arn: Optional[str] = Field(default=None, description="Default model ARN")
    fast_model_arn: Optional[str] = Field(default=None, description="Fast model ARN")
class CLIConfig(BaseModel):
    """Settings controlling CLI behavior and the look of interactive menus."""

    # --- Behavior -------------------------------------------------------
    claude_cli_name: str = Field(default="claude", description="Claude CLI executable name")
    auto_start: bool = Field(default=True, description="Auto-launch Claude Code after setup")
    show_progress: bool = Field(default=True, description="Show progress indicators")
    color_output: bool = Field(default=True, description="Enable colored output")

    # --- UI styling -----------------------------------------------------
    pointer_style: str = Field(default="❯", description="Menu pointer character")
    selected_color: str = Field(default="ansiblue", description="Selected item color")
    highlighted_color: str = Field(default="ansiblue", description="Highlighted item color")
class ClauthConfig(BaseModel):
    """Top-level CLAUTH configuration grouping the aws, models and cli sections.

    Each section defaults to a freshly constructed instance of its model.
    """

    # Reject unknown top-level keys instead of silently ignoring them.
    model_config = ConfigDict(extra="forbid")

    aws: AWSConfig = Field(default_factory=AWSConfig)
    models: ModelConfig = Field(default_factory=ModelConfig)
    cli: CLIConfig = Field(default_factory=CLIConfig)
class ConfigManager:
    """Manage CLAUTH configuration loading, saving, and validation.

    The main configuration lives at ``<config_dir>/config.toml``. Named
    profiles live at ``<config_dir>/profiles/<profile>.toml``.
    """

    def __init__(self, config_dir: Optional[Path] = None):
        """Initialize the manager and make sure its directories exist.

        Args:
            config_dir: Directory that holds ``config.toml`` and
                ``profiles/``. Defaults to the platform location returned
                by ``_get_default_config_dir``.
        """
        self.config_dir = config_dir or self._get_default_config_dir()
        self.config_file = self.config_dir / "config.toml"
        self.profiles_dir = self.config_dir / "profiles"
        self._config: Optional[ClauthConfig] = None

        # Ensure config directory exists
        self.config_dir.mkdir(parents=True, exist_ok=True)
        self.profiles_dir.mkdir(parents=True, exist_ok=True)

    def _get_default_config_dir(self) -> Path:
        """Return the platform-appropriate default configuration directory.

        On Windows this is ``%APPDATA%/clauth``, falling back to
        ``~/AppData/Roaming/clauth``. Elsewhere it is
        ``$XDG_CONFIG_HOME/clauth``, falling back to ``~/.config/clauth``.

        A variable that is set but empty is treated as unset, as the XDG
        Base Directory spec requires. Otherwise ``Path("")`` would silently
        resolve to the current working directory.
        """
        if os.name == "nt":  # Windows
            env_var, fallback = "APPDATA", Path.home() / "AppData" / "Roaming"
        else:  # Unix-like
            env_var, fallback = "XDG_CONFIG_HOME", Path.home() / ".config"

        config_home = Path(os.environ.get(env_var) or fallback)
        return config_home / "clauth"

    def load(self, profile: Optional[str] = None) -> ClauthConfig:
        """Load configuration from disk, creating a default file if needed.

        The steps are, in order:
        1. Read the TOML file for ``profile`` (or ``config.toml``).
        2. Migrate legacy placeholder values.
        3. Apply ``CLAUTH_*`` environment-variable overrides.

        Args:
            profile: Optional profile name; ``None`` selects ``config.toml``.

        Returns:
            The loaded configuration, which is also cached on the manager.

        Note:
            If the requested file is missing or invalid, a default
            configuration is written back to that same file. An invalid
            file is overwritten.
        """
        config_file = self._get_config_file(profile)

        if config_file.exists():
            try:
                with open(config_file, "r", encoding="utf-8") as f:
                    config_data = toml.load(f)
                self._config = ClauthConfig(**config_data)
            except (toml.TomlDecodeError, ValueError) as e:
                # If config is invalid, create default and save it (pydantic's
                # ValidationError is a ValueError subclass).
                print(f"Warning: Invalid config file, creating default: {e}")
                self._config = ClauthConfig()
                # Write to the file that was requested, not always config.toml.
                self.save(profile)
        else:
            # Create default configuration in the requested location.
            self._config = ClauthConfig()
            self.save(profile)

        # Migrate legacy placeholder URLs
        self._migrate_placeholder_urls(profile)

        # Env overrides are applied after the saves above, so they are not
        # written to disk here.
        # NOTE(review): a later save() (e.g. update_model_settings) dumps the
        # in-memory config and WILL persist these overrides.
        self._apply_env_overrides()

        return self._config

    def save(self, profile: Optional[str] = None) -> None:
        """Write the current in-memory configuration to disk as TOML.

        Args:
            profile: Optional profile name; ``None`` writes ``config.toml``.

        Raises:
            ValueError: If no configuration has been loaded yet.
        """
        if self._config is None:
            raise ValueError("No configuration loaded to save")

        config_file = self._get_config_file(profile)
        config_data = self._config.model_dump()

        with open(config_file, "w", encoding="utf-8") as f:
            toml.dump(config_data, f)

    def _get_config_file(self, profile: Optional[str] = None) -> Path:
        """Return ``profiles/<profile>.toml`` for a non-empty profile name.

        For ``None`` or ``""`` this returns the main ``config.toml``.
        """
        if profile:
            return self.profiles_dir / f"{profile}.toml"
        return self.config_file

    def _apply_env_overrides(self) -> None:
        """Overlay non-empty ``CLAUTH_*`` environment variables onto the config.

        Empty or unset variables are ignored. ``CLAUTH_AUTO_START`` is true
        only for "true", "1" or "yes" (case-insensitive).
        """
        if self._config is None:
            return

        # AWS configuration overrides.
        # NOTE(review): plain assignment does not run AWSConfig's HTTPS
        # validator, so a non-HTTPS CLAUTH_SSO_START_URL is accepted here.
        if env_profile := os.environ.get("CLAUTH_PROFILE"):
            self._config.aws.profile = env_profile
        if env_region := os.environ.get("CLAUTH_REGION"):
            self._config.aws.region = env_region
        if env_sso_url := os.environ.get("CLAUTH_SSO_START_URL"):
            self._config.aws.sso_start_url = env_sso_url
        if env_sso_region := os.environ.get("CLAUTH_SSO_REGION"):
            self._config.aws.sso_region = env_sso_region
        if env_session_name := os.environ.get("CLAUTH_SESSION_NAME"):
            self._config.aws.session_name = env_session_name

        # CLI configuration overrides
        if env_claude_cli := os.environ.get("CLAUTH_CLAUDE_CLI_NAME"):
            self._config.cli.claude_cli_name = env_claude_cli
        if env_auto_start := os.environ.get("CLAUTH_AUTO_START"):
            self._config.cli.auto_start = env_auto_start.lower() in ("true", "1", "yes")

        # Model configuration overrides
        if env_provider := os.environ.get("CLAUTH_PROVIDER_FILTER"):
            self._config.models.provider_filter = env_provider
        if env_default_model := os.environ.get("CLAUTH_DEFAULT_MODEL"):
            self._config.models.default_model = env_default_model
        if env_fast_model := os.environ.get("CLAUTH_FAST_MODEL"):
            self._config.models.fast_model = env_fast_model

    def _migrate_placeholder_urls(self, profile: Optional[str] = None) -> None:
        """Replace the legacy placeholder SSO start URL with ``None`` and save.

        Args:
            profile: The profile the config was loaded from, so the migrated
                config is written back to the same file (``None`` selects
                ``config.toml``).
        """
        if self._config is None:
            return

        # Older versions wrote this example URL as a default value.
        if self._config.aws.sso_start_url == "https://d-xxxxxxxxxx.awsapps.com/start/":
            print(
                "Warning: Migrating placeholder SSO start URL to None. You'll need to provide a real URL during init."
            )
            self._config.aws.sso_start_url = None
            # Save the migrated config back to the file it came from.
            self.save(profile)

    @property
    def config(self) -> ClauthConfig:
        """Return the cached configuration, loading ``config.toml`` on first access."""
        if self._config is None:
            self.load()
        return self._config

    def update_model_settings(
        self, default_model: str, fast_model: str, default_arn: str, fast_arn: str
    ) -> None:
        """Set the default/fast model IDs and ARNs, then save to ``config.toml``.

        Loads ``config.toml`` first if nothing has been loaded yet.
        """
        if self._config is None:
            self.load()

        self._config.models.default_model = default_model
        self._config.models.fast_model = fast_model
        self._config.models.default_model_arn = default_arn
        self._config.models.fast_model_arn = fast_arn
        self.save()

    def get_custom_style(self) -> Dict[str, str]:
        """Build the InquirerPy style mapping from the CLI color settings."""
        cli_config = self.config.cli
        return {
            "questionmark": "bold",
            "instruction": "#858585",
            "answer": "bold",
            # Pointer, selected and border share the "selected" color.
            "pointer": cli_config.selected_color,
            "highlighted": cli_config.highlighted_color,
            "selected": cli_config.selected_color,
            "border": cli_config.selected_color,
        }

    def list_profiles(self) -> list[str]:
        """Return the sorted names of ``*.toml`` files in the profiles directory."""
        if not self.profiles_dir.exists():
            return []
        return sorted(file.stem for file in self.profiles_dir.glob("*.toml"))

    def profile_exists(self, profile: str) -> bool:
        """Return True if ``profiles/<profile>.toml`` exists."""
        return (self.profiles_dir / f"{profile}.toml").exists()

    def delete_profile(self, profile: str) -> bool:
        """Delete ``profiles/<profile>.toml``.

        Returns:
            True if the file was deleted, False if it did not exist.
        """
        profile_file = self.profiles_dir / f"{profile}.toml"
        try:
            profile_file.unlink()
        except FileNotFoundError:
            return False
        return True
# Module-level ConfigManager, created lazily by get_config_manager().
_config_manager: Optional[ConfigManager] = None


def get_config_manager() -> ConfigManager:
    """Return the shared module-level ConfigManager, creating it on first call.

    The instance is built with the default configuration directory and is
    reused by every later call.
    """
    global _config_manager
    if _config_manager is not None:
        return _config_manager
    _config_manager = ConfigManager()
    return _config_manager
def get_config(profile: Optional[str] = None) -> ClauthConfig:
    """Load and return configuration through the shared ConfigManager.

    Args:
        profile: Optional profile name; ``None`` loads the main config file.

    Returns:
        The configuration returned by ``ConfigManager.load``. The file is
        re-read on every call.
    """
    manager = get_config_manager()
    loaded = manager.load(profile)
    return loaded