Coverage for src/clauth/commands/models.py: 80%

66 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-28 14:37 -0400

"""
CLAUTH Model Management Commands.

This module provides commands for listing and managing Bedrock models,
organized under a 'model' subcommand group (for example ``model list`` to
show the models available to the configured AWS profile/region, and
``model switch`` to interactively choose the default and fast models).
"""

7 

8import typer 

9from clauth.config import get_config_manager 

10from clauth.aws_utils import user_is_authenticated, list_bedrock_profiles 

11from clauth.helpers import handle_authentication_failure 

12from rich.console import Console 

13from InquirerPy import inquirer, get_style 

14 

# Typer sub-application that groups the `model` commands. With
# no_args_is_help=True, invoking the group without a subcommand prints its
# help text instead of silently doing nothing.
model_app = typer.Typer(
    name="model",
    help="Manage and switch between Bedrock models.",
    no_args_is_help=True,
)

# Shared Rich console used by the commands below for status spinners and
# styled output.
console = Console()

22 

23 

@model_app.command("list")
def list_models(
    profile: str = typer.Option(None, "--profile", "-p", help="AWS profile to use"),
    region: str = typer.Option(None, "--region", "-r", help="AWS region to use"),
    show_arn: bool = typer.Option(False, "--show-arn", help="Show model ARNs"),
):
    """
    List available Bedrock models.

    Loads the CLAUTH configuration, applies any ``--profile`` / ``--region``
    overrides to the in-memory config (this command does not save the
    configuration), verifies AWS authentication, and prints one model ID per
    line. With ``--show-arn``, each line also shows the model's ARN.

    If no models are discovered, a warning is written to stderr and the
    command returns without printing anything to stdout.

    Raises:
        typer.Exit: with code 1 when the user is not authenticated and
            ``handle_authentication_failure`` reports it could not recover.
    """
    # Load configuration and apply CLI overrides
    config_manager = get_config_manager()
    config = config_manager.load()

    # CLI flags take precedence over stored settings for this invocation only.
    if profile is not None:
        config.aws.profile = profile
    if region is not None:
        config.aws.region = region

    if not user_is_authenticated(profile=config.aws.profile):
        if not handle_authentication_failure(config.aws.profile):
            raise typer.Exit(1)

    with console.status("[bold blue]Fetching available models..."):
        model_ids, model_arns = list_bedrock_profiles(
            profile=config.aws.profile,
            region=config.aws.region,
            provider=config.models.provider_filter,
        )

    if not model_ids:
        # Explain an empty result instead of printing nothing. The message goes
        # to stderr so stdout remains a clean (empty) list for scripts.
        typer.secho(
            "No models found. Check your AWS permissions and region.",
            fg=typer.colors.YELLOW,
            err=True,
        )
        return

    # Builtin print() (not console.print) emits the IDs verbatim, without Rich
    # markup processing, which keeps the output easy to pipe.
    for model_id, model_arn in zip(model_ids, model_arns):
        if show_arn:
            print(model_id, " --> ", model_arn)
        else:
            print(model_id)

57 

58 

@model_app.command("switch")
def switch_models(
    profile: str = typer.Option(None, "--profile", "-p", help="AWS profile to use"),
    region: str = typer.Option(None, "--region", "-r", help="AWS region to use"),
    default_only: bool = typer.Option(
        False, "--default-only", help="Only change default model"
    ),
    fast_only: bool = typer.Option(False, "--fast-only", help="Only change fast model"),
):
    """
    Interactively switch the default and fast models.

    Loads the configuration, applies ``--profile`` / ``--region`` overrides
    in memory, checks authentication, and discovers the available Bedrock
    models. It then prompts for a new default model (unless ``--fast-only``)
    and a new fast model (unless ``--default-only``). Changes are persisted
    via ``config_manager.update_model_settings``; if nothing changed, a
    "No changes made." notice is printed instead.

    Raises:
        typer.Exit: with code 1 when both ``--default-only`` and
            ``--fast-only`` are given, authentication cannot be recovered,
            no models are configured yet, no models are discovered, or a
            model that is being kept is not among the discovered models (so
            its ARN cannot be resolved).
    """
    # Load configuration and apply CLI overrides
    config_manager = get_config_manager()
    config = config_manager.load()

    if profile is not None:
        config.aws.profile = profile
    if region is not None:
        config.aws.region = region

    # Validate that both flags aren't set
    if default_only and fast_only:
        typer.secho(
            "Error: Cannot use both --default-only and --fast-only flags together.",
            fg=typer.colors.RED,
        )
        raise typer.Exit(1)

    # Check authentication
    if not user_is_authenticated(profile=config.aws.profile):
        if not handle_authentication_failure(config.aws.profile):
            raise typer.Exit(1)

    # Check if models are configured
    if not config.models.default_model or not config.models.fast_model:
        typer.secho(
            "Model configuration missing. Run 'clauth init' for initial setup.",
            fg=typer.colors.RED,
        )
        raise typer.Exit(1)

    # Show current models
    console.print("\n[bold cyan]Current Models[/bold cyan]")
    console.print(f" Default: [green]{config.models.default_model}[/green]")
    console.print(f" Fast: [green]{config.models.fast_model}[/green]")
    console.print()

    # Discover available models
    with console.status("[bold blue]Discovering available models..."):
        model_ids, model_arns = list_bedrock_profiles(
            profile=config.aws.profile,
            region=config.aws.region,
            provider=config.models.provider_filter,
        )

    if not model_ids:
        typer.secho(
            "No models found. Check your AWS permissions and region.",
            fg=typer.colors.RED,
        )
        raise typer.Exit(1)

    # Map each discovered model ID to its ARN for the settings update below.
    model_map = dict(zip(model_ids, model_arns))

    # Get custom style for inquirer
    custom_style = get_style(config_manager.get_custom_style())

    # Start from the current values; each prompt below may replace one of them.
    new_default_model = config.models.default_model
    new_fast_model = config.models.fast_model

    # model_ids is non-empty here (checked above), so the fallbacks are safe.
    if not fast_only:
        new_default_model = _select_model(
            "Select new default model:",
            model_ids,
            current=config.models.default_model,
            fallback=model_ids[0],
            style=custom_style,
        )

    if not default_only:
        new_fast_model = _select_model(
            "Select new small/fast model:",
            model_ids,
            current=config.models.fast_model,
            fallback=model_ids[-1],
            style=custom_style,
        )

    # Check if anything changed
    if (
        new_default_model == config.models.default_model
        and new_fast_model == config.models.fast_model
    ):
        console.print("[yellow]No changes made.[/yellow]")
        return

    # With --fast-only / --default-only, one model keeps its stored value,
    # which may not be among the models discovered for this
    # profile/region/provider (e.g. after a --region override). Its ARN then
    # cannot be resolved; report that instead of crashing with a KeyError.
    unavailable = [m for m in (new_default_model, new_fast_model) if m not in model_map]
    if unavailable:
        typer.secho(
            "Model(s) not available for this profile/region: "
            f"{', '.join(unavailable)}. Re-run without --default-only/--fast-only "
            "to choose both models.",
            fg=typer.colors.RED,
        )
        raise typer.Exit(1)

    # Update configuration
    config_manager.update_model_settings(
        default_model=new_default_model,
        fast_model=new_fast_model,
        default_arn=model_map[new_default_model],
        fast_arn=model_map[new_fast_model],
    )

    # Show confirmation
    console.print("\n[bold green]✅ Models updated successfully![/bold green]")
    console.print(f" Default: [green]{new_default_model}[/green]")
    console.print(f" Fast: [green]{new_fast_model}[/green]")


def _select_model(message, choices, current, fallback, style):
    """Prompt the user to pick one model ID from ``choices``.

    The cursor starts on ``current`` when it is one of ``choices``; otherwise
    it starts on ``fallback``. Returns the value produced by the InquirerPy
    ``select`` prompt's ``execute()``.
    """
    return inquirer.select(
        message=message,
        instruction="↑↓ move • Enter select",
        pointer="▶ ",
        amark="✔",
        choices=choices,
        default=current if current in choices else fallback,
        style=style,
        max_height="100%",
    ).execute()