Coverage for src/clauth/commands/models.py: 80%
66 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-28 14:37 -0400
« prev ^ index » next coverage.py v7.10.7, created at 2025-09-28 14:37 -0400
1"""
2CLAUTH Model Management Commands.
4This module provides commands for listing and managing Bedrock models,
5organized under a 'model' subcommand group.
6"""
8import typer
9from clauth.config import get_config_manager
10from clauth.aws_utils import user_is_authenticated, list_bedrock_profiles
11from clauth.helpers import handle_authentication_failure
12from rich.console import Console
13from InquirerPy import inquirer, get_style
# Rich console shared by the commands below for spinners and styled output.
console = Console()

# Typer sub-application that groups the Bedrock model commands under "model".
# no_args_is_help=True makes Typer show help when invoked without a subcommand.
model_app = typer.Typer(
    no_args_is_help=True,
    help="Manage and switch between Bedrock models.",
    name="model",
)
@model_app.command("list")
def list_models(
    profile: str = typer.Option(None, "--profile", "-p", help="AWS profile to use"),
    region: str = typer.Option(None, "--region", "-r", help="AWS region to use"),
    show_arn: bool = typer.Option(False, "--show-arn", help="Show model ARNs"),
):
    """
    List available Bedrock models.

    Loads the CLAUTH configuration, applies any ``--profile`` / ``--region``
    overrides to the loaded config object, checks AWS authentication, and
    prints one discovered model ID per line. With ``--show-arn`` each line is
    ``<model_id> --> <model_arn>``.

    Raises:
        typer.Exit: with code 1 when the user is not authenticated and
            ``handle_authentication_failure`` does not recover.
    """
    # Load configuration and apply CLI overrides (this command never saves it).
    config_manager = get_config_manager()
    config = config_manager.load()

    if profile is not None:
        config.aws.profile = profile
    if region is not None:
        config.aws.region = region

    if not user_is_authenticated(profile=config.aws.profile):
        # A falsy result means authentication could not be recovered; abort.
        if not handle_authentication_failure(config.aws.profile):
            raise typer.Exit(1)

    # Keep only the AWS call under the spinner so the listing printed below is
    # not interleaved with the live status display.
    with console.status("[bold blue]Fetching available models..."):
        model_ids, model_arns = list_bedrock_profiles(
            profile=config.aws.profile,
            region=config.aws.region,
            provider=config.models.provider_filter,
        )

    if not model_ids:
        # Warn on stderr so scripts consuming stdout still see an empty listing.
        typer.secho(
            "No models found. Check your AWS permissions and region.",
            fg=typer.colors.YELLOW,
            err=True,
        )
        return

    for model_id, model_arn in zip(model_ids, model_arns):
        if show_arn:
            # Single formatted string: print(a, " --> ", b) would double the spaces.
            print(f"{model_id} --> {model_arn}")
        else:
            print(model_id)
def _select_model(message, choices, current, fallback, style):
    """
    Prompt the user to pick one model ID from ``choices``.

    Args:
        message: Prompt text shown above the list.
        choices: Model IDs offered to the user.
        current: Currently configured model ID; the cursor starts on it when it
            is one of ``choices``.
        fallback: Model ID the cursor starts on when ``current`` is not offered.
        style: InquirerPy style object built with ``get_style``.

    Returns:
        The value returned by ``inquirer.select(...).execute()``.
    """
    return inquirer.select(
        message=message,
        instruction="↑↓ move • Enter select",
        pointer="▶ ",
        amark="✔",
        choices=choices,
        default=current if current in choices else fallback,
        style=style,
        max_height="100%",
    ).execute()


@model_app.command("switch")
def switch_models(
    profile: str = typer.Option(None, "--profile", "-p", help="AWS profile to use"),
    region: str = typer.Option(None, "--region", "-r", help="AWS region to use"),
    default_only: bool = typer.Option(
        False, "--default-only", help="Only change default model"
    ),
    fast_only: bool = typer.Option(False, "--fast-only", help="Only change fast model"),
):
    """
    Interactively switch the default and fast models.

    Discovers the Bedrock models available for the (optionally overridden) AWS
    profile/region and prompts for a new default model, a new fast model, or
    both. If anything changed, it passes the chosen IDs and their ARNs to
    ``config_manager.update_model_settings``.

    Raises:
        typer.Exit: with code 1 when both ``--default-only`` and ``--fast-only``
            are given, authentication fails, no models are configured yet, no
            models are discovered, or the model kept unchanged by
            ``--default-only``/``--fast-only`` is not among the discovered models.
    """
    # Load configuration and apply CLI overrides
    config_manager = get_config_manager()
    config = config_manager.load()

    if profile is not None:
        config.aws.profile = profile
    if region is not None:
        config.aws.region = region

    # Validate that both flags aren't set
    if default_only and fast_only:
        typer.secho(
            "Error: Cannot use both --default-only and --fast-only flags together.",
            fg=typer.colors.RED,
        )
        raise typer.Exit(1)

    # Check authentication; a falsy recovery result means abort.
    if not user_is_authenticated(profile=config.aws.profile):
        if not handle_authentication_failure(config.aws.profile):
            raise typer.Exit(1)

    # Check if models are configured
    if not config.models.default_model or not config.models.fast_model:
        typer.secho(
            "Model configuration missing. Run 'clauth init' for initial setup.",
            fg=typer.colors.RED,
        )
        raise typer.Exit(1)

    # Show current models
    console.print("\n[bold cyan]Current Models[/bold cyan]")
    console.print(f"  Default: [green]{config.models.default_model}[/green]")
    console.print(f"  Fast:    [green]{config.models.fast_model}[/green]")
    console.print()

    # Discover available models
    with console.status("[bold blue]Discovering available models..."):
        model_ids, model_arns = list_bedrock_profiles(
            profile=config.aws.profile,
            region=config.aws.region,
            provider=config.models.provider_filter,
        )

    if not model_ids:
        typer.secho(
            "No models found. Check your AWS permissions and region.",
            fg=typer.colors.RED,
        )
        raise typer.Exit(1)

    # Model ID -> ARN, used to pass each selected model's ARN when saving.
    model_map = dict(zip(model_ids, model_arns))

    # With --default-only / --fast-only the other model keeps its configured ID,
    # but its ARN is still looked up in model_map below. If it is not among the
    # discovered models (e.g. different --region or provider filter), that
    # lookup would raise KeyError, so fail early with a clear message instead.
    if default_only or fast_only:
        kept_model = (
            config.models.fast_model if default_only else config.models.default_model
        )
        if kept_model not in model_map:
            typer.secho(
                f"Current model '{kept_model}' is not among the models available "
                "for this profile/region. Re-run without --default-only/--fast-only "
                "to choose both models.",
                fg=typer.colors.RED,
            )
            raise typer.Exit(1)

    # Get custom style for inquirer
    custom_style = get_style(config_manager.get_custom_style())

    # Start from the current values; only the prompted model(s) change.
    new_default_model = config.models.default_model
    new_fast_model = config.models.fast_model

    # model_ids is non-empty here, so index 0 / -1 are safe fallbacks.
    if not fast_only:
        new_default_model = _select_model(
            "Select new default model:",
            model_ids,
            config.models.default_model,
            model_ids[0],
            custom_style,
        )

    if not default_only:
        new_fast_model = _select_model(
            "Select new small/fast model:",
            model_ids,
            config.models.fast_model,
            model_ids[-1],
            custom_style,
        )

    # Check if anything changed
    if (
        new_default_model == config.models.default_model
        and new_fast_model == config.models.fast_model
    ):
        console.print("[yellow]No changes made.[/yellow]")
        return

    # Update configuration
    config_manager.update_model_settings(
        default_model=new_default_model,
        fast_model=new_fast_model,
        default_arn=model_map[new_default_model],
        fast_arn=model_map[new_fast_model],
    )

    # Show confirmation
    console.print("\n[bold green]✅ Models updated successfully![/bold green]")
    console.print(f"  Default: [green]{new_default_model}[/green]")
    console.print(f"  Fast:    [green]{new_fast_model}[/green]")