Coverage for src/sql_tool/cli/commands/_shared.py: 88%
69 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-14 15:28 -0500
1"""Shared CLI plumbing for command modules.
3Client creation, format-option handling, and output helpers.
4Distinct from cli.helpers which contains pure data-formatting functions.
5"""
7from __future__ import annotations
9import sys
10from typing import TYPE_CHECKING, Any
12from sql_tool.cli.helpers import fmt_size
13from sql_tool.cli.output import get_formatter, resolve_format, write_output
14from sql_tool.core.client import PgClient
15from sql_tool.core.config import load_config, resolve_config
17if TYPE_CHECKING:
18 import typer
20 from sql_tool.cli.output import OutputFormat
21 from sql_tool.core.models import QueryResult
def get_client(ctx: typer.Context, timeout: float | None = None) -> PgClient:
    """Build a ``PgClient`` from the config file plus any CLI-level overrides.

    Reads the shared state dict on ``ctx`` (created if missing). Loads the
    config file named by its ``config_file`` entry. Resolves that config using
    the ``profile`` and ``dsn`` entries, forwarding every connection setting
    that was actually supplied (non-``None``) as a keyword override.

    Args:
        ctx: Typer context whose ``obj`` dict holds global CLI options.
        timeout: Optional per-command timeout; forwarded as the ``timeout``
            override only when given.

    Returns:
        A client built from the resolved configuration.
    """
    state = ctx.ensure_object(dict)

    # Only values the user actually supplied are forwarded. ``None`` means
    # "not given" and must not mask values from the config file or profile.
    overrides: dict[str, Any] = {
        name: value
        for name in ("host", "port", "database", "user", "password", "schema")
        if (value := state.get(name)) is not None
    }
    if timeout is not None:
        overrides["timeout"] = timeout

    file_config = load_config(state.get("config_file"))
    return PgClient(
        resolve_config(
            file_config,
            profile_name=state.get("profile"),
            dsn=state.get("dsn"),
            **overrides,
        )
    )
def format_options(ctx: typer.Context) -> dict[str, Any]:
    """Gather output-formatting settings from ``ctx`` as formatter kwargs.

    Missing entries fall back to: no explicit format, non-compact output,
    a width of 40, and headers enabled.

    Returns:
        A dict with keys ``format_flag``, ``compact``, ``width`` and
        ``no_header``, suitable for ``get_formatter(**opts)``.
    """
    state = ctx.ensure_object(dict)
    requested_format = state.get("format")
    compact_output = state.get("compact", False)
    column_width = state.get("width", 40)
    hide_header = state.get("no_header", False)
    return dict(
        format_flag=requested_format,
        compact=compact_output,
        width=column_width,
        no_header=hide_header,
    )
def output_result(ctx: typer.Context, result: QueryResult) -> None:
    """Render ``result`` with the formatter configured on ``ctx`` and emit it.

    The formatter is built from ``format_options(ctx)`` and the output is
    handed to ``write_output``.
    """
    write_output(get_formatter(**format_options(ctx)), result)
def apply_local_format_options(
    ctx: typer.Context,
    *,
    format: OutputFormat | None = None,
    table: bool = False,
    compact: bool = False,
    width: int | None = None,
    no_header: bool = False,
) -> None:
    """Let command-local format flags override the global ones stored on ``ctx``.

    Only options that were actually set are written. When nothing is set,
    ``ctx.ensure_object`` is not called at all.

    Args:
        ctx: Typer context whose ``obj`` dict holds the shared options.
        format: Explicit output format; its ``.value`` is stored.
        table: Shorthand for table output; takes precedence over ``format``.
        compact: Store the compact setting when truthy.
        width: Store the column width when given.
        no_header: Store the no-header setting when truthy.
    """
    updates: dict[str, Any] = {}
    if format is not None:
        updates["format"] = format.value
    if table:
        # Applied after ``format`` so ``--table`` wins when both are passed.
        updates["format"] = "table"
    if compact:
        updates["compact"] = compact
    if width is not None:
        updates["width"] = width
    if no_header:
        updates["no_header"] = no_header

    if updates:
        ctx.ensure_object(dict).update(updates)
def is_table_format(ctx: typer.Context) -> bool:
    """Return ``True`` when the effective output format is ``"table"``.

    The requested format (possibly ``None``) is resolved via ``resolve_format``
    before being compared.
    """
    requested = ctx.ensure_object(dict).get("format")
    effective = resolve_format(requested)
    return effective == "table"
def size_formatter(ctx: typer.Context) -> tuple[Any, bool]:
    """Pick a byte-size formatter appropriate for the current output format.

    Returns:
        A ``(formatter, is_table)`` pair.
        - Table output: the formatter is ``fmt_size``.
        - Any other format: the formatter renders the raw value with ``str``,
          and falsy values (e.g. ``None``) render as ``"0"``.
    """
    table_mode = is_table_format(ctx)
    if table_mode:
        return fmt_size, table_mode

    def _raw_bytes(value):
        # Machine-readable formats get the plain number; missing sizes become 0.
        return str(value or 0)

    return _raw_bytes, table_mode
def parse_table_arg(
    table_arg: str, *, default_schema: str = "public"
) -> tuple[str, str]:
    """Split a ``[schema.]table`` CLI argument into ``(schema, table)``.

    The split happens on the first ``.`` only, so ``"a.b.c"`` yields
    ``("a", "b.c")``. An unqualified name is placed in ``default_schema``.
    That defaults to ``"public"`` for backward compatibility; callers that
    honor a configured schema can pass it instead. Quoted identifiers
    containing dots are not understood.

    Args:
        table_arg: Table name, optionally qualified with a schema.
        default_schema: Schema used when ``table_arg`` has no ``.``.

    Returns:
        A ``(schema, table)`` tuple.
    """
    schema, sep, table = table_arg.partition(".")
    if not sep:
        return default_schema, table_arg
    return schema, table
# Options that may be given bare, plus the value inserted for them.
_OPTIONAL_INT_FLAGS = frozenset({"--head", "--tail", "--sample"})
_OPTIONAL_INT_DEFAULT = "10"


def _is_int_literal(token: str) -> bool:
    """Return True if ``token`` parses as an integer (e.g. ``"5"``, ``"-3"``)."""
    try:
        int(token)
    except ValueError:
        return False
    return True


def preprocess_optional_int_flags() -> None:
    """Insert default '10' after bare --head/--tail/--sample when used without a value.

    Typer 0.23 doesn't support Click's flag_value parameter, so these options
    always require an explicit argument. This preprocessor lets users write
    ``--head`` instead of ``--head 10`` by inserting the default before Typer
    parses argv.

    A flag is treated as bare unless the next token is an integer literal.
    This makes ``table --head users`` become ``table --head 10 users``;
    checking only for a leading ``-`` would hand ``users`` to ``--head`` and
    Typer would reject it as a non-integer. Tokens after a ``--``
    end-of-options marker are positional and are copied through unchanged.

    Does nothing unless ``"table"`` appears in ``sys.argv``. Otherwise
    ``sys.argv`` is replaced with the rewritten list.
    """
    argv = sys.argv
    if "table" not in argv:
        return

    new_argv: list[str] = []
    for i, arg in enumerate(argv):
        new_argv.append(arg)
        if arg == "--":
            # Option parsing stops here; never inject values into positionals.
            new_argv.extend(argv[i + 1 :])
            break
        if arg in _OPTIONAL_INT_FLAGS:
            has_value = i + 1 < len(argv) and _is_int_literal(argv[i + 1])
            if not has_value:
                new_argv.append(_OPTIONAL_INT_DEFAULT)
    sys.argv = new_argv