Coverage for src / invariant / graph_serialization.py: 86.52%
356 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-08 09:24 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-08 09:24 +0000
1"""Graph serialization: JSON wire format for Invariant graphs.
3Encodes graphs (Node, SubGraphNode, SwitchNode) and params (ref, cel, Decimal, tuple,
4ICacheable) for storage and transmission. Distinct from artifact serialization
5in store/codec.py.
6"""
8import base64
9import importlib
10import json
11from decimal import Decimal
12from io import BytesIO
13from typing import Any
14from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
16from invariant.graph import Graph
17from invariant.node import Node, SubGraphNode, SwitchNode
18from invariant.params import cel, ref
19from invariant.protocol import ICacheable
# Envelope "version" values this module accepts when loading.
SUPPORTED_VERSIONS = {1}
# Required value of the envelope "format" field.
FORMAT_ID = "invariant-graph"
# Media type used for graph documents embedded in data URIs.
GRAPH_MEDIA_TYPE = "application/vnd.invariant.graph+json"
GRAPH_DATA_URI_PREFIX = "data:" + GRAPH_MEDIA_TYPE + ";base64,"

# A single-key dict using one of these keys is read back as a marker, so a
# plain dict that collides with them is escaped with "$literal" on encode.
RESERVED_KEYS = frozenset(
    ("$ref", "$cel", "$decimal", "$tuple", "$literal", "$icacheable")
)
def _encode_param_value(value: Any) -> Any:
    """Convert a parameter value into its JSON-safe marker encoding.

    Marker forms (each a single-key dict):

    * ``ref``        -> ``{"$ref": dep}``
    * ``cel``        -> ``{"$cel": expr}``
    * ``Decimal``    -> ``{"$decimal": str(value)}``
    * ``tuple``      -> ``{"$tuple": [...]}`` (items encoded recursively)
    * ``ICacheable`` -> ``{"$icacheable": {"type": ..., "value" | "payload_b64": ...}}``

    Dicts and lists are encoded element-wise. A plain dict whose only key is
    one of ``RESERVED_KEYS`` is wrapped as ``{"$literal": ...}`` so it is not
    mistaken for a marker. Every other value is returned unchanged.
    """
    if isinstance(value, ref):
        return {"$ref": value.dep}
    if isinstance(value, cel):
        return {"$cel": value.expr}
    if isinstance(value, Decimal):
        return {"$decimal": str(value)}
    if isinstance(value, tuple):
        return {"$tuple": list(map(_encode_param_value, value))}

    if isinstance(value, ICacheable):
        klass = value.__class__
        body: dict[str, Any] = {"type": f"{klass.__module__}.{klass.__name__}"}
        # Prefer the JSON hook when the type offers one; otherwise fall back
        # to the binary stream form, carried as base64 text.
        if hasattr(value, "to_json_value") and callable(value.to_json_value):
            body["value"] = value.to_json_value()
        else:
            buffer = BytesIO()
            value.to_stream(buffer)
            body["payload_b64"] = base64.b64encode(buffer.getvalue()).decode("ascii")
        return {"$icacheable": body}

    if isinstance(value, dict):
        encoded = {key: _encode_param_value(item) for key, item in value.items()}
        # A lone reserved key would decode as a marker: escape the dict.
        if len(encoded) == 1 and next(iter(encoded)) in RESERVED_KEYS:
            return {"$literal": encoded}
        return encoded

    if isinstance(value, list):
        return list(map(_encode_param_value, value))

    # None, bool, int, str (and anything else) pass through untouched.
    return value
def _decode_param_value(obj: Any, literal_mode: bool = False) -> Any:
    """Recursively decode a JSON value into a Python parameter value.

    Inverse of ``_encode_param_value``. A dict with exactly one reserved key
    is decoded as the matching marker (``$ref`` -> ``ref``, ``$cel`` ->
    ``cel``, ``$decimal`` -> ``Decimal``, ``$tuple`` -> ``tuple``,
    ``$icacheable`` -> ICacheable instance). Every other dict and list is
    decoded element-wise.

    ``{"$literal": {...}}`` is the escape the encoder emits for a plain
    single-key dict whose key collides with a reserved key. Only that wrapped
    dict is taken literally. Its values were fully encoded by the encoder, so
    they are decoded normally; markers and nested escapes inside it
    round-trip.

    Args:
        obj: JSON-compatible value to decode.
        literal_mode: When True, copy ``obj`` verbatim and never interpret
            any nested dict as a marker.

    Returns:
        The decoded Python value.

    Raises:
        ValueError: If a ``$tuple`` payload is not an array, or an
            ``$icacheable`` payload is malformed.
        decimal.InvalidOperation: If a ``$decimal`` string is not a number.
    """
    # Verbatim mode: rebuild containers without interpreting markers.
    if literal_mode:
        if isinstance(obj, dict):
            return {
                k: _decode_param_value(v, literal_mode=True) for k, v in obj.items()
            }
        if isinstance(obj, list):
            return [_decode_param_value(item, literal_mode=True) for item in obj]
        return obj

    if isinstance(obj, dict):
        if len(obj) == 1:
            ((key, val),) = obj.items()
            if key == "$ref":
                return ref(val)
            if key == "$cel":
                return cel(val)
            if key == "$decimal":
                return Decimal(val)
            if key == "$tuple":
                # Iterating a string or dict payload would silently yield
                # characters or keys; require an array.
                if not isinstance(val, (list, tuple)):
                    raise ValueError("$tuple value must be an array")
                return tuple(_decode_param_value(item) for item in val)
            if key == "$literal":
                if isinstance(val, dict):
                    # Keep the escaped dict's keys as-is, but decode its
                    # (already encoded) values normally.
                    return {k: _decode_param_value(v) for k, v in val.items()}
                # The encoder never wraps non-dicts; keep the historical
                # verbatim behavior for hand-written documents.
                return _decode_param_value(val, literal_mode=True)
            if key == "$icacheable":
                return _decode_icacheable(val)
        # Multi-key or non-reserved single-key dict: plain mapping.
        return {k: _decode_param_value(v) for k, v in obj.items()}

    if isinstance(obj, list):
        return [_decode_param_value(item) for item in obj]

    return obj
def _decode_icacheable(obj: dict) -> Any:
    """Decode an ``$icacheable`` payload object into an instance of its type.

    ``obj`` must hold a string ``type`` of the form ``"module.path.ClassName"``
    plus exactly one of:

    * ``value`` -- passed to ``cls.from_json_value``;
    * ``payload_b64`` -- base64 text, decoded and handed to ``cls.from_stream``
      wrapped in a ``BytesIO``.

    NOTE(review): ``type`` selects an arbitrary importable class, and
    importing its module runs that module's top-level code. Only decode
    documents from trusted sources, or consider checking ``cls`` against
    ``ICacheable`` before calling its hooks.

    Raises:
        ValueError: If ``obj`` is malformed, the type cannot be imported, the
            base64 payload is invalid, or the type's decoding hook fails.
    """
    if not isinstance(obj, dict):
        raise ValueError("$icacheable value must be an object")
    type_name = obj.get("type")
    if not type_name or not isinstance(type_name, str):
        raise ValueError("$icacheable must have non-empty string 'type'")
    has_payload = "payload_b64" in obj
    has_value = "value" in obj
    if has_payload and has_value:
        raise ValueError(
            "$icacheable must have exactly one of 'payload_b64' or 'value'"
        )
    if not has_payload and not has_value:
        raise ValueError("$icacheable must have 'payload_b64' or 'value'")

    # Reject names lacking a module or class part up front, instead of
    # failing with an opaque unpacking/import error.
    module_path, _, class_name = type_name.rpartition(".")
    if not module_path or not class_name:
        raise ValueError(
            f"$icacheable type '{type_name}' must be a fully qualified "
            "'module.ClassName' name"
        )
    try:
        module = importlib.import_module(module_path)
        cls = getattr(module, class_name)
    except (ImportError, AttributeError, TypeError, ValueError) as e:
        # import_module raises TypeError/ValueError for relative ("..pkg") or
        # otherwise malformed module names.
        raise ValueError(
            f"$icacheable type '{type_name}' could not be imported: {e}"
        ) from e

    if has_value:
        if not hasattr(cls, "from_json_value"):
            raise ValueError(
                f"$icacheable type '{type_name}' has 'value' but no "
                "from_json_value method"
            )
        try:
            return cls.from_json_value(obj["value"])
        except Exception as e:
            # Mirror the from_stream path: surface hook failures as ValueError.
            raise ValueError(
                f"$icacheable from_json_value failed for '{type_name}': {e}"
            ) from e

    try:
        # validate=True rejects non-alphabet characters instead of silently
        # discarding them, matching the data-URI payload decoding.
        payload = base64.b64decode(obj["payload_b64"], validate=True)
    except Exception as e:
        raise ValueError(f"$icacheable payload_b64 is invalid base64: {e}") from e
    try:
        return cls.from_stream(BytesIO(payload))
    except Exception as e:
        raise ValueError(
            f"$icacheable from_stream failed for '{type_name}': {e}"
        ) from e
def _encode_params(params: dict[str, Any]) -> dict[str, Any]:
    """Encode each param value; the result's keys are in sorted order."""
    pairs = [(name, _encode_param_value(raw)) for name, raw in params.items()]
    # Sorting keeps serialized output deterministic.
    pairs.sort()
    return dict(pairs)
def _decode_params(obj: dict) -> dict[str, Any]:
    """Decode every value of an encoded params object, keeping its keys."""
    decoded: dict[str, Any] = {}
    for name, raw in obj.items():
        decoded[name] = _decode_param_value(raw)
    return decoded
def dump_value_to_jsonable(value: Any) -> Any:
    """Public wrapper: encode ``value`` with the graph JSON marker encoding."""
    encoded = _encode_param_value(value)
    return encoded
def load_value_from_jsonable(obj: Any) -> Any:
    """Public wrapper: decode ``obj`` from the graph JSON marker encoding."""
    decoded = _decode_param_value(obj)
    return decoded
def _encode_vertex(vertex: Node | SubGraphNode | SwitchNode) -> dict:
    """Encode one graph vertex as a JSON object tagged with its ``kind``.

    ``deps`` and switch ``cases`` are emitted sorted for deterministic output.
    A Node's ``cache`` entry is written (as False) only when ``vertex.cache``
    is falsy. A SwitchNode's ``default`` is written only when not None.
    Anything that is neither a Node nor a SubGraphNode is encoded as a switch.
    """
    if isinstance(vertex, Node):
        encoded: dict = {
            "kind": "node",
            "op_name": vertex.op_name,
            "params": _encode_params(vertex.params),
            "deps": sorted(vertex.deps),
        }
        if not vertex.cache:
            encoded["cache"] = False
        return encoded

    if isinstance(vertex, SubGraphNode):
        return {
            "kind": "subgraph",
            "params": _encode_params(vertex.params),
            "deps": sorted(vertex.deps),
            "graph": _encode_graph(vertex.graph),
            "output": vertex.output,
        }

    # Keys are inserted in the canonical order: kind, selector, deps, cases,
    # then the optional default.
    encoded = {
        "kind": "switch",
        "selector": _encode_param_value(vertex.selector),
        "deps": sorted(vertex.deps),
        "cases": dict(sorted(vertex.cases.items())),
    }
    if vertex.default is not None:
        encoded["default"] = vertex.default
    return encoded
216def _decode_vertex(
217 obj: dict, legacy_kind_inference: bool = False
218) -> Node | SubGraphNode | SwitchNode:
219 """Decode a JSON object to a graph vertex. Validates before construction."""
220 if not isinstance(obj, dict):
221 raise ValueError("Vertex must be an object")
223 kind = obj.get("kind")
224 if kind is None and legacy_kind_inference:
225 if "op_name" in obj and "graph" not in obj:
226 kind = "node"
227 elif "graph" in obj and "output" in obj:
228 kind = "subgraph"
229 else:
230 raise ValueError(
231 "Vertex has no 'kind' and cannot infer from op_name/graph/output"
232 )
233 if kind is None:
234 raise ValueError("Vertex must have 'kind'")
235 if kind not in ("node", "subgraph", "switch"):
236 raise ValueError(f"Vertex has unsupported kind: {kind!r}")
238 if kind == "node":
239 _validate_node(obj, expected_kind=kind)
240 return Node(
241 op_name=obj["op_name"].strip(),
242 params=_decode_params(obj["params"]),
243 deps=list(obj["deps"]),
244 cache=obj.get("cache", True),
245 )
246 if kind == "subgraph":
247 _validate_subgraph(obj, legacy_kind_inference)
248 return SubGraphNode(
249 params=_decode_params(obj["params"]),
250 deps=list(obj["deps"]),
251 graph=_decode_graph(obj["graph"], legacy_kind_inference),
252 output=obj["output"],
253 )
254 if kind == "switch":
255 _validate_switch(obj)
256 return SwitchNode(
257 selector=_decode_param_value(obj["selector"]),
258 deps=list(obj["deps"]),
259 cases=dict(obj["cases"]),
260 default=obj.get("default"),
261 )
262 raise ValueError(f"Vertex has unsupported kind: {kind!r}")
265def _validate_node(obj: dict, expected_kind: str | None = None) -> None:
266 """Validate node object before construction."""
267 kind = expected_kind if expected_kind is not None else obj.get("kind")
268 if kind != "node":
269 raise ValueError("Node must have kind 'node'")
270 op_name = obj.get("op_name")
271 if not isinstance(op_name, str):
272 raise ValueError("Node must have string 'op_name'")
273 if not op_name.strip():
274 raise ValueError("Node op_name cannot be empty")
275 if "params" not in obj or not isinstance(obj["params"], dict):
276 raise ValueError("Node must have 'params' object")
277 if "deps" not in obj or not isinstance(obj["deps"], list):
278 raise ValueError("Node must have 'deps' array")
279 for i, dep in enumerate(obj["deps"]):
280 if not isinstance(dep, str):
281 raise ValueError(f"Node deps[{i}] must be string, got {type(dep).__name__}")
282 cache_val = obj.get("cache")
283 if cache_val is not None and not isinstance(cache_val, bool):
284 raise ValueError("Node 'cache' must be boolean when present")
def _validate_subgraph(obj: dict, legacy_kind_inference: bool = False) -> None:
    """Check that ``obj`` is a well-formed "subgraph" vertex.

    The ``kind`` check is skipped in legacy mode, where kind may have been
    inferred. The nested graph must be an object containing ``output``, and
    every nested vertex is validated recursively.

    Raises:
        ValueError: On the first structural problem found.
    """
    if not legacy_kind_inference and obj.get("kind") != "subgraph":
        raise ValueError("SubGraphNode must have kind 'subgraph'")
    if not isinstance(obj.get("params"), dict):
        raise ValueError("SubGraphNode must have 'params' object")

    deps = obj.get("deps")
    if not isinstance(deps, list):
        raise ValueError("SubGraphNode must have 'deps' array")
    for index, dep in enumerate(deps):
        if isinstance(dep, str):
            continue
        raise ValueError(
            f"SubGraphNode deps[{index}] must be string, got {type(dep).__name__}"
        )

    inner_graph = obj.get("graph")
    if not isinstance(inner_graph, dict):
        raise ValueError("SubGraphNode must have 'graph' object")
    output = obj.get("output")
    if not isinstance(output, str):
        raise ValueError("SubGraphNode must have string 'output'")
    if output not in inner_graph:
        raise ValueError(
            f"SubGraphNode output '{output}' must be key in graph. "
            f"Graph keys: {list(inner_graph.keys())}"
        )

    for node_id, vertex_obj in inner_graph.items():
        _validate_vertex_for_kind(vertex_obj, node_id, legacy_kind_inference)
def _validate_switch(obj: dict) -> None:
    """Check that ``obj`` is a well-formed "switch" vertex.

    Requires a ``selector`` key, a ``deps`` array of strings, and a non-empty
    ``cases`` object mapping string keys to non-empty string targets. An
    optional ``default`` must be a non-empty string when not None.

    Raises:
        ValueError: On the first structural problem found.
    """
    if obj.get("kind") != "switch":
        raise ValueError("SwitchNode must have kind 'switch'")
    if "selector" not in obj:
        raise ValueError("SwitchNode must have 'selector'")

    deps = obj.get("deps")
    if not isinstance(deps, list):
        raise ValueError("SwitchNode must have 'deps' array")
    for index, dep in enumerate(deps):
        if isinstance(dep, str):
            continue
        raise ValueError(
            f"SwitchNode deps[{index}] must be string, got {type(dep).__name__}"
        )

    cases = obj.get("cases")
    if not isinstance(cases, dict):
        raise ValueError("SwitchNode must have 'cases' object")
    if not cases:
        raise ValueError("SwitchNode cases must not be empty")
    for case_key, target in cases.items():
        if not isinstance(case_key, str):
            raise ValueError("SwitchNode cases keys must be strings")
        if not (isinstance(target, str) and target):
            raise ValueError("SwitchNode cases values must be non-empty strings")

    default = obj.get("default")
    if default is None:
        return
    if not (isinstance(default, str) and default):
        raise ValueError("SwitchNode default must be a non-empty string when present")
342def _validate_vertex_for_kind(
343 vertex_obj: Any, node_id: str, legacy_kind_inference: bool = False
344) -> None:
345 """Validate a vertex object has valid kind and structure."""
346 if not isinstance(vertex_obj, dict):
347 raise ValueError(f"Vertex '{node_id}' must be an object")
348 kind = vertex_obj.get("kind")
349 if kind is None and legacy_kind_inference:
350 if "op_name" in vertex_obj and "graph" not in vertex_obj:
351 kind = "node"
352 elif "graph" in vertex_obj and "output" in vertex_obj:
353 kind = "subgraph"
354 else:
355 raise ValueError(
356 f"Vertex '{node_id}' has no 'kind' and cannot infer from "
357 "op_name/graph/output"
358 )
359 if kind == "node":
360 _validate_node(vertex_obj, expected_kind="node")
361 elif kind == "subgraph":
362 _validate_subgraph(vertex_obj, legacy_kind_inference)
363 elif kind == "switch":
364 _validate_switch(vertex_obj)
365 else:
366 raise ValueError(f"Vertex '{node_id}' has unsupported kind: {kind!r}")
def _encode_graph(graph: Graph) -> dict:
    """Encode every vertex of ``graph``; node ids are emitted in sorted order."""
    encoded_pairs = [
        (node_id, _encode_vertex(vertex)) for node_id, vertex in graph.items()
    ]
    encoded_pairs.sort()
    return dict(encoded_pairs)
def _decode_graph(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Decode each vertex of a graph object, then verify switch targets.

    Raises:
        ValueError: If ``obj`` is not a dict, any vertex is invalid, or a
            switch targets a node id missing from this graph.
    """
    if not isinstance(obj, dict):
        raise ValueError("Graph must be an object")
    decoded: Graph = {
        node_id: _decode_vertex(vertex_obj, legacy_kind_inference)
        for node_id, vertex_obj in obj.items()
    }
    _validate_switch_targets(decoded)
    return decoded
def _validate_switch_targets(graph: Graph) -> None:
    """Ensure every SwitchNode branch target (cases and default) is a graph key.

    Raises:
        ValueError: On the first target that is not a node id of ``graph``.
    """
    known_ids = set(graph)
    for node_id, vertex in graph.items():
        if not isinstance(vertex, SwitchNode):
            continue
        branch_targets = [*vertex.cases.values()]
        if vertex.default is not None:
            branch_targets.append(vertex.default)
        for target in branch_targets:
            if target in known_ids:
                continue
            raise ValueError(
                f"SwitchNode '{node_id}' targets '{target}' which must be "
                f"a key in graph. Graph keys: {list(graph.keys())}"
            )
def _validate_output(graph: Graph, output: str | None) -> None:
    """Check a document-level ``output`` (if any) names a vertex of ``graph``.

    Raises:
        ValueError: If ``output`` is not a non-empty string or is not a key
            of ``graph``. ``None`` is accepted.
    """
    if output is None:
        return
    if isinstance(output, str) and output:
        if output in graph:
            return
        raise ValueError(
            f"Document output '{output}' must be key in graph. "
            f"Graph keys: {list(graph.keys())}"
        )
    raise ValueError("Document 'output' must be a non-empty string when present")
def _validate_output_arg(graph: Graph, output: str | None) -> None:
    """Check a caller-supplied ``output`` argument (if any) names a graph vertex.

    Raises:
        ValueError: If ``output`` is not a non-empty string or is not a key
            of ``graph``. ``None`` is accepted.
    """
    if output is None:
        return
    if isinstance(output, str) and output:
        if output in graph:
            return
        raise ValueError(
            f"output '{output}' must be a key in graph. "
            f"Graph keys: {list(graph.keys())}"
        )
    raise ValueError("output must be a non-empty string when present")
def _validate_envelope(obj: dict) -> None:
    """Validate the top-level document envelope before its graph is decoded.

    Requires a dict with:

    * ``format`` equal to ``FORMAT_ID``;
    * an integer ``version`` in ``SUPPORTED_VERSIONS``;
    * a ``graph`` object;
    * if present, a non-empty string ``output``.

    Raises:
        ValueError: If any requirement is not met.
    """
    if not isinstance(obj, dict):
        raise ValueError("Document must be a JSON object")
    fmt = obj.get("format")
    if fmt != FORMAT_ID:
        raise ValueError(f"Document format must be '{FORMAT_ID}', got {fmt!r}")
    version = obj.get("version")
    # Require a genuine int: bool is an int subclass (True == 1 would pass the
    # set lookup), and unhashable JSON values (arrays, objects) would make the
    # membership test raise TypeError instead of ValueError.
    if (
        isinstance(version, bool)
        or not isinstance(version, int)
        or version not in SUPPORTED_VERSIONS
    ):
        raise ValueError(
            f"Document version {version} is not supported. "
            f"Supported: {sorted(SUPPORTED_VERSIONS)}"
        )
    if "graph" not in obj:
        raise ValueError("Document must have 'graph'")
    if not isinstance(obj["graph"], dict):
        raise ValueError("Document 'graph' must be an object")
    if "output" in obj:
        output = obj["output"]
        if not isinstance(output, str) or not output:
            raise ValueError(
                "Document 'output' must be a non-empty string when present"
            )
def dump_graph_to_dict(graph: Graph, *, output: str | None = None) -> dict:
    """Build the envelope dict for ``graph``; vertex and param keys are sorted.

    Key insertion order is format, version, optional output, then graph.

    Raises:
        ValueError: If ``output`` is given but is not a key of ``graph``.
    """
    _validate_output_arg(graph, output)
    header = {"format": FORMAT_ID, "version": 1}
    optional = {} if output is None else {"output": output}
    return {**header, **optional, "graph": _encode_graph(graph)}
def dump_graph(graph: Graph, *, output: str | None = None) -> str:
    """Serialize ``graph`` to a JSON string with sorted keys (deterministic)."""
    envelope = dump_graph_to_dict(graph, output=output)
    return json.dumps(envelope, sort_keys=True)
def load_graph_document_from_dict(
    obj: dict, legacy_kind_inference: bool = False
) -> tuple[Graph, str | None]:
    """Load ``(graph, output)`` from an envelope dict; ``output`` may be None.

    Raises:
        ValueError: If the envelope, any vertex, or the output is invalid.
    """
    _validate_envelope(obj)
    decoded_graph = _decode_graph(obj["graph"], legacy_kind_inference)
    document_output = obj.get("output")
    _validate_output(decoded_graph, document_output)
    return decoded_graph, document_output
def load_graph_from_dict(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Load only the graph from an envelope dict (the document output is dropped)."""
    return load_graph_document_from_dict(obj, legacy_kind_inference)[0]
def load_graph_document(
    data: str | bytes, legacy_kind_inference: bool = False
) -> tuple[Graph, str | None]:
    """Parse JSON text (or UTF-8 encoded bytes) into ``(graph, output)``.

    Raises:
        ValueError: On invalid UTF-8 or JSON (``UnicodeDecodeError`` and
            ``json.JSONDecodeError`` are ValueError subclasses), or on an
            invalid document.
    """
    text = data.decode("utf-8") if isinstance(data, bytes) else data
    document = json.loads(text)
    return load_graph_document_from_dict(document, legacy_kind_inference)
def load_graph(data: str | bytes, legacy_kind_inference: bool = False) -> Graph:
    """Parse JSON text or bytes and return only the graph (output is dropped)."""
    return load_graph_document(data, legacy_kind_inference)[0]
def _encode_graph_document_payload(graph: Graph, output: str | None = None) -> str:
    """Return the envelope as compact, key-sorted JSON, base64-encoded (ASCII)."""
    envelope = dump_graph_to_dict(graph, output=output)
    compact_json = json.dumps(envelope, sort_keys=True, separators=(",", ":"))
    raw_bytes = compact_json.encode("utf-8")
    return base64.b64encode(raw_bytes).decode("ascii")
def _query_value_to_string(value: Any) -> str:
    """Render a context value as a query-string value.

    A string that is not valid JSON is passed through verbatim. Everything
    else, including strings that *do* parse as JSON (e.g. "1" or "true"), is
    emitted as compact, key-sorted JSON of its marker encoding. This lets
    ``_query_value_from_string`` tell the two cases apart.
    """
    is_plain_text = False
    if isinstance(value, str):
        try:
            json.loads(value)
        except json.JSONDecodeError:
            is_plain_text = True
    if is_plain_text:
        return value
    return json.dumps(
        dump_value_to_jsonable(value),
        sort_keys=True,
        separators=(",", ":"),
    )
def _query_value_from_string(value: str) -> Any:
    """Decode a query-string value written by ``_query_value_to_string``.

    JSON text is parsed and marker-decoded. Anything that fails to parse as
    JSON is returned unchanged.
    """
    try:
        parsed = json.loads(value)
        decoded = load_value_from_jsonable(parsed)
    except json.JSONDecodeError:
        decoded = value
    return decoded
def dump_graph_data_uri(
    graph: Graph,
    *,
    output: str | None = None,
    context: dict[str, Any] | None = None,
) -> str:
    """Serialize a graph document, plus optional query context, as a data URI.

    The graph document is the base64 payload. Context entries become a
    key-sorted, URL-encoded query string whose values are rendered by
    ``_query_value_to_string``. An empty or missing context adds no query.

    Raises:
        ValueError: If a context key is not a non-empty string, or ``output``
            is invalid.
    """
    uri = GRAPH_DATA_URI_PREFIX + _encode_graph_document_payload(graph, output)
    if not context:
        return uri

    if any(not isinstance(key, str) or not key for key in context):
        raise ValueError("context keys must be non-empty strings")
    query = urlencode(
        [
            (key, _query_value_to_string(value))
            for key, value in sorted(context.items())
        ]
    )
    return f"{uri}?{query}"
554def graph_data_uri_cache_key(data: str) -> str | None:
555 """Return the static graph data URI without query context."""
557 if not isinstance(data, str):
558 return None
560 parts = urlsplit(data)
561 graph_path_prefix = f"{GRAPH_MEDIA_TYPE};base64,"
562 if parts.scheme != "data" or not parts.path.startswith(graph_path_prefix):
563 return None
565 if parts.fragment:
566 raise ValueError("Invariant graph data URIs must not include fragments")
567 return urlunsplit((parts.scheme, parts.netloc, parts.path, "", ""))
def _decode_query_context(query: str) -> dict[str, Any]:
    """Parse a graph data URI query string into a context dict.

    Blank values are kept. Each value is decoded by
    ``_query_value_from_string``.

    Raises:
        ValueError: If a key is empty or appears more than once.
    """
    decoded: dict[str, Any] = {}
    for name, raw in parse_qsl(query, keep_blank_values=True):
        if not name:
            raise ValueError("Invariant graph data URI query keys must be non-empty")
        if name in decoded:
            raise ValueError(
                f"Invariant graph data URI query key {name!r} is duplicated"
            )
        decoded[name] = _query_value_from_string(raw)
    return decoded
def load_graph_data_uri(
    data: str, legacy_kind_inference: bool = False
) -> tuple[Graph, str | None, dict[str, Any]] | None:
    """Decode an Invariant graph data URI into ``(graph, output, context)``.

    Returns None when ``data`` is not an Invariant graph data URI.

    Raises:
        ValueError: If the payload is not strict base64 of UTF-8 JSON, the
            document is invalid, or the query context is malformed.
    """
    if graph_data_uri_cache_key(data) is None:
        return None

    parts = urlsplit(data)
    prefix_length = len(f"{GRAPH_MEDIA_TYPE};base64,")
    try:
        raw_payload = base64.b64decode(parts.path[prefix_length:], validate=True)
        document = json.loads(raw_payload.decode("utf-8"))
    except Exception as exc:
        raise ValueError(f"Invalid Invariant graph data URI payload: {exc}") from exc

    graph, output = load_graph_document_from_dict(document, legacy_kind_inference)
    return graph, output, _decode_query_context(parts.query)