Coverage for src/par_run/executor.py: 98%

203 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-31 17:02 -0400

"""Run groups of shell commands in parallel and report their output.

Command groups are read from / written to a ``commands.ini`` file (``[group.<name>]``
sections). Each command runs in a worker process; its output lines and final exit code
are streamed back through a queue and dispatched to synchronous or asyncio callbacks.
"""

2 

3import asyncio 

4import configparser 

5import enum 

6import multiprocessing as mp 

7import queue 

8import subprocess 

9from collections import OrderedDict 

10from concurrent.futures import Future, ProcessPoolExecutor 

11from pathlib import Path 

12from queue import Queue 

13from typing import List, Optional, Protocol, TypeVar, Union 

14from pydantic import BaseModel, Field, ConfigDict 

15 

16 

# Type variable named "ContextT"; it is not referenced elsewhere in this module
# (presumably kept for external users — confirm before removing).
ContextT = TypeVar("ContextT")

# Either flavour of future: a concurrent.futures.Future (from ``pool.submit``) or an
# asyncio.Future (from ``loop.run_in_executor``).
GenFuture = Union[Future, asyncio.Future]

21 

22 

class ProcessingStrategy(enum.Enum):
    """When command output is handed to the callbacks.

    ``ON_COMP`` buffers a command's output and replays it once that command completes;
    ``ON_RECV`` forwards every line as soon as it is received.
    """

    ON_COMP = "comp"  # deliver on completion
    ON_RECV = "recv"  # deliver on receipt

28 

29 

class CommandStatus(enum.Enum):
    """Lifecycle state of a command."""

    NOT_STARTED = "Not Started"
    RUNNING = "Running"
    SUCCESS = "Success"
    FAILURE = "Failure"

    def completed(self) -> bool:
        """Tell whether the command has finished, successfully or not."""
        return self is CommandStatus.SUCCESS or self is CommandStatus.FAILURE

41 

42 

class Command(BaseModel):
    """Holder for a shell command, its name and the state of its current run."""

    # Needed so pydantic accepts the non-pydantic future types stored in ``fut``.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    cmd: str
    status: CommandStatus = CommandStatus.NOT_STARTED
    # Output lines held back until they are flushed to callbacks (excluded from dumps).
    unflushed: List[str] = Field(default_factory=list, exclude=True)
    # Number of output lines seen that contain a non-whitespace character.
    num_non_empty_lines: int = Field(default=0, exclude=True)
    # Exit code once the command has terminated, otherwise None.
    ret_code: Optional[int] = Field(default=None, exclude=True)
    # Future of the in-flight run; cleared by set_ret_code().
    fut: Optional[GenFuture] = Field(default=None, exclude=True)

    def incr_line_count(self, line: str) -> None:
        """Increment the non-empty line count if ``line`` has any non-whitespace character."""
        if line.strip():
            self.num_non_empty_lines += 1

    def append_unflushed(self, line: str) -> None:
        """Buffer ``line`` for later delivery.

        The non-empty line count is not touched here; call :meth:`incr_line_count` for that.
        """
        self.unflushed.append(line)

    def clear_unflushed(self) -> None:
        """Discard all buffered output lines."""
        self.unflushed.clear()

    def set_ret_code(self, ret_code: int) -> None:
        """Record the exit code, release the run's future and set the final status.

        The status becomes ``SUCCESS`` for exit code 0 and ``FAILURE`` otherwise. The
        future is cancelled (a no-op if it has already finished) and then cleared.
        """
        self.ret_code = ret_code
        if self.fut:
            self.fut.cancel()
            self.fut = None
        self.status = CommandStatus.SUCCESS if ret_code == 0 else CommandStatus.FAILURE

    def set_running(self) -> None:
        """Mark the command as running."""
        self.status = CommandStatus.RUNNING

83 

84 

class CommandCB(Protocol):
    """Synchronous callbacks notified while a command group is being processed."""

    def on_start(self, cmd: Command) -> None:
        """Called before any output of ``cmd`` is delivered."""

    def on_recv(self, cmd: Command, output: str) -> None:
        """Called with one output line of ``cmd``."""

    def on_term(self, cmd: Command, exit_code: int) -> None:
        """Called with the exit code once ``cmd`` has terminated."""

89 

90 

class CommandAsyncCB(Protocol):
    """Asynchronous (awaited) callbacks notified while a command group is being processed."""

    async def on_start(self, cmd: Command) -> None:
        """Awaited before any output of ``cmd`` is delivered."""

    async def on_recv(self, cmd: Command, output: str) -> None:
        """Awaited with one output line of ``cmd``."""

    async def on_term(self, cmd: Command, exit_code: int) -> None:
        """Awaited with the exit code once ``cmd`` has terminated."""

95 

96 

class QRetriever:
    """Blocking queue reader that tolerates a bounded number of empty timeouts.

    Each :meth:`get` waits up to ``timeout`` seconds per attempt and makes up to
    ``retries + 1`` attempts before giving up.
    """

    def __init__(self, q: Queue, timeout: int, retries: int):
        self.q = q
        self.timeout = timeout
        self.retries = retries

    def get(self):
        """Return the next item from the queue.

        Raises:
            TimeoutError: if every attempt timed out while the queue stayed empty.
        """
        remaining = self.retries
        while True:
            try:
                return self.q.get(block=True, timeout=self.timeout)
            except queue.Empty:
                if remaining <= 0:
                    raise TimeoutError("Timeout waiting for command output")
                remaining -= 1

114 

115 

class CommandGroup(BaseModel):
    """Holder for a named group of commands that are run in parallel."""

    name: str
    cmds: OrderedDict[str, Command] = Field(default_factory=OrderedDict)

    async def run_async(
        self,
        strategy: ProcessingStrategy,
        callbacks: CommandAsyncCB,
    ) -> int:
        """Run every command in a worker process, reporting through async callbacks.

        Args:
            strategy: Deliver output per line (``ON_RECV``) or replay it when each
                command completes (``ON_COMP``).
            callbacks: Awaited on command start, each output line and termination.

        Returns:
            0 if every command exited with code 0, otherwise 1.

        Raises:
            TimeoutError: If no queue message arrives within the retriever's budget.
        """
        loop = asyncio.get_running_loop()
        # Scope the manager's server process and the pool's workers to this run instead
        # of leaving them alive until garbage collection.
        with mp.Manager() as manager:
            q = manager.Queue()
            pool = ProcessPoolExecutor()
            try:
                futs = [
                    loop.run_in_executor(pool, run_command, cmd.name, cmd.cmd, q)
                    for cmd in self.cmds.values()
                ]
                self._mark_running(futs)
                grp_exit_code = await self._process_q_async(q, strategy, callbacks)
            except BaseException:
                # Aborting (timeout, cancellation, callback error): don't block on
                # commands that are still running, and drop those not yet started.
                pool.shutdown(wait=False, cancel_futures=True)
                raise
            # Every command has already queued its exit code, its last action.
            pool.shutdown(wait=True)
        return grp_exit_code

    def run(
        self,
        strategy: ProcessingStrategy,
        callbacks: CommandCB,
    ) -> int:
        """Run every command in a worker process, reporting through sync callbacks.

        Args:
            strategy: Deliver output per line (``ON_RECV``) or replay it when each
                command completes (``ON_COMP``).
            callbacks: Called on command start, each output line and termination.

        Returns:
            0 if every command exited with code 0, otherwise 1.

        Raises:
            TimeoutError: If no queue message arrives within the retriever's budget.
        """
        # Scope the manager's server process and the pool's workers to this run.
        with mp.Manager() as manager:
            q = manager.Queue()
            pool = ProcessPoolExecutor()
            try:
                futs = [pool.submit(run_command, cmd.name, cmd.cmd, q) for cmd in self.cmds.values()]
                self._mark_running(futs)
                grp_exit_code = self._process_q(q, strategy, callbacks)
            except BaseException:
                # Aborting: don't block on still-running commands, drop unstarted ones.
                pool.shutdown(wait=False, cancel_futures=True)
                raise
            # Every command has already queued its exit code, its last action.
            pool.shutdown(wait=True)
        return grp_exit_code

    def _mark_running(self, futs: List[GenFuture]) -> None:
        """Attach each future to its command (same order as ``self.cmds``) and mark it running."""
        for cmd, fut in zip(self.cmds.values(), futs):
            cmd.fut = fut
            cmd.set_running()

    def _all_completed(self) -> bool:
        """Return True once every command in the group has a final status."""
        return all(cmd.status.completed() for cmd in self.cmds.values())

    @staticmethod
    def _decode_message(q_result) -> tuple[str, Optional[int], Optional[str]]:
        """Split a ``(name, payload)`` queue message into name, exit code and output line.

        ``run_command`` sends a ``str`` payload per output line and a final ``int``
        payload with the exit code, so exactly one of the two returned values is set.

        Raises:
            ValueError: If the payload is neither an ``int`` nor a ``str``.
        """
        cmd_name, payload = q_result[0], q_result[1]
        exit_code: Optional[int] = payload if isinstance(payload, int) else None
        output_line: Optional[str] = payload if isinstance(payload, str) else None
        if exit_code is None and output_line is None:
            raise ValueError("Invalid Q message")  # pragma: no cover
        return cmd_name, exit_code, output_line

    def _process_q(
        self,
        q: Queue,
        strategy: ProcessingStrategy,
        callbacks: CommandCB,
    ) -> int:
        """Consume queue messages until every command has completed, dispatching callbacks.

        ``ON_RECV``: every command is announced up front and each line is forwarded as
        it arrives. ``ON_COMP``: lines are buffered per command and replayed (start,
        lines, term) once that command exits. In both modes the command's return code
        and status are recorded before ``on_term`` is called.

        Returns:
            0 if every command exited with code 0, otherwise 1.
        """
        grp_exit_code = 0

        if strategy == ProcessingStrategy.ON_RECV:
            for cmd in self.cmds.values():
                callbacks.on_start(cmd)

        q_ret = QRetriever(q, timeout=10, retries=3)
        # Checking before reading lets an empty group return at once instead of waiting
        # for a message that will never arrive.
        while not self._all_completed():
            cmd_name, exit_code, output_line = self._decode_message(q_ret.get())
            cmd = self.cmds[cmd_name]

            if output_line is not None:
                cmd.incr_line_count(output_line)
                if strategy == ProcessingStrategy.ON_RECV:
                    callbacks.on_recv(cmd, output_line)
                else:
                    cmd.append_unflushed(output_line)
            elif exit_code is not None:
                if strategy == ProcessingStrategy.ON_COMP:
                    # The command finished: replay its buffered output now.
                    callbacks.on_start(cmd)
                    for line in cmd.unflushed:
                        callbacks.on_recv(cmd, line)
                    cmd.clear_unflushed()
                cmd.set_ret_code(exit_code)
                callbacks.on_term(cmd, exit_code)
                if exit_code != 0:
                    grp_exit_code = 1
        return grp_exit_code

    async def _process_q_async(
        self,
        q: Queue,
        strategy: ProcessingStrategy,
        callbacks: CommandAsyncCB,
    ) -> int:
        """Async twin of :meth:`_process_q`; callbacks are awaited.

        Returns:
            0 if every command exited with code 0, otherwise 1.
        """
        grp_exit_code = 0
        loop = asyncio.get_running_loop()

        if strategy == ProcessingStrategy.ON_RECV:
            for cmd in self.cmds.values():
                await callbacks.on_start(cmd)

        q_ret = QRetriever(q, timeout=10, retries=3)
        while not self._all_completed():
            # QRetriever.get blocks (up to timeout * (retries + 1) seconds); run it in a
            # thread so the event loop keeps serving other tasks meanwhile.
            q_result = await loop.run_in_executor(None, q_ret.get)
            cmd_name, exit_code, output_line = self._decode_message(q_result)
            cmd = self.cmds[cmd_name]

            if output_line is not None:
                cmd.incr_line_count(output_line)
                if strategy == ProcessingStrategy.ON_RECV:
                    await callbacks.on_recv(cmd, output_line)
                else:
                    cmd.append_unflushed(output_line)
            elif exit_code is not None:
                if strategy == ProcessingStrategy.ON_COMP:
                    # The command finished: replay its buffered output now.
                    await callbacks.on_start(cmd)
                    for line in cmd.unflushed:
                        await callbacks.on_recv(cmd, line)
                    cmd.clear_unflushed()
                cmd.set_ret_code(exit_code)
                await callbacks.on_term(cmd, exit_code)
                if exit_code != 0:
                    grp_exit_code = 1
        return grp_exit_code

272 

273 

def read_commands_ini(filename: Union[str, Path]) -> list[CommandGroup]:
    """Read a commands.ini file and return a list of CommandGroup objects.

    Every section named ``group.<name>`` becomes a CommandGroup called ``<name>``, and
    each key/value pair in it becomes a Command. Other sections are ignored. A missing
    or unreadable file yields an empty list, because ``ConfigParser.read`` skips it.
    Values use ConfigParser's default interpolation, so a literal ``%`` must be
    written as ``%%``.

    Args:
        filename (Union[str, Path]): The filename of the commands.ini file.

    Returns:
        list[CommandGroup]: A list of CommandGroup objects, in file order.
    """
    prefix = "group."
    config = configparser.ConfigParser()
    config.read(filename)

    command_groups = []
    for section in config.sections():
        if not section.startswith(prefix):
            continue
        # Remove only the leading prefix; str.replace would also strip any later
        # "group." inside the name (e.g. "group.a.group.b" -> "ab").
        group_name = section[len(prefix):]
        commands = OrderedDict()
        for name, cmd in config.items(section):
            name = name.strip()
            commands[name] = Command(name=name, cmd=cmd.strip())
        command_groups.append(CommandGroup(name=group_name, cmds=commands))

    return command_groups

297 return command_groups 

298 

299 

def write_commands_ini(filename: Union[str, Path], command_groups: list[CommandGroup]):
    """Serialise command groups to an INI file, one ``[group.<name>]`` section per group.

    Args:
        filename (Union[str, Path]): Destination path; the file is overwritten (UTF-8).
        command_groups (list[CommandGroup]): Groups to write; each command is stored as
            ``<command name> = <command line>``.
    """
    parser = configparser.ConfigParser()

    for group in command_groups:
        section = f"group.{group.name}"
        parser[section] = {}
        entries = parser[section]
        for command in group.cmds.values():
            entries[command.name] = command.cmd

    with open(filename, "w", encoding="utf-8") as handle:
        parser.write(handle)

317 

318 

def run_command(name: str, command: str, q: Queue) -> None:
    """Run a shell command and stream its output into a queue.

    stdout and stderr are merged. Each output line is put as ``(name, line)`` with its
    trailing whitespace (including the newline) removed; leading indentation is kept.
    After the process exits, ``(name, return_code)`` is put as the final message, with
    ``return_code`` an ``int``.

    Args:
        name (str): Name of the command, used to tag every message.
        command (str): Command line, executed through the shell.
        q (Queue): Queue to put the output into.

    Raises:
        ValueError: If the process reports no return code after waiting.
    """
    with subprocess.Popen(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        # Without this, undecodable output raises UnicodeDecodeError here, the exit code
        # is never queued and the consumer only notices after its queue timeout.
        errors="replace",
    ) as process:
        if process.stdout:
            for line in iter(process.stdout.readline, ""):
                # rstrip only: keep leading indentation (tracebacks, diffs, tables).
                q.put((name, line.rstrip()))
            process.stdout.close()
        process.wait()
        ret_code = process.returncode
        if ret_code is not None:
            q.put((name, int(ret_code)))
        else:
            raise ValueError("Process has no return code")  # pragma: no cover