Coverage for src/clauth/config.py: 86%

145 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-28 14:38 -0400

1# Copyright (c) 2025 Mahmood Khordoo 

2# 

3# This software is licensed under the MIT License. 

4# See the LICENSE file in the root directory for details. 

5 

6""" 

7Configuration management for CLAUTH. 

8 

9This module handles persistent configuration storage and management for CLAUTH. 

10It provides a configuration system using TOML files with support for multiple 

11profiles, environment variable overrides, and platform-appropriate storage locations. 

12 

13Classes: 

14 AWSConfig: AWS-related configuration settings

15 ModelConfig: Model selection and provider configuration 

16 CLIConfig: CLI behavior and appearance settings

17 ClauthConfig: Main configuration container 

18 ConfigManager: Configuration file management and operations 

19""" 

20 

21import os 

22import toml 

23from pathlib import Path 

24from typing import Optional, Dict, Any 

25from pydantic import BaseModel, Field, field_validator, ConfigDict 

26 

27 

class AWSConfig(BaseModel):
    """AWS-related configuration settings.

    Holds the AWS profile name, the default and SSO regions, the optional
    IAM Identity Center (SSO) start URL, the SSO session name, and the
    AWS CLI output format.
    """

    profile: str = Field(default="clauth", description="AWS profile name")
    region: str = Field(default="ap-southeast-2", description="Default AWS region")
    # None means "not configured yet"; a provided value must be an HTTPS URL
    # (enforced by validate_sso_url below).
    sso_start_url: Optional[str] = Field(
        default=None,
        description="IAM Identity Center (SSO) start URL (e.g., https://d-xxxxxxxxxx.awsapps.com/start/)",
    )
    sso_region: str = Field(default="ap-southeast-2", description="SSO region")
    session_name: str = Field(default="clauth-session", description="SSO session name")
    output_format: str = Field(default="json", description="AWS CLI output format")

    @field_validator("sso_start_url")
    @classmethod
    def validate_sso_url(cls, v: Optional[str]) -> Optional[str]:
        """Require the SSO start URL, when given, to use HTTPS.

        Args:
            v: The candidate URL, or ``None`` when no URL is configured.

        Returns:
            The value unchanged.

        Raises:
            ValueError: If ``v`` is not ``None`` and does not start with
                ``"https://"``.
        """
        if v is not None and not v.startswith("https://"):
            raise ValueError("SSO start URL must be HTTPS")
        return v

47 

48 

class ModelConfig(BaseModel):
    """Model selection settings.

    Stores the preferred provider plus the IDs and ARNs of the default and
    fast models. Every model field is optional and starts out as ``None``.
    """

    provider_filter: str = Field(default="anthropic", description="Preferred model provider")

    # Model identifiers
    default_model: Optional[str] = Field(default=None, description="Default model ID")
    fast_model: Optional[str] = Field(default=None, description="Fast/small model ID")

    # Matching ARNs for the two models above
    default_model_arn: Optional[str] = Field(default=None, description="Default model ARN")
    fast_model_arn: Optional[str] = Field(default=None, description="Fast model ARN")

59 

60 

class CLIConfig(BaseModel):
    """Settings controlling how the CLI behaves and how its menus look."""

    # Behavior
    claude_cli_name: str = Field(default="claude", description="Claude CLI executable name")
    auto_start: bool = Field(default=True, description="Auto-launch Claude Code after setup")
    show_progress: bool = Field(default=True, description="Show progress indicators")
    color_output: bool = Field(default=True, description="Enable colored output")

    # Appearance of interactive menus
    pointer_style: str = Field(default="❯", description="Menu pointer character")
    selected_color: str = Field(default="ansiblue", description="Selected item color")
    highlighted_color: str = Field(default="ansiblue", description="Highlighted item color")

79 

80 

class ClauthConfig(BaseModel):
    """Top-level CLAUTH configuration, grouping the AWS, model, and CLI sections."""

    # Unknown top-level keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    # Each section is built from its own defaults when not supplied.
    aws: AWSConfig = Field(default_factory=AWSConfig)
    models: ModelConfig = Field(default_factory=ModelConfig)
    cli: CLIConfig = Field(default_factory=CLIConfig)

89 

90 

class ConfigManager:
    """Manages CLAUTH configuration loading, saving, and validation.

    The main configuration is stored in ``<config_dir>/config.toml`` and
    named profiles in ``<config_dir>/profiles/<profile>.toml``.
    """

    # Legacy placeholder value that load() replaces with None.
    _PLACEHOLDER_SSO_URL = "https://d-xxxxxxxxxx.awsapps.com/start/"

    def __init__(self, config_dir: Optional[Path] = None):
        """Initialize ConfigManager with optional custom config directory.

        Args:
            config_dir: Directory that holds the configuration. When omitted,
                the platform default from ``_get_default_config_dir`` is used.

        Side effects:
            Creates the config directory and its ``profiles`` subdirectory
            if they do not exist.
        """
        # Path(...) lets callers pass a plain string path as well.
        self.config_dir = (
            Path(config_dir) if config_dir else self._get_default_config_dir()
        )
        self.config_file = self.config_dir / "config.toml"
        self.profiles_dir = self.config_dir / "profiles"
        self._config: Optional[ClauthConfig] = None

        # Ensure config directory exists
        self.config_dir.mkdir(parents=True, exist_ok=True)
        self.profiles_dir.mkdir(parents=True, exist_ok=True)

    def _get_default_config_dir(self) -> Path:
        """Get the default configuration directory.

        Uses ``%APPDATA%`` on Windows and ``$XDG_CONFIG_HOME`` elsewhere,
        falling back to ``~/AppData/Roaming`` or ``~/.config`` when the
        variable is unset or empty.
        """
        if os.name == "nt":  # Windows
            env_name, fallback = "APPDATA", Path.home() / "AppData" / "Roaming"
        else:  # Unix-like
            env_name, fallback = "XDG_CONFIG_HOME", Path.home() / ".config"

        # `or` instead of a .get() default: an empty variable must not turn
        # into a path relative to the current working directory.
        config_home = Path(os.environ.get(env_name) or fallback)
        return config_home / "clauth"

    def load(self, profile: Optional[str] = None) -> "ClauthConfig":
        """Load configuration from file with optional profile support.

        If the target file is missing, or cannot be parsed/validated, a
        default configuration is created and written to that same file.
        Afterwards legacy placeholder URLs are migrated and environment
        variable overrides are applied.

        Args:
            profile: Profile name to load; ``None`` loads the main config.

        Returns:
            The loaded configuration (also cached on the manager).
        """
        config_file = self._get_config_file(profile)

        if config_file.exists():
            try:
                with open(config_file, "r", encoding="utf-8") as f:
                    config_data = toml.load(f)
                self._config = ClauthConfig(**config_data)
            except (toml.TomlDecodeError, ValueError) as e:
                # If config is invalid, create default and save it
                print(f"Warning: Invalid config file, creating default: {e}")
                self._config = ClauthConfig()
                # Write back to the file that was loaded, not the main config.
                self.save(profile)
        else:
            # Create default configuration for the requested target
            self._config = ClauthConfig()
            self.save(profile)

        # Migrate legacy placeholder URLs (persisted to the same target file)
        self._migrate_placeholder_urls(profile)

        # Apply environment variable overrides
        self._apply_env_overrides()

        return self._config

    def save(self, profile: Optional[str] = None) -> None:
        """Save current configuration to file.

        Args:
            profile: Profile name to write; ``None`` writes the main config.

        Raises:
            ValueError: If no configuration has been loaded yet.
        """
        if self._config is None:
            raise ValueError("No configuration loaded to save")

        config_file = self._get_config_file(profile)
        config_data = self._config.model_dump()

        with open(config_file, "w", encoding="utf-8") as f:
            toml.dump(config_data, f)

    def _get_config_file(self, profile: Optional[str] = None) -> Path:
        """Get the config file path for a specific profile.

        A falsy ``profile`` (``None`` or ``""``) selects the main config file.
        """
        if profile:
            return self.profiles_dir / f"{profile}.toml"
        return self.config_file

    def _apply_env_overrides(self) -> None:
        """Apply environment variable overrides to configuration.

        Only variables that are set to a non-empty value take effect. Does
        nothing when no configuration is loaded.
        """
        if self._config is None:
            return

        aws = self._config.aws
        cli = self._config.cli
        models = self._config.models

        # (environment variable, config section, attribute) for plain strings
        string_overrides = (
            ("CLAUTH_PROFILE", aws, "profile"),
            ("CLAUTH_REGION", aws, "region"),
            ("CLAUTH_SSO_START_URL", aws, "sso_start_url"),
            ("CLAUTH_SSO_REGION", aws, "sso_region"),
            ("CLAUTH_SESSION_NAME", aws, "session_name"),
            ("CLAUTH_CLAUDE_CLI_NAME", cli, "claude_cli_name"),
            ("CLAUTH_PROVIDER_FILTER", models, "provider_filter"),
            ("CLAUTH_DEFAULT_MODEL", models, "default_model"),
            ("CLAUTH_FAST_MODEL", models, "fast_model"),
        )
        # NOTE(review): plain attribute assignment skips pydantic validation
        # unless validate_assignment is enabled, so e.g. a non-HTTPS
        # CLAUTH_SSO_START_URL is not rejected here - confirm this is intended.
        for env_name, section, attr in string_overrides:
            if value := os.environ.get(env_name):
                setattr(section, attr, value)

        # Boolean override: "true", "1" or "yes" (any case) enable auto-start;
        # any other non-empty value disables it.
        if env_auto_start := os.environ.get("CLAUTH_AUTO_START"):
            cli.auto_start = env_auto_start.lower() in ("true", "1", "yes")

    def _migrate_placeholder_urls(self, profile: Optional[str] = None) -> None:
        """Migrate legacy placeholder SSO URLs to None.

        Args:
            profile: Profile whose file receives the migrated config; ``None``
                targets the main config file.
        """
        if self._config is None:
            return

        # Check if sso_start_url contains the old placeholder
        if self._config.aws.sso_start_url == self._PLACEHOLDER_SSO_URL:
            print(
                "Warning: Migrating placeholder SSO start URL to None. You'll need to provide a real URL during init."
            )
            self._config.aws.sso_start_url = None
            # Save the migrated config back to the file it was loaded from
            self.save(profile)

    @property
    def config(self) -> "ClauthConfig":
        """Get current configuration, loading the main config if necessary."""
        if self._config is None:
            self.load()
        return self._config

    def update_model_settings(
        self, default_model: str, fast_model: str, default_arn: str, fast_arn: str
    ) -> None:
        """Update model settings and save configuration.

        Loads the main configuration first when nothing is loaded yet, then
        writes the result to the main config file.
        """
        models = self.config.models

        models.default_model = default_model
        models.fast_model = fast_model
        models.default_model_arn = default_arn
        models.fast_model_arn = fast_arn
        self.save()

    def get_custom_style(self) -> Dict[str, str]:
        """Get InquirerPy custom style based on configuration."""
        cli_config = self.config.cli
        return {
            "questionmark": "bold",
            "instruction": "#858585",
            "answer": "bold",
            "pointer": cli_config.selected_color,
            "highlighted": cli_config.highlighted_color,
            "selected": cli_config.selected_color,
            "border": cli_config.selected_color,
        }

    def list_profiles(self) -> list[str]:
        """List available configuration profiles, sorted by name."""
        if not self.profiles_dir.exists():
            return []

        return sorted(file.stem for file in self.profiles_dir.glob("*.toml"))

    def profile_exists(self, profile: str) -> bool:
        """Check if a configuration profile exists."""
        return (self.profiles_dir / f"{profile}.toml").exists()

    def delete_profile(self, profile: str) -> bool:
        """Delete a configuration profile.

        Returns:
            ``True`` if the profile file was removed, ``False`` if it did
            not exist.
        """
        profile_file = self.profiles_dir / f"{profile}.toml"
        # EAFP: avoids the gap between an exists() check and the unlink.
        try:
            profile_file.unlink()
        except FileNotFoundError:
            return False
        return True

265 

266 

# Global config manager instance: stays None until the first
# get_config_manager() call creates it, then is reused by later calls.
_config_manager: Optional[ConfigManager] = None

269 

270 

def get_config_manager() -> ConfigManager:
    """Return the shared ConfigManager, creating it on first use."""
    global _config_manager
    if _config_manager is not None:
        return _config_manager

    # First call: build the manager with the default config directory.
    _config_manager = ConfigManager()
    return _config_manager

277 

278 

def get_config(profile: Optional[str] = None) -> ClauthConfig:
    """Load and return the configuration through the shared ConfigManager.

    Args:
        profile: Profile name to load; ``None`` loads the main configuration.
    """
    manager = get_config_manager()
    return manager.load(profile)