diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py index cdc921ebc7..9b9c3a5170 100644 --- a/src/fairchem/core/common/utils.py +++ b/src/fairchem/core/common/utils.py @@ -46,6 +46,12 @@ from torch.nn.modules.module import _IncompatibleKeys +DEFAULT_ENV_VARS = { + # Expandable segments is a new cuda feature that helps with memory fragmentation during frequent allocations (ie: in the case of variable batch sizes). + # see https://pytorch.org/docs/stable/notes/cuda.html. + "PYTORCH_CUDA_ALLOC_CONF" : "expandable_segments:True", +} + # copied from https://stackoverflow.com/questions/33490870/parsing-yaml-in-python-detect-duplicated-keys # prevents loading YAMLS where keys have been overwritten class UniqueKeyLoader(yaml.SafeLoader): @@ -953,6 +959,12 @@ def check_traj_files(batch, traj_dir) -> bool: return all(fl.exists() for fl in traj_files) +def setup_env_vars() -> None: + for k, v in DEFAULT_ENV_VARS.items(): + os.environ[k] = v + logging.info(f"Setting env {k}={v}") + + @contextmanager def new_trainer_context(*, config: dict[str, Any], distributed: bool = False): from fairchem.core.common import distutils, gp_utils @@ -969,6 +981,7 @@ class _TrainingContext: trainer: BaseTrainer setup_logging() + setup_env_vars() original_config = config config = copy.deepcopy(original_config)