Coverage for src/invariant/registry.py: 85.84%

113 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-08 09:24 +0000

1"""OpRegistry for mapping operation names to callables.""" 

2 

import importlib
import logging
import types
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from importlib.metadata import entry_points
from typing import Any

from invariant.traits import TraitLike, decorated_traits, normalize_traits

11 

# An op package maps short op names to their implementing callables.
# ``OpRegistry.register_package`` registers each entry under
# ``"<prefix>:<short name>"``.
OpPackage = dict[str, Callable[..., Any]]

14 

15 

16@dataclass(frozen=True) 

17class OpBinding: 

18 """Registered operation plus scheduler metadata.""" 

19 

20 name: str 

21 op: Callable[..., Any] 

22 traits: frozenset[str] 

23 implementation_ref: str | None = None 

24 

25 

def import_implementation_ref(ref: str) -> Callable[..., Any]:
    """Resolve a ``module.path:qualname`` reference to the callable it names.

    The part before the first ``:`` is imported with
    ``importlib.import_module``. The part after it is walked one dotted
    attribute at a time, so nested names such as ``pkg.mod:Class.method``
    resolve.

    Args:
        ref: Reference string of the form ``module.path:qualname``.

    Returns:
        The callable object the reference points at.

    Raises:
        ValueError: If ``ref`` has no ``:``, the module or qualname part is
            empty, or the qualname contains an empty dotted component.
        TypeError: If the resolved object is not callable.
        ImportError: If the module part cannot be imported.
        AttributeError: If a qualname component does not exist.
    """
    if ":" in ref:
        module_name, _, qualname = ref.partition(":")
    else:
        module_name = qualname = ""
    if not module_name or not qualname:
        raise ValueError(
            "implementation_ref must use 'module.path:qualname' format"
        )

    target: Any = importlib.import_module(module_name)
    for attr_name in qualname.split("."):
        # Empty components ("a..b") are detected while walking, so a missing
        # earlier attribute still surfaces first as AttributeError.
        if attr_name == "":
            raise ValueError(f"Invalid implementation_ref {ref!r}")
        target = getattr(target, attr_name)

    if callable(target):
        return target
    raise TypeError(f"implementation_ref {ref!r} does not resolve to a callable")

48 

49 

50def infer_implementation_ref(op: Callable[..., Any]) -> str | None: 

51 """Infer an importable implementation ref for a top-level callable. 

52 

53 Local functions, lambdas, ``__main__`` callables, and dynamically replaced 

54 attributes return ``None``. Process and remote schedulers require an exact 

55 worker-resolvable reference and should reject missing refs. 

56 """ 

57 module_name = getattr(op, "__module__", None) 

58 qualname = getattr(op, "__qualname__", None) 

59 if ( 

60 not module_name 

61 or not qualname 

62 or module_name == "__main__" 

63 or "<locals>" in qualname 

64 or qualname == "<lambda>" 

65 ): 

66 return None 

67 

68 ref = f"{module_name}:{qualname}" 

69 try: 

70 imported = import_implementation_ref(ref) 

71 except Exception: 

72 return None 

73 if imported is not op: 

74 return None 

75 return ref 

76 

77 

class OpRegistry:
    """Singleton registry mapping string identifiers to executable Python callables.

    Decouples the "string" name in the graph definition from the actual Python code.
    Every ``OpRegistry()`` call returns the same shared instance, and the
    binding table is created only on the first ``__init__``.
    """

    _instance: "OpRegistry | None" = None
    _initialized: bool = False

    def __new__(cls) -> "OpRegistry":
        """Ensure singleton pattern."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Initialize the registry (only once)."""
        # __init__ runs on every OpRegistry() call; guard so the shared
        # binding table is not reset.
        if not OpRegistry._initialized:
            self._bindings: dict[str, OpBinding] = {}
            OpRegistry._initialized = True

    def register(
        self,
        name: str,
        op: Callable[..., Any],
        *,
        traits: Iterable[TraitLike] | None = None,
        implementation_ref: str | None = None,
    ) -> None:
        """Register an operation.

        Args:
            name: The string identifier for the operation.
            op: The callable that implements the operation.
                Should be a plain Python function with typed parameters.
            traits: Optional execution traits for scheduler routing. They are
                combined (``|``) with any traits reported by
                ``decorated_traits(op)``.
            implementation_ref: Optional worker-resolvable implementation
                reference in ``module.path:qualname`` form. An explicit
                reference is validated by importing it. When omitted, one is
                inferred for importable top-level callables and is otherwise
                left as ``None``.

        Raises:
            ValueError: If name is empty or already registered, or if an
                explicit implementation_ref is malformed.
            ImportError, AttributeError, TypeError: If an explicit
                implementation_ref cannot be resolved to a callable.
        """
        if not name:
            raise ValueError("Operation name cannot be empty")
        if name in self._bindings:
            raise ValueError(f"Operation '{name}' is already registered")

        if implementation_ref is not None:
            # Fail fast at registration rather than later on a worker.
            import_implementation_ref(implementation_ref)
        else:
            implementation_ref = infer_implementation_ref(op)

        merged_traits = decorated_traits(op) | normalize_traits(traits)
        self._bindings[name] = OpBinding(
            name=name,
            op=op,
            traits=merged_traits,
            implementation_ref=implementation_ref,
        )

    def get(self, name: str) -> Callable[..., Any]:
        """Get an operation by name.

        Args:
            name: The string identifier for the operation.

        Returns:
            The callable that implements the operation.

        Raises:
            KeyError: If operation is not registered.
        """
        return self.get_binding(name).op

    def get_binding(self, name: str) -> OpBinding:
        """Get the full operation binding by name.

        Raises:
            KeyError: If operation is not registered.
        """
        try:
            return self._bindings[name]
        except KeyError:
            raise KeyError(f"Operation '{name}' is not registered") from None

    def traits(self, name: str) -> frozenset[str]:
        """Get normalized execution traits for an operation."""
        return self.get_binding(name).traits

    def implementation_ref(self, name: str) -> str | None:
        """Get the worker-resolvable implementation reference for an operation."""
        return self.get_binding(name).implementation_ref

    def has(self, name: str) -> bool:
        """Check if an operation is registered.

        Args:
            name: The string identifier for the operation.

        Returns:
            True if registered, False otherwise.
        """
        return name in self._bindings

    def clear(self) -> None:
        """Clear all registered operations (mainly for testing)."""
        self._bindings.clear()

    def register_package(self, prefix: str, ops: OpPackage | Any) -> None:
        """Register all ops from a package under a common prefix.

        Each op is registered as ``"<prefix>:<short name>"``. Registration is
        all-or-nothing. If any op fails to register (for example because its
        name is already taken), the ops this call already added are removed
        again before the error propagates.

        Args:
            prefix: The namespace prefix (e.g. "poly").
            ops: Either a dict mapping short names to callables (OpPackage),
                or a Python module (or other object) that has an OPS dict
                attribute.

        Raises:
            ValueError: If prefix is empty, ops is invalid, or any operation
                name is already registered.
            AttributeError: If ops is a module but doesn't have an OPS attribute.
        """
        if not prefix:
            raise ValueError("Package prefix cannot be empty")

        ops_dict = self._extract_ops_dict(ops)

        added: list[str] = []
        try:
            for short_name, op in ops_dict.items():
                full_name = f"{prefix}:{short_name}"
                self.register(full_name, op)
                added.append(full_name)
        except Exception:
            # Undo this call's partial registration so a failed package never
            # leaves its namespace half-populated.
            for full_name in added:
                self._bindings.pop(full_name, None)
            raise

    @staticmethod
    def _extract_ops_dict(ops: OpPackage | Any) -> OpPackage:
        """Return the ops dict from a dict, a module, or an object with ``OPS``.

        Raises:
            AttributeError: If ops is a module without an OPS attribute.
            ValueError: If ops is neither a dict nor has an OPS attribute, or
                if its OPS attribute is not a dict.
        """
        if isinstance(ops, dict):
            return ops
        if not hasattr(ops, "OPS"):
            if isinstance(ops, types.ModuleType):
                raise AttributeError(
                    f"Module {ops.__name__} does not have an OPS attribute"
                )
            raise ValueError(
                f"ops must be a dict or module with OPS attribute, got {type(ops)}"
            )
        ops_dict = ops.OPS
        if not isinstance(ops_dict, dict):
            raise ValueError(f"OPS attribute must be a dict, got {type(ops_dict)}")
        return ops_dict

    def auto_discover(self) -> None:
        """Discover and register op packages from entry points.

        Scans the 'invariant.ops' entry point group. Each entry point
        should resolve to either:
        - A dict[str, Callable] (the OPS dict directly)
        - A callable that returns such a dict

        The entry point name becomes the package prefix.

        Discovery is best-effort and does not raise for individual entry
        points. An entry point is skipped, and a warning is logged, when it
        fails to load, resolves to something other than the shapes above, or
        fails to register (e.g. an operation name is already registered).
        Because ``register_package`` rolls back on failure, a skipped package
        leaves no partially registered ops behind.
        """
        logger = logging.getLogger(__name__)
        eps = entry_points(group="invariant.ops")

        for ep in eps:
            try:
                loaded = ep.load()

                ops_dict: Any
                if isinstance(loaded, dict):
                    ops_dict = loaded
                elif callable(loaded):
                    # Factory that builds and returns the OPS dict.
                    ops_dict = loaded()
                else:
                    ops_dict = loaded

                if not isinstance(ops_dict, dict):
                    logger.warning(
                        "Skipping 'invariant.ops' entry point %r: expected a "
                        "dict of ops or a callable returning one, got %s",
                        ep.name,
                        type(ops_dict).__name__,
                    )
                    continue

                # Register the package using the entry point name as prefix.
                self.register_package(ep.name, ops_dict)
            except Exception:
                # Keep discovery best-effort, but make the failure visible.
                logger.warning(
                    "Skipping 'invariant.ops' entry point %r: failed to load "
                    "or register its ops",
                    ep.name,
                    exc_info=True,
                )