# Source code for

"""Definition of `Job` and `Result` classes used to encapsulate an experiment
and the corresponding outcomes."""

import enum
import importlib
import os
import pickle
import re
import subprocess
import sys
import tempfile
import time
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union

# Public API of this module: the classes users instantiate directly.
__all__ = ["Job", "SlurmJob", "Result"]

# Global registries to control the job and result id assignment
_ID_COUNTER = -1  # last id returned by generate_id; -1 means none issued yet
_JOB_REGISTRY = set()  # ids of every Job created so far
_RESULT_REGISTRY = set()  # ids of every Result created so far


def reset_registry():
    """Reset the global job and result registries.

    The id counter is rewound so that the next generated id is 0 again, and
    the sets of already used job and result ids are emptied (in place, so any
    other reference to them observes the reset).

    Warning:
        This function should be used with care as it will allow for jobs with
        repeating IDs to be created. As a consequence, two or more
        :class:`Result` objects might coexist and make the actual experiment
        outcome ambiguous.
    """
    global _ID_COUNTER
    _ID_COUNTER = -1
    # Without clearing the registries the very next Job/Result would reuse an
    # id that is still registered and be rejected as a duplicate.
    _JOB_REGISTRY.clear()
    _RESULT_REGISTRY.clear()


def generate_id():
    """Generate a new, unused integer job id."""
    global _ID_COUNTER
    _ID_COUNTER += 1
    return _ID_COUNTER

def import_script(path):
    """Import a module or script by a given path.

    Args:
        path: :obj:`str`, can be either a module import of the form
            [package.]*[module] if the outer most package is in the
            `PYTHONPATH`, or a path to an arbitrary python script.

    Returns:
        The loaded python script as a module.

    Raises:
        FileNotFoundError: if `path` is neither an importable module nor an
            existing file.
        ValueError: if `path` is an existing file without a ``.py`` suffix.
    """
    # importlib treats names starting with '.' as relative imports and raises
    # TypeError for them, so only try a module import for other names; paths
    # such as './script.py' go straight to the file-based import below.
    if not path.startswith("."):
        try:
            return importlib.import_module(path)
        except ModuleNotFoundError:
            pass
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Cannot find script {path}.")
    script_name = os.path.basename(path)
    if not script_name.endswith(".py"):
        raise ValueError(
            f"Expected a python script ending with *.py, "
            f"found {script_name}.")
    import_path = os.path.dirname(os.path.abspath(path))
    # Make the script's directory importable so the script can be loaded as a
    # top-level module named after its file.
    if import_path not in sys.path:
        sys.path.insert(0, import_path)
    # The script may have been created after the interpreter started; drop
    # stale finder caches so the new file is visible.
    importlib.invalidate_caches()
    return importlib.import_module(script_name[:-len(".py")])

def run_command(cmd: List[str]) -> str:
    """Execute a command and return its standard output.

    The command is executed directly (not through a shell), so `cmd` must
    already be split into the program and its arguments.

    Args:
        cmd: :obj:`List[str]`. The command with its arguments to execute.

    Returns:
        The standard output of the command decoded as UTF-8.

    Raises:
        :obj:`OSError`: if the standard error stream is not empty.
    """
    ps = subprocess.run(cmd, capture_output=True)
    if ps.stderr:
        # errors="replace" so that undecodable stderr bytes cannot hide the
        # actual failure behind a UnicodeDecodeError.
        raise OSError(f"Failed running {' '.join(cmd)} with error message: "
                      f"{ps.stderr.decode('utf-8', errors='replace')}")
    return ps.stdout.decode("utf-8")

def get_callable_from_script(script_path: str, func_name: str = "main") -> Callable:
    """Wrap a function defined in a module or script into a callable.

    The module is imported (through :func:`import_script`) only when the
    returned wrapper is invoked; all positional and keyword arguments given to
    the wrapper are forwarded to the selected function.

    Args:
        script_path: str, the file path to the python script to run. It can
            either be given as a module i.e. in the [package.]*[module] form,
            or as a path to a *.py file in case it is not added into the
            PYTHONPATH environment variable.
        func_name: str, the name of the function to run. Defaults to "main".

    Returns:
        The wrapper which calls `func_name` from the script module and
        returns its result.

    Raises:
        AttributeError: when the wrapper is called and the script does not
            define a `func_name` function.
    """
    def wrapper(*args, **kwargs):
        module = import_script(script_path)
        if not hasattr(module, func_name):
            raise AttributeError(
                f"Cannot find {func_name} function in {script_path}.")
        # Keyword arguments must be forwarded as well: Job.__call__ passes
        # dict-style job arguments as keywords.
        return getattr(module, func_name)(*args, **kwargs)

    return wrapper

def run_script_with_args(binary: str, script_path: str, *args: Any, **kwargs: Any):
    """Run a script using a binary and command line arguments.

    Args:
        binary: str, the binary to run the script with, e.g. 'python'.
        script_path: str, the path to the script.
        *args: Any, a collection of arguments which will be converted to string
            and passed on to the run command, followed by the output file path.
        **kwargs: Any, keyword arguments whose keys and values are converted to
            strings and passed as consecutive arguments, followed by
            ``--output_file <path>``.

    Returns:
        The contents of the results, which the script is assumed to store,
        given an output file path as an argument (None if no result file
        shows up, see :func:`fetch_result`).

    Raises:
        FileNotFoundError: if the script cannot be found.

    Note:
        It assumes that the script will store the results on disk (pickled)
        using the path provided by the last of the command line arguments.
    """
    if not os.path.isfile(script_path):
        raise FileNotFoundError(f"Cannot find script {script_path}.")
    with tempfile.TemporaryDirectory() as tmpdir:
        output_file = os.path.join(tmpdir, "results.pkl")
        args_as_str, kwargs_as_str = [], []
        if args:
            args_as_str.extend([*map(str, args), output_file])
        if kwargs:
            kwargs_as_str.extend([
                str(item) for k_v in kwargs.items() for item in k_v
            ])
            kwargs_as_str.extend(["--output_file", output_file])
        if not args and not kwargs:
            # Without any argument the script would never learn where to
            # store its results; pass the path as the only (last) argument.
            args_as_str.append(output_file)
        run_command([binary, script_path, *args_as_str, *kwargs_as_str])
        # Read while the temporary directory (and thus the file) still exists.
        return fetch_result(output_file)

def fetch_result(output_file, n_trials: int = 5, waiting_time: float = 1.0) -> Any:
    """Load the (pickled) output file.

    Args:
        output_file: str, a path to the output file.
        n_trials: int, optional number of trials to load the file, afterwards a
            None is returned.
        waiting_time: float, time in seconds to wait before retrying to load
            the file.

    Returns:
        The unpickled output file if found, else None.
    """
    if output_file is None:
        return None
    for trial in range(n_trials):
        if os.path.isfile(output_file):
            break
        # Only wait if another trial follows; sleeping after the last failed
        # check would just delay returning None.
        if trial < n_trials - 1:
            time.sleep(waiting_time)
    else:
        return None
    # NOTE: pickle can execute arbitrary code; the file is expected to be
    # written by the job's own script, never by an untrusted party.
    with open(output_file, 'rb') as fp:
        return pickle.load(fp)

@dataclass(frozen=True)
class Job:
    """Default :class:`Job` class defining an experiment as a runnable task on
    the local machine.

    The job is defined by a callable function or a script task. In the case of
    the former the `args` will be passed directly to it upon calling.
    Otherwise either a module will be run as a script with command line
    arguments or a function, attribute of the module, will be called with the
    `args` as input. In both cases a :class:`Result` object will be returned.

    Attributes:
        id: :obj:`int`. The job identifier. Must be unique.
        args: :obj:`tuple` or :obj:`dict`. The arguments or keyword arguments
            for the callable function or script.
        task: :obj:`Callable` or :obj:`str`, a python function to run or a
            file path to a python script. A string of the form
            ``path/to/script.py:func`` selects the function ``func`` of that
            script instead of running the script itself.
        meta: :obj:`Any`, optional extra information. When it is a dict, its
            ``"binary"`` entry overrides the program used to run a script.
    """
    task: Union[Callable, str]
    args: Union[Tuple, Dict] = ()
    id: int = field(default_factory=generate_id)
    meta: Any = None

    # job related constants
    _JOB_SCRIPT_FUNC_SEPARATOR = ":"
    # NOTE(review): "source" is a shell builtin rather than an executable, and
    # run_command does not go through a shell; tasks without a .py/.sh suffix
    # should probably set meta["binary"] explicitly.
    _JOB_DEFAULT_BINARY = "source"
    # Kept for backward compatibility; splitting now uses the separator above.
    _JOB_SCRIPT_FUNC_SEPARATION_REGEX = r"[^\w\/\.]+"

    def __post_init__(self):
        if not (callable(self.task) or isinstance(self.task, str)):
            raise ValueError(
                "Job's task must be either a callable function "
                "or a path to a script."
            )
        if self.id in _JOB_REGISTRY:
            raise ValueError(
                f"Job with an ID {self.id} is already created. "
                f"Reusing IDs is prohibited."
            )
        _JOB_REGISTRY.add(self.id)

    def __hash__(self):
        return hash(str(self.id))

    def __call__(self, *args, **kwargs) -> 'Result':
        """Run the task and wrap its output in a :class:`Result`.

        Call-time positional arguments come before tuple-style job `args`;
        dict-style job `args` are merged with the call-time keyword arguments.
        """
        all_args = args
        all_kwargs = kwargs
        if isinstance(self.args, tuple):
            all_args += self.args
        else:
            all_kwargs = dict(**kwargs, **self.args)
        runnable = self.task if callable(self.task) else self._build_callable()
        return Result(id=self.id, data=runnable(*all_args, **all_kwargs))

    def _build_callable(self):
        """Create a function from a string task.

        If the task is of the form ``/path/to/script.py:func_to_run``, split
        the path from the function name and return a callable running
        ``script.func_to_run``. If the task is of the form
        ``/path/to/script.py``, return a callable running the script with the
        inferred binary (e.g. ``python /path/to/script.py``).

        Raises:
            ValueError: if the script path or the function name is empty.
        """
        sep = self._JOB_SCRIPT_FUNC_SEPARATOR
        if sep in self.task:
            # Split on the last run of separators (e.g. ':' or '::') so that
            # paths containing '-' or spaces are kept intact.
            head, _, func_name = self.task.rpartition(sep)
            script_path = head.rstrip(sep).strip()
            func_name = func_name.strip()
            if not script_path or not func_name:
                raise ValueError(
                    f"Empty path {script_path} or function name {func_name}")
            return get_callable_from_script(script_path, func_name)
        binary = self._infer_binary()
        return partial(run_script_with_args, binary, self.task)

    def _infer_binary(self):
        """Pick the program used to run a script task.

        An explicit ``meta["binary"]`` wins; otherwise the script suffix
        decides (``.py`` -> python, ``.sh`` -> bash), falling back to
        `_JOB_DEFAULT_BINARY`.
        """
        if isinstance(self.meta, dict) and "binary" in self.meta:
            return self.meta["binary"]
        if self.task.endswith(".py"):
            return "python"
        if self.task.endswith(".sh"):
            return "bash"
        return self._JOB_DEFAULT_BINARY
class SlurmJobState(enum.Enum):
    """Some of the most frequently encountered slurm job statuses."""
    PENDING = 0
    RUNNING = 1
    COMPLETED = 2
    FAILED = 3
    CANCELLED = 4
    UNKNOWN = 5

    @classmethod
    def from_string(cls, state: str) -> "SlurmJobState":
        """Map a slurm state name to a :class:`SlurmJobState` member.

        The comparison ignores case and surrounding whitespace, so both the
        upper-case names printed by ``scontrol`` (e.g. "RUNNING") and
        lower-case names are accepted. Anything else, including None, maps
        to :attr:`UNKNOWN`.
        """
        if not isinstance(state, str):
            return cls.UNKNOWN
        return cls.__members__.get(state.strip().upper(), cls.UNKNOWN)
@dataclass(frozen=True)
class SlurmJob(Job):
    """A :class:`Job` subclass to schedule tasks on Slurm.

    Writes a batch script built from `meta`, submits it with 'sbatch' and
    polls 'scontrol show job' until the job leaves the pending/running states.

    Attributes:
        output_file: (optional) :obj:`str`. Path to the file where the
            executed script will dump the result file. If none is provided,
            a file path inside the system temporary directory is generated.
    """
    output_file: str = None

    # slurm shell commands
    _SLURM_CMD_PUSH = ["sbatch"]
    _SLURM_CMD_KILL = ["scancel"]
    _SLURM_CMD_INFO = ["scontrol", "show", "job"]
    # slurm script elements
    _SLURM_SCRIPT_PREAMBLE = "#!/bin/bash"
    _SLURM_SCRIPT_LINE_PREFIX = "#SBATCH"
    _SLURM_SCRIPT_JOB_NAME = "--job-name"
    _SLURM_SCRIPT_OUT_NAME = "--output"
    _SLURM_SCRIPT_RESOURCES_MEM = "--mem"
    _SLURM_SCRIPT_RESOURCES_TIME = "--time"
    _SLURM_SCRIPT_RESOURCES_CPU = "--cpus-per-task"
    _SLURM_SCRIPT_RESOURCES_GPU = "--gres"
    # other macros
    _SLURM_JOB_STATE_REGEX = r"JobState=(RUNNING|PENDING|COMPLETED|FAILED|CANCELLED)"
    # Short-lived states a job passes through at start-up and shutdown; they
    # are treated as "still running" rather than as an unknown state.
    _SLURM_JOB_TRANSIENT_STATE_REGEX = r"JobState=(CONFIGURING|COMPLETING)\b"
    # seconds to wait between two status queries
    _SLURM_POLL_INTERVAL = 1

    def __post_init__(self):
        if not isinstance(self.task, str):
            raise ValueError("Slurm job must be defined with a script to run.")
        super().__post_init__()
        if self.output_file is None:
            # NOTE(review): the temporary directory may be local to the
            # submitting node; on clusters without a shared tmp directory,
            # pass an explicit output_file on a shared filesystem.
            default_output = os.path.join(
                tempfile.gettempdir(), f"slurm_job_{self.id}_{os.getpid()}.pkl")
            # The dataclass is frozen, so plain assignment is not allowed.
            object.__setattr__(self, "output_file", default_output)

    def __call__(self) -> 'Result':
        res = self._execute_job()
        return Result(id=self.id, data=res)

    def _execute_job(self) -> Any:
        """Submit the job, wait for it to finish and load its result.

        Returns:
            The unpickled result file for a completed job, None for a failed
            or cancelled one.

        Raises:
            RuntimeError: if the slurm job id cannot be parsed from the
                'sbatch' reply or the job ends up in an unknown state.
        """
        with tempfile.NamedTemporaryFile(mode="w+t", suffix=".sh") as fp:
            fp.writelines(self._create_slurm_script())
            # Flush the buffered lines so 'sbatch' reads the complete script.
            fp.flush()
            response = run_command(self._SLURM_CMD_PUSH + [fp.name])
            id_match = re.search(r"\d+", response)
            if id_match is None:
                raise RuntimeError(
                    f"Cannot parse the slurm job id from: {response}")
            slurm_id = int(id_match.group())
            while True:
                slurm_status = self._query_job_status(slurm_id)
                if slurm_status in (SlurmJobState.RUNNING, SlurmJobState.PENDING):
                    time.sleep(self._SLURM_POLL_INTERVAL)
                elif slurm_status in (SlurmJobState.CANCELLED, SlurmJobState.FAILED):
                    return None
                elif slurm_status == SlurmJobState.COMPLETED:
                    return fetch_result(self.output_file)
                else:
                    raise RuntimeError(f"Unknown state of slurm job {slurm_id}.")

    def _create_slurm_script(self) -> List[str]:
        """Build the lines of the sbatch script for this job.

        Returns:
            The script lines, each terminated by a newline.

        Raises:
            ValueError: if `meta` is empty, since the slurm parameters are read
                from it.
        """
        if not self.meta:
            raise ValueError(f"Cannot infer slurm job parameters. "
                             f"Fill in meta dict in job {self.id}.")
        prefix = self._SLURM_SCRIPT_LINE_PREFIX
        resources = self.meta.get("resources", {})

        # Preamble, job name and output log filename definitions
        content_lines = [
            f"{self._SLURM_SCRIPT_PREAMBLE}\n",
            f"{prefix} {self._SLURM_SCRIPT_JOB_NAME}=job_{self.id}\n",
            f"{prefix} {self._SLURM_SCRIPT_OUT_NAME}=log_%j.txt\n",
        ]

        # Resources specification
        n_cpus = int(resources.get("cpu", 1))
        if n_cpus >= 1:
            content_lines.append(
                f"{prefix} {self._SLURM_SCRIPT_RESOURCES_CPU}={n_cpus}\n")
        gpus = str(resources.get("gpu", ""))
        if gpus:
            if gpus.isnumeric():
                # a bare count becomes a generic "gpu:<count>" request
                gpus = f"gpu:{gpus}"
            content_lines.append(
                f"{prefix} {self._SLURM_SCRIPT_RESOURCES_GPU}={gpus}\n")
        mem = str(resources.get("memory", ""))
        if mem:
            content_lines.append(
                f"{prefix} {self._SLURM_SCRIPT_RESOURCES_MEM}={mem}\n")
        limit_time = str(resources.get("time", ""))
        if limit_time:
            content_lines.append(
                f"{prefix} {self._SLURM_SCRIPT_RESOURCES_TIME}={limit_time}\n")

        # Task specification; the result path is always the last argument
        binary = self.meta.get("binary", "python")
        if isinstance(self.args, tuple):
            # build positional arguments
            script_args = ' '.join([*map(str, self.args), self.output_file])
        else:
            # build named arguments
            script_args = ' '.join([
                *(str(item) for key_val in self.args.items() for item in key_val),
                "--output_file", self.output_file,
            ])
        content_lines.append(f"{binary} {self.task} {script_args}\n")
        return content_lines

    def _query_job_status(self, slurm_id: int) -> SlurmJobState:
        """Ask slurm for the current state of job `slurm_id`."""
        response = run_command(self._SLURM_CMD_INFO + [str(slurm_id)])
        if re.search(self._SLURM_JOB_TRANSIENT_STATE_REGEX, response):
            return SlurmJobState.RUNNING
        job_state = re.search(self._SLURM_JOB_STATE_REGEX, response)
        if job_state is not None:
            # scontrol prints upper-case names; normalise them to the
            # lower-case spelling SlurmJobState.from_string compares against.
            job_state = job_state.group(1).lower()
        return SlurmJobState.from_string(job_state)
@dataclass(frozen=True)
class Result:
    """A :class:`Result` class to store the output of the executed :class:`Job`.

    It shares the same id as the job which generated it.

    Attributes:
        id: :obj:`int`. The identifier of the `Result` object which
            corresponds to the job that has been run.
        data: :obj:`Any`. The output data of the job.
    """
    data: Any
    id: int

    def __post_init__(self):
        if self.id in _RESULT_REGISTRY:
            raise ValueError(
                f"Result with an ID {self.id} is already created. "
                f"Reusing IDs is prohibited."
            )
        _RESULT_REGISTRY.add(self.id)

    def __hash__(self):
        # Hash by id (like Job) so results holding unhashable data such as
        # lists or dicts can still be stored in sets and used as dict keys.
        return hash(str(self.id))