Coverage for src / invariant / graph_serialization.py: 83.87%

248 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-03 19:45 +0000

1"""Graph serialization: JSON wire format for Invariant graphs. 

2 

3Encodes graphs (Node, SubGraphNode) and params (ref, cel, Decimal, tuple, 

4ICacheable) for storage and transmission. Distinct from artifact serialization 

5in store/codec.py. 

6""" 

7 

8import base64 

9import importlib 

10import json 

11from decimal import Decimal 

12from io import BytesIO 

13from typing import Any 

14 

15from invariant.graph import Graph 

16from invariant.node import Node, SubGraphNode 

17from invariant.params import cel, ref 

18from invariant.protocol import ICacheable 

19 

# Envelope "version" values this module can read.
SUPPORTED_VERSIONS = {1}
# Required value of the envelope's "format" field.
FORMAT_ID = "invariant-graph"
# Media type of the graph-plus-output wrapper, and the data-URI prefix used
# when that wrapper is emitted as base64.
GRAPH_OUTPUT_MEDIA_TYPE = "application/vnd.invariant.graph-output+json"
GRAPH_OUTPUT_DATA_URI_PREFIX = "data:" + GRAPH_OUTPUT_MEDIA_TYPE + ";base64,"

# Keys that, as the sole key of a JSON object, mark an encoded value rather
# than a plain dict.
RESERVED_KEYS = frozenset(
    ("$ref", "$cel", "$decimal", "$tuple", "$literal", "$icacheable")
)

28 

29 

def _encode_param_value(value: Any) -> Any:
    """Convert a parameter value into its JSON-compatible marker encoding.

    Mapping:
        ``ref``        -> ``{"$ref": dep}``
        ``cel``        -> ``{"$cel": expr}``
        ``Decimal``    -> ``{"$decimal": "<str>"}``
        ``tuple``      -> ``{"$tuple": [...]}`` (items encoded recursively)
        ``ICacheable`` -> ``{"$icacheable": {"type": ..., "value": ...}}`` when
                          the instance has a callable ``to_json_value``, else
                          ``{"$icacheable": {"type": ..., "payload_b64": ...}}``
        ``dict``       -> values encoded recursively; a single-key dict whose
                          key is reserved is wrapped in ``{"$literal": ...}``
        ``list``       -> items encoded recursively
    Anything else is returned unchanged.
    """
    if isinstance(value, ref):
        return {"$ref": value.dep}
    if isinstance(value, cel):
        return {"$cel": value.expr}
    if isinstance(value, Decimal):
        return {"$decimal": str(value)}
    if isinstance(value, tuple):
        return {"$tuple": list(map(_encode_param_value, value))}

    if isinstance(value, ICacheable):
        klass = value.__class__
        qualified = f"{klass.__module__}.{klass.__name__}"
        # Prefer the JSON form when the type offers one; fall back to the
        # binary stream form, base64-encoded.
        if callable(getattr(value, "to_json_value", None)):
            return {"$icacheable": {"type": qualified, "value": value.to_json_value()}}
        buffer = BytesIO()
        value.to_stream(buffer)
        encoded_payload = base64.b64encode(buffer.getvalue()).decode("ascii")
        return {"$icacheable": {"type": qualified, "payload_b64": encoded_payload}}

    if isinstance(value, dict):
        converted = {key: _encode_param_value(item) for key, item in value.items()}
        # A plain dict such as {"$ref": ...} would otherwise be read back as a
        # marker, so escape it.
        if len(converted) == 1 and next(iter(converted)) in RESERVED_KEYS:
            return {"$literal": converted}
        return converted

    if isinstance(value, list):
        return [_encode_param_value(item) for item in value]

    return value

74 

75 

76def _decode_param_value(obj: Any, literal_mode: bool = False) -> Any: 

77 """Recursively decode a JSON value to Python parameter value.""" 

78 # In literal mode, never treat dicts as markers 

79 if literal_mode: 

80 if isinstance(obj, dict): 

81 return { 

82 k: _decode_param_value(v, literal_mode=True) for k, v in obj.items() 

83 } 

84 if isinstance(obj, list): 

85 return [_decode_param_value(item, literal_mode=True) for item in obj] 

86 return obj 

87 

88 # Single-key dict with reserved key -> marker or escape 

89 if isinstance(obj, dict): 

90 if len(obj) == 1: 

91 (key, val) = next(iter(obj.items())) 

92 if key == "$ref": 

93 return ref(val) 

94 if key == "$cel": 

95 return cel(val) 

96 if key == "$decimal": 

97 return Decimal(val) 

98 if key == "$tuple": 

99 return tuple(_decode_param_value(item) for item in val) 

100 if key == "$literal": 

101 return _decode_param_value(val, literal_mode=True) 

102 if key == "$icacheable": 

103 return _decode_icacheable(val) 

104 # Multi-key or non-reserved: recursive decode 

105 return {k: _decode_param_value(v) for k, v in obj.items()} 

106 

107 if isinstance(obj, list): 

108 return [_decode_param_value(item) for item in obj] 

109 

110 return obj 

111 

112 

113def _decode_icacheable(obj: dict) -> Any: 

114 """Decode $icacheable object to ICacheable instance.""" 

115 if not isinstance(obj, dict): 

116 raise ValueError("$icacheable value must be an object") 

117 type_name = obj.get("type") 

118 if not type_name or not isinstance(type_name, str): 

119 raise ValueError("$icacheable must have non-empty string 'type'") 

120 if "payload_b64" in obj and "value" in obj: 

121 raise ValueError( 

122 "$icacheable must have exactly one of 'payload_b64' or 'value'" 

123 ) 

124 if "payload_b64" not in obj and "value" not in obj: 

125 raise ValueError("$icacheable must have 'payload_b64' or 'value'") 

126 

127 module_path, class_name = type_name.rsplit(".", 1) 

128 try: 

129 module = importlib.import_module(module_path) 

130 cls = getattr(module, class_name) 

131 except (ImportError, AttributeError) as e: 

132 raise ValueError( 

133 f"$icacheable type '{type_name}' could not be imported: {e}" 

134 ) from e 

135 

136 if "value" in obj: 

137 if not hasattr(cls, "from_json_value"): 

138 raise ValueError( 

139 f"$icacheable type '{type_name}' has 'value' but no " 

140 "from_json_value method" 

141 ) 

142 return cls.from_json_value(obj["value"]) 

143 

144 # payload_b64 

145 try: 

146 payload = base64.b64decode(obj["payload_b64"]) 

147 except Exception as e: 

148 raise ValueError(f"$icacheable payload_b64 is invalid base64: {e}") from e 

149 stream = BytesIO(payload) 

150 try: 

151 return cls.from_stream(stream) 

152 except Exception as e: 

153 raise ValueError( 

154 f"$icacheable from_stream failed for '{type_name}': {e}" 

155 ) from e 

156 

157 

158def _encode_params(params: dict[str, Any]) -> dict[str, Any]: 

159 """Encode params dict with sorted keys for determinism.""" 

160 return dict(sorted((k, _encode_param_value(v)) for k, v in params.items())) 

161 

162 

163def _decode_params(obj: dict) -> dict[str, Any]: 

164 """Decode params dict.""" 

165 return {k: _decode_param_value(v) for k, v in obj.items()} 

166 

167 

def dump_value_to_jsonable(value: Any) -> Any:
    """Encode ``value`` with the same marker encoding used for graph params.

    Public entry point for callers that need to serialize a standalone value
    (not a whole graph) in the graph wire format.
    """
    jsonable = _encode_param_value(value)
    return jsonable

171 

172 

def load_value_from_jsonable(obj: Any) -> Any:
    """Decode ``obj`` using the same marker decoding applied to graph params.

    Public entry point for callers that need to deserialize a standalone value
    (not a whole graph) from the graph wire format.
    """
    value = _decode_param_value(obj)
    return value

176 

177 

def _encode_vertex(vertex: Node | SubGraphNode) -> dict:
    """Encode one graph vertex as a JSON object tagged with ``kind``.

    A ``Node`` becomes ``{"kind": "node", "op_name", "params", "deps"}``, plus
    ``"cache": False`` only when caching is turned off. Any other vertex is
    treated as a ``SubGraphNode`` and becomes ``{"kind": "subgraph", "params",
    "deps", "graph", "output"}``. Params and deps are emitted in sorted order.
    """
    if not isinstance(vertex, Node):
        return {
            "kind": "subgraph",
            "params": _encode_params(vertex.params),
            "deps": sorted(vertex.deps),
            "graph": _encode_graph(vertex.graph),
            "output": vertex.output,
        }

    encoded: dict = {
        "kind": "node",
        "op_name": vertex.op_name,
        "params": _encode_params(vertex.params),
        "deps": sorted(vertex.deps),
    }
    # The decoder defaults a missing "cache" to True, so only the opt-out
    # needs to be written.
    if not vertex.cache:
        encoded["cache"] = False
    return encoded

197 } 

198 

199 

def _decode_vertex(
    obj: dict, legacy_kind_inference: bool = False
) -> Node | SubGraphNode:
    """Decode a JSON vertex object into a ``Node`` or ``SubGraphNode``.

    The object is validated before anything is constructed.

    Args:
        obj: Vertex object whose ``kind`` is ``"node"`` or ``"subgraph"``.
        legacy_kind_inference: When True and ``kind`` is absent, infer it.
            ``op_name`` without ``graph`` means a node; ``graph`` plus
            ``output`` means a subgraph. The flag is also passed down to
            nested graphs.

    Returns:
        The constructed ``Node`` or ``SubGraphNode``.

    Raises:
        ValueError: If ``obj`` is not an object, its kind is missing, unknown
            or cannot be inferred, or validation fails.
    """
    if not isinstance(obj, dict):
        raise ValueError("Vertex must be an object")

    kind = obj.get("kind")
    if kind is None and legacy_kind_inference:
        if "op_name" in obj and "graph" not in obj:
            kind = "node"
        elif "graph" in obj and "output" in obj:
            kind = "subgraph"
        else:
            raise ValueError(
                "Vertex has no 'kind' and cannot infer from op_name/graph/output"
            )
    if kind is None:
        raise ValueError("Vertex must have 'kind'")

    if kind == "node":
        _validate_node(obj, expected_kind=kind)
        return Node(
            op_name=obj["op_name"].strip(),
            params=_decode_params(obj["params"]),
            deps=list(obj["deps"]),
            # _validate_node accepts an explicit "cache": null. Treat it like
            # an absent key (caching on) instead of passing None to Node.
            cache=obj.get("cache") is not False,
        )
    if kind == "subgraph":
        _validate_subgraph(obj, legacy_kind_inference)
        return SubGraphNode(
            params=_decode_params(obj["params"]),
            deps=list(obj["deps"]),
            graph=_decode_graph(obj["graph"], legacy_kind_inference),
            output=obj["output"],
        )
    # Any other kind value is rejected here.
    raise ValueError(f"Vertex has unsupported kind: {kind!r}")

239 

240 

241def _validate_node(obj: dict, expected_kind: str | None = None) -> None: 

242 """Validate node object before construction.""" 

243 kind = expected_kind if expected_kind is not None else obj.get("kind") 

244 if kind != "node": 

245 raise ValueError("Node must have kind 'node'") 

246 op_name = obj.get("op_name") 

247 if not isinstance(op_name, str): 

248 raise ValueError("Node must have string 'op_name'") 

249 if not op_name.strip(): 

250 raise ValueError("Node op_name cannot be empty") 

251 if "params" not in obj or not isinstance(obj["params"], dict): 

252 raise ValueError("Node must have 'params' object") 

253 if "deps" not in obj or not isinstance(obj["deps"], list): 

254 raise ValueError("Node must have 'deps' array") 

255 for i, dep in enumerate(obj["deps"]): 

256 if not isinstance(dep, str): 

257 raise ValueError(f"Node deps[{i}] must be string, got {type(dep).__name__}") 

258 cache_val = obj.get("cache") 

259 if cache_val is not None and not isinstance(cache_val, bool): 

260 raise ValueError("Node 'cache' must be boolean when present") 

261 

262 

263def _validate_subgraph(obj: dict, legacy_kind_inference: bool = False) -> None: 

264 """Validate subgraph object before construction.""" 

265 kind = obj.get("kind") 

266 if not legacy_kind_inference and kind != "subgraph": 

267 raise ValueError("SubGraphNode must have kind 'subgraph'") 

268 if "params" not in obj or not isinstance(obj["params"], dict): 

269 raise ValueError("SubGraphNode must have 'params' object") 

270 if "deps" not in obj or not isinstance(obj["deps"], list): 

271 raise ValueError("SubGraphNode must have 'deps' array") 

272 for i, dep in enumerate(obj["deps"]): 

273 if not isinstance(dep, str): 

274 raise ValueError( 

275 f"SubGraphNode deps[{i}] must be string, got {type(dep).__name__}" 

276 ) 

277 if "graph" not in obj or not isinstance(obj["graph"], dict): 

278 raise ValueError("SubGraphNode must have 'graph' object") 

279 output = obj.get("output") 

280 if not isinstance(output, str): 

281 raise ValueError("SubGraphNode must have string 'output'") 

282 if output not in obj["graph"]: 

283 raise ValueError( 

284 f"SubGraphNode output '{output}' must be key in graph. " 

285 f"Graph keys: {list(obj['graph'].keys())}" 

286 ) 

287 for node_id, vertex_obj in obj["graph"].items(): 

288 _validate_vertex_for_kind(vertex_obj, node_id, legacy_kind_inference) 

289 

290 

291def _validate_vertex_for_kind( 

292 vertex_obj: Any, node_id: str, legacy_kind_inference: bool = False 

293) -> None: 

294 """Validate a vertex object has valid kind and structure.""" 

295 if not isinstance(vertex_obj, dict): 

296 raise ValueError(f"Vertex '{node_id}' must be an object") 

297 kind = vertex_obj.get("kind") 

298 if kind is None and legacy_kind_inference: 

299 if "op_name" in vertex_obj and "graph" not in vertex_obj: 

300 kind = "node" 

301 elif "graph" in vertex_obj and "output" in vertex_obj: 

302 kind = "subgraph" 

303 else: 

304 raise ValueError( 

305 f"Vertex '{node_id}' has no 'kind' and cannot infer from " 

306 "op_name/graph/output" 

307 ) 

308 if kind == "node": 

309 _validate_node(vertex_obj, expected_kind="node") 

310 elif kind == "subgraph": 

311 _validate_subgraph(vertex_obj, legacy_kind_inference) 

312 else: 

313 raise ValueError(f"Vertex '{node_id}' has unsupported kind: {kind!r}") 

314 

315 

def _encode_graph(graph: Graph) -> dict:
    """Encode every vertex of ``graph``; the returned dict is ordered by node id."""
    encoded = {node_id: _encode_vertex(vertex) for node_id, vertex in graph.items()}
    return {node_id: encoded[node_id] for node_id in sorted(encoded)}

319 

320 

def _decode_graph(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Decode a JSON graph object (node id -> vertex object) into a Graph.

    Raises:
        ValueError: If ``obj`` is not an object, or a vertex fails to decode.
    """
    if not isinstance(obj, dict):
        raise ValueError("Graph must be an object")
    return {
        node_id: _decode_vertex(vertex_obj, legacy_kind_inference)
        for node_id, vertex_obj in obj.items()
    }

329 

330 

331def _validate_envelope(obj: dict) -> None: 

332 """Validate top-level envelope.""" 

333 if not isinstance(obj, dict): 

334 raise ValueError("Document must be a JSON object") 

335 fmt = obj.get("format") 

336 if fmt != FORMAT_ID: 

337 raise ValueError(f"Document format must be '{FORMAT_ID}', got {fmt!r}") 

338 version = obj.get("version") 

339 if version not in SUPPORTED_VERSIONS: 

340 raise ValueError( 

341 f"Document version {version} is not supported. " 

342 f"Supported: {sorted(SUPPORTED_VERSIONS)}" 

343 ) 

344 if "graph" not in obj: 

345 raise ValueError("Document must have 'graph'") 

346 if not isinstance(obj["graph"], dict): 

347 raise ValueError("Document 'graph' must be an object") 

348 

349 

def dump_graph_to_dict(graph: Graph) -> dict:
    """Wrap the encoded ``graph`` in a versioned envelope dict.

    The envelope holds ``format`` (``FORMAT_ID``), ``version`` (1) and the
    encoded ``graph``. Vertex and param keys are emitted in sorted order, so
    equal graphs produce equal dicts.
    """
    envelope: dict = {"format": FORMAT_ID, "version": 1}
    envelope["graph"] = _encode_graph(graph)
    return envelope

357 

358 

def dump_graph(graph: Graph) -> str:
    """Return ``graph`` as a JSON string with keys sorted at every level."""
    envelope = dump_graph_to_dict(graph)
    return json.dumps(envelope, sort_keys=True)

362 

363 

def load_graph_from_dict(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Validate an envelope dict and decode the graph it contains.

    Raises:
        ValueError: If the envelope is invalid. Errors raised while decoding
            vertices propagate unchanged.
    """
    _validate_envelope(obj)
    graph_obj = obj["graph"]
    return _decode_graph(graph_obj, legacy_kind_inference)

368 

369 

def load_graph(data: str | bytes, legacy_kind_inference: bool = False) -> Graph:
    """Parse a JSON document (``str`` or UTF-8 ``bytes``) and load its graph."""
    text = data.decode("utf-8") if isinstance(data, bytes) else data
    return load_graph_from_dict(json.loads(text), legacy_kind_inference)

376 

377 

def dump_graph_output_to_dict(graph: Graph, output: str) -> dict[str, Any]:
    """Bundle ``graph`` (as an envelope dict) with the name of its output node."""
    wrapper: dict[str, Any] = {"graph": dump_graph_to_dict(graph)}
    wrapper["output"] = output
    return wrapper

382 

383 

def load_graph_output_from_dict(
    obj: dict[str, Any], legacy_kind_inference: bool = False
) -> tuple[Graph, str]:
    """Unpack a ``{"graph": <envelope>, "output": <name>}`` wrapper.

    ``output`` defaults to ``"output"`` when the key is absent.

    Returns:
        ``(graph, output)``.

    Raises:
        ValueError: If ``obj`` is not an object, ``output`` is not a non-empty
            string, ``graph`` is not an object, or the envelope is invalid.
    """
    if not isinstance(obj, dict):
        raise ValueError("Document must be an object")

    output = obj.get("output", "output")
    if not (isinstance(output, str) and output):
        raise ValueError("Document 'output' must be a non-empty string")

    envelope = obj.get("graph")
    if not isinstance(envelope, dict):
        raise ValueError("Document must have object 'graph'")

    return load_graph_from_dict(envelope, legacy_kind_inference), output

402 

403 

def dump_graph_output_data_uri(graph: Graph, output: str) -> str:
    """Return the graph-plus-output wrapper as a base64 ``data:`` URI.

    The embedded JSON is compact and key-sorted, so equal inputs produce
    identical URIs.
    """
    wrapper = dump_graph_output_to_dict(graph, output)
    compact_json = json.dumps(wrapper, sort_keys=True, separators=(",", ":"))
    b64_text = base64.b64encode(compact_json.encode("utf-8")).decode("ascii")
    return GRAPH_OUTPUT_DATA_URI_PREFIX + b64_text

414 

415 

def load_graph_output_data_uri(
    data: str, legacy_kind_inference: bool = False
) -> tuple[Graph, str] | None:
    """Decode a graph-plus-output ``data:`` URI into ``(graph, output)``.

    This is best-effort by contract. It returns ``None`` when ``data`` is not
    a string carrying ``GRAPH_OUTPUT_DATA_URI_PREFIX``. It also returns
    ``None`` when strict base64 decoding, UTF-8 decoding, JSON parsing or
    graph loading raises.
    """
    if not isinstance(data, str):
        return None
    if not data.startswith(GRAPH_OUTPUT_DATA_URI_PREFIX):
        return None

    b64_text = data.removeprefix(GRAPH_OUTPUT_DATA_URI_PREFIX)
    try:
        raw = base64.b64decode(b64_text, validate=True)
        parsed = json.loads(raw.decode("utf-8"))
        return load_graph_output_from_dict(parsed, legacy_kind_inference)
    except Exception:
        # Intentionally broad: any parsing failure maps to the documented
        # None result rather than an exception.
        return None