Coverage for src/clauth/aws_utils.py: 79%

207 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-28 14:48 -0400

1# Copyright (c) 2025 Mahmood Khordoo 

2# 

3# This software is licensed under the MIT License. 

4# See the LICENSE file in the root directory for details. 

5 

"""
AWS utilities for CLAUTH.

This module provides AWS-specific functionality: configuring IAM-user and
AWS SSO authentication for an AWS CLI profile, verifying that a profile has
working credentials, cleaning up AWS CLI configuration (profiles,
credentials, SSO sessions and the SSO token cache) and discovering the
available Bedrock inference profiles.

Functions:
    setup_iam_user_auth: Configure a profile with IAM user access keys
    setup_sso_auth: Configure a profile for AWS SSO and log in
    user_is_authenticated: Check if a profile has valid AWS credentials
    get_existing_sso_start_url: Read the start URL of an existing SSO session
    remove_sso_session: Remove an SSO session section from ~/.aws/config
    clear_sso_cache: Delete cached SSO tokens from ~/.aws/sso/cache
    delete_aws_credentials_profile: Remove a profile from ~/.aws/credentials
    delete_aws_profile: Remove a profile from ~/.aws/config
    list_bedrock_profiles: Discover available Bedrock inference profiles
"""

17 

18import typer 

19import subprocess 

20import boto3 

21from rich.console import Console 

22from botocore.config import Config 

23from botocore.exceptions import ( 

24 NoCredentialsError, 

25 ClientError, 

26 BotoCoreError, 

27 TokenRetrievalError, 

28) 

29 

30 

# Module-wide Rich console; the helpers in this file print their coloured
# status, warning and error messages through it.
console: Console = Console()

32 

33 

def setup_iam_user_auth(profile: str, region: str) -> bool:
    """
    Set up IAM user (access key) authentication for solo developers.

    Writes the profile's default region, runs the interactive
    ``aws configure`` wizard so the user can enter their access keys, removes
    leftover SSO settings from the profile in ``~/.aws/config`` and finally
    verifies the credentials with an STS call.

    Args:
        profile: AWS profile name to configure.
        region: Default AWS region written to the profile.

    Returns:
        bool: True if the profile was configured and the credentials are
        valid. False if the AWS CLI is missing, a CLI command failed, or the
        credentials were rejected.
    """
    typer.secho("Setting up IAM user authentication...", fg=typer.colors.BLUE)
    typer.echo("You'll need your AWS Access Key ID and Secret Access Key.")
    typer.echo(
        "Get these from: AWS Console → IAM → Users → [Your User] → Security credentials"
    )
    typer.echo()

    try:
        # Set the region first, so it's the default in the interactive prompt
        subprocess.run(
            ["aws", "configure", "set", "region", region, "--profile", profile],
            check=True,
        )

        # Run aws configure for the specific profile to get keys
        subprocess.run(["aws", "configure", "--profile", profile], check=True)
    except FileNotFoundError:
        # subprocess raises this when the `aws` executable is not on PATH.
        typer.secho(
            "❌ AWS CLI not found. Please install the AWS CLI and try again.",
            fg=typer.colors.RED,
        )
        return False
    except subprocess.CalledProcessError:
        typer.secho(
            "❌ Failed to configure IAM user authentication", fg=typer.colors.RED
        )
        return False

    # Remove SSO keys left over from a previous SSO setup of this profile
    # (presumably so they cannot take precedence over the new access keys).
    _clear_profile_sso_settings(profile)

    # Verify that the credentials are valid
    if not user_is_authenticated(profile=profile):
        typer.secho(
            "❌ IAM authentication failed. Please check your credentials and try again.",
            fg=typer.colors.RED,
        )
        return False

    typer.secho(
        f"✅ IAM user authentication configured for profile '{profile}'",
        fg=typer.colors.GREEN,
    )
    return True


def _clear_profile_sso_settings(profile: str) -> None:
    """Remove ``sso_*`` keys from *profile* in ~/.aws/config (best effort).

    Failures are reported as a warning and never raised.
    """
    from pathlib import Path
    import configparser

    sso_settings_to_clear = (
        "sso_start_url",
        "sso_region",
        "sso_account_id",
        "sso_role_name",
        "sso_session",
    )
    # ~/.aws/config stores the default profile as [default]; named profiles
    # (and optionally the default one) as [profile <name>].
    if profile == "default":
        candidate_sections = ("default", "profile default")
    else:
        candidate_sections = (f"profile {profile}",)

    try:
        aws_config_file = Path.home() / ".aws" / "config"
        if not aws_config_file.exists():
            return

        config_parser = configparser.ConfigParser()
        config_parser.read(aws_config_file)

        removed_any = False
        for section in candidate_sections:
            if not config_parser.has_section(section):
                continue
            for setting in sso_settings_to_clear:
                # remove_option returns True only if the option existed.
                if config_parser.remove_option(section, setting):
                    removed_any = True

        # Only rewrite when something changed: configparser does not keep
        # comments, so every rewrite strips them from the user's file.
        if removed_any:
            with open(aws_config_file, "w") as f:
                config_parser.write(f)
    except Exception as e:
        console.print(
            f"[yellow]Warning: Could not clean SSO settings from AWS config: {e}[/yellow]"
        )

110 

111 

def setup_sso_auth(config, cli_overrides) -> bool:
    """
    Set up AWS SSO authentication for enterprise users.

    Steps:
    1. Pre-seed the profile in ``~/.aws/config`` with region, output format,
       SSO session name and SSO region.
    2. Run the interactive ``aws configure sso`` wizard.
    3. Ensure the profile carries an ``sso_start_url``.
    4. Run ``aws sso login``.

    Args:
        config: Configuration object. Reads ``config.aws.region``,
            ``output_format``, ``session_name``, ``profile`` and
            ``sso_start_url``.
        cli_overrides: Dict indicating which CLI parameters were provided.
            Not currently used by this function.

    Returns:
        bool: True if every AWS CLI step succeeded. False if the AWS CLI is
        missing or any step exited with a non-zero status.
    """
    aws = config.aws
    profile = aws.profile

    # Set up basic profile configuration before SSO
    args = {
        "region": aws.region,  # Pass default AWS region to avoid extra prompt
        "output": aws.output_format,
        "sso_session": aws.session_name,  # Pre-set session name for consistency
        "sso_region": aws.region,
    }

    try:
        typer.secho("Configuring AWS profile...", fg=typer.colors.BLUE)
        for arg, value in args.items():
            subprocess.run(
                ["aws", "configure", "set", arg, value, "--profile", profile],
                check=True,
            )

        typer.echo(
            "Opening the AWS SSO wizard. AWS CLI will prompt for SSO details. To reset the SSO session run `clauth reset --complete`"
        )
        typer.secho(
            "Tip: The SSO Start URL typically looks like: https://d-...awsapps.com/start/",
            fg=typer.colors.YELLOW,
        )
        subprocess.run(["aws", "configure", "sso", "--profile", profile], check=True)

        # The AWS CLI has been observed to complain when sso_start_url is set
        # in only one of the sso-session / profile sections, so also copy it
        # onto the profile. Prefer the URL stored in the existing SSO session
        # and fall back to the one in the CLAUTH configuration.
        session_start_url = get_existing_sso_start_url(aws.session_name)
        sso_start_url = session_start_url or aws.sso_start_url
        if sso_start_url:
            if session_start_url:
                typer.echo(
                    f"Reusing existing SSO Start URL from session '{aws.session_name}'"
                )
            else:
                typer.echo(f"Using configured SSO Start URL for profile '{profile}'")
            subprocess.run(
                [
                    "aws",
                    "configure",
                    "set",
                    "sso_start_url",
                    sso_start_url,
                    "--profile",
                    profile,
                ],
                check=True,
            )

        # check=True so a failed login is reported as a failure instead of
        # falling through to the success message below.
        subprocess.run(["aws", "sso", "login", "--profile", profile], check=True)
    except FileNotFoundError:
        # subprocess raises this when the `aws` executable is not on PATH.
        typer.secho(
            "❌ AWS CLI not found. Please install the AWS CLI and try again.",
            fg=typer.colors.RED,
        )
        return False
    except subprocess.CalledProcessError:
        typer.secho("❌ SSO setup failed", fg=typer.colors.RED)
        return False

    typer.secho(
        f"Authentication successful for profile '{profile}'.",
        fg=typer.colors.GREEN,
    )
    return True

190 

191 

192# Configuration management command group 

193# config_app = typer.Typer(help="Configuration management commands") 

194# app.add_typer(config_app, name="config") 

195 

196 

def user_is_authenticated(profile: str) -> bool:
    """Check if user is authenticated with AWS using the specified profile.

    Makes an STS ``GetCallerIdentity`` call through a boto3 session for
    *profile*. Every failure is reported with a printed hint and turned into
    ``False``; nothing is raised.

    Args:
        profile: AWS profile name to check.

    Returns:
        bool: True if the STS call succeeded, False otherwise.
    """
    try:
        session = boto3.Session(profile_name=profile)
        # Succeeds only if the profile resolves to usable credentials.
        session.client("sts").get_caller_identity()
        return True
    except (NoCredentialsError, TokenRetrievalError):
        print(
            "No credentials found. Please run 'clauth init' to set up authentication."
        )
        return False
    except ClientError as e:
        # Use .get() so a malformed error payload cannot raise a KeyError out
        # of this handler.
        error_code = e.response.get("Error", {}).get("Code")
        if error_code in (
            "UnauthorizedSSOToken",
            "ExpiredToken",
            "InvalidClientTokenId",
            "SignatureDoesNotMatch",  # wrong secret access key
        ):
            print(
                "Credentials expired or invalid. Please run 'clauth init' to re-authenticate."
            )
        else:
            print(f"Error getting token: {e}")
        return False
    except Exception as e:
        print(f"Unexpected error during authentication: {e}")
        return False

228 

229 

230def get_existing_sso_start_url(session_name: str) -> str | None: 

231 """Get the existing SSO start URL from an SSO session. 

232 

233 Args: 

234 session_name: Name of the SSO session to check 

235 

236 Returns: 

237 str | None: The SSO start URL if found, None otherwise 

238 """ 

239 try: 

240 from pathlib import Path 

241 import configparser 

242 

243 aws_config_file = Path.home() / ".aws" / "config" 

244 if not aws_config_file.exists(): 

245 return None 

246 

247 config_parser = configparser.ConfigParser() 

248 config_parser.read(aws_config_file) 

249 

250 session_section = f"sso-session {session_name}" 

251 if config_parser.has_section(session_section): 

252 return config_parser.get(session_section, "sso_start_url", fallback=None) 

253 

254 return None 

255 

256 except Exception: 

257 # If we can't read the config, return None 

258 return None 

259 

260 

def remove_sso_session(session_name: str) -> bool:
    """Delete the ``[sso-session <session_name>]`` section from ~/.aws/config.

    Args:
        session_name: Name of the SSO session to remove

    Returns:
        bool: True when the section was removed, was not present, or no
        config file exists; False if any error occurred.
    """
    try:
        import configparser
        from pathlib import Path

        config_path = Path.home().joinpath(".aws", "config")
        if not config_path.exists():
            console.print("[yellow]No AWS config file found.[/yellow]")
            return True

        parser = configparser.ConfigParser()
        parser.read(config_path)

        # SSO sessions live in sections named "sso-session <name>".
        section = f"sso-session {session_name}"
        if not parser.has_section(section):
            console.print(
                f"[yellow]SSO session '{session_name}' not found in AWS config.[/yellow]"
            )
            return True

        parser.remove_section(section)
        with open(config_path, "w") as handle:
            parser.write(handle)
        console.print(
            f"[green]SUCCESS: Removed SSO session '{session_name}' from AWS config.[/green]"
        )
        return True
    except Exception as exc:
        console.print(
            f"[red]ERROR: Failed to remove SSO session '{session_name}': {exc}[/red]"
        )
        return False

311 

312 

def clear_sso_cache(profile_name: str | None = None) -> bool:
    """Clear the AWS SSO token cache.

    Deletes every ``*.json`` file in ``~/.aws/sso/cache``. Files that cannot
    be deleted are reported as warnings and skipped.

    Args:
        profile_name: Accepted for API compatibility but currently ignored.
            The whole cache is cleared regardless of profile.

    Returns:
        bool: True if the cache directory was processed or does not exist;
        False if an unexpected error occurred.
    """
    try:
        from pathlib import Path

        aws_cache_dir = Path.home() / ".aws" / "sso" / "cache"

        if not aws_cache_dir.exists():
            console.print("[yellow]No SSO cache directory found.[/yellow]")
            return True

        # Clear all SSO cache files; keep going if an individual file fails.
        cache_files_deleted = 0
        for cache_file in aws_cache_dir.glob("*.json"):
            try:
                cache_file.unlink()
                cache_files_deleted += 1
            except Exception as e:
                console.print(
                    f"[yellow]Warning: Could not delete cache file {cache_file.name}: {e}[/yellow]"
                )

        if cache_files_deleted > 0:
            console.print(
                f"[green]SUCCESS: Cleared {cache_files_deleted} SSO cache files.[/green]"
            )
        else:
            console.print("[yellow]No SSO cache files found to clear.[/yellow]")

        return True

    except Exception as e:
        console.print(f"[red]ERROR: Error clearing SSO cache: {e}[/red]")
        return False

357 

358 

def delete_aws_credentials_profile(profile_name: str) -> bool:
    """Remove the ``[<profile_name>]`` section from ~/.aws/credentials.

    Args:
        profile_name: Name of the AWS profile whose stored keys are removed

    Returns:
        bool: True when the section was removed, was not present, or no
        credentials file exists; False if any error occurred.
    """
    try:
        from pathlib import Path
        import configparser

        credentials_path = Path.home().joinpath(".aws", "credentials")
        if not credentials_path.exists():
            console.print(
                "[yellow]No AWS credentials file found to delete profile from.[/yellow]"
            )
            return True

        parser = configparser.ConfigParser()
        parser.read(credentials_path)

        # The credentials file uses the bare profile name as section header.
        if not parser.has_section(profile_name):
            console.print(
                f"[yellow]AWS credentials for profile '{profile_name}' do not exist.[/yellow]"
            )
            return True

        parser.remove_section(profile_name)
        with open(credentials_path, "w") as handle:
            parser.write(handle)
        console.print(
            f"[green]SUCCESS: AWS credentials for profile '{profile_name}' deleted successfully.[/green]"
        )
        return True
    except Exception as exc:
        console.print(
            f"[red]ERROR: Unexpected error deleting AWS credentials profile: {exc}[/red]"
        )
        return False

403 

404 

def delete_aws_profile(profile_name: str) -> bool:
    """Delete an AWS profile from ~/.aws/config.

    Named profiles are stored as ``[profile <name>]``. The default profile is
    stored as ``[default]`` (``[profile default]`` is also accepted), so both
    spellings are removed when *profile_name* is ``"default"``.

    Args:
        profile_name: Name of the AWS profile to delete

    Returns:
        bool: True if profile was deleted or didn't exist, False on error
    """
    try:
        from pathlib import Path
        import configparser

        aws_config_file = Path.home() / ".aws" / "config"

        if not aws_config_file.exists():
            console.print(
                "[yellow]No AWS config file found to delete profile from.[/yellow]"
            )
            return True

        config_parser = configparser.ConfigParser()
        config_parser.read(aws_config_file)

        if profile_name == "default":
            candidate_sections = ("default", "profile default")
        else:
            candidate_sections = (f"profile {profile_name}",)

        removed_any = False
        for section in candidate_sections:
            # remove_section returns True only if the section existed.
            if config_parser.remove_section(section):
                removed_any = True

        if removed_any:
            with open(aws_config_file, "w") as f:
                config_parser.write(f)
            console.print(
                f"[green]SUCCESS: AWS profile '{profile_name}' deleted successfully.[/green]"
            )
        else:
            console.print(
                f"[yellow]AWS profile '{profile_name}' does not exist in config file.[/yellow]"
            )

        return True

    except Exception as e:
        console.print(f"[red]ERROR: Unexpected error deleting AWS profile: {e}[/red]")
        return False

448 

449 

def list_bedrock_profiles(
    profile: str, region: str, provider: str = "anthropic", sort: bool = True
) -> tuple[list[str], list[str]]:
    """
    List available Bedrock inference profiles for the specified provider.

    Queries Bedrock ``ListInferenceProfiles``, following ``nextToken`` so
    every page is collected. It then optionally sorts the profile ARNs in
    reverse order and keeps only ARNs containing *provider*
    (case-insensitive substring match).

    Args:
        profile: AWS profile name to use
        region: AWS region to query
        provider: Model provider to filter by (default: 'anthropic')
        sort: Whether to sort results in reverse order (default: True)

    Returns:
        Tuple of (model_ids, model_arns) lists. model_ids[i] is the last
        "/"-separated segment of model_arns[i]. Both lists are empty when
        nothing matches or an error occurs; the reason is printed.
    """
    try:
        session = boto3.Session(profile_name=profile, region_name=region)
        client = session.client("bedrock")

        # ListInferenceProfiles is paginated: keep requesting pages until the
        # response no longer carries a nextToken.
        inference_summaries = []
        request_kwargs = {}
        while True:
            resp = client.list_inference_profiles(**request_kwargs)
            inference_summaries.extend(resp.get("inferenceProfileSummaries", []))
            next_token = resp.get("nextToken")
            if not next_token:
                break
            request_kwargs["nextToken"] = next_token

        if not inference_summaries:
            print(f"No inference profiles found in region {region}")
            return [], []

        model_arns = [p["inferenceProfileArn"] for p in inference_summaries]

        if sort:
            model_arns.sort(reverse=True)

        # Filter by provider (hoist the lowercase form out of the loop)
        provider_key = provider.lower()
        model_arn_by_provider = [
            arn for arn in model_arns if provider_key in arn.lower()
        ]

        if not model_arn_by_provider:
            print(f"No models found for provider '{provider}' in region {region}")
            return [], []

        model_ids = [arn.split("/")[-1] for arn in model_arn_by_provider]
        return model_ids, model_arn_by_provider

    except (BotoCoreError, ClientError) as e:
        print(f"Error listing inference profiles: {e}")
        return [], []
    except Exception as e:
        print(f"Unexpected error listing models: {e}")
        return [], []

499 

500 

if __name__ == "__main__":
    # Manual smoke test: list the Anthropic inference profiles visible to the
    # "clauth" profile in ap-southeast-2 and dump the raw result tuple.
    result = list_bedrock_profiles(profile="clauth", region="ap-southeast-2")
    separator = "==============="
    print(separator)
    print(result)