Coverage for src / sql_tool / cli / commands / _shared.py: 88%

69 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-14 15:28 -0500

1"""Shared CLI plumbing for command modules. 

2 

3Client creation, format-option handling, and output helpers. 

4Distinct from cli.helpers which contains pure data-formatting functions. 

5""" 

6 

7from __future__ import annotations 

8 

9import sys 

10from typing import TYPE_CHECKING, Any 

11 

12from sql_tool.cli.helpers import fmt_size 

13from sql_tool.cli.output import get_formatter, resolve_format, write_output 

14from sql_tool.core.client import PgClient 

15from sql_tool.core.config import load_config, resolve_config 

16 

17if TYPE_CHECKING: 

18 import typer 

19 

20 from sql_tool.cli.output import OutputFormat 

21 from sql_tool.core.models import QueryResult 

22 

23 

def get_client(ctx: typer.Context, timeout: float | None = None) -> PgClient:
    """Build a ``PgClient`` from the options stored on the Typer context.

    The config file named by the context's ``config_file`` entry is loaded,
    then resolved together with the ``profile`` and ``dsn`` entries and any
    connection settings present on the context.

    Args:
        ctx: Typer context whose object dict holds the stored CLI options.
        timeout: Optional timeout passed along as an override when not None.

    Returns:
        A ``PgClient`` constructed from the resolved configuration.
    """
    state = ctx.ensure_object(dict)
    base_config = load_config(state.get("config_file"))

    # Only settings that were actually supplied (non-None) become overrides.
    connection_keys = ("host", "port", "database", "user", "password", "schema")
    supplied = {name: state.get(name) for name in connection_keys}
    overrides: dict[str, Any] = {
        name: value for name, value in supplied.items() if value is not None
    }
    if timeout is not None:
        overrides["timeout"] = timeout

    settings = resolve_config(
        base_config,
        profile_name=state.get("profile"),
        dsn=state.get("dsn"),
        **overrides,
    )
    return PgClient(settings)

44 

45 

def format_options(ctx: typer.Context) -> dict[str, Any]:
    """Collect the output-format settings stored on the Typer context.

    Returns:
        A dict with the keys ``format_flag``, ``compact``, ``width`` and
        ``no_header``. A setting missing from the context falls back to
        ``None``, ``False``, ``40`` and ``False`` respectively.
    """
    state = ctx.ensure_object(dict)
    # (returned key, key stored on the context, fallback when absent)
    layout = (
        ("format_flag", "format", None),
        ("compact", "compact", False),
        ("width", "width", 40),
        ("no_header", "no_header", False),
    )
    return {
        out_key: state.get(stored_key, fallback)
        for out_key, stored_key, fallback in layout
    }

54 

55 

def output_result(ctx: typer.Context, result: QueryResult) -> None:
    """Write ``result`` using a formatter built from the context's format options.

    Args:
        ctx: Typer context carrying the stored format settings.
        result: Query result handed to ``write_output``.
    """
    settings = format_options(ctx)
    write_output(get_formatter(**settings), result)

60 

61 

def apply_local_format_options(
    ctx: typer.Context,
    *,
    format: OutputFormat | None = None,
    table: bool = False,
    compact: bool = False,
    width: int | None = None,
    no_header: bool = False,
) -> None:
    """Store command-level format options on the Typer context.

    Only options that were actually given are written; entries already on the
    context are otherwise left alone. When both ``format`` and ``table`` are
    given, ``table`` wins and the stored format is ``"table"``.

    Args:
        ctx: Typer context whose object dict receives the settings.
        format: Output format; its ``.value`` is stored under ``"format"``.
        table: When true, store ``"table"`` under ``"format"``.
        compact: Stored under ``"compact"`` when truthy.
        width: Stored under ``"width"`` when not None.
        no_header: Stored under ``"no_header"`` when truthy.
    """
    nothing_given = (
        format is None
        and not table
        and not compact
        and width is None
        and not no_header
    )
    if nothing_given:
        # Leave the context untouched (its object dict is not even requested).
        return

    state = ctx.ensure_object(dict)
    changes: dict[str, Any] = {}
    if format is not None:
        changes["format"] = format.value
    if table:
        # Applied after ``format`` so the table switch takes precedence.
        changes["format"] = "table"
    if compact:
        changes["compact"] = compact
    if width is not None:
        changes["width"] = width
    if no_header:
        changes["no_header"] = no_header
    state.update(changes)

83 

84 

def is_table_format(ctx: typer.Context) -> bool:
    """Report whether the format stored on the context resolves to ``"table"``.

    The context's ``"format"`` entry (possibly absent) is passed through
    ``resolve_format`` and the outcome compared with ``"table"``.
    """
    requested = ctx.ensure_object(dict).get("format")
    effective = resolve_format(requested)
    return effective == "table"

88 

89 

def size_formatter(ctx: typer.Context) -> tuple[Any, bool]:
    """Choose how sizes are rendered for the current output format.

    Returns:
        A ``(formatter, is_table)`` pair. For table output the formatter is
        ``fmt_size``; otherwise it renders the value with ``str``, treating
        falsy input (such as ``None``) as ``0``.
    """
    table_mode = is_table_format(ctx)
    if table_mode:
        return fmt_size, table_mode

    def plain_size(b: Any) -> str:
        # Falsy sizes are rendered as "0".
        return str(b or 0)

    return plain_size, table_mode

94 

95 

def parse_table_arg(
    table_arg: str, *, default_schema: str = "public"
) -> tuple[str, str]:
    """Split a ``[schema.]table`` argument into ``(schema, table)``.

    Only the first dot separates schema from table, so ``"a.b.c"`` yields
    ``("a", "b.c")``. The argument is split textually; quoting is not
    interpreted.

    Args:
        table_arg: Table reference, optionally prefixed with ``schema.``.
        default_schema: Schema returned when ``table_arg`` contains no dot.
            Defaults to ``"public"``.

    Returns:
        The ``(schema, table)`` pair.
    """
    schema, dot, table = table_arg.partition(".")
    if dot:
        return schema, table
    return default_schema, table_arg

101 

102 

def preprocess_optional_int_flags() -> None:
    """Insert default '10' after bare --head/--tail/--sample when used without a value.

    Typer 0.23 doesn't support Click's flag_value parameter, so these options
    always require an explicit argument. This preprocessor lets users write
    ``--head`` instead of ``--head 10`` by inserting the default before Typer
    parses argv.

    A flag counts as bare when it is the last token, or when the token after
    it is not an integer value. The latter covers another option as well as a
    positional such as a table name, so ``table --head users`` becomes
    ``table --head 10 users`` instead of handing ``users`` to the parser as
    the integer. Nothing is done unless ``"table"`` appears in ``sys.argv``.
    ``sys.argv`` is rebound to the rewritten list.
    """
    if "table" not in sys.argv:
        return
    flags = {"--head", "--tail", "--sample"}
    argv = sys.argv
    new_argv: list[str] = []
    for index, token in enumerate(argv):
        new_argv.append(token)
        if token not in flags:
            continue
        has_following = index + 1 < len(argv)
        if not (has_following and _is_optional_int_value(argv[index + 1])):
            new_argv.append("10")
    sys.argv = new_argv


def _is_optional_int_value(token: str) -> bool:
    """Return True if ``token`` can serve as the integer value of a flag."""
    # Tokens starting with "-" are never taken as the flag's value.
    if token.startswith("-"):
        return False
    try:
        int(token)
    except ValueError:
        return False
    return True