Coverage for src/clauth/helpers.py: 92%

88 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-28 14:37 -0400

1# Copyright (c) 2025 Mahmood Khordoo 

2# 

3# This software is licensed under the MIT License. 

4# See the LICENSE file in the root directory for details. 

5 

6""" 

7CLAUTH Shared Utility Functions. 

8 

9This module contains utility functions used across multiple CLI commands 

10and modules to avoid circular imports. 

11""" 

12 

13import os 

14import shutil 

15import subprocess 

16import typer 

17from clauth.config import get_config_manager 

18from clauth.aws_utils import list_bedrock_profiles 

19from InquirerPy import inquirer 

20from rich.console import Console 

21from pyfiglet import Figlet 

22from textwrap import dedent 

23from InquirerPy import get_style 

24 

# Module-level Rich console shared by this module's helpers for styled output.
console = Console()

26 

27 

class ExecutableNotFoundError(Exception):
    """Signal that a requested executable could not be located on the system PATH."""

32 

33 

def clear_screen():
    """Wipe the terminal using the clear command native to the host OS.

    Windows (``os.name == "nt"``) uses ``cls``; every other platform uses
    ``clear``. The command's exit status is not checked.
    """
    command = "clear"
    if os.name == "nt":
        command = "cls"
    os.system(command)

37 

38 

def show_welcome_logo(console: Console) -> None:
    """
    Render the CLAUTH banner followed by a short introductory message.

    The banner is drawn with pyfiglet's ``slant`` font and printed in bold
    cyan. The message below it lists the prerequisites and points the user
    at ``clauth init --help``.

    Args:
        console: Rich console instance the banner and message are printed to
    """
    banner = Figlet(font="slant").renderText("CLAUTH")
    intro = dedent("""
        [bold]Welcome to CLAUTH[/bold]
        Let’s set up your environment for Claude Code on Amazon Bedrock.

        Prerequisites:
        • AWS CLI v2
        • Claude Code CLI

        Tip: run [bold]clauth init --help[/bold] to view options.
        """).strip()

    console.print(banner, style="bold cyan")
    console.print(intro)

62 

63 

def choose_auth_method():
    """
    Ask the user, via an interactive menu, how they want to authenticate.

    The menu is rendered with the custom style provided by the config
    manager.

    Returns:
        str: The value of the chosen entry: 'sso', 'iam', or 'skip'
    """
    from InquirerPy import inquirer
    from clauth.config import get_config_manager

    menu_style = get_style(get_config_manager().get_custom_style())
    menu_entries = [
        {"name": "AWS SSO (for teams/organizations)", "value": "sso"},
        {"name": "IAM User Access Keys (for solo developers)", "value": "iam"},
        {"name": "Skip (I'm already configured)", "value": "skip"},
    ]

    prompt = inquirer.select(
        message="Choose your authentication method:",
        instruction="↑↓ move • Enter select",
        choices=menu_entries,
        pointer="▶ ",
        amark="✔",
        style=menu_style,
        max_height="100%",
    )
    return prompt.execute()

91 

92 

def get_app_path(exe_name: str = "claude") -> str:
    """Resolve an executable name to its full path on the system PATH.

    The plain name must be resolvable by ``shutil.which``. On Windows, the
    ``.cmd`` variant is then tried, followed by ``.exe``. The first variant
    found is returned in place of the plain lookup. A variant is skipped when
    the name already ends with that extension. The chosen path is echoed to
    the terminal before it is returned.

    Args:
        exe_name: Name of the executable to find

    Returns:
        Full path to the executable

    Raises:
        ValueError: If ``exe_name`` is empty, None, or whitespace-only
        ExecutableNotFoundError: If the plain name is not found in PATH
    """
    if not (exe_name and exe_name.strip()):
        raise ValueError(f"Invalid executable name provided: {exe_name!r}")

    resolved = shutil.which(exe_name)
    if resolved is None:
        raise ExecutableNotFoundError(
            f"{exe_name} not found in system PATH. Please ensure it is installed and in your PATH."
        )

    if os.name == "nt":
        lowered = exe_name.lower()
        # Lazily generated so candidates are probed in order: .cmd, then .exe.
        candidates = (
            exe_name + suffix
            for suffix in (".cmd", ".exe")
            if not lowered.endswith(suffix)
        )
        for candidate in candidates:
            variant = shutil.which(candidate)
            if variant:
                typer.echo(
                    f"Found multiple {exe_name} executables, using: {variant}"
                )
                return variant

    typer.echo(f"Using executable: {resolved}")
    return resolved

133 

134 

def is_sso_profile(profile: str) -> bool:
    """
    Check if a given AWS profile is configured for SSO.

    A profile counts as SSO when ``aws configure get`` returns a non-empty
    value for either of two keys:

    - ``sso_start_url``: the legacy layout, where the start URL is stored on
      the profile itself.
    - ``sso_session``: the AWS CLI v2 layout, where the profile references a
      shared ``[sso-session ...]`` section that holds the start URL.

    This is a best-effort probe. Any failure to run the AWS CLI (not
    installed, invalid argument, ...) is reported as "not SSO" rather than
    raised.

    Args:
        profile: AWS profile name to check

    Returns:
        bool: True if profile has SSO configuration, False otherwise
    """
    for key in ("sso_start_url", "sso_session"):
        try:
            result = subprocess.run(
                ["aws", "configure", "get", key, "--profile", profile],
                capture_output=True,
                text=True,
                check=False,
            )
        except Exception:
            # Deliberately broad: callers only need a yes/no answer, and a
            # missing/unrunnable AWS CLI simply means "not SSO" here.
            return False
        # Compare explicitly so callers always receive a bool, never the raw
        # CLI output string.
        if result.returncode == 0 and result.stdout.strip():
            return True
    return False

155 

156 

def handle_authentication_failure(profile: str) -> bool:
    """
    Recover from a failed AWS authentication for ``profile`` where possible.

    Non-SSO profiles cannot be refreshed automatically: the user is told to
    run ``clauth init`` and False is returned. For SSO profiles,
    ``aws sso login`` is run for the profile. A non-zero exit from that
    command is reported and yields False.

    Args:
        profile: AWS profile name that failed authentication

    Returns:
        bool: True if the SSO login succeeded, False otherwise
    """
    if not is_sso_profile(profile):
        # Non-SSO profile - direct to init
        typer.secho(
            "Authentication required. Please run 'clauth init' to set up authentication.",
            fg=typer.colors.RED,
        )
        return False

    typer.secho(
        "SSO token expired. Attempting to re-authenticate...",
        fg=typer.colors.YELLOW,
    )
    login_command = ["aws", "sso", "login", "--profile", profile]
    try:
        subprocess.run(login_command, check=True)
    except subprocess.CalledProcessError:
        typer.secho(
            "SSO login failed. Run 'clauth init' for full setup.",
            fg=typer.colors.RED,
        )
        return False

    typer.secho(
        f"Successfully re-authenticated with profile '{profile}'",
        fg=typer.colors.GREEN,
    )
    return True

195 

196 

def prompt_for_region_if_needed(config, cli_overrides):
    """
    Prompt the user for an AWS region unless one was supplied on the CLI.

    When ``cli_overrides`` has no truthy ``"region"`` entry, the user picks a
    region from a short list or types a custom one. The choice is written to
    ``config.aws.region`` and persisted through the config manager.

    Args:
        config: Loaded configuration object. ``config.aws.region`` is read as
            the menu default and overwritten with the user's choice.
        cli_overrides: Mapping of values passed on the command line. Only the
            ``"region"`` key is consulted.

    Returns:
        bool: True when a region is available (already supplied on the CLI,
        or selected and saved). False when the user entered an invalid custom
        region.
    """
    import re

    if cli_overrides.get("region"):
        # Region already provided on the command line, so there is nothing to
        # ask. Return True explicitly so callers always get a bool.
        return True

    console.print("\n[bold]AWS Region Selection[/bold]")
    console.print("Please select your preferred AWS region.")
    console.print("This will be used for default AWS services.\n")

    custom_region_option = "Other (enter custom region)"
    region_options = [
        "us-east-1",
        "us-west-2",
        "eu-west-1",
        "ap-southeast-1",
        "ap-southeast-2",
        "ap-northeast-1",
        "ca-central-1",
        custom_region_option,
    ]

    selected_option = inquirer.select(
        message="Select your AWS region:",
        instruction="↑↓ move • Enter select",
        choices=region_options,
        default=config.aws.region
        if config.aws.region in region_options
        else "us-east-1",
        pointer="▶ ",
        amark="✔",
    ).execute()

    if selected_option == custom_region_option:
        # Normalize first: stray whitespace and upper-case letters are common
        # typos, and AWS region names are always lower-case.
        custom_region = (typer.prompt("AWS Region") or "").strip().lower()
        # Region names look like "us-east-1", "eu-central-2", "us-gov-west-1".
        if not re.fullmatch(r"[a-z]{2}(?:-[a-z]+)+-\d+", custom_region):
            typer.secho("Error: Invalid region format.", fg=typer.colors.RED)
            return False
        selected_region = custom_region
    else:
        selected_region = selected_option

    config.aws.region = selected_region
    # Use a single manager reference so the config assigned is the one saved.
    config_manager = get_config_manager()
    config_manager._config = config
    config_manager.save()
    console.print(f"[green]✓ Region set to: {selected_region}[/green]\n")
    return True

241 

242 

def validate_model_id(id: str):
    """
    Validate that a model ID exists in available Bedrock profiles.

    The list of valid IDs is fetched with ``list_bedrock_profiles``, using
    the AWS profile, region and provider filter from the saved CLAUTH config.

    Args:
        id: Model ID to validate. The name shadows the ``id`` builtin but is
            kept for backward compatibility with keyword callers.

    Returns:
        str: The validated model ID, unchanged

    Raises:
        typer.BadParameter: If the model ID is not among the available models
    """
    config = get_config_manager().load()
    # Scope the spinner to the lookup only, so it has stopped before any
    # validation error is reported.
    with console.status("[bold blue]Validating model ID..."):
        model_ids, _model_arns = list_bedrock_profiles(
            profile=config.aws.profile,
            region=config.aws.region,
            provider=config.models.provider_filter,
        )
    if id not in model_ids:
        raise typer.BadParameter(
            f"{id} is not valid or supported model. Valid Models: {model_ids}"
        )
    return id