Coverage for src\clauth\config.py: 37%

137 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-18 10:09 -0400

1"""Configuration management for CLAUTH.""" 

2 

3import os 

4import toml 

5from pathlib import Path 

6from typing import Optional, Dict, Any 

7from pydantic import BaseModel, Field, validator 

8 

9 

class AWSConfig(BaseModel):
    """AWS CLI profile and IAM Identity Center (SSO) settings used by CLAUTH."""

    profile: str = Field(default="clauth", description="AWS profile name")
    region: str = Field(default="ap-southeast-2", description="Default AWS region")
    sso_start_url: str = Field(
        default="https://d-97671967ae.awsapps.com/start/#",
        description="IAM Identity Center (SSO) start URL",
    )
    sso_region: str = Field(default="ap-southeast-2", description="SSO region")
    session_name: str = Field(default="clauth-session", description="SSO session name")
    output_format: str = Field(default="json", description="AWS CLI output format")

    @validator('sso_start_url')
    def validate_sso_url(cls, v):
        """Accept the SSO start URL only when it uses the ``https://`` scheme.

        Raises:
            ValueError: If the URL does not start with ``https://``.
        """
        if v.startswith('https://'):
            return v
        raise ValueError('SSO start URL must be HTTPS')

28 

29 

class ModelConfig(BaseModel):
    """Model selection settings; the model IDs and ARNs start out unset (None)."""

    provider_filter: str = Field(default="anthropic", description="Preferred model provider")
    # Model identifiers and their ARNs have no default; they are None until assigned.
    default_model: Optional[str] = Field(default=None, description="Default model ID")
    fast_model: Optional[str] = Field(default=None, description="Fast/small model ID")
    default_model_arn: Optional[str] = Field(default=None, description="Default model ARN")
    fast_model_arn: Optional[str] = Field(default=None, description="Fast model ARN")

38 

39 

class CLIConfig(BaseModel):
    """Settings controlling how the CLAUTH command line behaves and looks."""

    # Behaviour
    claude_cli_name: str = Field(default="claude", description="Claude CLI executable name")
    auto_start: bool = Field(default=True, description="Auto-launch Claude Code after setup")
    show_progress: bool = Field(default=True, description="Show progress indicators")
    color_output: bool = Field(default=True, description="Enable colored output")

    # Interactive-menu appearance
    pointer_style: str = Field(default="❯", description="Menu pointer character")
    selected_color: str = Field(default="ansiblue", description="Selected item color")
    highlighted_color: str = Field(default="ansiblue", description="Highlighted item color")

52 

53 

class ClauthConfig(BaseModel):
    """Root CLAUTH configuration made of the ``aws``, ``models`` and ``cli`` sections."""

    # default_factory gives every ClauthConfig its own fresh section instances.
    aws: AWSConfig = Field(default_factory=AWSConfig)
    models: ModelConfig = Field(default_factory=ModelConfig)
    cli: CLIConfig = Field(default_factory=CLIConfig)

    class Config:
        # Unknown top-level keys are a validation error instead of being ignored.
        extra = "forbid"

63 

64 

class ConfigManager:
    """Manages CLAUTH configuration loading, saving, and validation.

    The main configuration lives in ``<config_dir>/config.toml``; named profiles
    live in ``<config_dir>/profiles/<name>.toml``. After loading, ``CLAUTH_*``
    environment variables may override individual fields in memory. Those
    overrides are transient: :meth:`save` writes the pre-override value for any
    field that still holds the value taken from the environment.
    """

    # (environment variable, config section, field), applied in this order.
    _ENV_OVERRIDES = (
        ('CLAUTH_PROFILE', 'aws', 'profile'),
        ('CLAUTH_REGION', 'aws', 'region'),
        ('CLAUTH_SSO_START_URL', 'aws', 'sso_start_url'),
        ('CLAUTH_SSO_REGION', 'aws', 'sso_region'),
        ('CLAUTH_SESSION_NAME', 'aws', 'session_name'),
        ('CLAUTH_CLAUDE_CLI_NAME', 'cli', 'claude_cli_name'),
        ('CLAUTH_AUTO_START', 'cli', 'auto_start'),
        ('CLAUTH_PROVIDER_FILTER', 'models', 'provider_filter'),
        ('CLAUTH_DEFAULT_MODEL', 'models', 'default_model'),
        ('CLAUTH_FAST_MODEL', 'models', 'fast_model'),
    )

    # Lower-cased environment values that mean True for boolean fields.
    _TRUTHY = ('true', '1', 'yes')

    def __init__(self, config_dir: Optional[Path] = None):
        """Initialize ConfigManager with optional custom config directory.

        Args:
            config_dir: Directory holding ``config.toml`` and ``profiles/``;
                defaults to the platform config directory.

        The config directory and its ``profiles`` subdirectory are created if
        missing.
        """
        self.config_dir = config_dir or self._get_default_config_dir()
        self.config_file = self.config_dir / "config.toml"
        self.profiles_dir = self.config_dir / "profiles"
        self._config: Optional[ClauthConfig] = None
        # Profile passed to the most recent load(); None means the main file.
        self._profile: Optional[str] = None
        # (section, field) -> (value before override, value taken from the env).
        # save() uses this to keep transient env overrides out of the file.
        self._env_originals: dict[tuple[str, str], tuple[Any, Any]] = {}

        # Ensure config directory exists
        self.config_dir.mkdir(parents=True, exist_ok=True)
        self.profiles_dir.mkdir(parents=True, exist_ok=True)

    def _get_default_config_dir(self) -> Path:
        """Return the default configuration directory for this platform.

        Windows uses ``%APPDATA%`` (fallback ``~/AppData/Roaming``); other
        systems use ``$XDG_CONFIG_HOME`` (fallback ``~/.config``). Unset, empty
        or relative values are ignored, so the config never ends up relative to
        the current working directory.
        """
        if os.name == 'nt':  # Windows
            env_var, fallback = 'APPDATA', Path.home() / 'AppData' / 'Roaming'
        else:  # Unix-like
            env_var, fallback = 'XDG_CONFIG_HOME', Path.home() / '.config'

        env_value = os.environ.get(env_var)
        if env_value and Path(env_value).is_absolute():
            config_home = Path(env_value)
        else:
            config_home = fallback
        return config_home / 'clauth'

    def load(self, profile: Optional[str] = None) -> ClauthConfig:
        """Load configuration from file with optional profile support.

        Args:
            profile: Profile name; ``None`` loads the main ``config.toml``.

        Returns:
            The loaded configuration with environment overrides applied.

        If the file is missing, it is created with defaults. If it cannot be
        parsed or fails validation, it is moved to ``<name>.bak`` and replaced
        with defaults. Defaults are always written to the requested file (the
        profile file when ``profile`` is given), never to the main config.
        """
        config_file = self._get_config_file(profile)
        self._profile = profile
        # Overrides recorded by an earlier load() do not describe this config.
        self._env_originals = {}

        if config_file.exists():
            try:
                with open(config_file, 'r', encoding='utf-8') as f:
                    config_data = toml.load(f)
                self._config = ClauthConfig(**config_data)
            except (toml.TomlDecodeError, ValueError) as e:
                # If config is invalid, keep a backup, then create default and save it
                print(f"Warning: Invalid config file, creating default: {e}")
                self._backup_invalid_file(config_file)
                self._config = ClauthConfig()
                self.save(profile)
        else:
            # Create default configuration
            self._config = ClauthConfig()
            self.save(profile)

        # Apply environment variable overrides
        self._apply_env_overrides()

        return self._config

    @staticmethod
    def _backup_invalid_file(config_file: Path) -> None:
        """Move an unusable config file to ``<name>.bak`` (best effort)."""
        backup = config_file.with_name(config_file.name + '.bak')
        try:
            config_file.replace(backup)
        except OSError as e:
            print(f"Warning: Could not back up {config_file} to {backup}: {e}")

    def save(self, profile: Optional[str] = None) -> None:
        """Save current configuration to file.

        Args:
            profile: Profile to write; ``None`` writes the main ``config.toml``.

        Raises:
            ValueError: If no configuration has been loaded yet.

        A field that still holds the value injected by a ``CLAUTH_*``
        environment override is written with its pre-override value.
        """
        if self._config is None:
            raise ValueError("No configuration loaded to save")

        config_file = self._get_config_file(profile)
        config_data = self._config.dict()
        for (section, field), (original, applied) in self._env_originals.items():
            # Only undo the override if nobody has changed the field since.
            if config_data[section].get(field) == applied:
                config_data[section][field] = original

        with open(config_file, 'w', encoding='utf-8') as f:
            toml.dump(config_data, f)

    def _profile_path(self, profile: str) -> Path:
        """Return the file path for the named profile."""
        return self.profiles_dir / f"{profile}.toml"

    def _get_config_file(self, profile: Optional[str] = None) -> Path:
        """Get the config file path for a profile (main file when falsy)."""
        if profile:
            return self._profile_path(profile)
        return self.config_file

    def _apply_env_overrides(self) -> None:
        """Apply ``CLAUTH_*`` environment variable overrides to configuration.

        Empty or unset variables are ignored. For boolean fields, ``true``,
        ``1`` and ``yes`` (case-insensitive) mean True and anything else means
        False. Each override is checked by the section model's validators.
        An invalid value (e.g. a non-HTTPS ``CLAUTH_SSO_START_URL``) is skipped
        with a warning, so it cannot bypass validation.
        """
        if self._config is None:
            return

        for env_var, section_name, field in self._ENV_OVERRIDES:
            raw = os.environ.get(env_var)
            if not raw:
                continue

            section = getattr(self._config, section_name)
            current = getattr(section, field)
            if isinstance(current, bool):
                value: Any = raw.lower() in self._TRUTHY
            else:
                value = raw

            try:
                # Plain attribute assignment skips validators, so rebuild the
                # section to run them on the candidate value.
                validated = type(section)(**{**section.dict(), field: value})
            except ValueError as e:
                print(f"Warning: Ignoring invalid {env_var}: {e}")
                continue

            new_value = getattr(validated, field)
            self._env_originals[(section_name, field)] = (current, new_value)
            setattr(section, field, new_value)

    @property
    def config(self) -> ClauthConfig:
        """Get current configuration, loading the main config if necessary."""
        if self._config is None:
            self.load()
        return self._config

    def update_model_settings(self, default_model: str, fast_model: str,
                              default_arn: str, fast_arn: str) -> None:
        """Update model settings and save configuration.

        The configuration is written back to the file it was loaded from: the
        active profile if one was loaded, otherwise the main ``config.toml``.
        """
        if self._config is None:
            self.load()

        models = self._config.models
        models.default_model = default_model
        models.fast_model = fast_model
        models.default_model_arn = default_arn
        models.fast_model_arn = fast_arn
        self.save(self._profile)

    def get_custom_style(self) -> Dict[str, str]:
        """Get InquirerPy custom style based on configuration.

        The pointer uses the configured ``selected_color``.
        """
        cli_config = self.config.cli
        return {
            "questionmark": "bold",
            "instruction": "dim",
            "answer": "bold",
            "pointer": cli_config.selected_color,
            "highlighted": cli_config.highlighted_color,
            "selected": cli_config.selected_color,
        }

    def list_profiles(self) -> list[str]:
        """Return the sorted names of available configuration profiles."""
        if not self.profiles_dir.exists():
            return []
        return sorted(path.stem for path in self.profiles_dir.glob("*.toml"))

    def profile_exists(self, profile: str) -> bool:
        """Check if a configuration profile exists."""
        return self._profile_path(profile).exists()

    def delete_profile(self, profile: str) -> bool:
        """Delete a configuration profile.

        Returns:
            True if the profile file was removed, False if it did not exist.
        """
        try:
            self._profile_path(profile).unlink()
        except FileNotFoundError:
            return False
        return True

212 

213 

# Global config manager instance, created lazily by get_config_manager()
_config_manager: Optional[ConfigManager] = None


def get_config_manager() -> ConfigManager:
    """Return the module-level ConfigManager, constructing it on first call."""
    global _config_manager
    if _config_manager is not None:
        return _config_manager
    _config_manager = ConfigManager()
    return _config_manager

224 

225 

def get_config(profile: Optional[str] = None) -> ClauthConfig:
    """Load configuration through the shared ConfigManager.

    Each call re-reads the file; ``profile=None`` selects the main config.
    """
    manager = get_config_manager()
    return manager.load(profile)