diff --git a/cli/medperf/cli.py b/cli/medperf/cli.py index c6524a21a..078d37841 100644 --- a/cli/medperf/cli.py +++ b/cli/medperf/cli.py @@ -17,6 +17,7 @@ import medperf.commands.association.association as association import medperf.commands.compatibility_test.compatibility_test as compatibility_test import medperf.commands.storage as storage +from medperf.utils import check_for_updates app = typer.Typer() app.add_typer(mlcube.app, name="mlcube", help="Manage mlcubes") @@ -99,5 +100,6 @@ def main( logging.info(f"Running MedPerf v{__version__} on {loglevel} logging level") logging.info(f"Executed command: {' '.join(sys.argv[1:])}") + check_for_updates() config.ui.print(f"MedPerf {__version__}") diff --git a/cli/medperf/commands/execution.py b/cli/medperf/commands/execution.py index f604c74d0..d8afb2244 100644 --- a/cli/medperf/commands/execution.py +++ b/cli/medperf/commands/execution.py @@ -90,7 +90,7 @@ def run_inference(self): except ExecutionError as e: if not self.ignore_model_errors: logging.error(f"Model MLCube Execution failed: {e}") - raise ExecutionError("Model MLCube failed") + raise ExecutionError(f"Model MLCube failed: {e}") else: self.partial = True logging.warning(f"Model MLCube Execution failed: {e}") diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 915e15039..6e6f65677 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -1,6 +1,5 @@ import os import yaml -import pexpect import logging from typing import List, Dict, Optional, Union from pydantic import Field @@ -12,6 +11,7 @@ remove_path, verify_hash, generate_tmp_path, + spawn_and_kill, ) from medperf.entities.interface import Entity, Uploadable from medperf.entities.schemas import MedperfSchema, DeployableSchema @@ -248,7 +248,8 @@ def _set_image_hash_from_registry(self): cmd = f"mlcube --log-level {config.loglevel} inspect --mlcube={self.cube_path} --format=yaml" cmd += f" --platform={config.platform} --output-file {tmp_out_yaml}" 
logging.info(f"Running MLCube command: {cmd}") - with pexpect.spawn(cmd, timeout=config.mlcube_inspect_timeout) as proc: + with spawn_and_kill(cmd, timeout=config.mlcube_inspect_timeout) as proc_wrapper: + proc = proc_wrapper.proc combine_proc_sp_text(proc) if proc.exitstatus != 0: raise ExecutionError("There was an error while inspecting the image hash") @@ -266,7 +267,8 @@ def _get_image_from_registry(self): if config.platform == "singularity": cmd += f" -Psingularity.image={self._converted_singularity_image_name}" logging.info(f"Running MLCube command: {cmd}") - with pexpect.spawn(cmd, timeout=config.mlcube_configure_timeout) as proc: + with spawn_and_kill(cmd, timeout=config.mlcube_configure_timeout) as proc_wrapper: + proc = proc_wrapper.proc combine_proc_sp_text(proc) if proc.exitstatus != 0: raise ExecutionError("There was an error while retrieving the MLCube image") @@ -361,7 +363,8 @@ def run( cmd += " -Pplatform.accelerator_count=0" logging.info(f"Running MLCube command: {cmd}") - with pexpect.spawn(cmd, timeout=timeout) as proc: + with spawn_and_kill(cmd, timeout=timeout) as proc_wrapper: + proc = proc_wrapper.proc proc_out = combine_proc_sp_text(proc) if output_logs is not None: diff --git a/cli/medperf/tests/entities/test_cube.py b/cli/medperf/tests/entities/test_cube.py index 758c64097..c7b8998cc 100644 --- a/cli/medperf/tests/entities/test_cube.py +++ b/cli/medperf/tests/entities/test_cube.py @@ -39,8 +39,9 @@ def setup(request, mocker, comms, fs): # Mock additional third party elements mpexpect = MockPexpect(0) - mocker.patch(PATCH_CUBE.format("pexpect.spawn"), side_effect=mpexpect.spawn) + mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn) mocker.patch(PATCH_CUBE.format("combine_proc_sp_text"), return_value="") + mocker.patch(PATCH_CUBE.format("spawn_and_kill.killpg"), return_value="") return request.param @@ -95,7 +96,7 @@ def test_download_run_files_without_image_configures_mlcube( fs.create_file( "tmp_path", 
contents=yaml.dump({"hash": NO_IMG_CUBE["image_hash"]}) ) - spy = mocker.spy(medperf.entities.cube.pexpect, "spawn") + spy = mocker.spy(medperf.entities.cube.spawn_and_kill, "spawn") expected_cmds = [ f"mlcube --log-level debug configure --mlcube={self.manifest_path} --platform={config.platform}", f"mlcube --log-level debug inspect --mlcube={self.manifest_path}" @@ -122,7 +123,7 @@ def test_download_run_files_stops_execution_if_configure_fails( "tmp_path", contents=yaml.dump({"hash": NO_IMG_CUBE["image_hash"]}) ) mpexpect = MockPexpect(1, "expected_hash") - mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn) + mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn) # Act & Assert cube = Cube.get(self.id) @@ -147,7 +148,7 @@ def test_download_run_files_without_image_fails_with_wrong_hash( @pytest.mark.parametrize("setup", [{"remote": [DEFAULT_CUBE]}], indirect=True) def test_download_run_files_with_image_isnt_configured(self, mocker, setup): # Arrange - spy = mocker.spy(medperf.entities.cube.pexpect, "spawn") + spy = mocker.spy(medperf.entities.cube.spawn_and_kill, "spawn") # Act cube = Cube.get(self.id) @@ -175,7 +176,7 @@ def test_cube_runs_command(self, mocker, timeout, setup, task): # Arrange mpexpect = MockPexpect(0, "expected_hash") spy = mocker.patch( - PATCH_CUBE.format("pexpect.spawn"), side_effect=mpexpect.spawn + PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn ) mocker.patch(PATCH_CUBE.format("Cube.get_config"), side_effect=["", ""]) expected_cmd = ( @@ -196,7 +197,7 @@ def test_cube_runs_command(self, mocker, timeout, setup, task): def test_cube_runs_command_with_rw_access(self, mocker, setup, task): # Arrange mpexpect = MockPexpect(0, "expected_hash") - spy = mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn) + spy = mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn) mocker.patch( PATCH_CUBE.format("Cube.get_config"), side_effect=["", ""], @@ -219,7 +220,7 @@ 
def test_cube_runs_command_with_rw_access(self, mocker, setup, task): def test_cube_runs_command_with_extra_args(self, mocker, setup, task): # Arrange mpexpect = MockPexpect(0, "expected_hash") - spy = mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn) + spy = mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn) mocker.patch(PATCH_CUBE.format("Cube.get_config"), side_effect=["", ""]) expected_cmd = ( f"mlcube --log-level debug run --mlcube={self.manifest_path} --task={task} " @@ -239,7 +240,7 @@ def test_cube_runs_command_with_extra_args(self, mocker, setup, task): def test_cube_runs_command_and_preserves_runtime_args(self, mocker, setup, task): # Arrange mpexpect = MockPexpect(0, "expected_hash") - spy = mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn) + spy = mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn) mocker.patch( PATCH_CUBE.format("Cube.get_config"), side_effect=["cpuarg cpuval", "gpuarg gpuval"], @@ -262,7 +263,7 @@ def test_cube_runs_command_and_preserves_runtime_args(self, mocker, setup, task) def test_run_stops_execution_if_child_fails(self, mocker, setup, task): # Arrange mpexpect = MockPexpect(1, "expected_hash") - mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn) + mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn) mocker.patch(PATCH_CUBE.format("Cube.get_config"), side_effect=["", ""]) # Act & Assert diff --git a/cli/medperf/tests/mocks/pexpect.py b/cli/medperf/tests/mocks/pexpect.py index 54dea77b0..83f6811c9 100644 --- a/cli/medperf/tests/mocks/pexpect.py +++ b/cli/medperf/tests/mocks/pexpect.py @@ -1,7 +1,8 @@ class MockChild: - def __init__(self, exitstatus, stdout): + def __init__(self, exitstatus, stdout, pid): self.exitstatus = exitstatus self.stdout = stdout + self.pid = pid def __enter__(self, *args, **kwargs): return self @@ -18,11 +19,15 @@ def isalive(self): def close(self): pass + def wait(self): + pass + class 
MockPexpect: - def __init__(self, exitstatus, stdout=""): + def __init__(self, exitstatus, stdout="", pid=123456): self.exitstatus = exitstatus self.stdout = stdout + self.pid = pid def spawn(self, command: str, timeout: int = 30) -> MockChild: - return MockChild(self.exitstatus, self.stdout) + return MockChild(self.exitstatus, self.stdout, self.pid) diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py index 2714f97f0..c054d64c1 100644 --- a/cli/medperf/utils.py +++ b/cli/medperf/utils.py @@ -2,6 +2,7 @@ import re import os +import signal import yaml import random import hashlib @@ -18,6 +19,7 @@ from typing import List from colorama import Fore, Style from pexpect.exceptions import TIMEOUT +from git import Repo, GitCommandError import medperf.config as config from medperf.exceptions import ExecutionError, MedperfException, InvalidEntityError @@ -438,3 +440,75 @@ def filter_latest_associations(associations, entity_key): latest_associations = list(latest_associations.values()) return latest_associations + + +def check_for_updates() -> None: + """Check if the current branch is up-to-date with its remote counterpart using GitPython.""" + repo = Repo(config.BASE_DIR) + if repo.bare: + logging.debug('Repo is bare') + return + + logging.debug(f'Current git commit: {repo.head.commit.hexsha}') + + try: + for remote in repo.remotes: + remote.fetch() + + if repo.head.is_detached: + logging.debug('Repo is in detached state') + return + + current_branch = repo.active_branch + tracking_branch = current_branch.tracking_branch() + + if tracking_branch is None: + logging.debug("Current branch does not track a remote branch.") + return + if current_branch.commit.hexsha == tracking_branch.commit.hexsha: + logging.debug('No git branch updates.') + return + + logging.debug(f'Git branch updates found: {current_branch.commit.hexsha} -> {tracking_branch.commit.hexsha}') + config.ui.print_warning('MedPerf client updates found. 
Please, update your MedPerf installation.') + except GitCommandError as e: + logging.debug('Exception raised during updates check. Maybe user checked out repo with git@ and private key ' + 'or repo is in detached / non-tracked state?') + logging.debug(e) + + +class spawn_and_kill: + def __init__(self, cmd, timeout=None, *args, **kwargs): + self.cmd = cmd + self.timeout = timeout + self._args = args + self._kwargs = kwargs + self.proc: spawn + self.exception_occurred = False + + @staticmethod + def spawn(*args, **kwargs): + return spawn(*args, **kwargs) + + def killpg(self): + os.killpg(self.pid, signal.SIGINT) + + def __enter__(self): + self.proc = self.spawn(self.cmd, timeout=self.timeout, *self._args, **self._kwargs) + self.pid = self.proc.pid + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type: + self.exception_occurred = True + # Forcefully kill the process group if any exception occurred, in particular, + # - KeyboardInterrupt (user pressed Ctrl+C in terminal) + # - any other medperf exception like OOM or bug + # - pexpect.TIMEOUT + logging.info(f'Killing ancestor processes because of exception: {exc_val=}') + self.killpg() + + self.proc.close() + self.proc.wait() + # Return False to propagate exceptions, if any + return False diff --git a/cli/requirements.txt b/cli/requirements.txt index 4ddb37547..5a5240ba6 100644 --- a/cli/requirements.txt +++ b/cli/requirements.txt @@ -21,3 +21,4 @@ schema==0.7.5 setuptools<=66.1.1 email-validator==2.0.0 auth0-python==4.3.0 +GitPython==3.1.41