diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 2860b82059..f460da8344 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -276,6 +276,10 @@ def make_dataclass_and_log_config( dataclass_dict_config: DictConfig = om.structured( dataclass_constructor(**unstructured_config)) + # Error on missing mandatory values: + for key in dataclass_fields: + _ = dataclass_dict_config[key] + # Convert DictConfig to dict for dataclass constructor so that child # configs are not DictConfigs dataclass_config: T = dataclass_constructor( diff --git a/scripts/train/train.py b/scripts/train/train.py index 139f2ff7ff..19693852a2 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -44,10 +44,6 @@ def validate_config(train_config: TrainConfig): """Validates compatible model and dataloader selection.""" - # Check for missing mandatory fields and throw error early. - for field in TRAIN_CONFIG_KEYS: - _ = getattr(train_config, field) - # Validate the rest of the config loaders = [train_config.train_loader] if train_config.eval_loaders is not None: