Coverage for src / invariant / graph_serialization.py: 86.52%

356 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-06 12:18 +0000

1"""Graph serialization: JSON wire format for Invariant graphs. 

2 

3Encodes graphs (Node, SubGraphNode, SwitchNode) and params (ref, cel, Decimal, tuple, 

4ICacheable) for storage and transmission. Distinct from artifact serialization 

5in store/codec.py. 

6""" 

7 

8import base64 

9import importlib 

10import json 

11from decimal import Decimal 

12from io import BytesIO 

13from typing import Any 

14from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit 

15 

16from invariant.graph import Graph 

17from invariant.node import Node, SubGraphNode, SwitchNode 

18from invariant.params import cel, ref 

19from invariant.protocol import ICacheable 

20 

# Envelope "version" values this module accepts when loading a document.
SUPPORTED_VERSIONS = {1}
# Required value of the envelope "format" field.
FORMAT_ID = "invariant-graph"
# Media type used when a graph document is embedded in a data URI.
GRAPH_MEDIA_TYPE = "application/vnd.invariant.graph+json"
GRAPH_DATA_URI_PREFIX = "data:" + GRAPH_MEDIA_TYPE + ";base64,"

# Keys that turn a single-key JSON object into a marker during param decoding.
RESERVED_KEYS = frozenset(
    ("$ref", "$cel", "$decimal", "$tuple", "$literal", "$icacheable")
)

29 

30 

def _encode_param_value(value: Any) -> Any:
    """Convert one parameter value into its JSON-serializable wire form.

    ``ref``, ``cel``, ``Decimal``, ``tuple`` and ``ICacheable`` values become
    single-key marker objects. Dicts and lists are encoded element by element.
    Anything else is returned unchanged.
    """
    if isinstance(value, ref):
        return {"$ref": value.dep}
    if isinstance(value, cel):
        return {"$cel": value.expr}
    if isinstance(value, Decimal):
        return {"$decimal": str(value)}
    if isinstance(value, tuple):
        return {"$tuple": list(map(_encode_param_value, value))}

    if isinstance(value, ICacheable):
        klass = value.__class__
        qualified_name = f"{klass.__module__}.{klass.__name__}"
        # Use the JSON-native form when the object provides one; otherwise
        # fall back to the base64-encoded byte stream.
        if hasattr(value, "to_json_value") and callable(value.to_json_value):
            body = {"type": qualified_name, "value": value.to_json_value()}
        else:
            buffer = BytesIO()
            value.to_stream(buffer)
            encoded_bytes = base64.b64encode(buffer.getvalue())
            body = {
                "type": qualified_name,
                "payload_b64": encoded_bytes.decode("ascii"),
            }
        return {"$icacheable": body}

    if isinstance(value, dict):
        encoded = {}
        for key, item in value.items():
            encoded[key] = _encode_param_value(item)
        # A plain single-key dict whose key is reserved would be read back as
        # a marker, so it is wrapped in "$literal".
        if len(encoded) == 1 and next(iter(encoded)) in RESERVED_KEYS:
            return {"$literal": encoded}
        return encoded

    if isinstance(value, list):
        return list(map(_encode_param_value, value))

    # Everything else is passed through untouched.
    return value

75 

76 

77def _decode_param_value(obj: Any, literal_mode: bool = False) -> Any: 

78 """Recursively decode a JSON value to Python parameter value.""" 

79 # In literal mode, never treat dicts as markers 

80 if literal_mode: 

81 if isinstance(obj, dict): 

82 return { 

83 k: _decode_param_value(v, literal_mode=True) for k, v in obj.items() 

84 } 

85 if isinstance(obj, list): 

86 return [_decode_param_value(item, literal_mode=True) for item in obj] 

87 return obj 

88 

89 # Single-key dict with reserved key -> marker or escape 

90 if isinstance(obj, dict): 

91 if len(obj) == 1: 

92 (key, val) = next(iter(obj.items())) 

93 if key == "$ref": 

94 return ref(val) 

95 if key == "$cel": 

96 return cel(val) 

97 if key == "$decimal": 

98 return Decimal(val) 

99 if key == "$tuple": 

100 return tuple(_decode_param_value(item) for item in val) 

101 if key == "$literal": 

102 return _decode_param_value(val, literal_mode=True) 

103 if key == "$icacheable": 

104 return _decode_icacheable(val) 

105 # Multi-key or non-reserved: recursive decode 

106 return {k: _decode_param_value(v) for k, v in obj.items()} 

107 

108 if isinstance(obj, list): 

109 return [_decode_param_value(item) for item in obj] 

110 

111 return obj 

112 

113 

114def _decode_icacheable(obj: dict) -> Any: 

115 """Decode $icacheable object to ICacheable instance.""" 

116 if not isinstance(obj, dict): 

117 raise ValueError("$icacheable value must be an object") 

118 type_name = obj.get("type") 

119 if not type_name or not isinstance(type_name, str): 

120 raise ValueError("$icacheable must have non-empty string 'type'") 

121 if "payload_b64" in obj and "value" in obj: 

122 raise ValueError( 

123 "$icacheable must have exactly one of 'payload_b64' or 'value'" 

124 ) 

125 if "payload_b64" not in obj and "value" not in obj: 

126 raise ValueError("$icacheable must have 'payload_b64' or 'value'") 

127 

128 module_path, class_name = type_name.rsplit(".", 1) 

129 try: 

130 module = importlib.import_module(module_path) 

131 cls = getattr(module, class_name) 

132 except (ImportError, AttributeError) as e: 

133 raise ValueError( 

134 f"$icacheable type '{type_name}' could not be imported: {e}" 

135 ) from e 

136 

137 if "value" in obj: 

138 if not hasattr(cls, "from_json_value"): 

139 raise ValueError( 

140 f"$icacheable type '{type_name}' has 'value' but no " 

141 "from_json_value method" 

142 ) 

143 return cls.from_json_value(obj["value"]) 

144 

145 # payload_b64 

146 try: 

147 payload = base64.b64decode(obj["payload_b64"]) 

148 except Exception as e: 

149 raise ValueError(f"$icacheable payload_b64 is invalid base64: {e}") from e 

150 stream = BytesIO(payload) 

151 try: 

152 return cls.from_stream(stream) 

153 except Exception as e: 

154 raise ValueError( 

155 f"$icacheable from_stream failed for '{type_name}': {e}" 

156 ) from e 

157 

158 

159def _encode_params(params: dict[str, Any]) -> dict[str, Any]: 

160 """Encode params dict with sorted keys for determinism.""" 

161 return dict(sorted((k, _encode_param_value(v)) for k, v in params.items())) 

162 

163 

164def _decode_params(obj: dict) -> dict[str, Any]: 

165 """Decode params dict.""" 

166 return {k: _decode_param_value(v) for k, v in obj.items()} 

167 

168 

def dump_value_to_jsonable(value: Any) -> Any:
    """Encode *value* with the same marker encoding used for graph params."""
    jsonable = _encode_param_value(value)
    return jsonable

172 

173 

def load_value_from_jsonable(obj: Any) -> Any:
    """Decode *obj* from the marker encoding used for graph params."""
    decoded = _decode_param_value(obj)
    return decoded

177 

178 

def _encode_vertex(vertex: Node | SubGraphNode | SwitchNode) -> dict:
    """Turn one graph vertex into its JSON object.

    ``Node`` and ``SubGraphNode`` are recognised by type; any other vertex is
    encoded as a switch. ``deps`` are emitted sorted, switch ``cases`` are
    emitted sorted by key, and a node's ``cache`` flag is written only when
    it is falsy.
    """
    if isinstance(vertex, Node):
        node_obj: dict = {
            "kind": "node",
            "op_name": vertex.op_name,
            "params": _encode_params(vertex.params),
            "deps": sorted(vertex.deps),
        }
        if not vertex.cache:
            node_obj["cache"] = False
        return node_obj

    if isinstance(vertex, SubGraphNode):
        return {
            "kind": "subgraph",
            "params": _encode_params(vertex.params),
            "deps": sorted(vertex.deps),
            "graph": _encode_graph(vertex.graph),
            "output": vertex.output,
        }

    # SwitchNode: keys are inserted in the order kind, selector, deps, cases
    # and, only when set, default.
    ordered_cases = dict(sorted(vertex.cases.items()))
    switch_obj: dict = {
        "kind": "switch",
        "selector": _encode_param_value(vertex.selector),
        "deps": sorted(vertex.deps),
        "cases": ordered_cases,
    }
    if vertex.default is not None:
        switch_obj["default"] = vertex.default
    return switch_obj

214 

215 

def _decode_vertex(
    obj: dict, legacy_kind_inference: bool = False
) -> Node | SubGraphNode | SwitchNode:
    """Validate a JSON vertex object and build the matching vertex.

    When ``kind`` is missing and ``legacy_kind_inference`` is true, the kind
    is inferred: ``op_name`` without ``graph`` means a node, ``graph`` plus
    ``output`` means a subgraph.

    Raises:
        ValueError: If the object is not a dict, its kind is missing,
            unsupported or cannot be inferred, or kind-specific validation
            fails.
    """
    if not isinstance(obj, dict):
        raise ValueError("Vertex must be an object")

    kind = obj.get("kind")
    if kind is None:
        if not legacy_kind_inference:
            raise ValueError("Vertex must have 'kind'")
        has_graph = "graph" in obj
        if "op_name" in obj and not has_graph:
            kind = "node"
        elif has_graph and "output" in obj:
            kind = "subgraph"
        else:
            raise ValueError(
                "Vertex has no 'kind' and cannot infer from op_name/graph/output"
            )

    if kind == "node":
        _validate_node(obj, expected_kind=kind)
        return Node(
            op_name=obj["op_name"].strip(),
            params=_decode_params(obj["params"]),
            deps=list(obj["deps"]),
            cache=obj.get("cache", True),
        )

    if kind == "subgraph":
        _validate_subgraph(obj, legacy_kind_inference)
        return SubGraphNode(
            params=_decode_params(obj["params"]),
            deps=list(obj["deps"]),
            graph=_decode_graph(obj["graph"], legacy_kind_inference),
            output=obj["output"],
        )

    if kind == "switch":
        _validate_switch(obj)
        return SwitchNode(
            selector=_decode_param_value(obj["selector"]),
            deps=list(obj["deps"]),
            cases=dict(obj["cases"]),
            default=obj.get("default"),
        )

    raise ValueError(f"Vertex has unsupported kind: {kind!r}")

263 

264 

265def _validate_node(obj: dict, expected_kind: str | None = None) -> None: 

266 """Validate node object before construction.""" 

267 kind = expected_kind if expected_kind is not None else obj.get("kind") 

268 if kind != "node": 

269 raise ValueError("Node must have kind 'node'") 

270 op_name = obj.get("op_name") 

271 if not isinstance(op_name, str): 

272 raise ValueError("Node must have string 'op_name'") 

273 if not op_name.strip(): 

274 raise ValueError("Node op_name cannot be empty") 

275 if "params" not in obj or not isinstance(obj["params"], dict): 

276 raise ValueError("Node must have 'params' object") 

277 if "deps" not in obj or not isinstance(obj["deps"], list): 

278 raise ValueError("Node must have 'deps' array") 

279 for i, dep in enumerate(obj["deps"]): 

280 if not isinstance(dep, str): 

281 raise ValueError(f"Node deps[{i}] must be string, got {type(dep).__name__}") 

282 cache_val = obj.get("cache") 

283 if cache_val is not None and not isinstance(cache_val, bool): 

284 raise ValueError("Node 'cache' must be boolean when present") 

285 

286 

287def _validate_subgraph(obj: dict, legacy_kind_inference: bool = False) -> None: 

288 """Validate subgraph object before construction.""" 

289 kind = obj.get("kind") 

290 if not legacy_kind_inference and kind != "subgraph": 

291 raise ValueError("SubGraphNode must have kind 'subgraph'") 

292 if "params" not in obj or not isinstance(obj["params"], dict): 

293 raise ValueError("SubGraphNode must have 'params' object") 

294 if "deps" not in obj or not isinstance(obj["deps"], list): 

295 raise ValueError("SubGraphNode must have 'deps' array") 

296 for i, dep in enumerate(obj["deps"]): 

297 if not isinstance(dep, str): 

298 raise ValueError( 

299 f"SubGraphNode deps[{i}] must be string, got {type(dep).__name__}" 

300 ) 

301 if "graph" not in obj or not isinstance(obj["graph"], dict): 

302 raise ValueError("SubGraphNode must have 'graph' object") 

303 output = obj.get("output") 

304 if not isinstance(output, str): 

305 raise ValueError("SubGraphNode must have string 'output'") 

306 if output not in obj["graph"]: 

307 raise ValueError( 

308 f"SubGraphNode output '{output}' must be key in graph. " 

309 f"Graph keys: {list(obj['graph'].keys())}" 

310 ) 

311 for node_id, vertex_obj in obj["graph"].items(): 

312 _validate_vertex_for_kind(vertex_obj, node_id, legacy_kind_inference) 

313 

314 

315def _validate_switch(obj: dict) -> None: 

316 """Validate switch object before construction.""" 

317 if obj.get("kind") != "switch": 

318 raise ValueError("SwitchNode must have kind 'switch'") 

319 if "selector" not in obj: 

320 raise ValueError("SwitchNode must have 'selector'") 

321 if "deps" not in obj or not isinstance(obj["deps"], list): 

322 raise ValueError("SwitchNode must have 'deps' array") 

323 for i, dep in enumerate(obj["deps"]): 

324 if not isinstance(dep, str): 

325 raise ValueError( 

326 f"SwitchNode deps[{i}] must be string, got {type(dep).__name__}" 

327 ) 

328 if "cases" not in obj or not isinstance(obj["cases"], dict): 

329 raise ValueError("SwitchNode must have 'cases' object") 

330 if not obj["cases"]: 

331 raise ValueError("SwitchNode cases must not be empty") 

332 for case_key, target in obj["cases"].items(): 

333 if not isinstance(case_key, str): 

334 raise ValueError("SwitchNode cases keys must be strings") 

335 if not isinstance(target, str) or not target: 

336 raise ValueError("SwitchNode cases values must be non-empty strings") 

337 default = obj.get("default") 

338 if default is not None and (not isinstance(default, str) or not default): 

339 raise ValueError("SwitchNode default must be a non-empty string when present") 

340 

341 

342def _validate_vertex_for_kind( 

343 vertex_obj: Any, node_id: str, legacy_kind_inference: bool = False 

344) -> None: 

345 """Validate a vertex object has valid kind and structure.""" 

346 if not isinstance(vertex_obj, dict): 

347 raise ValueError(f"Vertex '{node_id}' must be an object") 

348 kind = vertex_obj.get("kind") 

349 if kind is None and legacy_kind_inference: 

350 if "op_name" in vertex_obj and "graph" not in vertex_obj: 

351 kind = "node" 

352 elif "graph" in vertex_obj and "output" in vertex_obj: 

353 kind = "subgraph" 

354 else: 

355 raise ValueError( 

356 f"Vertex '{node_id}' has no 'kind' and cannot infer from " 

357 "op_name/graph/output" 

358 ) 

359 if kind == "node": 

360 _validate_node(vertex_obj, expected_kind="node") 

361 elif kind == "subgraph": 

362 _validate_subgraph(vertex_obj, legacy_kind_inference) 

363 elif kind == "switch": 

364 _validate_switch(vertex_obj) 

365 else: 

366 raise ValueError(f"Vertex '{node_id}' has unsupported kind: {kind!r}") 

367 

368 

def _encode_graph(graph: Graph) -> dict:
    """Encode every vertex and return a JSON object ordered by node id."""
    pairs = [(node_id, _encode_vertex(vertex)) for node_id, vertex in graph.items()]
    # Node ids are unique, so ordering by id alone matches ordering the pairs.
    pairs.sort(key=lambda pair: pair[0])
    return dict(pairs)

372 

373 

def _decode_graph(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Decode a JSON graph object into a ``Graph``.

    Each vertex is decoded in document order; afterwards the switch branch
    targets are checked against the decoded node ids.

    Raises:
        ValueError: If ``obj`` is not a dict, a vertex is invalid, or a switch
            targets an unknown node.
    """
    if not isinstance(obj, dict):
        raise ValueError("Graph must be an object")

    decoded: Graph = {
        node_id: _decode_vertex(vertex_obj, legacy_kind_inference)
        for node_id, vertex_obj in obj.items()
    }
    _validate_switch_targets(decoded)
    return decoded

383 

384 

def _validate_switch_targets(graph: Graph) -> None:
    """Ensure every switch case and default target is a node id of *graph*.

    Raises:
        ValueError: Naming the first switch and target that is not a key of
            the graph.
    """
    known_ids = set(graph)
    for node_id, vertex in graph.items():
        if isinstance(vertex, SwitchNode):
            branch_targets = [*vertex.cases.values()]
            if vertex.default is not None:
                branch_targets.append(vertex.default)
            for target in branch_targets:
                if target in known_ids:
                    continue
                raise ValueError(
                    f"SwitchNode '{node_id}' targets '{target}' which must be "
                    f"a key in graph. Graph keys: {list(graph.keys())}"
                )

400 

401 

def _validate_output(graph: Graph, output: str | None) -> None:
    """Check a document's ``output`` against the decoded graph.

    ``None`` is accepted. Otherwise ``output`` must be a non-empty string
    that is a key of *graph*; a ``ValueError`` is raised if not.
    """
    if output is None:
        return
    is_name = isinstance(output, str) and bool(output)
    if not is_name:
        raise ValueError("Document 'output' must be a non-empty string when present")
    if output in graph:
        return
    raise ValueError(
        f"Document output '{output}' must be key in graph. "
        f"Graph keys: {list(graph.keys())}"
    )

412 

413 

def _validate_output_arg(graph: Graph, output: str | None) -> None:
    """Check the ``output`` argument passed to the dump functions.

    ``None`` is accepted. Otherwise ``output`` must be a non-empty string
    that is a key of *graph*; a ``ValueError`` is raised if not.
    """
    if output is None:
        return
    is_name = isinstance(output, str) and bool(output)
    if not is_name:
        raise ValueError("output must be a non-empty string when present")
    if output in graph:
        return
    raise ValueError(
        f"output '{output}' must be a key in graph. "
        f"Graph keys: {list(graph.keys())}"
    )

424 

425 

426def _validate_envelope(obj: dict) -> None: 

427 """Validate top-level envelope.""" 

428 if not isinstance(obj, dict): 

429 raise ValueError("Document must be a JSON object") 

430 fmt = obj.get("format") 

431 if fmt != FORMAT_ID: 

432 raise ValueError(f"Document format must be '{FORMAT_ID}', got {fmt!r}") 

433 version = obj.get("version") 

434 if version not in SUPPORTED_VERSIONS: 

435 raise ValueError( 

436 f"Document version {version} is not supported. " 

437 f"Supported: {sorted(SUPPORTED_VERSIONS)}" 

438 ) 

439 if "graph" not in obj: 

440 raise ValueError("Document must have 'graph'") 

441 if not isinstance(obj["graph"], dict): 

442 raise ValueError("Document 'graph' must be an object") 

443 if "output" in obj: 

444 output = obj["output"] 

445 if not isinstance(output, str) or not output: 

446 raise ValueError( 

447 "Document 'output' must be a non-empty string when present" 

448 ) 

449 

450 

def dump_graph_to_dict(graph: Graph, *, output: str | None = None) -> dict:
    """Build the envelope dict for *graph*.

    Keys are inserted in the order ``format``, ``version``, ``output`` (only
    when given) and ``graph``.

    Raises:
        ValueError: If ``output`` is given but is not a non-empty string that
            names a key of *graph*.
    """
    _validate_output_arg(graph, output)
    envelope: dict = {"format": FORMAT_ID, "version": 1}
    if output is not None:
        envelope["output"] = output
    envelope["graph"] = _encode_graph(graph)
    return envelope

462 

463 

def dump_graph(graph: Graph, *, output: str | None = None) -> str:
    """Serialize *graph* to a JSON string with keys sorted at every level."""
    envelope = dump_graph_to_dict(graph, output=output)
    return json.dumps(envelope, sort_keys=True)

467 

468 

def load_graph_document_from_dict(
    obj: dict, legacy_kind_inference: bool = False
) -> tuple[Graph, str | None]:
    """Load a graph and its optional ``output`` from an envelope dict.

    The envelope is validated first, then the graph is decoded, then the
    document's ``output`` (``None`` when absent) is checked against it.
    """
    _validate_envelope(obj)
    decoded_graph = _decode_graph(obj["graph"], legacy_kind_inference)
    document_output = obj.get("output")
    _validate_output(decoded_graph, document_output)
    return decoded_graph, document_output

478 

479 

def load_graph_from_dict(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Load only the graph from an envelope dict; the output is dropped."""
    document = load_graph_document_from_dict(obj, legacy_kind_inference)
    return document[0]

484 

485 

def load_graph_document(
    data: str | bytes, legacy_kind_inference: bool = False
) -> tuple[Graph, str | None]:
    """Parse a JSON document (``bytes`` are decoded as UTF-8) and load it."""
    text = data.decode("utf-8") if isinstance(data, bytes) else data
    envelope = json.loads(text)
    return load_graph_document_from_dict(envelope, legacy_kind_inference)

494 

495 

def load_graph(data: str | bytes, legacy_kind_inference: bool = False) -> Graph:
    """Parse a JSON document and return only its graph; the output is dropped."""
    document = load_graph_document(data, legacy_kind_inference)
    return document[0]

500 

501 

def _encode_graph_document_payload(graph: Graph, output: str | None = None) -> str:
    """Return the base64 text of the compact, key-sorted JSON envelope."""
    envelope = dump_graph_to_dict(graph, output=output)
    compact_json = json.dumps(envelope, sort_keys=True, separators=(",", ":"))
    encoded = base64.b64encode(compact_json.encode("utf-8"))
    return encoded.decode("ascii")

509 

510 

511def _query_value_to_string(value: Any) -> str: 

512 if isinstance(value, str): 

513 try: 

514 json.loads(value) 

515 except json.JSONDecodeError: 

516 return value 

517 return json.dumps( 

518 dump_value_to_jsonable(value), 

519 separators=(",", ":"), 

520 sort_keys=True, 

521 ) 

522 

523 

524def _query_value_from_string(value: str) -> Any: 

525 try: 

526 return load_value_from_jsonable(json.loads(value)) 

527 except json.JSONDecodeError: 

528 return value 

529 

530 

def dump_graph_data_uri(
    graph: Graph,
    *,
    output: str | None = None,
    context: dict[str, Any] | None = None,
) -> str:
    """Serialize a graph document as a data URI, with context as the query.

    When ``context`` is empty or ``None`` the bare data URI is returned.
    Otherwise its entries are appended as URL-encoded query parameters in
    sorted order.

    Raises:
        ValueError: If a context key is not a non-empty string.
    """
    payload = _encode_graph_document_payload(graph, output)
    base_uri = GRAPH_DATA_URI_PREFIX + payload
    if not context:
        return base_uri

    # All keys are checked before any value is converted.
    if any(not isinstance(name, str) or not name for name in context):
        raise ValueError("context keys must be non-empty strings")

    pairs = []
    for name, raw in sorted(context.items()):
        pairs.append((name, _query_value_to_string(raw)))
    return base_uri + "?" + urlencode(pairs)

552 

553 

554def graph_data_uri_cache_key(data: str) -> str | None: 

555 """Return the static graph data URI without query context.""" 

556 

557 if not isinstance(data, str): 

558 return None 

559 

560 parts = urlsplit(data) 

561 graph_path_prefix = f"{GRAPH_MEDIA_TYPE};base64," 

562 if parts.scheme != "data" or not parts.path.startswith(graph_path_prefix): 

563 return None 

564 

565 if parts.fragment: 

566 raise ValueError("Invariant graph data URIs must not include fragments") 

567 return urlunsplit((parts.scheme, parts.netloc, parts.path, "", "")) 

568 

569 

570def _decode_query_context(query: str) -> dict[str, Any]: 

571 context: dict[str, Any] = {} 

572 for key, value in parse_qsl(query, keep_blank_values=True): 

573 if not key: 

574 raise ValueError("Invariant graph data URI query keys must be non-empty") 

575 if key in context: 

576 raise ValueError( 

577 f"Invariant graph data URI query key {key!r} is duplicated" 

578 ) 

579 context[key] = _query_value_from_string(value) 

580 return context 

581 

582 

def load_graph_data_uri(
    data: str, legacy_kind_inference: bool = False
) -> tuple[Graph, str | None, dict[str, Any]] | None:
    """Decode an Invariant graph data URI and its optional query context.

    Returns:
        ``None`` when ``data`` is not a graph data URI; otherwise a tuple of
        the graph, the document's output (or ``None``) and the context dict
        parsed from the query string.

    Raises:
        ValueError: If the base64/JSON payload cannot be decoded, or the
            document or the query context is invalid.
    """
    if graph_data_uri_cache_key(data) is None:
        return None

    split = urlsplit(data)
    path_prefix = f"{GRAPH_MEDIA_TYPE};base64,"
    b64_text = split.path[len(path_prefix) :]
    try:
        raw = base64.b64decode(b64_text, validate=True)
        document = json.loads(raw.decode("utf-8"))
    except Exception as exc:
        raise ValueError(f"Invalid Invariant graph data URI payload: {exc}") from exc

    # The document is loaded before the query context is parsed.
    graph, output = load_graph_document_from_dict(document, legacy_kind_inference)
    query_context = _decode_query_context(split.query)
    return graph, output, query_context