Coverage for src/par_run/executor.py: 85%

251 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-12 08:32 -0400

"""Run groups of shell commands in parallel and report their output.

Defines the ``Command`` / ``CommandGroup`` models, the callback protocols used to
report progress, ``run_command`` (executed in a process pool, streaming output
through a queue), and readers/writers for ``commands.ini`` files and the
``[tool.par-run]`` table of a TOML file.
"""

2 

3import asyncio 

4import configparser 

5import enum 

6import multiprocessing as mp 

7import os 

8import queue 

9import subprocess 

10from collections import OrderedDict 

11from concurrent.futures import Future, ProcessPoolExecutor 

12from pathlib import Path 

13from queue import Queue 

14import time 

15from typing import List, Optional, Protocol, TypeVar, Union, Any 

16from pydantic import BaseModel, Field, ConfigDict 

17import tomlkit 

18 

# Type alias for a generic future: a concurrent.futures.Future (returned by
# pool.submit in CommandGroup.run) or an asyncio.Future (returned by
# run_in_executor in CommandGroup.run_async).
GenFuture = Union[Future, asyncio.Future]

# NOTE(review): not referenced anywhere in this module; may be used by other
# modules or be left over — confirm before removing.
ContextT = TypeVar("ContextT")

23 

24 

class ProcessingStrategy(enum.Enum):
    """Enum for processing strategies.

    Selects how a CommandGroup delivers command output to its callbacks.
    """

    # Buffer each command's output; when that command's exit code arrives, call
    # on_start, replay the buffered lines through on_recv, then call on_term.
    ON_COMP = "comp"
    # Call on_start for every command up front, then forward each output line to
    # on_recv as soon as it is received, and on_term when the exit code arrives.
    ON_RECV = "recv"

30 

31 

class CommandStatus(enum.Enum):
    """Lifecycle state of a single command."""

    NOT_STARTED = "Not Started"
    RUNNING = "Running"
    SUCCESS = "Success"
    FAILURE = "Failure"

    def completed(self) -> bool:
        """Tell whether this is a terminal state (finished, successfully or not)."""
        return self is CommandStatus.SUCCESS or self is CommandStatus.FAILURE

43 

44 

class Command(BaseModel):
    """Holder for a command, its configuration and its runtime state.

    Fields declared with ``exclude=True`` are runtime bookkeeping and are left out
    of model serialization; the remaining fields describe the command itself.
    """

    # Required because ``fut`` holds Future objects, which pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    cmd: str
    # NOTE(review): passenv is stored but not passed to run_command in this module.
    passenv: Optional[list[str]] = Field(default=None)
    setenv: Optional[dict[str, str]] = Field(default=None)
    status: CommandStatus = CommandStatus.NOT_STARTED
    # Output lines buffered until they are replayed to callbacks (ON_COMP strategy).
    unflushed: List[str] = Field(default_factory=list, exclude=True)
    num_non_empty_lines: int = Field(default=0, exclude=True)
    ret_code: Optional[int] = Field(default=None, exclude=True)
    # Future of the worker running this command; cleared once the exit code is known.
    fut: Optional[GenFuture] = Field(default=None, exclude=True)
    # time.perf_counter() value recorded by set_running(); elapsed is measured from it.
    start_time: Optional[float] = Field(default=None, exclude=True)
    elapsed: Optional[float] = Field(default=None, exclude=True)

    def incr_line_count(self, line: str) -> None:
        """Increment the non-empty line count if ``line`` has any non-whitespace text."""
        if line.strip():
            self.num_non_empty_lines += 1

    def append_unflushed(self, line: str) -> None:
        """Buffer an output line for later replay.

        This does not change ``num_non_empty_lines``; callers use ``incr_line_count`` for that.
        """
        self.unflushed.append(line)

    def clear_unflushed(self) -> None:
        """Discard all buffered output lines."""
        self.unflushed.clear()

    def set_ret_code(self, ret_code: int) -> None:
        """Record the exit code, elapsed time and final status of the command.

        Args:
            ret_code: Process exit code; 0 marks SUCCESS, anything else FAILURE.
        """
        if self.start_time is not None:
            self.elapsed = time.perf_counter() - self.start_time
        self.ret_code = ret_code
        if self.fut is not None:
            # The exit code has already been reported, so the worker's future is no
            # longer needed; cancel() has no effect on a future that already finished.
            self.fut.cancel()
            self.fut = None
        self.status = CommandStatus.SUCCESS if ret_code == 0 else CommandStatus.FAILURE

    def set_running(self) -> None:
        """Mark the command as running and start its elapsed-time clock."""
        self.start_time = time.perf_counter()
        self.status = CommandStatus.RUNNING

92 

93 

class CommandCB(Protocol):
    """Synchronous receiver of progress events for the commands of a group."""

    def on_start(self, cmd: Command) -> None:
        """Called before any output line of ``cmd`` is delivered."""

    def on_recv(self, cmd: Command, output: str) -> None:
        """Called with one line of output produced by ``cmd``."""

    def on_term(self, cmd: Command, exit_code: int) -> None:
        """Called once ``cmd`` has exited with ``exit_code``."""

98 

99 

class CommandAsyncCB(Protocol):
    """Asynchronous receiver of progress events for the commands of a group."""

    async def on_start(self, cmd: Command) -> None:
        """Awaited before any output line of ``cmd`` is delivered."""

    async def on_recv(self, cmd: Command, output: str) -> None:
        """Awaited with one line of output produced by ``cmd``."""

    async def on_term(self, cmd: Command, exit_code: int) -> None:
        """Awaited once ``cmd`` has exited with ``exit_code``."""

104 

105 

class QRetriever:
    """Blocking queue reader that tolerates a bounded number of empty waits.

    Each wait lasts up to ``timeout`` seconds; after the first empty wait up to
    ``retries`` further waits are attempted before giving up.
    """

    def __init__(self, q: Queue, timeout: int, retries: int):
        self.q = q
        self.timeout = timeout
        self.retries = retries

    def get(self):
        """Return the next queue item.

        Raises:
            TimeoutError: If every allowed wait ends without an item.
        """
        remaining = self.retries
        while True:
            try:
                return self.q.get(block=True, timeout=self.timeout)
            except queue.Empty:
                if remaining <= 0:
                    raise TimeoutError("Timeout waiting for command output")
                remaining -= 1

    def __str__(self) -> str:
        return f"QRetriever(timeout={self.timeout}, retries={self.retries})"

126 

127 

class CommandGroup(BaseModel):
    """Holder for a group of commands that are run in parallel.

    ``timeout`` (seconds) and ``retries`` bound how long the group waits for the
    next message from its commands: each wait lasts up to ``timeout`` seconds and
    up to ``retries`` further waits are made before ``TimeoutError`` is raised.
    """

    name: str
    desc: Optional[str] = None
    cmds: OrderedDict[str, Command] = Field(default_factory=OrderedDict)
    timeout: int = Field(default=30)
    retries: int = Field(default=3)

    async def run_async(
        self,
        strategy: ProcessingStrategy,
        callbacks: CommandAsyncCB,
    ) -> int:
        """Run all commands in a process pool, awaiting ``callbacks`` as output arrives.

        Returns:
            0 if every command exited with code 0, otherwise 1.

        Raises:
            TimeoutError: If no message arrives within the timeout/retry budget.
        """
        loop = asyncio.get_running_loop()
        # The manager and the pool are released when the run ends, including on error,
        # instead of relying on garbage collection to stop their processes.
        with mp.Manager() as manager:
            q = manager.Queue()
            pool = ProcessPoolExecutor()
            try:
                futs = [
                    loop.run_in_executor(pool, run_command, cmd.name, cmd.cmd, cmd.setenv, q)
                    for cmd in self.cmds.values()
                ]
                for cmd, fut in zip(self.cmds.values(), futs):
                    cmd.fut = fut
                    cmd.set_running()
                return await self._process_q_async(q, strategy, callbacks)
            finally:
                # wait=False: on error paths workers may still be running commands.
                pool.shutdown(wait=False, cancel_futures=True)

    def run(self, strategy: ProcessingStrategy, callbacks: CommandCB) -> int:
        """Run all commands in a process pool, calling ``callbacks`` as output arrives.

        Returns:
            0 if every command exited with code 0, otherwise 1.

        Raises:
            TimeoutError: If no message arrives within the timeout/retry budget.
        """
        with mp.Manager() as manager:
            q = manager.Queue()
            pool = ProcessPoolExecutor()
            try:
                futs = [pool.submit(run_command, cmd.name, cmd.cmd, cmd.setenv, q) for cmd in self.cmds.values()]
                for cmd, fut in zip(self.cmds.values(), futs):
                    cmd.fut = fut
                    cmd.set_running()
                return self._process_q(q, strategy, callbacks)
            finally:
                # wait=False: on error paths workers may still be running commands.
                pool.shutdown(wait=False, cancel_futures=True)

    @staticmethod
    def _parse_q_message(q_result: Any) -> tuple[str, Optional[int], Optional[str]]:
        """Split a queue message into (command name, exit code, output line).

        Exactly one of the exit code and the output line is not None.

        Raises:
            ValueError: If the payload is neither an int exit code nor a str line.
        """
        cmd_name = q_result[0]
        payload = q_result[1]
        exit_code: Optional[int] = payload if isinstance(payload, int) else None
        output_line: Optional[str] = payload if isinstance(payload, str) else None
        if exit_code is None and output_line is None:
            raise ValueError("Invalid Q message")  # pragma: no cover
        return cmd_name, exit_code, output_line

    def _process_q(
        self,
        q: Queue,
        strategy: ProcessingStrategy,
        callbacks: CommandCB,
    ) -> int:
        """Dispatch queue messages to ``callbacks`` until every command has completed.

        Returns:
            0 if every command exited with code 0, otherwise 1.
        """
        grp_exit_code = 0
        if not self.cmds:
            # No command was started, so no message will ever arrive.
            return grp_exit_code

        if strategy == ProcessingStrategy.ON_RECV:
            for cmd in self.cmds.values():
                callbacks.on_start(cmd)

        q_ret = QRetriever(q, self.timeout, self.retries)
        while True:
            cmd_name, exit_code, output_line = self._parse_q_message(q_ret.get())
            cmd = self.cmds[cmd_name]

            if strategy == ProcessingStrategy.ON_RECV:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    callbacks.on_recv(cmd, output_line)
                elif exit_code is not None:
                    cmd.set_ret_code(exit_code)
                    callbacks.on_term(cmd, exit_code)
            elif strategy == ProcessingStrategy.ON_COMP:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    cmd.append_unflushed(output_line)
                elif exit_code is not None:
                    callbacks.on_start(cmd)
                    for line in cmd.unflushed:
                        callbacks.on_recv(cmd, line)
                    cmd.clear_unflushed()
                    # Record the result before on_term (as ON_RECV does) so the callback
                    # sees the final status and elapsed time.
                    cmd.set_ret_code(exit_code)
                    callbacks.on_term(cmd, exit_code)

            if exit_code is not None and exit_code != 0:
                grp_exit_code = 1

            if all(c.status.completed() for c in self.cmds.values()):
                break
        return grp_exit_code

    async def _process_q_async(
        self,
        q: Queue,
        strategy: ProcessingStrategy,
        callbacks: CommandAsyncCB,
    ) -> int:
        """Async counterpart of ``_process_q``; awaits each callback.

        The blocking queue wait runs in a worker thread so the event loop stays
        responsive while no output is available.

        Returns:
            0 if every command exited with code 0, otherwise 1.
        """
        grp_exit_code = 0
        if not self.cmds:
            # No command was started, so no message will ever arrive.
            return grp_exit_code

        if strategy == ProcessingStrategy.ON_RECV:
            for cmd in self.cmds.values():
                await callbacks.on_start(cmd)

        q_ret = QRetriever(q, self.timeout, self.retries)
        while True:
            q_result = await asyncio.to_thread(q_ret.get)
            cmd_name, exit_code, output_line = self._parse_q_message(q_result)
            cmd = self.cmds[cmd_name]

            if strategy == ProcessingStrategy.ON_RECV:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    await callbacks.on_recv(cmd, output_line)
                elif exit_code is not None:
                    cmd.set_ret_code(exit_code)
                    await callbacks.on_term(cmd, exit_code)
            elif strategy == ProcessingStrategy.ON_COMP:
                if output_line is not None:
                    cmd.incr_line_count(output_line)
                    cmd.append_unflushed(output_line)
                elif exit_code is not None:
                    await callbacks.on_start(cmd)
                    for line in cmd.unflushed:
                        await callbacks.on_recv(cmd, line)
                    cmd.clear_unflushed()
                    # Record the result before on_term (as ON_RECV does) so the callback
                    # sees the final status and elapsed time.
                    cmd.set_ret_code(exit_code)
                    await callbacks.on_term(cmd, exit_code)

            if exit_code is not None and exit_code != 0:
                grp_exit_code = 1

            if all(c.status.completed() for c in self.cmds.values()):
                break
        return grp_exit_code

276 

277 

def read_commands_ini(filename: Union[str, Path]) -> list[CommandGroup]:
    """Read a commands.ini file and return a list of CommandGroup objects.

    Every section named ``group.<name>`` becomes a CommandGroup called ``<name>``.
    Each option in that section becomes a Command whose name is the option key and
    whose command line is the option value. Other sections are ignored.

    Standard ``configparser`` rules apply. Option names are lower-cased, and values
    use basic ``%`` interpolation (write a literal ``%`` as ``%%``). Options from a
    ``[DEFAULT]`` section appear in every group. A missing file is silently skipped,
    which yields an empty list.

    Args:
        filename (Union[str, Path]): The filename of the commands.ini file.

    Returns:
        list[CommandGroup]: A list of CommandGroup objects.
    """
    prefix = "group."
    config = configparser.ConfigParser()
    # UTF-8 to match write_commands_ini, rather than the platform's locale encoding.
    config.read(filename, encoding="utf-8")

    command_groups = []
    for section in config.sections():
        if not section.startswith(prefix):
            continue
        # Strip only the leading prefix; str.replace would also remove any later
        # "group." inside the name (e.g. "group.a.group.b" -> "a.b").
        group_name = section[len(prefix):]
        commands = OrderedDict()
        for name, cmd in config.items(section):
            name = name.strip()
            commands[name] = Command(name=name, cmd=cmd.strip())
        command_groups.append(CommandGroup(name=group_name, cmds=commands))

    return command_groups

302 

303 

def write_commands_ini(filename: Union[str, Path], command_groups: list[CommandGroup]):
    """Serialize command groups to an INI file, replacing any existing content.

    Each group becomes a ``[group.<name>]`` section with one
    ``<command name> = <command line>`` option per command. The file is UTF-8.

    Args:
        filename (Union[str, Path]): Destination path of the commands.ini file.
        command_groups (list[CommandGroup]): Groups to write, in order.
    """
    parser = configparser.ConfigParser()

    for grp in command_groups:
        section_key = f"group.{grp.name}"
        # Assigning an empty mapping creates (or resets) the section.
        parser[section_key] = {}
        section = parser[section_key]
        for cmd in grp.cmds.values():
            section[cmd.name] = cmd.cmd

    with open(filename, "w", encoding="utf-8") as handle:
        parser.write(handle)

321 

322 

def _validate_mandatory_keys(data: tomlkit.items.Table, keys: list[str], context: str) -> tuple[Any, ...]:
    """Fetch required keys from a table, failing on the first one that is absent.

    A key counts as absent when it is missing or its value is falsy (e.g. "").

    Args:
        data (tomlkit.items.Table): Table to read from.
        keys (list[str]): Required keys, in the order their values are returned.
        context (str): Human-readable location used in the error message.

    Returns:
        tuple[Any, ...]: The values for ``keys``, in the same order.

    Raises:
        ValueError: If a key is missing or has a falsy value.
    """

    def _require(key: str) -> Any:
        value = data.get(key, None)
        if not value:
            raise ValueError(f"{key} is mandatory, not found in {context}")
        return value

    return tuple(_require(key) for key in keys)

337 

338 

def _get_optional_keys(data: tomlkit.items.Table, keys: list[str], default=None) -> tuple[Optional[Any], ...]:
    """Look up optional keys, substituting ``default`` for any that are missing.

    Args:
        data (tomlkit.items.Table): Table to read from.
        keys (list[str]): Keys to look up, in the order their values are returned.
        default: Value used for keys not present in ``data``.

    Returns:
        tuple[Optional[Any], ...]: One value per key, in order.
    """
    return tuple(map(lambda key: data.get(key, default), keys))

347 

348 

def read_commands_toml(filename: Union[str, Path]) -> list[CommandGroup]:
    """Read a commands.toml file and return a list of CommandGroup objects.

    Groups are read from ``tool.par-run.groups``. Each group needs a ``name`` and may
    set ``desc``, ``timeout`` and ``retries``. Each entry of a group's ``commands``
    needs ``name`` and ``exec`` and may set ``setenv`` and ``passenv``.

    Args:
        filename (Union[str, Path]): The filename of the commands.toml file.

    Returns:
        list[CommandGroup]: A list of CommandGroup objects.

    Raises:
        ValueError: If there is no ``tool.par-run`` data, or a group or command lacks
            a mandatory key.
    """
    with open(filename, "r", encoding="utf-8") as toml_file:
        toml_data = tomlkit.parse(toml_file.read())

    cmd_groups_data = toml_data.get("tool", {}).get("par-run", {})
    if not cmd_groups_data:
        raise ValueError("No par-run data found in toml file")

    command_groups = []
    for group_data in cmd_groups_data.get("groups", []):
        (group_name,) = _validate_mandatory_keys(group_data, ["name"], "top level par-run group")
        group_desc, group_timeout, group_retries = _get_optional_keys(
            group_data, ["desc", "timeout", "retries"], default=None
        )

        # A missing or zero timeout falls back to 30s: with timeout=0 every queue wait
        # would give up immediately.
        if not group_timeout:
            group_timeout = 30
        # Only default retries when absent; an explicit 0 means "do not retry".
        if group_retries is None:
            group_retries = 3

        commands = OrderedDict()
        for cmd_data in group_data.get("commands", []):
            name, exec_cmd = _validate_mandatory_keys(cmd_data, ["name", "exec"], f"command group {group_name}")
            setenv, passenv = _get_optional_keys(cmd_data, ["setenv", "passenv"], default=None)
            commands[name] = Command(name=name, cmd=exec_cmd, setenv=setenv, passenv=passenv)

        command_groups.append(
            CommandGroup(
                name=group_name, desc=group_desc, cmds=commands, timeout=group_timeout, retries=group_retries
            )
        )

    return command_groups

391 

392 

def run_command(name: str, command: str, setenv: Optional[dict[str, str]], q: Queue) -> None:
    """Run a shell command and stream its output into a queue.

    Each line the command writes to stdout or stderr (merged) is put on ``q`` as
    ``(name, line)``, with the line ending and trailing whitespace removed but
    leading indentation kept. When the process exits, a final ``(name, exit_code)``
    tuple with an int exit code is put on ``q``.

    Args:
        name (str): Name identifying the command in queue messages.
        command (str): Command line, run through the shell (``shell=True``).
        setenv (Optional[dict[str, str]]): Variables layered on top of a copy of the
            current environment. When empty or None, the child inherits the
            environment unchanged.
        q (Queue): Queue to put the output into.

    Raises:
        ValueError: If the process reports no return code after waiting.
    """
    new_env = None
    if setenv:
        # Work on a copy so the parent's environment is left untouched.
        new_env = os.environ.copy()
        new_env.update(setenv)

    # shell=True: ``command`` is a shell command line from the project's own
    # configuration and must never be built from untrusted input.
    with subprocess.Popen(
        command,
        shell=True,
        env=new_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        # Without this, one undecodable byte raises UnicodeDecodeError. No exit code
        # would then be queued, and the consumer would wait until it timed out.
        errors="replace",
    ) as process:
        if process.stdout:
            for line in process.stdout:
                q.put((name, line.rstrip()))
            process.stdout.close()
        process.wait()

    ret_code = process.returncode
    if ret_code is not None:
        q.put((name, int(ret_code)))
    else:
        raise ValueError("Process has no return code")  # pragma: no cover