Coverage for src/invariant/registry.py: 85.84%
113 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-05-08 09:24 +0000
1"""OpRegistry for mapping operation names to callables."""
import importlib
import logging
import types
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from importlib.metadata import entry_points
from typing import Any

from invariant.traits import TraitLike, decorated_traits, normalize_traits
# Type alias for op packages: a dict mapping short op names to their callables.
# OpRegistry.register_package registers each entry as "<prefix>:<short name>".
OpPackage = dict[str, Callable[..., Any]]
@dataclass(frozen=True)
class OpBinding:
    """Registered operation plus scheduler metadata.

    Immutable (``frozen=True``); OpRegistry stores one binding per registered
    operation name.
    """

    # Registry key the operation was registered under (e.g. "<prefix>:<name>"
    # when registered through a package).
    name: str
    # The Python callable that implements the operation.
    op: Callable[..., Any]
    # Normalized execution traits used for scheduler routing.
    traits: frozenset[str]
    # Worker-resolvable reference in "module.path:qualname" form, or None when
    # none was supplied and none could be inferred.
    implementation_ref: str | None = None
def import_implementation_ref(ref: str) -> Callable[..., Any]:
    """Import a callable from ``module.path:qualname``.

    Everything before the first ``:`` is imported as a module. The remainder is
    walked one dotted attribute at a time, so nested names such as
    ``pkg.mod:Outer.method`` resolve as well.

    Args:
        ref: Reference string in ``module.path:qualname`` form.

    Returns:
        The callable that ``ref`` names.

    Raises:
        ValueError: If ``ref`` has no ``:``, has an empty module or qualname
            part, or the walk reaches an empty dotted segment.
        ImportError: If the module part cannot be imported.
        AttributeError: If an attribute along the qualname is missing.
        TypeError: If the resolved object is not callable.
    """
    module_name, separator, qualname = ref.partition(":")
    # A missing separator and an empty half are the same format error.
    if not (separator and module_name and qualname):
        raise ValueError(
            "implementation_ref must use 'module.path:qualname' format"
        )

    target: Any = importlib.import_module(module_name)
    for attr_name in qualname.split("."):
        if attr_name == "":
            raise ValueError(f"Invalid implementation_ref {ref!r}")
        target = getattr(target, attr_name)

    if callable(target):
        return target
    raise TypeError(f"implementation_ref {ref!r} does not resolve to a callable")
def infer_implementation_ref(op: Callable[..., Any]) -> str | None:
    """Infer an importable implementation ref for a top-level callable.

    Builds ``"<op.__module__>:<op.__qualname__>"`` and keeps it only if
    importing that reference yields ``op`` itself (identity check). Local
    functions, lambdas, ``__main__`` callables, and dynamically replaced
    attributes therefore return ``None``. Process and remote schedulers
    require an exact worker-resolvable reference and should reject missing
    refs.
    """
    module_name = getattr(op, "__module__", None)
    qualname = getattr(op, "__qualname__", None)

    # Order matters: the qualname checks only run once qualname is known truthy.
    not_importable = (
        not module_name
        or not qualname
        or module_name == "__main__"
        or "<locals>" in qualname
        or qualname == "<lambda>"
    )
    if not_importable:
        return None

    candidate = f"{module_name}:{qualname}"
    try:
        resolved = import_implementation_ref(candidate)
    except Exception:
        # Best-effort probe: any import/lookup failure just means "no ref".
        return None
    return candidate if resolved is op else None
78class OpRegistry:
79 """Singleton registry mapping string identifiers to executable Python callables.
81 Decouples the "string" name in the graph definition from the actual Python code.
82 """
84 _instance: "OpRegistry | None" = None
85 _initialized: bool = False
87 def __new__(cls) -> "OpRegistry":
88 """Ensure singleton pattern."""
89 if cls._instance is None:
90 cls._instance = super().__new__(cls)
91 return cls._instance
93 def __init__(self) -> None:
94 """Initialize the registry (only once)."""
95 if not OpRegistry._initialized:
96 self._bindings: dict[str, OpBinding] = {}
97 OpRegistry._initialized = True
99 def register(
100 self,
101 name: str,
102 op: Callable[..., Any],
103 *,
104 traits: Iterable[TraitLike] | None = None,
105 implementation_ref: str | None = None,
106 ) -> None:
107 """Register an operation.
109 Args:
110 name: The string identifier for the operation.
111 op: The callable that implements the operation.
112 Should be a plain Python function with typed parameters.
113 traits: Optional execution traits for scheduler routing.
114 implementation_ref: Optional worker-resolvable implementation
115 reference in ``module.path:qualname`` form.
117 Raises:
118 ValueError: If name is empty or already registered.
119 """
120 if not name:
121 raise ValueError("Operation name cannot be empty")
122 if name in self._bindings:
123 raise ValueError(f"Operation '{name}' is already registered")
125 if implementation_ref is not None:
126 import_implementation_ref(implementation_ref)
127 else:
128 implementation_ref = infer_implementation_ref(op)
130 merged_traits = decorated_traits(op) | normalize_traits(traits)
131 self._bindings[name] = OpBinding(
132 name=name,
133 op=op,
134 traits=merged_traits,
135 implementation_ref=implementation_ref,
136 )
138 def get(self, name: str) -> Callable[..., Any]:
139 """Get an operation by name.
141 Args:
142 name: The string identifier for the operation.
144 Returns:
145 The callable that implements the operation.
147 Raises:
148 KeyError: If operation is not registered.
149 """
150 if name not in self._bindings:
151 raise KeyError(f"Operation '{name}' is not registered")
152 return self._bindings[name].op
154 def get_binding(self, name: str) -> OpBinding:
155 """Get the full operation binding by name."""
156 if name not in self._bindings:
157 raise KeyError(f"Operation '{name}' is not registered")
158 return self._bindings[name]
160 def traits(self, name: str) -> frozenset[str]:
161 """Get normalized execution traits for an operation."""
162 return self.get_binding(name).traits
164 def implementation_ref(self, name: str) -> str | None:
165 """Get the worker-resolvable implementation reference for an operation."""
166 return self.get_binding(name).implementation_ref
168 def has(self, name: str) -> bool:
169 """Check if an operation is registered.
171 Args:
172 name: The string identifier for the operation.
174 Returns:
175 True if registered, False otherwise.
176 """
177 return name in self._bindings
179 def clear(self) -> None:
180 """Clear all registered operations (mainly for testing)."""
181 self._bindings.clear()
183 def register_package(self, prefix: str, ops: OpPackage | Any) -> None:
184 """Register all ops from a package under a common prefix.
186 Args:
187 prefix: The namespace prefix (e.g. "poly").
188 ops: Either a dict mapping short names to callables (OpPackage),
189 or a Python module that has an OPS dict attribute.
191 Raises:
192 ValueError: If prefix is empty, ops is invalid, or any operation
193 name is already registered.
194 AttributeError: If ops is a module but doesn't have an OPS attribute.
195 """
196 if not prefix:
197 raise ValueError("Package prefix cannot be empty")
199 # Extract the ops dict from the input
200 ops_dict: OpPackage
201 if isinstance(ops, dict):
202 ops_dict = ops
203 elif isinstance(ops, types.ModuleType):
204 # It's a module - check for OPS attribute
205 if not hasattr(ops, "OPS"):
206 raise AttributeError(
207 f"Module {ops.__name__} does not have an OPS attribute"
208 )
209 ops_dict = ops.OPS
210 if not isinstance(ops_dict, dict):
211 raise ValueError(f"OPS attribute must be a dict, got {type(ops_dict)}")
212 elif hasattr(ops, "OPS"):
213 # Object with OPS attribute (not a module)
214 ops_dict = ops.OPS
215 if not isinstance(ops_dict, dict):
216 raise ValueError(f"OPS attribute must be a dict, got {type(ops_dict)}")
217 else:
218 raise ValueError(
219 f"ops must be a dict or module with OPS attribute, got {type(ops)}"
220 )
222 # Register each op with the prefix
223 for name, op in ops_dict.items():
224 full_name = f"{prefix}:{name}"
225 self.register(full_name, op)
227 def auto_discover(self) -> None:
228 """Discover and register op packages from entry points.
230 Scans the 'invariant.ops' entry point group. Each entry point
231 should resolve to either:
232 - A dict[str, Callable] (the OPS dict directly)
233 - A callable that returns such a dict
235 The entry point name becomes the package prefix.
237 Raises:
238 ValueError: If any operation name is already registered
239 (via register_package).
240 """
241 eps = entry_points(group="invariant.ops")
243 for ep in eps:
244 try:
245 # Load the entry point
246 loaded = ep.load()
248 # Extract the ops dict
249 ops_dict: OpPackage
250 if isinstance(loaded, dict):
251 ops_dict = loaded
252 elif callable(loaded):
253 # Callable that returns the dict
254 result = loaded()
255 if not isinstance(result, dict):
256 continue # Skip invalid entry points
257 ops_dict = result
258 else:
259 continue # Skip invalid entry points
261 # Register the package using the entry point name as prefix
262 self.register_package(ep.name, ops_dict)
263 except Exception:
264 # Skip invalid entry points silently
265 continue