Coverage for src / invariant / cli.py: 32.31%

130 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-06 12:18 +0000

1"""Command-line interface for executing serialized Invariant graphs.""" 

2 

3import argparse 

4import json 

5import sys 

6from dataclasses import dataclass 

7from pathlib import Path 

8from typing import Any, TextIO 

9 

10from invariant.executor import Executor 

11from invariant.graph import Graph 

12from invariant.graph_serialization import ( 

13 dump_value_to_jsonable, 

14 load_graph_document_from_dict, 

15 load_value_from_jsonable, 

16) 

17from invariant.protocol import ICacheable 

18from invariant.registry import OpRegistry 

19from invariant.store.null import NullStore 

20from invariant.yaml_serialization import _load_yaml_document 

21 

22 

23@dataclass(frozen=True) 

24class _CliOutput: 

25 value: Any 

26 is_mapping: bool 

27 selected_key: str | None 

28 

29 

def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the ``invariant`` command.

    Returns:
        A parser with one optional positional ``graph`` argument plus the
        ``--input-format``, ``--context``, ``--param``, ``--pick``,
        ``--pretty``, ``-o/--output`` and ``--output-format`` options,
        registered in that order.
    """
    cli = argparse.ArgumentParser(
        prog="invariant",
        description="Execute a serialized Invariant graph and emit JSON results.",
    )
    add = cli.add_argument

    # --- graph input --------------------------------------------------------
    add(
        "graph",
        nargs="?",
        default="-",
        help="Path to graph JSON or YAML. Reads stdin when omitted or '-'.",
    )
    add(
        "--input-format",
        default="auto",
        choices=["auto", "json", "yaml"],
        help="Graph input format. auto detects .yaml/.yml files; "
        "stdin defaults to JSON.",
    )

    # --- execution context and output selection -----------------------------
    add(
        "--context",
        metavar="CONTEXT_FILE",
        help="Path to a JSON object containing external context values.",
    )
    add(
        "--param",
        metavar="KEY=VALUE",
        action="append",
        default=[],
        help="Override or add one external context value. VALUE accepts "
        "JSON scalars/objects, Invariant JSON markers, and bare strings.",
    )
    add(
        "--pick",
        metavar="KEY",
        action="append",
        default=[],
        help="Requested graph output. May be supplied multiple times.",
    )

    # --- result emission ----------------------------------------------------
    add("--pretty", action="store_true", help="Emit indented JSON.")
    add(
        "-o",
        "--output",
        metavar="FILE",
        help="Write output to FILE instead of stdout.",
    )
    add(
        "--output-format",
        default="auto",
        choices=["auto", "json", "binary"],
        help="File output format. auto writes selected ICacheable values "
        "as binary and everything else as JSON.",
    )
    return cli

93 

94 

def _read_graph_arg(graph_arg: str, stdin: TextIO) -> str:
    """Return the raw text of the graph document.

    ``"-"`` reads everything from *stdin*; any other value is treated as a
    filesystem path and read as UTF-8 text.
    """
    if graph_arg != "-":
        return Path(graph_arg).read_text(encoding="utf-8")
    return stdin.read()

99 

100 

def _detect_input_format(graph_arg: str, input_format: str) -> str:
    """Resolve the ``--input-format`` choice to a concrete format name.

    Any value other than ``"auto"`` is returned unchanged. In ``auto`` mode,
    stdin (``"-"``) is JSON, and a file is YAML only when its suffix is
    ``.yaml`` or ``.yml`` (case-insensitive); everything else is JSON.
    """
    if input_format == "auto":
        from_yaml_file = (
            graph_arg != "-"
            and Path(graph_arg).suffix.lower() in (".yaml", ".yml")
        )
        return "yaml" if from_yaml_file else "json"
    return input_format

110 

111 

def _load_input_document(
    data: str, *, graph_arg: str = "-", input_format: str = "auto"
) -> tuple[Graph, str | None]:
    """Parse graph document text and load it through the graph serializer.

    Args:
        data: Raw document text.
        graph_arg: The graph argument the text came from. It is used only so
            that ``auto`` format detection can inspect the file suffix.
        input_format: ``"auto"``, ``"json"`` or ``"yaml"``.

    Returns:
        The result of ``load_graph_document_from_dict``, annotated as the
        graph plus an optional output key. ``_execute_cli`` uses that key as
        the default ``--pick``.

    Raises:
        ValueError: If the top-level document is not a mapping.
    """
    if _detect_input_format(graph_arg, input_format) == "yaml":
        document = _load_yaml_document(data)
    else:
        document = json.loads(data)

    if isinstance(document, dict):
        return load_graph_document_from_dict(document)
    raise ValueError("Graph document must be an object")

125 

126 

127def _load_context(path: str | None) -> dict[str, Any]: 

128 if path is None: 

129 return {} 

130 

131 obj = json.loads(Path(path).read_text(encoding="utf-8")) 

132 if not isinstance(obj, dict): 

133 raise ValueError("Context document must be a JSON object") 

134 

135 return {key: load_value_from_jsonable(value) for key, value in obj.items()} 

136 

137 

def _parse_param_value(value: str) -> Any:
    """Decode the VALUE half of a ``--param KEY=VALUE`` argument.

    VALUE is first parsed as JSON. If that fails, the text is used verbatim
    as a bare string, so ``--param name=alice`` works without quoting.
    Successfully parsed JSON is then passed to ``load_value_from_jsonable``.
    """
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError:
        # Not valid JSON: treat the argument as a plain string.
        return value
    # Decoded outside the ``try`` so that a JSONDecodeError raised by
    # load_value_from_jsonable is reported instead of being mistaken for
    # "VALUE is not JSON" and silently downgraded to the raw string.
    return load_value_from_jsonable(parsed)

143 

144 

def _parse_param(param: str) -> tuple[str, Any]:
    """Split a ``KEY=VALUE`` argument into its key and decoded value.

    Only the first ``=`` separates key from value, so VALUE may itself
    contain ``=``. VALUE is decoded by ``_parse_param_value``.

    Raises:
        ValueError: If *param* has no ``=`` or the key before it is empty.
    """
    # An empty key is exactly the case where "=" is the first character.
    if "=" not in param or param.startswith("="):
        raise ValueError("--param must be in KEY=VALUE form")
    key, raw_value = param.split("=", 1)
    return key, _parse_param_value(raw_value)

150 

151 

def _load_params(params: list[str]) -> dict[str, Any]:
    """Collect repeated ``--param KEY=VALUE`` arguments into a dict.

    Each argument is parsed with ``_parse_param``. When a key appears more
    than once, the later occurrence wins.
    """
    loaded: dict[str, Any] = {}
    for param in params:
        key, value = _parse_param(param)
        loaded[key] = value
    return loaded

154 

155 

def _encode_result_context(results: dict[str, Any]) -> dict[str, Any]:
    """Return *results* with every value passed through ``dump_value_to_jsonable``.

    Keys are copied unchanged; insertion order is preserved.
    """
    encoded: dict[str, Any] = {}
    for name, result in results.items():
        encoded[name] = dump_value_to_jsonable(result)
    return encoded

158 

159 

def _execute_cli(args: argparse.Namespace, stdin: TextIO) -> _CliOutput:
    """Load the graph and context described by *args* and execute the graph.

    Args:
        args: Parsed command-line arguments from ``_build_parser``.
        stdin: Stream read when the graph argument is ``"-"``.

    Returns:
        A ``_CliOutput`` wrapping the single requested value. When more than
        one output was requested, it wraps the whole results mapping instead.

    Raises:
        ValueError: If no ``--pick`` was given and the document supplies no
            default output.
    """
    graph_text = _read_graph_arg(args.graph, stdin)
    graph, default_output = _load_input_document(
        graph_text, graph_arg=args.graph, input_format=args.input_format
    )

    # --param values are applied after the context file, so they override it.
    context = _load_context(args.context)
    context.update(_load_params(args.param))

    # Explicit --pick values take precedence over the document's default.
    if args.pick:
        requested_outputs = tuple(args.pick)
    elif default_output is not None:
        requested_outputs = (default_output,)
    else:
        raise ValueError(
            "Graph document has no default output; supply at least one --pick"
        )

    # Reset the registry with clear(), then populate it via auto_discover().
    registry = OpRegistry()
    registry.clear()
    registry.auto_discover()

    results = Executor(registry, NullStore()).execute(
        graph, requested_outputs, context=context
    )
    if len(requested_outputs) != 1:
        return _CliOutput(results, is_mapping=True, selected_key=None)

    (only_key,) = requested_outputs
    return _CliOutput(
        results[only_key], is_mapping=False, selected_key=only_key
    )

191 

192 

def _jsonable_output(output: _CliOutput) -> Any:
    """Convert a CLI result into a JSON-compatible structure.

    A results mapping is encoded entry by entry via
    ``_encode_result_context``. A single value is encoded directly with
    ``dump_value_to_jsonable``.
    """
    return (
        _encode_result_context(output.value)
        if output.is_mapping
        else dump_value_to_jsonable(output.value)
    )

197 

198 

def _write_json_output(
    output: _CliOutput, stream: TextIO, *, pretty: bool
) -> None:
    """Serialize *output* as JSON to *stream*, followed by a newline.

    Keys are sorted for deterministic output. With ``pretty`` the document
    uses two-space indentation; otherwise the most compact separators are
    used.

    The whole document is encoded in memory before anything is written. As a
    result, an encoding error leaves *stream* untouched rather than
    half-written (important when *stream* is stdout). It also lets the
    compact form use CPython's C-accelerated encoder; ``json.dump`` streams
    through the pure-Python encoder. The emitted text is identical to what
    ``json.dump`` with the same options would produce.
    """
    text = json.dumps(
        _jsonable_output(output),
        indent=2 if pretty else None,
        separators=None if pretty else (",", ":"),
        sort_keys=True,
    )
    stream.write(text + "\n")

210 

211 

def _write_binary_output(output: _CliOutput, path: Path) -> None:
    """Write a single selected ICacheable value to *path* in binary form.

    The parent directory is created if needed. If the value has a callable
    ``to_file`` attribute, that is called with *path*. Otherwise *path* is
    opened in binary write mode and handed to the value's ``to_stream``.

    Raises:
        ValueError: If *output* is a results mapping, or if the selected
            value is not an ``ICacheable``.
    """
    if output.is_mapping:
        raise ValueError("Binary output requires exactly one selected output")

    value = output.value
    if not isinstance(value, ICacheable):
        label = "" if not output.selected_key else f" '{output.selected_key}'"
        raise ValueError(
            f"Output{label} is {type(value).__name__}, not an ICacheable value"
        )

    path.parent.mkdir(parents=True, exist_ok=True)
    to_file = getattr(value, "to_file", None)
    if not callable(to_file):
        with path.open("wb") as stream:
            value.to_stream(stream)
    else:
        to_file(path)

231 

232 

def _write_output_file(
    output: _CliOutput,
    *,
    path: Path,
    output_format: str,
    pretty: bool,
) -> None:
    """Write *output* to *path* according to ``--output-format``.

    The formats behave as follows:

    - ``binary`` always goes through ``_write_binary_output``.
    - ``auto`` chooses binary for a single selected ``ICacheable`` value and
      JSON for anything else.
    - ``json`` always writes JSON, creating the parent directory if needed.

    Raises:
        ValueError: For an unrecognised *output_format*, or from
            ``_write_binary_output`` when binary output is not possible.
    """
    wants_binary = output_format == "binary"
    if not wants_binary and output_format == "auto" and not output.is_mapping:
        wants_binary = isinstance(output.value, ICacheable)
    if wants_binary:
        _write_binary_output(output, path)
        return

    if output_format not in ("auto", "json"):
        raise ValueError(f"Unsupported output format: {output_format}")

    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as stream:
        _write_json_output(output, stream, pretty=pretty)

255 

256 

def main(argv: list[str] | None = None) -> int:
    """Run the ``invariant`` command line and return a process exit code.

    Args:
        argv: Command-line arguments without the program name. ``None`` lets
            argparse fall back to ``sys.argv[1:]``.

    Returns:
        0 on success. Returns 1 when execution or output writing raises; the
        error is then reported on stderr as ``invariant: error: <message>``.
        Argument-parsing errors are handled by argparse itself, which exits.
    """
    args = _build_parser().parse_args(argv)

    try:
        output = _execute_cli(args, sys.stdin)
        if args.output:
            _write_output_file(
                output,
                path=Path(args.output),
                output_format=args.output_format,
                pretty=args.pretty,
            )
        else:
            _write_json_output(output, sys.stdout, pretty=args.pretty)
    except Exception as exc:  # top-level CLI boundary: report, don't trace back
        # Some exceptions stringify to "" (e.g. a bare ``KeyError()``); fall
        # back to the exception type name so the message is never empty.
        message = str(exc) or type(exc).__name__
        print(f"invariant: error: {message}", file=sys.stderr)
        return 1

    return 0

277 

278 

if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())