From f7166420bdd4e8c75db202bd0c2b860297ba994e Mon Sep 17 00:00:00 2001
From: Milo Cress
Date: Thu, 18 Apr 2024 19:49:29 +0000
Subject: [PATCH] fix dist

---
 llmfoundry/utils/config_utils.py | 4 ++--
 scripts/eval/eval.py             | 4 ++++
 scripts/train/train.py           | 2 ++
 3 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 35ab679b4a..026169df11 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -90,7 +90,7 @@ def calculate_batch_size_info(
 
 
 # Coming soon: this conversion math will be done inside Composer Trainer
-def update_batch_size_info(cfg: Dict[str, Any]) -> DictConfig:
+def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]:
     device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(
         cfg['global_train_batch_size'], cfg['device_train_microbatch_size'])
     cfg['n_gpus'] = dist.get_world_size()
@@ -102,7 +102,7 @@ def update_batch_size_info(cfg: Dict[str, Any]) -> DictConfig:
         if cfg['device_train_microbatch_size'] == 'auto':
             cfg['device_eval_batch_size'] = 1  # TODO debug auto eval microbatching
         else:
-            cfg['device_eval_batch_size'] = cfg.device_train_microbatch_size
+            cfg['device_eval_batch_size'] = cfg['device_train_microbatch_size']
     return cfg
 
 
diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py
index 4c50aefb80..1e5e382a53 100644
--- a/scripts/eval/eval.py
+++ b/scripts/eval/eval.py
@@ -203,7 +203,11 @@ class EvalConfig:
 
 
 def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
+    # Resolve all interpolation variables as early as possible
     unstructured_config = om.to_container(cfg, resolve=True)
+    assert isinstance(unstructured_config, dict)
+    assert all(isinstance(k, str) for k in unstructured_config.keys())
+    unstructured_config = {str(k): v for k, v in unstructured_config.items()}
 
     # flatten union types before creating structured config:
     if 'eval_gauntlet' in unstructured_config:
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 45ae255850..812f8a7921 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -196,6 +196,8 @@ def main(cfg: DictConfig) -> Trainer:
     # Resolve all interpolation variables as early as possible
     unstructured_config = om.to_container(cfg, resolve=True)
     assert isinstance(unstructured_config, dict)
+    assert all(isinstance(k, str) for k in unstructured_config.keys())
+    unstructured_config = {str(k): v for k, v in unstructured_config.items()}
 
     # Structured config does not support unions of containers, so separate single and plural containers
     if (loader := unstructured_config.get('eval_loader', None)) is not None: