Coverage for src/par_run/executor.py: 85%
251 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-12 08:32 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-12 08:32 -0400
1"""Todo"""
3import asyncio
4import configparser
5import enum
6import multiprocessing as mp
7import os
8import queue
9import subprocess
10from collections import OrderedDict
11from concurrent.futures import Future, ProcessPoolExecutor
12from pathlib import Path
13from queue import Queue
14import time
15from typing import List, Optional, Protocol, TypeVar, Union, Any
16from pydantic import BaseModel, Field, ConfigDict
17import tomlkit
# A command's pending result is either a concurrent.futures.Future (returned by
# ProcessPoolExecutor.submit) or an asyncio.Future (returned by
# loop.run_in_executor); this alias covers both.
GenFuture = Union[Future, asyncio.Future]

# Generic type variable for a caller-supplied context object.
ContextT = TypeVar("ContextT")
class ProcessingStrategy(enum.Enum):
    """Selects when a command's output lines are handed to the callbacks."""

    # Buffer a command's lines and replay them once its exit code arrives.
    ON_COMP = "comp"
    # Forward every line to the callbacks as soon as it is received.
    ON_RECV = "recv"
class CommandStatus(enum.Enum):
    """Lifecycle states a command moves through."""

    NOT_STARTED = "Not Started"
    RUNNING = "Running"
    SUCCESS = "Success"
    FAILURE = "Failure"

    def completed(self) -> bool:
        """Tell whether this status is terminal, i.e. SUCCESS or FAILURE."""
        return self is CommandStatus.SUCCESS or self is CommandStatus.FAILURE
class Command(BaseModel):
    """A named shell command plus the bookkeeping gathered while it runs.

    ``name``, ``cmd``, ``passenv``, ``setenv`` and ``status`` are ordinary model
    fields. The remaining fields are declared with ``exclude=True`` and hold
    per-run state.
    """

    # Needed because ``fut`` holds Future objects, which pydantic cannot validate.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    cmd: str
    passenv: Optional[list[str]] = Field(default=None)
    setenv: Optional[dict[str, str]] = Field(default=None)
    status: CommandStatus = CommandStatus.NOT_STARTED
    # Output lines buffered until they are replayed to the callbacks.
    unflushed: List[str] = Field(default_factory=list, exclude=True)
    # Number of output lines seen that contain something other than whitespace.
    num_non_empty_lines: int = Field(default=0, exclude=True)
    ret_code: Optional[int] = Field(default=None, exclude=True)
    fut: Optional[GenFuture] = Field(default=None, exclude=True)
    # ``time.perf_counter()`` reading taken by ``set_running``.
    start_time: Optional[float] = Field(default=None, exclude=True)
    # Seconds between ``set_running`` and ``set_ret_code``.
    elapsed: Optional[float] = Field(default=None, exclude=True)

    def incr_line_count(self, line: str) -> None:
        """Increment the non-empty line count if ``line`` is not blank."""
        if line.strip():
            self.num_non_empty_lines += 1

    def append_unflushed(self, line: str) -> None:
        """Append a line to the buffered (unflushed) output."""
        self.unflushed.append(line)

    def clear_unflushed(self) -> None:
        """Clear the unflushed output."""
        self.unflushed.clear()

    def set_ret_code(self, ret_code: int) -> None:
        """Record the return code and derive the final status from it.

        Also computes ``elapsed`` when the command was started through
        ``set_running``, and cancels and drops the attached future.

        Args:
            ret_code: Exit code of the command; 0 means SUCCESS, anything else
                means FAILURE.
        """
        # Compare against None explicitly: a perf_counter reading of 0.0 is a
        # valid start time and must not be mistaken for "never started".
        if self.start_time is not None:
            self.elapsed = time.perf_counter() - self.start_time
        self.ret_code = ret_code
        if self.fut:
            self.fut.cancel()
            self.fut = None
        if ret_code == 0:
            self.status = CommandStatus.SUCCESS
        else:
            self.status = CommandStatus.FAILURE

    def set_running(self) -> None:
        """Stamp the start time and set the command status to running."""
        self.start_time = time.perf_counter()
        self.status = CommandStatus.RUNNING
class CommandCB(Protocol):
    """Synchronous callbacks invoked while a command group is processed."""

    def on_start(self, cmd: Command) -> None:
        """Called to announce ``cmd`` before any of its output is reported."""

    def on_recv(self, cmd: Command, output: str) -> None:
        """Called with one line of output produced by ``cmd``."""

    def on_term(self, cmd: Command, exit_code: int) -> None:
        """Called with the exit code once ``cmd`` has terminated."""
class CommandAsyncCB(Protocol):
    """Coroutine callbacks invoked while a command group is processed."""

    async def on_start(self, cmd: Command) -> None:
        """Awaited to announce ``cmd`` before any of its output is reported."""

    async def on_recv(self, cmd: Command, output: str) -> None:
        """Awaited with one line of output produced by ``cmd``."""

    async def on_term(self, cmd: Command, exit_code: int) -> None:
        """Awaited with the exit code once ``cmd`` has terminated."""
class QRetriever:
    """Blocking queue reader that retries a fixed number of times."""

    def __init__(self, q: Queue, timeout: int, retries: int):
        self.q = q
        self.timeout = timeout
        self.retries = retries

    def get(self):
        """Return the next queue item.

        Each attempt blocks for up to ``timeout`` seconds. After the first
        attempt, up to ``retries`` further attempts are made before giving up.

        Raises:
            TimeoutError: If every attempt found the queue empty.
        """
        retries_left = self.retries
        while True:
            try:
                return self.q.get(block=True, timeout=self.timeout)
            except queue.Empty:
                if retries_left <= 0:
                    raise TimeoutError("Timeout waiting for command output")
                retries_left -= 1

    def __str__(self) -> str:
        return f"QRetriever(timeout={self.timeout}, retries={self.retries})"
class CommandGroup(BaseModel):
    """A named, ordered group of commands that are executed in parallel.

    ``timeout`` (seconds per attempt) and ``retries`` are handed to
    ``QRetriever`` and bound how long the group waits for the next message from
    its commands before ``TimeoutError`` is raised.
    """

    name: str
    desc: Optional[str] = None
    cmds: OrderedDict[str, Command] = Field(default_factory=OrderedDict)
    timeout: int = Field(default=30)
    retries: int = Field(default=3)

    async def run_async(
        self,
        strategy: ProcessingStrategy,
        callbacks: CommandAsyncCB,
    ) -> int:
        """Run every command in a process pool and report through coroutine callbacks.

        Args:
            strategy: Whether output is forwarded as received or on completion.
            callbacks: Coroutine callbacks awaited for start/output/termination.

        Returns:
            0 if every command exited with 0 (or the group is empty), else 1.

        Raises:
            TimeoutError: If no message arrives within the configured
                timeout/retries.
        """
        if not self.cmds:
            # Nothing to run: avoid spawning a manager and a pool for no work.
            return 0

        loop = asyncio.get_running_loop()
        with mp.Manager() as manager:
            q = manager.Queue()
            pool = ProcessPoolExecutor()
            try:
                for cmd in self.cmds.values():
                    cmd.fut = loop.run_in_executor(pool, run_command, cmd.name, cmd.cmd, cmd.setenv, q)
                    cmd.set_running()
                return await self._process_q_async(q, strategy, callbacks)
            finally:
                # Do not wait: after a timeout a hung command must not block
                # the caller; on success the workers are already finishing.
                pool.shutdown(wait=False, cancel_futures=True)

    def run(self, strategy: ProcessingStrategy, callbacks: CommandCB) -> int:
        """Run every command in a process pool and report through callbacks.

        Args:
            strategy: Whether output is forwarded as received or on completion.
            callbacks: Callbacks invoked for start/output/termination.

        Returns:
            0 if every command exited with 0 (or the group is empty), else 1.

        Raises:
            TimeoutError: If no message arrives within the configured
                timeout/retries.
        """
        if not self.cmds:
            # Nothing to run: avoid spawning a manager and a pool for no work.
            return 0

        with mp.Manager() as manager:
            q = manager.Queue()
            pool = ProcessPoolExecutor()
            try:
                for cmd in self.cmds.values():
                    cmd.fut = pool.submit(run_command, cmd.name, cmd.cmd, cmd.setenv, q)
                    cmd.set_running()
                return self._process_q(q, strategy, callbacks)
            finally:
                # Do not wait: after a timeout a hung command must not block
                # the caller; on success the workers are already finishing.
                pool.shutdown(wait=False, cancel_futures=True)

    def _parse_q_message(self, q_result: Any) -> tuple[Command, Optional[int], Optional[str]]:
        """Split a queue message into ``(command, exit_code, output_line)``.

        A message is ``(command_name, payload)``: an ``int`` payload is the exit
        code, a ``str`` payload is one output line. Exactly one of the two
        returned optionals is not ``None``.

        Raises:
            ValueError: If the payload is neither ``int`` nor ``str``.
            KeyError: If the command name is not part of this group.
        """
        cmd_name, payload = q_result[0], q_result[1]
        exit_code: Optional[int] = payload if isinstance(payload, int) else None
        output_line: Optional[str] = payload if isinstance(payload, str) else None
        if exit_code is None and output_line is None:
            raise ValueError("Invalid Q message")  # pragma: no cover
        return self.cmds[cmd_name], exit_code, output_line

    def _process_q(
        self,
        q: Queue,
        strategy: ProcessingStrategy,
        callbacks: CommandCB,
    ) -> int:
        """Drain the queue until every command has completed.

        Returns:
            0 if every command exited with 0, otherwise 1.
        """
        grp_exit_code = 0

        if strategy == ProcessingStrategy.ON_RECV:
            for cmd in self.cmds.values():
                callbacks.on_start(cmd)

        q_ret = QRetriever(q, self.timeout, self.retries)
        while not all(cmd.status.completed() for cmd in self.cmds.values()):
            cmd, exit_code, output_line = self._parse_q_message(q_ret.get())

            if output_line is not None:
                cmd.incr_line_count(output_line)
                if strategy == ProcessingStrategy.ON_RECV:
                    callbacks.on_recv(cmd, output_line)
                else:
                    # ON_COMP: hold the line back until the command terminates.
                    cmd.append_unflushed(output_line)
                continue

            if strategy == ProcessingStrategy.ON_RECV:
                cmd.set_ret_code(exit_code)
                callbacks.on_term(cmd, exit_code)
            else:
                callbacks.on_start(cmd)
                for line in cmd.unflushed:
                    callbacks.on_recv(cmd, line)
                cmd.clear_unflushed()
                # NOTE(review): on_term runs before set_ret_code here, so the
                # callback still sees RUNNING; ON_RECV does it the other way
                # round. Order kept as-is - confirm which one is intended.
                callbacks.on_term(cmd, exit_code)
                cmd.set_ret_code(exit_code)
            if exit_code != 0:
                grp_exit_code = 1
        return grp_exit_code

    async def _process_q_async(
        self,
        q: Queue,
        strategy: ProcessingStrategy,
        callbacks: CommandAsyncCB,
    ) -> int:
        """Drain the queue until every command has completed, awaiting callbacks.

        Returns:
            0 if every command exited with 0, otherwise 1.
        """
        grp_exit_code = 0

        if strategy == ProcessingStrategy.ON_RECV:
            for cmd in self.cmds.values():
                await callbacks.on_start(cmd)

        loop = asyncio.get_running_loop()
        q_ret = QRetriever(q, self.timeout, self.retries)
        while not all(cmd.status.completed() for cmd in self.cmds.values()):
            # The queue read blocks for up to ``timeout`` seconds per attempt;
            # run it in the default thread pool so the event loop stays free.
            q_result = await loop.run_in_executor(None, q_ret.get)
            cmd, exit_code, output_line = self._parse_q_message(q_result)

            if output_line is not None:
                cmd.incr_line_count(output_line)
                if strategy == ProcessingStrategy.ON_RECV:
                    await callbacks.on_recv(cmd, output_line)
                else:
                    # ON_COMP: hold the line back until the command terminates.
                    cmd.append_unflushed(output_line)
                continue

            if strategy == ProcessingStrategy.ON_RECV:
                cmd.set_ret_code(exit_code)
                await callbacks.on_term(cmd, exit_code)
            else:
                await callbacks.on_start(cmd)
                for line in cmd.unflushed:
                    await callbacks.on_recv(cmd, line)
                cmd.clear_unflushed()
                # NOTE(review): on_term runs before set_ret_code here, so the
                # callback still sees RUNNING; ON_RECV does it the other way
                # round. Order kept as-is - confirm which one is intended.
                await callbacks.on_term(cmd, exit_code)
                cmd.set_ret_code(exit_code)
            if exit_code != 0:
                grp_exit_code = 1
        return grp_exit_code
def read_commands_ini(filename: Union[str, Path]) -> list[CommandGroup]:
    """Read a commands.ini file and return a list of CommandGroup objects.

    Only sections whose name starts with ``group.`` are used; the rest of the
    section name becomes the group name and every option in the section becomes
    a command (option name -> command name, option value -> command line).

    Args:
        filename (Union[str, Path]): The filename of the commands.ini file.

    Returns:
        list[CommandGroup]: A list of CommandGroup objects, in file order.
    """
    prefix = "group."
    config = configparser.ConfigParser()
    config.read(filename)

    command_groups = []
    for section in config.sections():
        if not section.startswith(prefix):
            continue
        # Strip only the leading prefix: str.replace would also delete any later
        # "group." inside the name (e.g. "group.a.group.b" -> "a.b").
        group_name = section[len(prefix):]
        commands = OrderedDict()
        for name, cmd in config.items(section):
            name = name.strip()
            commands[name] = Command(name=name, cmd=cmd.strip())
        command_groups.append(CommandGroup(name=group_name, cmds=commands))

    return command_groups
def write_commands_ini(filename: Union[str, Path], command_groups: list[CommandGroup]):
    """Serialise command groups into an ini file.

    Each group is written as a ``group.<name>`` section holding one
    ``<command name> = <command line>`` option per command.

    Args:
        filename (Union[str, Path]): Destination path of the ini file.
        command_groups (list[CommandGroup]): The groups to write.
    """
    parser = configparser.ConfigParser()

    for grp in command_groups:
        header = f"group.{grp.name}"
        parser[header] = {}
        section = parser[header]
        for entry in grp.cmds.values():
            section[entry.name] = entry.cmd

    with open(filename, "w", encoding="utf-8") as ini_file:
        parser.write(ini_file)
def _validate_mandatory_keys(data: tomlkit.items.Table, keys: list[str], context: str) -> tuple[Any, ...]:
    """Fetch the mandatory keys from the data, failing on the first one missing.

    A key whose value is falsy (absent, empty, zero) counts as missing.

    Args:
        data (tomlkit.items.Table): The data to validate.
        keys (list[str]): The mandatory keys.
        context (str): Description of the enclosing element, used in the error.

    Returns:
        tuple[Any, ...]: The values of the keys, in the order of ``keys``.

    Raises:
        ValueError: If a key is missing or its value is falsy.
    """
    found = []
    for key in keys:
        value = data.get(key, None)
        if value:
            found.append(value)
            continue
        raise ValueError(f"{key} is mandatory, not found in {context}")
    return tuple(found)
def _get_optional_keys(data: tomlkit.items.Table, keys: list[str], default=None) -> tuple[Optional[Any], ...]:
    """Look up optional keys, substituting ``default`` for absent ones.

    Args:
        data (tomlkit.items.Table): The data to use as source.
        keys (list[str]): The optional keys.
        default: Value used for every key that is not present in ``data``.

    Returns:
        tuple[Optional[Any], ...]: One value per key, in the order of ``keys``.
    """
    values = [data.get(key, default) for key in keys]
    return tuple(values)
def read_commands_toml(filename: Union[str, Path]) -> list[CommandGroup]:
    """Read a commands.toml file and return a list of CommandGroup objects.

    Groups are read from the ``groups`` array of the ``[tool.par-run]`` table.
    Each group needs a ``name`` and may set ``desc``, ``timeout`` and
    ``retries``; each entry of its ``commands`` array needs ``name`` and
    ``exec`` and may set ``setenv`` and ``passenv``.

    Args:
        filename (Union[str, Path]): The filename of the commands.toml file.

    Returns:
        list[CommandGroup]: A list of CommandGroup objects, in file order.

    Raises:
        ValueError: If the file has no ``[tool.par-run]`` data or a mandatory
            key is missing.
    """
    with open(filename, "r", encoding="utf-8") as toml_file:
        toml_data = tomlkit.parse(toml_file.read())

    cmd_groups_data = toml_data.get("tool", {}).get("par-run", {})
    if not cmd_groups_data:
        raise ValueError("No par-run data found in toml file")

    command_groups = []
    for group_data in cmd_groups_data.get("groups", []):
        (group_name,) = _validate_mandatory_keys(group_data, ["name"], "top level par-run group")
        group_desc, group_timeout, group_retries = _get_optional_keys(
            group_data, ["desc", "timeout", "retries"], default=None
        )

        # A zero-second wait could never succeed, so a falsy timeout keeps
        # falling back to the default.
        if not group_timeout:
            group_timeout = 30
        # ``retries = 0`` is a legitimate "do not retry" setting; only a
        # missing key falls back to the default.
        if group_retries is None:
            group_retries = 3

        commands = OrderedDict()
        for cmd_data in group_data.get("commands", []):
            # Named exec_cmd rather than exec to avoid shadowing the builtin.
            name, exec_cmd = _validate_mandatory_keys(cmd_data, ["name", "exec"], f"command group {group_name}")
            setenv, passenv = _get_optional_keys(cmd_data, ["setenv", "passenv"], default=None)

            commands[name] = Command(name=name, cmd=exec_cmd, setenv=setenv, passenv=passenv)
        command_group = CommandGroup(
            name=group_name, desc=group_desc, cmds=commands, timeout=group_timeout, retries=group_retries
        )
        command_groups.append(command_group)

    return command_groups
def run_command(name: str, command: str, setenv: Optional[dict[str, str]], q: Queue) -> None:
    """Run a shell command and stream its output into a queue.

    Every line of combined stdout/stderr is put on the queue as
    ``(name, line)`` with surrounding whitespace stripped. When the process has
    finished, a final ``(name, return_code)`` tuple with an ``int`` return code
    is put on the queue.

    Args:
        name (str): Name of the command; used as the first element of every message.
        command (str): Command line, executed through the shell.
        setenv (Optional[dict[str, str]]): Extra environment variables layered
            over a copy of the current environment; the environment is
            inherited unchanged when empty or None.
        q (Queue): Queue to put the output into.

    Raises:
        ValueError: If the process reports no return code after being waited on.
    """
    new_env = None
    if setenv:
        new_env = os.environ.copy()
        new_env.update(setenv)

    with subprocess.Popen(
        command,
        shell=True,
        env=new_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        # Undecodable bytes must not abort the reader: a decoding error here
        # would mean the exit code is never queued and the consumer times out.
        errors="replace",
    ) as process:
        if process.stdout:
            for line in process.stdout:
                q.put((name, line.strip()))
        process.wait()
        ret_code = process.returncode
        if ret_code is None:
            raise ValueError("Process has no return code")  # pragma: no cover
        q.put((name, int(ret_code)))