diff --git a/libensemble/resources/platforms.py b/libensemble/resources/platforms.py index fe274b254..580e55b92 100644 --- a/libensemble/resources/platforms.py +++ b/libensemble/resources/platforms.py @@ -8,6 +8,7 @@ option or the environment variable ``LIBE_PLATFORM``. """ +import logging import os import subprocess from typing import Optional @@ -16,6 +17,10 @@ from libensemble.utils.misc import specs_dump +logger = logging.getLogger(__name__) +# To change logging level for just this module +# logger.setLevel(logging.DEBUG) + class PlatformException(Exception): """Platform module exception""" @@ -269,6 +274,7 @@ class Known_platforms(BaseModel): generic_rocm: GenericROCm = GenericROCm() crusher: Crusher = Crusher() frontier: Frontier = Frontier() + perlmutter: Perlmutter = Perlmutter() perlmutter_c: PerlmutterCPU = PerlmutterCPU() perlmutter_g: PerlmutterGPU = PerlmutterGPU() polaris: Polaris = Polaris() @@ -292,10 +298,15 @@ def known_envs(): """Detect system by environment variables""" name = None if os.environ.get("NERSC_HOST") == "perlmutter": - if "gpu_" in os.environ.get("SLURM_JOB_PARTITION"): - name = "perlmutter_g" + partition = os.environ.get("SLURM_JOB_PARTITION") + if partition: + if "gpu_" in partition: + name = "perlmutter_g" + else: + name = "perlmutter_c" else: - name = "perlmutter_c" + name = "perlmutter" + logger.manager_warning("Perlmutter detected, but no compute partition detected. Are you on login nodes?") return name diff --git a/libensemble/tests/unit_tests/test_platform.py b/libensemble/tests/unit_tests/test_platform.py index f319e0184..ba27d03b2 100644 --- a/libensemble/tests/unit_tests/test_platform.py +++ b/libensemble/tests/unit_tests/test_platform.py @@ -1,6 +1,12 @@ import pytest -from libensemble.resources.platforms import Known_platforms, PlatformException, get_platform, known_system_detect +from libensemble.resources.platforms import ( + Known_platforms, + PlatformException, + get_platform, + known_envs, + known_system_detect, +) from libensemble.utils.misc import specs_dump my_spec = { @@ -20,8 +26,12 @@ } -def test_platform_empty(): +def test_platform_empty(monkeypatch): """Test no platform options supplied""" + + # Ensure NERSC_HOST not set + monkeypatch.delenv("NERSC_HOST", raising=False) + exp = {} libE_specs = {} platform_info = get_platform(libE_specs) @@ -55,10 +65,13 @@ def test_platform_known(): assert platform_info == exp, f"platform_info does not match expected: {platform_info}" -def test_platform_specs(): +def test_platform_specs(monkeypatch): """Test known platform and platform_specs supplied""" from libensemble.specs import LibeSpecs + # Ensure NERSC_HOST not set + monkeypatch.delenv("NERSC_HOST", raising=False) + exp = my_spec libE_specs = {"platform_specs": my_spec} platform_info = get_platform(libE_specs) @@ -81,7 +94,12 @@ def test_platform_specs(): assert specs_dump(LS.platform_specs, exclude_none=True) == exp, "Conversion isn't as expected" -def test_known_sys_detect(): +def test_known_sys_detect(monkeypatch): + """Test detection of known system""" + + # Ensure NERSC_HOST not set + monkeypatch.delenv("NERSC_HOST", raising=False) + known_platforms = specs_dump(Known_platforms(), exclude_none=True) get_sys_cmd = "echo summit.olcf.ornl.gov" # Overrides default "hostname -d" name = known_system_detect(cmd=get_sys_cmd) @@ -94,9 +112,26 @@ def test_known_sys_detect(): assert name is None, f"Expected known_system_detect to return None ({name})" +def test_env_sys_detect(monkeypatch): + """Test detection of system partitions""" + monkeypatch.setenv("NERSC_HOST", "other_host") + monkeypatch.setenv("SLURM_JOB_PARTITION", "cpu_test_partition") + name = known_envs() + assert name is None + monkeypatch.setenv("NERSC_HOST", "perlmutter") + + monkeypatch.setenv("SLURM_JOB_PARTITION", "gpu_test_partition") + name = known_envs() + assert name == "perlmutter_g" + + monkeypatch.setenv("SLURM_JOB_PARTITION", "cpu_test_partition") + name = known_envs() + assert name == "perlmutter_c" + + monkeypatch.delenv("SLURM_JOB_PARTITION", raising=False) + name = known_envs() + assert name == "perlmutter" + + if __name__ == "__main__": - test_platform_empty() - test_unknown_platform() - test_platform_known() - test_platform_specs() - test_known_sys_detect() + pytest.main([__file__])