From 7674e7d362c95e665a9186172914bb79ea7e209d Mon Sep 17 00:00:00 2001 From: Ben Schroeter Date: Wed, 17 Jul 2024 14:52:58 +1000 Subject: [PATCH] Added job dependencies and directives support. Fixes #14 and #10 --- hpcpy/__init__.py | 4 +- hpcpy/_version.py | 157 ++++++++++++++++++------------- hpcpy/client.py | 194 ++++++++++++++++++++++++++++++--------- hpcpy/constants.py | 41 +++++---- hpcpy/exceptions.py | 4 +- hpcpy/utilities.py | 40 ++++---- tests/test_client.py | 36 +++++--- tests/test_hpcpy.py | 24 ----- tests/test_pbs_client.py | 64 +++++++++++++ tests/test_utilities.py | 3 +- 10 files changed, 382 insertions(+), 185 deletions(-) delete mode 100644 tests/test_hpcpy.py create mode 100644 tests/test_pbs_client.py diff --git a/hpcpy/__init__.py b/hpcpy/__init__.py index 9564b5a..6cbface 100644 --- a/hpcpy/__init__.py +++ b/hpcpy/__init__.py @@ -1,3 +1,5 @@ """Top-level package for hpcpy.""" + from . import _version -__version__ = _version.get_versions()['version'] + +__version__ = _version.get_versions()["version"] diff --git a/hpcpy/_version.py b/hpcpy/_version.py index 1795c4d..af95c7a 100644 --- a/hpcpy/_version.py +++ b/hpcpy/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -68,12 +67,14 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate @@ -100,10 +101,14 @@ def run_command( try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) break except OSError as e: if e.errno == errno.ENOENT: @@ -141,15 +146,21 @@ def versions_from_parentdir( for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -212,7 +223,7 @@ def git_versions_from_keywords( # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -221,7 +232,7 @@ def git_versions_from_keywords( # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -229,32 +240,36 @@ def git_versions_from_keywords( for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs( - tag_prefix: str, - root: str, - verbose: bool, - runner: Callable = run_command + tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command ) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. @@ -273,8 +288,7 @@ def git_pieces_from_vcs( env.pop("GIT_DIR", None) runner = functools.partial(runner, env=env) - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -282,10 +296,19 @@ def git_pieces_from_vcs( # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -300,8 +323,7 @@ def git_pieces_from_vcs( pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -341,17 +363,16 @@ def git_pieces_from_vcs( dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -360,10 +381,12 @@ def git_pieces_from_vcs( if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -412,8 +435,7 @@ def render_pep440(pieces: Dict[str, Any]) -> str: rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -442,8 +464,7 @@ def render_pep440_branch(pieces: Dict[str, Any]) -> str: rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -604,11 +625,13 @@ def render_git_describe_long(pieces: Dict[str, Any]) -> str: def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -632,9 +655,13 @@ def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions() -> Dict[str, Any]: @@ -648,8 +675,7 @@ def get_versions() -> Dict[str, Any]: verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -658,13 +684,16 @@ def get_versions() -> Dict[str, Any]: # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -678,6 +707,10 @@ def get_versions() -> Dict[str, Any]: except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/hpcpy/client.py b/hpcpy/client.py index 686b4fd..45629f0 100644 --- a/hpcpy/client.py +++ b/hpcpy/client.py @@ -1,19 +1,22 @@ """Abstract class for client implementation.""" + from hpcpy.utilities import shell, interpolate_file_template import hpcpy.constants as hc import hpcpy.exceptions as hx +import hpcpy.utilities as hu from pathlib import Path from random import choice from string import ascii_uppercase import os import json -import datetime +from datetime import datetime, timedelta import pandas as pd +from typing import Union class Client: - def __init__(self, tmp_submit, tmp_status, tmp_delete, job_script_expiry='1H'): + def __init__(self, tmp_submit, tmp_status, tmp_delete, job_script_expiry="1H"): # Set the command templates self._tmp_submit = tmp_submit @@ -32,7 +35,7 @@ def _clean_rendered_job_scripts(self): rendered_job_scripts = self.list_rendered_job_scripts() # Work out the threshold - now = datetime.datetime.now() + now = datetime.now() threshold = now - pd.to_timedelta(self.job_script_expiry).to_pytimedelta() for rjs in rendered_job_scripts: @@ -42,7 +45,7 @@ def _clean_rendered_job_scripts(self): continue # Get the modified time of the file, check threshold and delete - mod_time = datetime.datetime.fromtimestamp(os.path.getmtime(rjs)) + mod_time = datetime.fromtimestamp(os.path.getmtime(rjs)) if mod_time <= threshold: os.remove(rjs) @@ -56,10 +59,12 @@ def list_rendered_job_scripts(self): """ return [hc.JOB_SCRIPT_DIR / rjs for rjs in os.listdir(hc.JOB_SCRIPT_DIR)] - def submit(self, job_script, render=False, **context): + def submit( + self, job_script, directives=list(), render=False, dry_run=False, **context + ): """Submit the job script. - + Parameters ---------- job_script : path-like @@ -71,22 +76,26 @@ def submit(self, job_script, render=False, **context): """ if render: - - _job_script = self._render_job_script( - job_script, - **context - ) + + _job_script = self._render_job_script(job_script, **context) else: - + _job_script = job_script - context['job_script'] = _job_script + # Add the directives to the interpolation context (will return blank string if nothing there) + context["directives"] = self._render_directives(directives) + + context["job_script"] = _job_script cmd = self._tmp_submit.format(**context) + + # Just return the command string for the user without submitting + if dry_run: + return cmd + result = self._shell(cmd) return result - def status(self, job_id): """Check the status of a job. @@ -100,7 +109,6 @@ def status(self, job_id): result = self._shell(cmd) return result - def delete(self, job_id): """Delete/cancel a job. @@ -127,7 +135,6 @@ def is_queued(self, job_id): True if queued, False otherwise. """ return self.status(job_id) == hc.STATUS_QUEUED - def is_running(self, job_id): """Check if the job is running. @@ -144,7 +151,6 @@ def is_running(self, job_id): """ return self.status(job_id) == hc.STATUS_RUNNING - def _shell(self, cmd, decode=True): """Generic shell interface to capture exceptions. @@ -163,11 +169,10 @@ def _shell(self, cmd, decode=True): result = shell(cmd) if decode: - result = result.stdout.decode('utf8').strip() + result = result.stdout.decode("utf8").strip() return result - def _get_job_script_filename(self, filepath, hash_length=8) -> str: """Generate a script filename with a random prefix. @@ -185,9 +190,8 @@ def _get_job_script_filename(self, filepath, hash_length=8) -> str: """ filename, ext = os.path.splitext(filepath) filename = os.path.basename(filename) - _hash = ''.join(choice(ascii_uppercase) for i in range(hash_length)) - return f'{filename}_{_hash}{ext}' - + _hash = "".join(choice(ascii_uppercase) for i in range(hash_length)) + return f"{filename}_{_hash}{ext}" def _render_job_script(self, template, **context): """Render a job script. @@ -202,12 +206,9 @@ def _render_job_script(self, template, **context): str Path to the rendered job script. """ - + # Render the template - _rendered = interpolate_file_template( - template, - **context - ) + _rendered = interpolate_file_template(template, **context) # Generate the output filepath os.makedirs(hc.JOB_SCRIPT_DIR, exist_ok=True) @@ -215,25 +216,43 @@ def _render_job_script(self, template, **context): output_filepath = hc.JOB_SCRIPT_DIR / output_filename # Write it out - with open(output_filepath, 'w') as fo: + with open(output_filepath, "w") as fo: fo.write(_rendered) - + return output_filepath + def _render_directives(self, directives): + """Render the directives into a single string for command interpolation. + + Parameters + ---------- + directives : list + List of scheduler-compliant directives. One per item. + + Returns + ------- + str + Rendered directives, or blank string. + """ + + # Render blank directives if empty + if not directives: + return "" + + return " " + " ".join(directives) + class PBSClient(Client): def __init__(self): - + # Set up the templates super().__init__( - tmp_submit=hc.PBS_SUBMIT, - tmp_status=hc.PBS_STATUS, - tmp_delete=hc.PBS_DELETE + tmp_submit=hc.PBS_SUBMIT, tmp_status=hc.PBS_STATUS, tmp_delete=hc.PBS_DELETE ) def status(self, job_id): - + # Get the raw response raw = super().status(job_id=job_id) @@ -241,9 +260,97 @@ def status(self, job_id): parsed = json.loads(raw) # Get the status out of the job ID - _status = parsed.get('Jobs').get(job_id).get('job_state') + _status = parsed.get("Jobs").get(job_id).get("job_state") return hc.PBS_STATUSES[_status] + def submit( + self, + job_script: Union[str, Path], + directives: list = None, + render: bool = False, + dry_run: bool = False, + depends_on: list = None, + delay: Union[datetime, timedelta] = None, + queue: str = None, + walltime: timedelta = None, + storage: list = None, + **context, + ): + """Submit a job to the scheduler. + + Parameters + ---------- + job_script : Union[str, Path] + Path to the script. + directives : list, optional + List of complete directives to submit, by default list() + render : bool, optional + Render the job script from a template, by default False + dry_run : bool, optional + Return rather than executing the command, by default False + depends_on : list, optional + List of job IDs with successful exit on which this job depends, by default list() + delay: Union[datetime, timedelta] + Delay the start of this job until specific date or interval, by default None + queue: str, optional + Queue on which to submit the job, by default None + walltime: timedelta, optional + Walltime expressed as a timedelta, by default None + storage: list, optional + List of storage mounts to apply, by default None + **context: + Additional key/value pairs to be added to command/jobscript interpolation + """ + + directives = directives if isinstance(directives, list) else list() + + # Add job depends + if depends_on: + depends_on = hu.ensure_list(depends_on) + directives.append("-W depend=afterok:" + ":".join(depends_on)) + + # Add delay (specified time or delta) + if delay: + + current_time = datetime.now() + delay_str = None + + if isinstance(delay, datetime) and delay > current_time: + delay_str = delay.strftime("%Y%m%d%H%M.%S") + + elif isinstance(delay, timedelta) and (current_time + delay) > current_time: + delay_str = (current_time + delay).strftime("%Y%m%d%H%M.%S") + else: + raise ValueError( + "Job submission delay argument either incorrect or puts the job in the past." + ) + + # Add the delay directive + directives.append(f"-a {delay_str}") + + # Add queue + if queue: + directives.append(f"-q {queue}") + + # Add walltime + if walltime: + _walltime = str(walltime) + directives.append(f"-l walltime={_walltime}") + + # Add storage + if storage: + storage_str = "+".join(storage) + directives.append(f"-l storage={storage_str}") + + # Call the super + return super().submit( + job_script=job_script, + directives=directives, + render=render, + dry_run=dry_run, + **context, + ) + class SlurmClient(Client): pass @@ -255,13 +362,14 @@ def __init__(self): super().__init__( tmp_submit=hc.MOCK_SUBMIT, tmp_status=hc.MOCK_STATUS, - tmp_delete=hc.MOCK_DELETE + tmp_delete=hc.MOCK_DELETE, ) - + def status(self, job_id): status_code = super().status(job_id=job_id) return hc.MOCK_STATUSES[status_code] + class ClientFactory: def get_client() -> Client: @@ -278,11 +386,7 @@ def get_client() -> Client: When no scheduler can be detected. """ - clients = dict( - ls=MockClient, - qsub=PBSClient, - sbatch=SlurmClient - ) + clients = dict(ls=MockClient, qsub=PBSClient, sbatch=SlurmClient) # Remove the MockClient if dev mode is off if os.getenv("HPCPY_DEV_MODE", "0") != "1": @@ -290,7 +394,7 @@ def get_client() -> Client: # Loop through the clients in order, looking for a valid scheduler for cmd, client in clients.items(): - if shell(f'which {cmd}', check=False).returncode == 0: + if shell(f"which {cmd}", check=False).returncode == 0: return client() - - raise hx.NoClientException() \ No newline at end of file + + raise hx.NoClientException() diff --git a/hpcpy/constants.py b/hpcpy/constants.py index 8f50134..802df40 100644 --- a/hpcpy/constants.py +++ b/hpcpy/constants.py @@ -1,23 +1,24 @@ """Constants.""" + from pathlib import Path # Location for rendered job scripts -JOB_SCRIPT_DIR = Path.home() / '.hpcpy' / 'job_scripts' +JOB_SCRIPT_DIR = Path.home() / ".hpcpy" / "job_scripts" JOB_SCRIPT_DIR.mkdir(parents=True, exist_ok=True) # Statuses -STATUS_CYCLE_HARVESTING = 'U' -STATUS_EXITING = 'E' -STATUS_FINISHED = 'F' -STATUS_HAS_SUBJOB = 'B' -STATUS_HELD = 'H' -STATUS_MOVED = 'M' -STATUS_MOVING = 'T' -STATUS_QUEUED = 'Q' -STATUS_RUNNING = 'R' -STATUS_SUBJOB_COMPLETED = 'X' -STATUS_SUSPENDED = 'S' -STATUS_WAITING = 'W' +STATUS_CYCLE_HARVESTING = "U" +STATUS_EXITING = "E" +STATUS_FINISHED = "F" +STATUS_HAS_SUBJOB = "B" +STATUS_HELD = "H" +STATUS_MOVED = "M" +STATUS_MOVING = "T" +STATUS_QUEUED = "Q" +STATUS_RUNNING = "R" +STATUS_SUBJOB_COMPLETED = "X" +STATUS_SUSPENDED = "S" +STATUS_WAITING = "W" # PBS status translation PBS_STATUSES = dict( @@ -32,18 +33,18 @@ T=STATUS_MOVING, U=STATUS_CYCLE_HARVESTING, W=STATUS_WAITING, - X=STATUS_SUBJOB_COMPLETED + X=STATUS_SUBJOB_COMPLETED, ) # PBS command templates -PBS_SUBMIT = 'qsub {job_script}' -PBS_STATUS = 'qstat -f -F json {job_id}' -PBS_DELETE = 'qdel {job_id}' +PBS_SUBMIT = "qsub{directives} {job_script}" +PBS_STATUS = "qstat -f -F json {job_id}" +PBS_DELETE = "qdel {job_id}" # Mock command templateds -MOCK_SUBMIT = 'echo 12345' -MOCK_STATUS = 'echo Q' +MOCK_SUBMIT = "echo 12345" +MOCK_STATUS = "echo Q" MOCK_DELETE = 'echo "DELETED"' # Mock status translation -MOCK_STATUSES = PBS_STATUSES \ No newline at end of file +MOCK_STATUSES = PBS_STATUSES diff --git a/hpcpy/exceptions.py b/hpcpy/exceptions.py index 78408f0..483d2f8 100644 --- a/hpcpy/exceptions.py +++ b/hpcpy/exceptions.py @@ -1,3 +1,5 @@ class NoClientException(Exception): def __init__(self): - super().__init__('Unable to detect scheduler type, cannot determine client type.') \ No newline at end of file + super().__init__( + "Unable to detect scheduler type, cannot determine client type." + ) diff --git a/hpcpy/utilities.py b/hpcpy/utilities.py index f478b9c..47b7e3a 100644 --- a/hpcpy/utilities.py +++ b/hpcpy/utilities.py @@ -1,4 +1,5 @@ """Utilities.""" + import subprocess as sp import jinja2 as j2 import jinja2.meta as j2m @@ -24,17 +25,13 @@ def shell(cmd, shell=True, check=True, capture_output=True, **kwargs): ------- subprocess.CompletedProcess Process object. - + Raises ------ subprocess.CalledProcessError """ return sp.run( - cmd, - shell=shell, - check=check, - capture_output=capture_output, - **kwargs + cmd, shell=shell, check=check, capture_output=capture_output, **kwargs ) @@ -50,7 +47,7 @@ def interpolate_string_template(template, **kwargs) -> str: ------- str Interpolated template. - + Raises ------ jinja2.exceptions.UndefinedError : @@ -58,22 +55,19 @@ def interpolate_string_template(template, **kwargs) -> str: """ # Set up the rendering environment - env = j2.Environment( - loader=j2.BaseLoader(), - undefined=j2.DebugUndefined - ) - + env = j2.Environment(loader=j2.BaseLoader(), undefined=j2.DebugUndefined) + # Render the template _template = env.from_string(template) rendered = _template.render(**kwargs) - + # Look for undefined variables (those that remain even after conditionals) ast = env.parse(rendered) undefined = j2m.find_undeclared_variables(ast) if undefined: - raise j2.UndefinedError(f'The following variables are undefined: {undefined!r}') - + raise j2.UndefinedError(f"The following variables are undefined: {undefined!r}") + return rendered @@ -90,9 +84,10 @@ def interpolate_file_template(filepath, **kwargs): str Interpolated template. """ - template = open(filepath, 'r').read() + template = open(filepath, "r").read() return interpolate_string_template(template, **kwargs) + def get_installed_root() -> Path: """Get the installed root of the benchcab installation. @@ -102,4 +97,15 @@ def get_installed_root() -> Path: Path to the installed root. """ - return Path(resources.files("hpcpy")) \ No newline at end of file + return Path(resources.files("hpcpy")) + + +def ensure_list(obj): + """Ensure the object provided is a list. + + Parameters + ---------- + obj : mixed + Object of any type + """ + return obj if isinstance(obj, list) else [obj] diff --git a/tests/test_client.py b/tests/test_client.py index ff61330..e517251 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,11 @@ """Client tests.""" + import pytest from hpcpy.client import ClientFactory import hpcpy.constants as hc import hpcpy.utilities as hu + @pytest.fixture def client(): return ClientFactory.get_client() @@ -12,42 +14,48 @@ def client(): def test_get_job_script_filename(client): """Test get_job_script_filename.""" hash_length = 5 - input_filename = 'file.sh' - result = client._get_job_script_filename(f'/path/to/{input_filename}', hash_length=hash_length) + input_filename = "file.sh" + result = client._get_job_script_filename( + f"/path/to/{input_filename}", hash_length=hash_length + ) assert result != input_filename - assert len(result) == len(input_filename) + hash_length + 1 # underscore + assert len(result) == len(input_filename) + hash_length + 1 # underscore + def test_submit(client): """Test submit.""" - result = client.submit('test.txt') - assert result == '12345' + result = client.submit("test.txt") + assert result == "12345" + def test_status(client): """Test status.""" - result = client.status('12345') + result = client.status("12345") assert result == hc.STATUS_QUEUED + def test_delete(client): """Test delete.""" - result = client.delete('12345') - assert result == 'DELETED' + result = client.delete("12345") + assert result == "DELETED" + def test_render_job_script(client): """Test rendering the job script.""" - + # Write it out. - template_filepath = hu.get_installed_root() / 'data' / 'test' / 'test.j2' - rendered_filepath = client._render_job_script(template_filepath, myarg='world') + template_filepath = hu.get_installed_root() / "data" / "test" / "test.j2" + rendered_filepath = client._render_job_script(template_filepath, myarg="world") - rendered = open(rendered_filepath, 'r').read().strip() - assert rendered == 'hello world' + rendered = open(rendered_filepath, "r").read().strip() + assert rendered == "hello world" def test_clean_rendered_job_scripts(client): """Test cleaning out the rendered job scripts.""" # Run the clean command - client.job_script_expiry = '0h' + client.job_script_expiry = "0h" # Ensure that the job script directory is empty client._clean_rendered_job_scripts() diff --git a/tests/test_hpcpy.py b/tests/test_hpcpy.py deleted file mode 100644 index 19a033a..0000000 --- a/tests/test_hpcpy.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env python - -"""Tests for `hpcpy` package.""" - -import pytest - - -from hpcpy import hpcpy - - -@pytest.fixture -def response(): - """Sample pytest fixture. - - See more at: http://doc.pytest.org/en/latest/fixture.html - """ - # import requests - # return requests.get('https://github.com/audreyr/cookiecutter-pypackage') - - -def test_content(response): - """Sample pytest test function with the pytest fixture as an argument.""" - # from bs4 import BeautifulSoup - # assert 'GitHub' in BeautifulSoup(response.content).title.string diff --git a/tests/test_pbs_client.py b/tests/test_pbs_client.py new file mode 100644 index 0000000..6d4752b --- /dev/null +++ b/tests/test_pbs_client.py @@ -0,0 +1,64 @@ +import pytest +from hpcpy.client import PBSClient +from datetime import datetime, timedelta + + +@pytest.fixture +def client(): + return PBSClient() + + +def test_directives(client): + """Test if the directives are properly interpolated""" + expected = "qsub -q express -l walltime=10:00:00 test.sh" + result = client.submit( + "test.sh", directives=["-q express", "-l walltime=10:00:00"], dry_run=True + ) + + assert result == expected + + +def test_depends_on(client): + """Test if the depends_on argument is correctly applied.""" + expected = "qsub -W depend=afterok:job1:job2 test.sh" + result = client.submit("test.sh", depends_on=["job1", "job2"], dry_run=True) + + assert result == expected + + +def test_delay(client): + """Test if delay is correctly applied""" + run_at = datetime(2200, 7, 26, 12, 0, 0) + run_at_str = run_at.strftime("%Y%m%d%H%M.%S") + expected = f"qsub -a {run_at_str} test.sh" + result = client.submit("test.sh", delay=run_at, dry_run=True) + + assert result == expected + + +def test_queue(client): + """Test if the queue argument is added.""" + expected = "qsub -q express test.sh" + result = client.submit("test.sh", dry_run=True, queue="express") + + assert result == expected + + +def test_walltime(client): + """Test if the walltime argument is added.""" + expected = "qsub -l walltime=2:30:12 test.sh" + result = client.submit( + "test.sh", dry_run=True, walltime=timedelta(hours=2, minutes=30, seconds=12) + ) + + assert result == expected + + +def test_storage(client): + """Test if the storage argument is added.""" + expected = "qsub -l storage=gdata/rp23+scratch/rp23 test.sh" + result = client.submit( + "test.sh", dry_run=True, storage=["gdata/rp23", "scratch/rp23"] + ) + + assert result == expected diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 408337c..ecd6324 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,6 +1,7 @@ """Tests for utilities.py""" + import hpcpy.utilities as hu def test_interpolate_string_template(): - assert hu.interpolate_string_template('hello {{arg}}', arg='world') == 'hello world' \ No newline at end of file + assert hu.interpolate_string_template("hello {{arg}}", arg="world") == "hello world"