Skip to content

Commit

Permalink
Merge branch 'main' into mlcube-download-log
Browse files Browse the repository at this point in the history
  • Loading branch information
hasan7n authored Feb 21, 2024
2 parents 600d757 + 824c432 commit 3893989
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 17 deletions.
2 changes: 2 additions & 0 deletions cli/medperf/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import medperf.commands.association.association as association
import medperf.commands.compatibility_test.compatibility_test as compatibility_test
import medperf.commands.storage as storage
from medperf.utils import check_for_updates

app = typer.Typer()
app.add_typer(mlcube.app, name="mlcube", help="Manage mlcubes")
Expand Down Expand Up @@ -99,5 +100,6 @@ def main(

logging.info(f"Running MedPerf v{__version__} on {loglevel} logging level")
logging.info(f"Executed command: {' '.join(sys.argv[1:])}")
check_for_updates()

config.ui.print(f"MedPerf {__version__}")
2 changes: 1 addition & 1 deletion cli/medperf/commands/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def run_inference(self):
except ExecutionError as e:
if not self.ignore_model_errors:
logging.error(f"Model MLCube Execution failed: {e}")
raise ExecutionError("Model MLCube failed")
raise ExecutionError(f"Model MLCube failed: {e}")
else:
self.partial = True
logging.warning(f"Model MLCube Execution failed: {e}")
Expand Down
11 changes: 7 additions & 4 deletions cli/medperf/entities/cube.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import yaml
import pexpect
import logging
from typing import List, Dict, Optional, Union
from pydantic import Field
Expand All @@ -12,6 +11,7 @@
remove_path,
verify_hash,
generate_tmp_path,
spawn_and_kill,
)
from medperf.entities.interface import Entity, Uploadable
from medperf.entities.schemas import MedperfSchema, DeployableSchema
Expand Down Expand Up @@ -248,7 +248,8 @@ def _set_image_hash_from_registry(self):
cmd = f"mlcube --log-level {config.loglevel} inspect --mlcube={self.cube_path} --format=yaml"
cmd += f" --platform={config.platform} --output-file {tmp_out_yaml}"
logging.info(f"Running MLCube command: {cmd}")
with pexpect.spawn(cmd, timeout=config.mlcube_inspect_timeout) as proc:
with spawn_and_kill(cmd, timeout=config.mlcube_inspect_timeout) as proc_wrapper:
proc = proc_wrapper.proc
combine_proc_sp_text(proc)
if proc.exitstatus != 0:
raise ExecutionError("There was an error while inspecting the image hash")
Expand All @@ -266,7 +267,8 @@ def _get_image_from_registry(self):
if config.platform == "singularity":
cmd += f" -Psingularity.image={self._converted_singularity_image_name}"
logging.info(f"Running MLCube command: {cmd}")
with pexpect.spawn(cmd, timeout=config.mlcube_configure_timeout) as proc:
with spawn_and_kill(cmd, timeout=config.mlcube_configure_timeout) as proc_wrapper:
proc = proc_wrapper.proc
combine_proc_sp_text(proc)
if proc.exitstatus != 0:
raise ExecutionError("There was an error while retrieving the MLCube image")
Expand Down Expand Up @@ -361,7 +363,8 @@ def run(
cmd += " -Pplatform.accelerator_count=0"

logging.info(f"Running MLCube command: {cmd}")
with pexpect.spawn(cmd, timeout=timeout) as proc:
with spawn_and_kill(cmd, timeout=timeout) as proc_wrapper:
proc = proc_wrapper.proc
proc_out = combine_proc_sp_text(proc)

if output_logs is not None:
Expand Down
19 changes: 10 additions & 9 deletions cli/medperf/tests/entities/test_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def setup(request, mocker, comms, fs):

# Mock additional third party elements
mpexpect = MockPexpect(0)
mocker.patch(PATCH_CUBE.format("pexpect.spawn"), side_effect=mpexpect.spawn)
mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn)
mocker.patch(PATCH_CUBE.format("combine_proc_sp_text"), return_value="")
mocker.patch(PATCH_CUBE.format("spawn_and_kill.killpg"), return_value="")

return request.param

Expand Down Expand Up @@ -95,7 +96,7 @@ def test_download_run_files_without_image_configures_mlcube(
fs.create_file(
"tmp_path", contents=yaml.dump({"hash": NO_IMG_CUBE["image_hash"]})
)
spy = mocker.spy(medperf.entities.cube.pexpect, "spawn")
spy = mocker.spy(medperf.entities.cube.spawn_and_kill, "spawn")
expected_cmds = [
f"mlcube --log-level debug configure --mlcube={self.manifest_path} --platform={config.platform}",
f"mlcube --log-level debug inspect --mlcube={self.manifest_path}"
Expand All @@ -122,7 +123,7 @@ def test_download_run_files_stops_execution_if_configure_fails(
"tmp_path", contents=yaml.dump({"hash": NO_IMG_CUBE["image_hash"]})
)
mpexpect = MockPexpect(1, "expected_hash")
mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn)
mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn)

# Act & Assert
cube = Cube.get(self.id)
Expand All @@ -147,7 +148,7 @@ def test_download_run_files_without_image_fails_with_wrong_hash(
@pytest.mark.parametrize("setup", [{"remote": [DEFAULT_CUBE]}], indirect=True)
def test_download_run_files_with_image_isnt_configured(self, mocker, setup):
# Arrange
spy = mocker.spy(medperf.entities.cube.pexpect, "spawn")
spy = mocker.spy(medperf.entities.cube.spawn_and_kill, "spawn")

# Act
cube = Cube.get(self.id)
Expand Down Expand Up @@ -175,7 +176,7 @@ def test_cube_runs_command(self, mocker, timeout, setup, task):
# Arrange
mpexpect = MockPexpect(0, "expected_hash")
spy = mocker.patch(
PATCH_CUBE.format("pexpect.spawn"), side_effect=mpexpect.spawn
PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn
)
mocker.patch(PATCH_CUBE.format("Cube.get_config"), side_effect=["", ""])
expected_cmd = (
Expand All @@ -196,7 +197,7 @@ def test_cube_runs_command(self, mocker, timeout, setup, task):
def test_cube_runs_command_with_rw_access(self, mocker, setup, task):
# Arrange
mpexpect = MockPexpect(0, "expected_hash")
spy = mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn)
spy = mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn)
mocker.patch(
PATCH_CUBE.format("Cube.get_config"),
side_effect=["", ""],
Expand All @@ -219,7 +220,7 @@ def test_cube_runs_command_with_rw_access(self, mocker, setup, task):
def test_cube_runs_command_with_extra_args(self, mocker, setup, task):
# Arrange
mpexpect = MockPexpect(0, "expected_hash")
spy = mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn)
spy = mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn)
mocker.patch(PATCH_CUBE.format("Cube.get_config"), side_effect=["", ""])
expected_cmd = (
f"mlcube --log-level debug run --mlcube={self.manifest_path} --task={task} "
Expand All @@ -239,7 +240,7 @@ def test_cube_runs_command_with_extra_args(self, mocker, setup, task):
def test_cube_runs_command_and_preserves_runtime_args(self, mocker, setup, task):
# Arrange
mpexpect = MockPexpect(0, "expected_hash")
spy = mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn)
spy = mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn)
mocker.patch(
PATCH_CUBE.format("Cube.get_config"),
side_effect=["cpuarg cpuval", "gpuarg gpuval"],
Expand All @@ -262,7 +263,7 @@ def test_cube_runs_command_and_preserves_runtime_args(self, mocker, setup, task)
def test_run_stops_execution_if_child_fails(self, mocker, setup, task):
# Arrange
mpexpect = MockPexpect(1, "expected_hash")
mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn)
mocker.patch(PATCH_CUBE.format("spawn_and_kill.spawn"), side_effect=mpexpect.spawn)
mocker.patch(PATCH_CUBE.format("Cube.get_config"), side_effect=["", ""])

# Act & Assert
Expand Down
11 changes: 8 additions & 3 deletions cli/medperf/tests/mocks/pexpect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
class MockChild:
def __init__(self, exitstatus, stdout):
def __init__(self, exitstatus, stdout, pid):
self.exitstatus = exitstatus
self.stdout = stdout
self.pid = pid

def __enter__(self, *args, **kwargs):
return self
Expand All @@ -18,11 +19,15 @@ def isalive(self):
def close(self):
pass

def wait(self):
pass


class MockPexpect:
def __init__(self, exitstatus, stdout=""):
def __init__(self, exitstatus, stdout="", pid=123456):
self.exitstatus = exitstatus
self.stdout = stdout
self.pid = pid

def spawn(self, command: str, timeout: int = 30) -> MockChild:
return MockChild(self.exitstatus, self.stdout)
return MockChild(self.exitstatus, self.stdout, self.pid)
74 changes: 74 additions & 0 deletions cli/medperf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
import os
import signal
import yaml
import random
import hashlib
Expand All @@ -18,6 +19,7 @@
from typing import List
from colorama import Fore, Style
from pexpect.exceptions import TIMEOUT
from git import Repo, GitCommandError
import medperf.config as config
from medperf.exceptions import ExecutionError, MedperfException, InvalidEntityError

Expand Down Expand Up @@ -438,3 +440,75 @@ def filter_latest_associations(associations, entity_key):

latest_associations = list(latest_associations.values())
return latest_associations


def check_for_updates() -> None:
"""Check if the current branch is up-to-date with its remote counterpart using GitPython."""
repo = Repo(config.BASE_DIR)
if repo.bare:
logging.debug('Repo is bare')
return

logging.debug(f'Current git commit: {repo.head.commit.hexsha}')

try:
for remote in repo.remotes:
remote.fetch()

if repo.head.is_detached:
logging.debug('Repo is in detached state')
return

current_branch = repo.active_branch
tracking_branch = current_branch.tracking_branch()

if tracking_branch is None:
logging.debug("Current branch does not track a remote branch.")
return
if current_branch.commit.hexsha == tracking_branch.commit.hexsha:
logging.debug('No git branch updates.')
return

logging.debug(f'Git branch updates found: {current_branch.commit.hexsha} -> {tracking_branch.commit.hexsha}')
config.ui.print_warning('MedPerf client updates found. Please, update your MedPerf installation.')
except GitCommandError as e:
logging.debug('Exception raised during updates check. Maybe user checked out repo with git@ and private key'
'or repo is in detached / non-tracked state?')
logging.debug(e)


class spawn_and_kill:
def __init__(self, cmd, timeout=None, *args, **kwargs):
self.cmd = cmd
self.timeout = timeout
self._args = args
self._kwargs = kwargs
self.proc: spawn
self.exception_occurred = False

@staticmethod
def spawn(*args, **kwargs):
return spawn(*args, **kwargs)

def killpg(self):
os.killpg(self.pid, signal.SIGINT)

def __enter__(self):
self.proc = self.spawn(self.cmd, timeout=self.timeout, *self._args, **self._kwargs)
self.pid = self.proc.pid
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type:
self.exception_occurred = True
# Forcefully kill the process group if any exception occurred, in particular,
# - KeyboardInterrupt (user pressed Ctrl+C in terminal)
# - any other medperf exception like OOM or bug
# - pexpect.TIMEOUT
logging.info(f'Killing ancestor processes because of exception: {exc_val=}')
self.killpg()

self.proc.close()
self.proc.wait()
# Return False to propagate exceptions, if any
return False
1 change: 1 addition & 0 deletions cli/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ schema==0.7.5
setuptools<=66.1.1
email-validator==2.0.0
auth0-python==4.3.0
GitPython==3.1.41

0 comments on commit 3893989

Please sign in to comment.