Coverage for src / invariant / graph_serialization.py: 83.87%
248 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-03 19:45 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-03 19:45 +0000
1"""Graph serialization: JSON wire format for Invariant graphs.
3Encodes graphs (Node, SubGraphNode) and params (ref, cel, Decimal, tuple,
4ICacheable) for storage and transmission. Distinct from artifact serialization
5in store/codec.py.
6"""
8import base64
9import importlib
10import json
11from decimal import Decimal
12from io import BytesIO
13from typing import Any
15from invariant.graph import Graph
16from invariant.node import Node, SubGraphNode
17from invariant.params import cel, ref
18from invariant.protocol import ICacheable
# Envelope identity: the schema versions this module accepts when loading and
# the value expected in a document's "format" field.
SUPPORTED_VERSIONS = {1}
FORMAT_ID = "invariant-graph"

# Media type and data-URI prefix used for base64 graph-plus-output payloads.
GRAPH_OUTPUT_MEDIA_TYPE = "application/vnd.invariant.graph-output+json"
GRAPH_OUTPUT_DATA_URI_PREFIX = "data:" + GRAPH_OUTPUT_MEDIA_TYPE + ";base64,"

# A single-key dict whose key is one of these is a param marker, not plain data.
RESERVED_KEYS = frozenset(
    ("$ref", "$cel", "$decimal", "$tuple", "$literal", "$icacheable")
)
def _encode_param_value(value: Any) -> Any:
    """Turn a parameter value into its JSON-serializable marker encoding.

    Special values become single-key marker objects: ``ref`` -> ``$ref``,
    ``cel`` -> ``$cel``, ``Decimal`` -> ``$decimal`` (string form), ``tuple``
    -> ``$tuple`` and ``ICacheable`` -> ``$icacheable``. Dicts and lists are
    encoded element by element; every other value is returned unchanged.
    """
    if isinstance(value, ref):
        return {"$ref": value.dep}
    if isinstance(value, cel):
        return {"$cel": value.expr}
    if isinstance(value, Decimal):
        return {"$decimal": str(value)}
    if isinstance(value, tuple):
        return {"$tuple": list(map(_encode_param_value, value))}

    if isinstance(value, ICacheable):
        cls = value.__class__
        descriptor = {"type": f"{cls.__module__}.{cls.__name__}"}
        # Prefer the JSON-native form when the object offers one; otherwise
        # fall back to the binary stream, base64-encoded.
        to_json_value = getattr(value, "to_json_value", None)
        if callable(to_json_value):
            descriptor["value"] = to_json_value()
        else:
            buffer = BytesIO()
            value.to_stream(buffer)
            descriptor["payload_b64"] = base64.b64encode(buffer.getvalue()).decode(
                "ascii"
            )
        return {"$icacheable": descriptor}

    if isinstance(value, dict):
        encoded = {key: _encode_param_value(item) for key, item in value.items()}
        # A plain one-key dict that looks like a marker would be misread on
        # decode, so it is escaped by wrapping it in $literal.
        if len(encoded) == 1 and next(iter(encoded)) in RESERVED_KEYS:
            return {"$literal": encoded}
        return encoded

    if isinstance(value, list):
        return list(map(_encode_param_value, value))

    # Anything else (None, bool, int, str, ...) is passed through as-is.
    return value
76def _decode_param_value(obj: Any, literal_mode: bool = False) -> Any:
77 """Recursively decode a JSON value to Python parameter value."""
78 # In literal mode, never treat dicts as markers
79 if literal_mode:
80 if isinstance(obj, dict):
81 return {
82 k: _decode_param_value(v, literal_mode=True) for k, v in obj.items()
83 }
84 if isinstance(obj, list):
85 return [_decode_param_value(item, literal_mode=True) for item in obj]
86 return obj
88 # Single-key dict with reserved key -> marker or escape
89 if isinstance(obj, dict):
90 if len(obj) == 1:
91 (key, val) = next(iter(obj.items()))
92 if key == "$ref":
93 return ref(val)
94 if key == "$cel":
95 return cel(val)
96 if key == "$decimal":
97 return Decimal(val)
98 if key == "$tuple":
99 return tuple(_decode_param_value(item) for item in val)
100 if key == "$literal":
101 return _decode_param_value(val, literal_mode=True)
102 if key == "$icacheable":
103 return _decode_icacheable(val)
104 # Multi-key or non-reserved: recursive decode
105 return {k: _decode_param_value(v) for k, v in obj.items()}
107 if isinstance(obj, list):
108 return [_decode_param_value(item) for item in obj]
110 return obj
113def _decode_icacheable(obj: dict) -> Any:
114 """Decode $icacheable object to ICacheable instance."""
115 if not isinstance(obj, dict):
116 raise ValueError("$icacheable value must be an object")
117 type_name = obj.get("type")
118 if not type_name or not isinstance(type_name, str):
119 raise ValueError("$icacheable must have non-empty string 'type'")
120 if "payload_b64" in obj and "value" in obj:
121 raise ValueError(
122 "$icacheable must have exactly one of 'payload_b64' or 'value'"
123 )
124 if "payload_b64" not in obj and "value" not in obj:
125 raise ValueError("$icacheable must have 'payload_b64' or 'value'")
127 module_path, class_name = type_name.rsplit(".", 1)
128 try:
129 module = importlib.import_module(module_path)
130 cls = getattr(module, class_name)
131 except (ImportError, AttributeError) as e:
132 raise ValueError(
133 f"$icacheable type '{type_name}' could not be imported: {e}"
134 ) from e
136 if "value" in obj:
137 if not hasattr(cls, "from_json_value"):
138 raise ValueError(
139 f"$icacheable type '{type_name}' has 'value' but no "
140 "from_json_value method"
141 )
142 return cls.from_json_value(obj["value"])
144 # payload_b64
145 try:
146 payload = base64.b64decode(obj["payload_b64"])
147 except Exception as e:
148 raise ValueError(f"$icacheable payload_b64 is invalid base64: {e}") from e
149 stream = BytesIO(payload)
150 try:
151 return cls.from_stream(stream)
152 except Exception as e:
153 raise ValueError(
154 f"$icacheable from_stream failed for '{type_name}': {e}"
155 ) from e
158def _encode_params(params: dict[str, Any]) -> dict[str, Any]:
159 """Encode params dict with sorted keys for determinism."""
160 return dict(sorted((k, _encode_param_value(v)) for k, v in params.items()))
163def _decode_params(obj: dict) -> dict[str, Any]:
164 """Decode params dict."""
165 return {k: _decode_param_value(v) for k, v in obj.items()}
def dump_value_to_jsonable(value: Any) -> Any:
    """Return *value* in the graph JSON marker encoding.

    Public entry point that delegates to ``_encode_param_value``.
    """
    jsonable = _encode_param_value(value)
    return jsonable
def load_value_from_jsonable(obj: Any) -> Any:
    """Return the Python value described by a graph JSON marker encoding.

    Public entry point that delegates to ``_decode_param_value``.
    """
    value = _decode_param_value(obj)
    return value
def _encode_vertex(vertex: Node | SubGraphNode) -> dict:
    """Build the JSON object for one vertex.

    A ``Node`` is written with kind ``"node"`` and its ``op_name``; any other
    vertex is written as kind ``"subgraph"`` with its nested graph and output
    name. ``deps`` are sorted, and ``cache`` is only emitted when it is falsy.
    """
    if not isinstance(vertex, Node):
        # Anything that is not a Node is encoded as a SubGraphNode.
        subgraph_doc: dict = {"kind": "subgraph"}
        subgraph_doc["params"] = _encode_params(vertex.params)
        subgraph_doc["deps"] = sorted(vertex.deps)
        subgraph_doc["graph"] = _encode_graph(vertex.graph)
        subgraph_doc["output"] = vertex.output
        return subgraph_doc

    node_doc: dict = {"kind": "node"}
    node_doc["op_name"] = vertex.op_name
    node_doc["params"] = _encode_params(vertex.params)
    node_doc["deps"] = sorted(vertex.deps)
    # Caching is the default, so only the opt-out is recorded.
    if not vertex.cache:
        node_doc["cache"] = False
    return node_doc
def _decode_vertex(
    obj: dict, legacy_kind_inference: bool = False
) -> Node | SubGraphNode:
    """Decode a JSON object to a Node or SubGraphNode, validating first.

    Args:
        obj: The vertex object from the document.
        legacy_kind_inference: When true and ``kind`` is missing, infer it:
            ``op_name`` without ``graph`` means a node, ``graph`` together
            with ``output`` means a subgraph.

    Returns:
        A ``Node`` for kind ``"node"``, a ``SubGraphNode`` for ``"subgraph"``.

    Raises:
        ValueError: If ``obj`` is not a dict, the kind is missing, cannot be
            inferred or is unsupported, or the kind-specific validation fails.
    """
    if not isinstance(obj, dict):
        raise ValueError("Vertex must be an object")

    kind = obj.get("kind")
    if kind is None and legacy_kind_inference:
        if "op_name" in obj and "graph" not in obj:
            kind = "node"
        elif "graph" in obj and "output" in obj:
            kind = "subgraph"
        else:
            raise ValueError(
                "Vertex has no 'kind' and cannot infer from op_name/graph/output"
            )
    if kind is None:
        raise ValueError("Vertex must have 'kind'")

    if kind == "node":
        _validate_node(obj, expected_kind=kind)
        # Validation accepts an explicit JSON null for "cache"; treat it like
        # an absent key. Forwarding None would yield a falsy cache flag that
        # the encoder then writes back as "cache": false.
        cache = obj.get("cache")
        return Node(
            op_name=obj["op_name"].strip(),
            params=_decode_params(obj["params"]),
            deps=list(obj["deps"]),
            cache=True if cache is None else cache,
        )
    if kind == "subgraph":
        _validate_subgraph(obj, legacy_kind_inference)
        return SubGraphNode(
            params=_decode_params(obj["params"]),
            deps=list(obj["deps"]),
            graph=_decode_graph(obj["graph"], legacy_kind_inference),
            output=obj["output"],
        )
    raise ValueError(f"Vertex has unsupported kind: {kind!r}")
241def _validate_node(obj: dict, expected_kind: str | None = None) -> None:
242 """Validate node object before construction."""
243 kind = expected_kind if expected_kind is not None else obj.get("kind")
244 if kind != "node":
245 raise ValueError("Node must have kind 'node'")
246 op_name = obj.get("op_name")
247 if not isinstance(op_name, str):
248 raise ValueError("Node must have string 'op_name'")
249 if not op_name.strip():
250 raise ValueError("Node op_name cannot be empty")
251 if "params" not in obj or not isinstance(obj["params"], dict):
252 raise ValueError("Node must have 'params' object")
253 if "deps" not in obj or not isinstance(obj["deps"], list):
254 raise ValueError("Node must have 'deps' array")
255 for i, dep in enumerate(obj["deps"]):
256 if not isinstance(dep, str):
257 raise ValueError(f"Node deps[{i}] must be string, got {type(dep).__name__}")
258 cache_val = obj.get("cache")
259 if cache_val is not None and not isinstance(cache_val, bool):
260 raise ValueError("Node 'cache' must be boolean when present")
263def _validate_subgraph(obj: dict, legacy_kind_inference: bool = False) -> None:
264 """Validate subgraph object before construction."""
265 kind = obj.get("kind")
266 if not legacy_kind_inference and kind != "subgraph":
267 raise ValueError("SubGraphNode must have kind 'subgraph'")
268 if "params" not in obj or not isinstance(obj["params"], dict):
269 raise ValueError("SubGraphNode must have 'params' object")
270 if "deps" not in obj or not isinstance(obj["deps"], list):
271 raise ValueError("SubGraphNode must have 'deps' array")
272 for i, dep in enumerate(obj["deps"]):
273 if not isinstance(dep, str):
274 raise ValueError(
275 f"SubGraphNode deps[{i}] must be string, got {type(dep).__name__}"
276 )
277 if "graph" not in obj or not isinstance(obj["graph"], dict):
278 raise ValueError("SubGraphNode must have 'graph' object")
279 output = obj.get("output")
280 if not isinstance(output, str):
281 raise ValueError("SubGraphNode must have string 'output'")
282 if output not in obj["graph"]:
283 raise ValueError(
284 f"SubGraphNode output '{output}' must be key in graph. "
285 f"Graph keys: {list(obj['graph'].keys())}"
286 )
287 for node_id, vertex_obj in obj["graph"].items():
288 _validate_vertex_for_kind(vertex_obj, node_id, legacy_kind_inference)
291def _validate_vertex_for_kind(
292 vertex_obj: Any, node_id: str, legacy_kind_inference: bool = False
293) -> None:
294 """Validate a vertex object has valid kind and structure."""
295 if not isinstance(vertex_obj, dict):
296 raise ValueError(f"Vertex '{node_id}' must be an object")
297 kind = vertex_obj.get("kind")
298 if kind is None and legacy_kind_inference:
299 if "op_name" in vertex_obj and "graph" not in vertex_obj:
300 kind = "node"
301 elif "graph" in vertex_obj and "output" in vertex_obj:
302 kind = "subgraph"
303 else:
304 raise ValueError(
305 f"Vertex '{node_id}' has no 'kind' and cannot infer from "
306 "op_name/graph/output"
307 )
308 if kind == "node":
309 _validate_node(vertex_obj, expected_kind="node")
310 elif kind == "subgraph":
311 _validate_subgraph(vertex_obj, legacy_kind_inference)
312 else:
313 raise ValueError(f"Vertex '{node_id}' has unsupported kind: {kind!r}")
def _encode_graph(graph: Graph) -> dict:
    """Encode every vertex of *graph* into a JSON object ordered by node id."""
    encoded_pairs = [(node_id, _encode_vertex(vertex)) for node_id, vertex in graph.items()]
    # Node ids are unique, so ordering by id alone fully determines the result.
    encoded_pairs.sort(key=lambda pair: pair[0])
    return dict(encoded_pairs)
def _decode_graph(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Decode a JSON graph object into a mapping of node id to vertex.

    Raises ``ValueError`` if *obj* is not a dict or any vertex fails to decode.
    """
    if not isinstance(obj, dict):
        raise ValueError("Graph must be an object")
    return {
        node_id: _decode_vertex(vertex_obj, legacy_kind_inference)
        for node_id, vertex_obj in obj.items()
    }
331def _validate_envelope(obj: dict) -> None:
332 """Validate top-level envelope."""
333 if not isinstance(obj, dict):
334 raise ValueError("Document must be a JSON object")
335 fmt = obj.get("format")
336 if fmt != FORMAT_ID:
337 raise ValueError(f"Document format must be '{FORMAT_ID}', got {fmt!r}")
338 version = obj.get("version")
339 if version not in SUPPORTED_VERSIONS:
340 raise ValueError(
341 f"Document version {version} is not supported. "
342 f"Supported: {sorted(SUPPORTED_VERSIONS)}"
343 )
344 if "graph" not in obj:
345 raise ValueError("Document must have 'graph'")
346 if not isinstance(obj["graph"], dict):
347 raise ValueError("Document 'graph' must be an object")
def dump_graph_to_dict(graph: Graph) -> dict:
    """Wrap the encoded graph in the format/version envelope.

    The encoded graph has sorted keys, so the result is deterministic.
    """
    envelope: dict = {"format": FORMAT_ID, "version": 1}
    envelope["graph"] = _encode_graph(graph)
    return envelope
def dump_graph(graph: Graph) -> str:
    """Render *graph* as a JSON string with sorted keys (deterministic)."""
    envelope = dump_graph_to_dict(graph)
    return json.dumps(envelope, sort_keys=True)
def load_graph_from_dict(obj: dict, legacy_kind_inference: bool = False) -> Graph:
    """Validate an envelope dict and decode the graph it carries."""
    _validate_envelope(obj)
    encoded_graph = obj["graph"]
    return _decode_graph(encoded_graph, legacy_kind_inference)
def load_graph(data: str | bytes, legacy_kind_inference: bool = False) -> Graph:
    """Parse a JSON document (text, or UTF-8 encoded bytes) into a graph."""
    text = data.decode("utf-8") if isinstance(data, bytes) else data
    return load_graph_from_dict(json.loads(text), legacy_kind_inference)
def dump_graph_output_to_dict(graph: Graph, output: str) -> dict[str, Any]:
    """Pair the graph's envelope dict with the name of its output node."""
    wrapper: dict[str, Any] = {"graph": dump_graph_to_dict(graph)}
    wrapper["output"] = output
    return wrapper
def load_graph_output_from_dict(
    obj: dict[str, Any], legacy_kind_inference: bool = False
) -> tuple[Graph, str]:
    """Read a graph-plus-output wrapper dict and return ``(graph, output)``.

    ``output`` falls back to the string ``"output"`` when the key is absent.
    Raises ``ValueError`` if *obj* is not a dict, ``output`` is not a
    non-empty string, ``graph`` is not a dict, or the envelope is invalid.
    """
    if not isinstance(obj, dict):
        raise ValueError("Document must be an object")

    output = obj.get("output", "output")
    if not (isinstance(output, str) and output):
        raise ValueError("Document 'output' must be a non-empty string")

    envelope = obj.get("graph")
    if not isinstance(envelope, dict):
        raise ValueError("Document must have object 'graph'")

    return load_graph_from_dict(envelope, legacy_kind_inference), output
def dump_graph_output_data_uri(graph: Graph, output: str) -> str:
    """Return the graph-plus-output wrapper as a base64 data URI.

    The wrapper is dumped as compact JSON with sorted keys, so equal inputs
    yield the same URI.
    """
    wrapper = dump_graph_output_to_dict(graph, output)
    compact_json = json.dumps(wrapper, separators=(",", ":"), sort_keys=True)
    b64_text = base64.b64encode(compact_json.encode("utf-8")).decode("ascii")
    return GRAPH_OUTPUT_DATA_URI_PREFIX + b64_text
def load_graph_output_data_uri(
    data: str, legacy_kind_inference: bool = False
) -> tuple[Graph, str] | None:
    """Decode a graph-plus-output data URI into ``(graph, output)``.

    Best-effort by design: returns ``None`` when *data* is not a string, does
    not start with the expected prefix, or fails at any decoding step.
    """
    if isinstance(data, str) and data.startswith(GRAPH_OUTPUT_DATA_URI_PREFIX):
        b64_text = data[len(GRAPH_OUTPUT_DATA_URI_PREFIX) :]
        try:
            raw = base64.b64decode(b64_text, validate=True)
            document = json.loads(raw.decode("utf-8"))
            return load_graph_output_from_dict(document, legacy_kind_inference)
        except Exception:
            # Any failure (bad base64, bad JSON, invalid graph) means the
            # input is not a usable payload.
            return None
    return None