diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 7b93a4235d..615b0b09d0 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -390,7 +390,7 @@ def build_optimizer(model: torch.nn.Module, name: str, optimizer_config[k] = om.to_container(v, resolve=True) params = _extract_param_groups(model, optimizer_config) - kwargs = optimizer_config + kwargs = {**optimizer_config} if 'params' in kwargs: raise ValueError( diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 7901d2b44d..35ab679b4a 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -90,19 +90,19 @@ def calculate_batch_size_info( # Coming soon: this conversion math will be done inside Composer Trainer -def update_batch_size_info(cfg: DictConfig) -> DictConfig: +def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]: device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info( - cfg.global_train_batch_size, cfg.device_train_microbatch_size) - cfg.n_gpus = dist.get_world_size() - cfg.device_train_batch_size = device_train_batch_size - cfg.device_train_microbatch_size = device_train_microbatch_size - cfg.device_train_grad_accum = device_train_grad_accum + cfg['global_train_batch_size'], cfg['device_train_microbatch_size']) + cfg['n_gpus'] = dist.get_world_size() + cfg['device_train_batch_size'] = device_train_batch_size + cfg['device_train_microbatch_size'] = device_train_microbatch_size + cfg['device_train_grad_accum'] = device_train_grad_accum # Safely set `device_eval_batch_size` if not provided by user if 'device_eval_batch_size' not in cfg: - if cfg.device_train_microbatch_size == 'auto': - cfg.device_eval_batch_size = 1 # TODO debug auto eval microbatching + if cfg['device_train_microbatch_size'] == 'auto': + cfg['device_eval_batch_size'] = 1 # TODO debug auto eval microbatching else: - cfg.device_eval_batch_size = 
cfg.device_train_microbatch_size + cfg['device_eval_batch_size'] = cfg['device_train_microbatch_size'] return cfg diff --git a/scripts/train/train.py b/scripts/train/train.py index 30baaac260..576ac383d1 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -115,17 +115,16 @@ class TrainConfig: def validate_config(train_config: TrainConfig): """Validates compatible model and dataloader selection.""" loaders = [train_config.train_loader] - if train_config.eval_loader is not None or train_config.eval_loaders is not None: - if isinstance(train_config.eval_loaders, list): - for loader in (train_config.eval_loaders or []): # pyright - if 'label' not in loader or loader['label'] is None: - raise ValueError( - 'When specifying multiple evaluation datasets, each one must include the \ `label` attribute.') - loaders.append(loader) - else: - assert train_config.eval_loader is not None # pyright being pyright - loaders.append(train_config.eval_loader) + if train_config.eval_loaders is not None: + for loader in (train_config.eval_loaders or []): # pyright + if 'label' not in loader or loader['label'] is None: + raise ValueError( + 'When specifying multiple evaluation datasets, each one must include the \ `label` attribute.') + loaders.append(loader) + if train_config.eval_loader is not None: + assert train_config.eval_loaders is None, 'Only one of `eval_loader` or `eval_loaders` should be provided.' 
+ loaders.append(train_config.eval_loader) for loader in loaders: if loader['name'] == 'text': if train_config.model['name'] == 'hf_t5': @@ -194,14 +193,13 @@ def validate_config(train_config: TrainConfig): def main(cfg: DictConfig) -> Trainer: + # Resolve all interpolation variables as early as possible unstructured_config = om.to_container(cfg, resolve=True) assert isinstance(unstructured_config, dict) - # Resolve all interpolation variables as early as possible - om.resolve(unstructured_config) # Structured config does not support unions of containers, so separate single and plural containers if (loader := unstructured_config.get('eval_loader', None)) is not None: - if isinstance(loader, ListConfig): + if isinstance(loader, list) or isinstance(loader, ListConfig): unstructured_config['eval_loaders'] = list( unstructured_config.pop('eval_loader')) if (tasks := unstructured_config.get('icl_tasks', None)) is not None: @@ -441,7 +439,7 @@ def main(cfg: DictConfig) -> Trainer: # Callbacks callbacks: List[Callback] = [ - build_callback(str(name), callback_cfg, om.to_container(logged_cfg)) + build_callback(str(name), callback_cfg, logged_cfg) for name, callback_cfg in callback_configs.items() ] if callback_configs else [] @@ -588,7 +586,7 @@ def main(cfg: DictConfig) -> Trainer: if should_log_config: log.info('Logging config') - log_config(logged_cfg) + log_config(DictConfig(logged_cfg)) torch.cuda.empty_cache() gc.collect() diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index 7899eeda0a..9efc04755d 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -11,7 +11,7 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om -from llmfoundry.utils.config_utils import update_batch_size_info +from llmfoundry.utils.config_utils import to_str_dict, update_batch_size_info from scripts.train.train import TrainConfig, main, validate_config # noqa: E402 from tests.data_utils 
import (create_arxiv_dataset, create_c4_dataset_xxsmall, gpt_tiny_cfg) @@ -158,7 +158,8 @@ def test_validate_config(): test_cfg: DictConfig = om.load(f) # type: ignore test_cfg.model.ffn_config.moe_world_size = 4 test_cfg.fsdp_config.use_orig_params = False - test_cfg = update_batch_size_info(test_cfg) + test_cfg_dict = to_str_dict(test_cfg) + test_cfg_dict = update_batch_size_info(test_cfg_dict) with pytest.raises( ValueError, match= diff --git a/tests/a_scripts/train/test_train_inputs.py b/tests/a_scripts/train/test_train_inputs.py index 24cad29a6b..c2dd5b3d27 100644 --- a/tests/a_scripts/train/test_train_inputs.py +++ b/tests/a_scripts/train/test_train_inputs.py @@ -79,7 +79,8 @@ def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None: for param in mandatory_params: orig_param = cfg.pop(param) with pytest.raises( - (omegaconf.errors.MissingMandatoryValue, NameError)): + (omegaconf.errors.MissingMandatoryValue, NameError, + omegaconf.errors.InterpolationKeyError)): main(cfg) cfg[param] = orig_param