Skip to content

Commit

Permalink
Improve code, documentation and modularize test cases for config
Browse files Browse the repository at this point in the history
  • Loading branch information
abhaasgoyal committed Feb 2, 2024
1 parent 1f1f827 commit b78294b
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 61 deletions.
1 change: 0 additions & 1 deletion benchcab/benchcab.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def fluxsite_submit_job(
) -> None:
"""Submits the PBS job script step in the fluxsite test workflow."""
config = self._get_config(config_path)

self._validate_environment(project=config["project"], modules=config["modules"])
if self.benchcab_exe_path is None:
msg = "Path to benchcab executable is undefined."
Expand Down
11 changes: 7 additions & 4 deletions benchcab/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

"""A module containing all *_config() functions."""
import os
import sys
from pathlib import Path

import yaml
Expand Down Expand Up @@ -91,10 +90,14 @@ def read_optional_key(config: dict):
raise ValueError(msg)
config["project"] = os.environ["PROJECT"]

# Directory List is obtained from Gadi Resources - https://opus.nci.org.au/display/Help/0.+Welcome+to+Gadi
data_dirs = ["/g/data", "/scratch"]
groups = list(
set([group for data_dir in data_dirs for group in os.listdir(data_dir)])
set(
[
group
for data_dir in internal.USER_PROJECT_DIRS
for group in os.listdir(data_dir)
]
)
)

if config["project"] not in groups:
Expand Down
6 changes: 5 additions & 1 deletion benchcab/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

_, NODENAME, _, _, _ = os.uname()

CONFIG_REQUIRED_KEYS = ["realisations", "modules", "experiment"]
CONFIG_REQUIRED_KEYS = ["realisations", "modules"]

# Parameters for job script:
QSUB_FNAME = "benchmark_cable_qsub.sh"
Expand All @@ -28,6 +28,10 @@
# Path to the user's current working directory
CWD = Path.cwd()

# Directory List is obtained from Gadi User Guide in Section - Gadi Resources
# https://opus.nci.org.au/display/Help/0.+Welcome+to+Gadi
USER_PROJECT_DIRS = ["/g/data", "/scratch"]

# Path to the user's home directory
HOME_DIR = Path(os.environ["HOME"])

Expand Down
115 changes: 60 additions & 55 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,21 @@
import benchcab.internal as bi
import benchcab.utils as bu

# Temporarily set $PROJECT for testing module
NO_OPTIONAL_CONFIG_PROJECT = "hh5"
OPTIONAL_CONFIG_PROJECT = "ks32"


# Temporarily set $PROJECT for testing module
@pytest.fixture(autouse=True, scope="module")
def _set_project_validation_dirs():
with mock.patch("os.listdir") as mocked_listdir:
mocked_listdir.return_value = [
NO_OPTIONAL_CONFIG_PROJECT,
OPTIONAL_CONFIG_PROJECT,
]
yield


@pytest.fixture(autouse=True)
def _set_project_env_variable(monkeypatch):
# Clear existing environment variables first
Expand All @@ -24,13 +35,6 @@ def _set_project_env_variable(monkeypatch):
yield


@pytest.fixture(autouse=True)
def _set_project_validation_dirs():
with mock.patch("os.listdir") as mocked_listdir:
mocked_listdir.return_value = ["hh5", OPTIONAL_CONFIG_PROJECT]
yield


@pytest.fixture()
def config_str(request) -> str:
"""Provide relative YAML path string of data files."""
Expand All @@ -50,10 +54,10 @@ def empty_config() -> dict:


@pytest.fixture()
def default_only_config() -> dict:
def no_optional_config() -> dict:
"""Config with no optional parameters.
Reads from config-basic.yml
Expected value after reading from config-basic.yml
"""
return {
"modules": ["intel-compiler/2021.1.1", "netcdf/4.7.4", "openmpi/4.1.0"],
Expand All @@ -69,12 +73,12 @@ def default_only_config() -> dict:


@pytest.fixture()
def all_optional_default_config(default_only_config) -> dict:
"""Config with all optional parameters set as default.
def all_optional_default_config(no_optional_config) -> dict:
"""Populate all keys in config with default optional values.
Reads from config-basic.yml
Expected value after reading from config-basic.yml
"""
config = default_only_config | {
config = no_optional_config | {
"project": OPTIONAL_CONFIG_PROJECT,
"fluxsite": {
"experiment": bi.FLUXSITE_DEFAULT_EXPERIMENT,
Expand All @@ -90,13 +94,13 @@ def all_optional_default_config(default_only_config) -> dict:


@pytest.fixture()
def all_optional_custom_config(default_only_config) -> dict:
"""Config with custom optional parameters.
def all_optional_custom_config(no_optional_config) -> dict:
"""Populate all keys in config with custom optional values.
Reads from config-optional.yml
Expected value after reading from config-optional.yml
"""
config = default_only_config | {
"project": "hh5",
config = no_optional_config | {
"project": NO_OPTIONAL_CONFIG_PROJECT,
"fluxsite": {
"experiment": "AU-Tum",
"multiprocess": False,
Expand Down Expand Up @@ -126,7 +130,7 @@ def all_optional_custom_config(default_only_config) -> dict:
@pytest.mark.parametrize(
("config_str", "output_config", "pytest_error"),
[
("config-basic.yml", "default_only_config", does_not_raise()),
("config-basic.yml", "no_optional_config", does_not_raise()),
("config-optional.yml", "all_optional_custom_config", does_not_raise()),
("config-missing.yml", "empty_config", pytest.raises(FileNotFoundError)),
],
Expand Down Expand Up @@ -155,44 +159,45 @@ def test_validate_config(config_str, pytest_error):
assert bc.validate_config(config)


@pytest.mark.parametrize(
("input_config", "output_config"),
[
("default_only_config", "all_optional_default_config"),
("all_optional_default_config", "all_optional_default_config"),
("all_optional_custom_config", "all_optional_custom_config"),
],
)
def test_read_optional_key_add_data(input_config, output_config, request):
"""Test default key-values are added if not provided by config.yaml, and existing keys stay intact."""
config = request.getfixturevalue(input_config)
bc.read_optional_key(config)
assert pformat(config) == pformat(request.getfixturevalue(output_config))

class TestReadOptionalKey:
"""Tests related to adding optional keys in config."""

def test_no_project_name(default_only_config, monkeypatch):
"""If project key and $PROJECT are not provided, then raise error."""
monkeypatch.delenv("PROJECT")
err_msg = re.escape(
"""Couldn't resolve project: check 'project' in config.yaml
and/or $PROJECT set in ~/.config/gadi-login.conf
"""
@pytest.mark.parametrize(
("input_config", "output_config"),
[
("no_optional_config", "all_optional_default_config"),
("all_optional_default_config", "all_optional_default_config"),
("all_optional_custom_config", "all_optional_custom_config"),
],
)
with pytest.raises(ValueError, match=err_msg):
bc.read_optional_key(default_only_config)

def test_read_optional_key_add_data(self, input_config, output_config, request):
"""Test default key-values are added if not provided by config.yaml, and existing keys stay intact."""
config = request.getfixturevalue(input_config)
bc.read_optional_key(config)
assert pformat(config) == pformat(request.getfixturevalue(output_config))

def test_user_not_in_project(default_only_config):
"""If user is not in viewable NCI projects, raise error."""
default_only_config["project"] = "non_existing"
err_msg = re.escape(
"User is not a member of project [non_existing]: Check if project key is correct"
)
with pytest.raises(
ValueError,
match=err_msg,
):
bc.read_optional_key(default_only_config)
def test_no_project_name(self, no_optional_config, monkeypatch):
"""If project key and $PROJECT are not provided, then raise error."""
monkeypatch.delenv("PROJECT")
err_msg = re.escape(
"""Couldn't resolve project: check 'project' in config.yaml
and/or $PROJECT set in ~/.config/gadi-login.conf
"""
)
with pytest.raises(ValueError, match=err_msg):
bc.read_optional_key(no_optional_config)

def test_user_not_in_project(self, no_optional_config):
"""If user is not in viewable NCI projects, raise error."""
no_optional_config["project"] = "non_existing"
err_msg = re.escape(
"User is not a member of project [non_existing]: Check if project key is correct"
)
with pytest.raises(
ValueError,
match=err_msg,
):
bc.read_optional_key(no_optional_config)


@pytest.mark.parametrize(
Expand Down

0 comments on commit b78294b

Please sign in to comment.