Coverage for src/invariant/scheduler.py: 67.07%

82 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-05-08 09:24 +0000

1"""Invocation schedulers for async execution.""" 

2 

3import asyncio 

4from collections.abc import Callable 

5from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor 

6from dataclasses import dataclass 

7from typing import Any, Protocol 

8 

9from invariant.invocation import invoke_op 

10from invariant.registry import OpRegistry, import_implementation_ref 

11from invariant.store.codec import deserialize, serialize 

12from invariant.traits import OpTrait 

13 

14 

15@dataclass(frozen=True) 

16class InvocationRequest: 

17 """A scheduler-facing operation invocation request.""" 

18 

19 op_name: str 

20 op: Callable[..., Any] 

21 manifest: dict[str, Any] 

22 traits: frozenset[str] 

23 implementation_ref: str | None = None 

24 cache_key: tuple[str, str] | None = None 

25 

26 

class InvocationScheduler(Protocol):
    """Structural interface shared by local and remote invocation schedulers.

    Any object exposing a compatible ``async invoke(request)`` method satisfies
    this protocol; no subclassing is required.
    """

    async def invoke(self, request: InvocationRequest) -> Any:
        """Run the operation described by ``request`` and return its artifact."""
        ...

33 

34 

class InlineScheduler:
    """Scheduler that runs each operation synchronously on the event-loop thread.

    The op runs to completion inside ``invoke`` before control returns to the
    loop, so this scheduler is only appropriate for quick, non-blocking ops.
    """

    async def invoke(self, request: InvocationRequest) -> Any:
        """Call the op directly via ``invoke_op`` and return its artifact."""
        op, op_name, manifest = request.op, request.op_name, request.manifest
        return invoke_op(op, op_name, manifest)

41 

42 

class ThreadPoolScheduler:
    """Invoke operations in a thread pool.

    The scheduler either creates and owns its own ``ThreadPoolExecutor``
    (sized by ``max_workers``) or borrows a caller-supplied ``executor``.
    Only an owned executor is shut down by :meth:`aclose`.
    """

    def __init__(
        self,
        max_workers: int | None = None,
        executor: ThreadPoolExecutor | None = None,
    ) -> None:
        """Create the scheduler.

        Args:
            max_workers: Worker count for an internally created pool; ``None``
                uses the ``ThreadPoolExecutor`` default.
            executor: An existing pool to borrow instead of creating one.

        Raises:
            ValueError: If both ``max_workers`` and ``executor`` are given.
        """
        if max_workers is not None and executor is not None:
            raise ValueError("max_workers cannot be set when executor is provided")
        # Pick the executor and decide ownership with the same ``is None`` test
        # so the two can never disagree (a truthiness test could replace a
        # falsy executor while still marking it as not owned).
        self._owns_executor = executor is None
        self._executor = (
            ThreadPoolExecutor(max_workers=max_workers)
            if executor is None
            else executor
        )

    async def invoke(self, request: InvocationRequest) -> Any:
        """Invoke an operation in the thread pool and await its artifact."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor,
            invoke_op,
            request.op,
            request.op_name,
            request.manifest,
        )

    async def aclose(self) -> None:
        """Shut down the owned thread pool without blocking the event loop.

        ``shutdown(wait=True)`` waits for in-flight work to finish. Running it
        via ``asyncio.to_thread`` keeps the event loop responsive while it
        waits. A borrowed executor is left running for its owner to manage.
        """
        if self._owns_executor:
            await asyncio.to_thread(self._executor.shutdown, wait=True)

71 

72 

class ProcessPoolScheduler:
    """Invoke worker-resolvable operations in a process pool.

    Only the op name, its ``implementation_ref`` string and codec-serialized
    bytes cross the process boundary; the worker resolves the callable itself
    (see ``_process_worker_invoke``). The scheduler either creates and owns a
    ``ProcessPoolExecutor`` (sized by ``max_workers``) or borrows a
    caller-supplied ``executor``; only an owned executor is shut down by
    :meth:`aclose`.
    """

    def __init__(
        self,
        max_workers: int | None = None,
        executor: ProcessPoolExecutor | None = None,
    ) -> None:
        """Create the scheduler.

        Args:
            max_workers: Worker count for an internally created pool; ``None``
                uses the ``ProcessPoolExecutor`` default.
            executor: An existing pool to borrow instead of creating one.

        Raises:
            ValueError: If both ``max_workers`` and ``executor`` are given.
        """
        if max_workers is not None and executor is not None:
            raise ValueError("max_workers cannot be set when executor is provided")
        # Pick the executor and decide ownership with the same ``is None`` test
        # so the two can never disagree.
        self._owns_executor = executor is None
        self._executor = (
            ProcessPoolExecutor(max_workers=max_workers)
            if executor is None
            else executor
        )

    async def invoke(self, request: InvocationRequest) -> Any:
        """Invoke an operation through an Invariant codec process boundary.

        Raises:
            ValueError: If ``request.implementation_ref`` is missing or empty,
                since the worker would have no way to resolve the callable.
        """
        if not request.implementation_ref:
            raise ValueError(
                f"Op '{request.op_name}' cannot run in a process because it has "
                "no worker-resolvable implementation_ref"
            )

        # Encode the manifest in the parent so only bytes are sent to the worker.
        manifest_payload = serialize(request.manifest)
        loop = asyncio.get_running_loop()
        artifact_payload = await loop.run_in_executor(
            self._executor,
            _process_worker_invoke,
            request.op_name,
            request.implementation_ref,
            manifest_payload,
        )
        return deserialize(artifact_payload)

    async def aclose(self) -> None:
        """Shut down the owned process pool without blocking the event loop.

        ``shutdown(wait=True)`` waits for in-flight work and worker processes
        to finish. Running it via ``asyncio.to_thread`` keeps the event loop
        responsive meanwhile. A borrowed executor is left for its owner.
        """
        if self._owns_executor:
            await asyncio.to_thread(self._executor.shutdown, wait=True)

109 

110 

class RoutingScheduler:
    """Route invocations to local schedulers according to op traits.

    Routing precedence (first match wins):

    1. ``process_scheduler`` if the request has the PROCESS_SAFE trait and a
       process scheduler is configured.
    2. ``thread_scheduler`` if the request has the BLOCKING or IO_BOUND trait
       and a thread scheduler is configured.
    3. ``inline_scheduler`` otherwise (an ``InlineScheduler`` by default).
    """

    def __init__(
        self,
        *,
        inline_scheduler: InvocationScheduler | None = None,
        thread_scheduler: InvocationScheduler | None = None,
        process_scheduler: InvocationScheduler | None = None,
    ) -> None:
        # Explicit ``is None`` so a supplied scheduler is never swapped out
        # merely because the object happens to be falsy.
        self.inline_scheduler = (
            InlineScheduler() if inline_scheduler is None else inline_scheduler
        )
        self.thread_scheduler = thread_scheduler
        self.process_scheduler = process_scheduler

    async def invoke(self, request: InvocationRequest) -> Any:
        """Route an invocation to the first configured matching scheduler."""
        if (
            OpTrait.PROCESS_SAFE.value in request.traits
            and self.process_scheduler is not None
        ):
            return await self.process_scheduler.invoke(request)

        if self.thread_scheduler is not None and (
            OpTrait.BLOCKING.value in request.traits
            or OpTrait.IO_BOUND.value in request.traits
        ):
            return await self.thread_scheduler.invoke(request)

        return await self.inline_scheduler.invoke(request)

    async def aclose(self) -> None:
        """Close every child scheduler that exposes ``aclose``.

        Children are closed in the order process, thread, inline. A failure in
        one child no longer prevents the remaining children from being closed;
        after every close has been attempted, the first exception raised is
        re-raised so callers still see the failure.
        """
        first_error: Exception | None = None
        for scheduler in (
            self.process_scheduler,
            self.thread_scheduler,
            self.inline_scheduler,
        ):
            # getattr on None also yields None, so unset slots are skipped.
            close = getattr(scheduler, "aclose", None)
            if close is None:
                continue
            try:
                await close()
            except Exception as exc:  # keep closing the remaining children
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

151 

152 

def _process_worker_invoke(
    op_name: str,
    implementation_ref: str,
    manifest_payload: bytes,
) -> bytes:
    """Entry point executed inside a process-pool worker.

    Only plain strings and Invariant codec bytes cross the process boundary.
    The worker resolves the exact callable locally, invokes it on the decoded
    manifest, and returns the codec-encoded artifact.

    Raises:
        ValueError: If auto-discovery binds ``op_name`` to an implementation
            reference different from the one the parent requested.
    """
    # NOTE(review): clear() + auto_discover() run on every invocation; if
    # discovery is costly, consider caching per worker process.
    registry = OpRegistry()
    registry.clear()
    registry.auto_discover()

    if not registry.has(op_name):
        # Discovery did not find the op: import the exact callable the parent
        # named and register it under the requested name.
        op = import_implementation_ref(implementation_ref)
        registry.register(op_name, op, implementation_ref=implementation_ref)
    else:
        binding = registry.get_binding(op_name)
        discovered_ref = binding.implementation_ref
        # Refuse to run an implementation other than the one the parent chose.
        if discovered_ref != implementation_ref:
            raise ValueError(
                f"Worker discovered op '{op_name}' as "
                f"{discovered_ref!r}, but request requires "
                f"{implementation_ref!r}"
            )
        op = binding.op

    manifest = deserialize(manifest_payload)
    return serialize(invoke_op(op, op_name, manifest))

187 

188 

# Public API of this module. ``_process_worker_invoke`` is deliberately left
# out: it is only the process-pool worker entry point.
__all__ = [
    "InlineScheduler",
    "InvocationRequest",
    "InvocationScheduler",
    "ProcessPoolScheduler",
    "RoutingScheduler",
    "ThreadPoolScheduler",
]