Skip to content

Commit

Permalink
polish train
Browse files — browse the repository at this point in the history
  • Loading branch information
milocress committed Apr 18, 2024
1 parent 40324c8 commit e7a2bfc
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 28 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def build_optimizer(model: torch.nn.Module, name: str,
optimizer_config[k] = om.to_container(v, resolve=True)

params = _extract_param_groups(model, optimizer_config)
kwargs = optimizer_config
kwargs = {**optimizer_config}

if 'params' in kwargs:
raise ValueError(
Expand Down
18 changes: 9 additions & 9 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,19 @@ def calculate_batch_size_info(


# Coming soon: this conversion math will be done inside Composer Trainer
def update_batch_size_info(cfg: DictConfig) -> DictConfig:
def update_batch_size_info(cfg: Dict[str, Any]) -> DictConfig:
device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(
cfg.global_train_batch_size, cfg.device_train_microbatch_size)
cfg.n_gpus = dist.get_world_size()
cfg.device_train_batch_size = device_train_batch_size
cfg.device_train_microbatch_size = device_train_microbatch_size
cfg.device_train_grad_accum = device_train_grad_accum
cfg['global_train_batch_size'], cfg['device_train_microbatch_size'])
cfg['n_gpus'] = dist.get_world_size()
cfg['device_train_batch_size'] = device_train_batch_size
cfg['device_train_microbatch_size'] = device_train_microbatch_size
cfg['device_train_grad_accum'] = device_train_grad_accum
# Safely set `device_eval_batch_size` if not provided by user
if 'device_eval_batch_size' not in cfg:
if cfg.device_train_microbatch_size == 'auto':
cfg.device_eval_batch_size = 1 # TODO debug auto eval microbatching
if cfg['device_train_microbatch_size'] == 'auto':
cfg['device_eval_batch_size'] = 1 # TODO debug auto eval microbatching
else:
cfg.device_eval_batch_size = cfg.device_train_microbatch_size
cfg['device_eval_batch_size'] = cfg.device_train_microbatch_size
return cfg


Expand Down
28 changes: 13 additions & 15 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,16 @@ class TrainConfig:
def validate_config(train_config: TrainConfig):
"""Validates compatible model and dataloader selection."""
loaders = [train_config.train_loader]
if train_config.eval_loader is not None or train_config.eval_loaders is not None:
if isinstance(train_config.eval_loaders, list):
for loader in (train_config.eval_loaders or []): # pyright
if 'label' not in loader or loader['label'] is None:
raise ValueError(
'When specifying multiple evaluation datasets, each one must include the \
if train_config.eval_loaders is not None:
for loader in (train_config.eval_loaders or []): # pyright
if 'label' not in loader or loader['label'] is None:
raise ValueError(
'When specifying multiple evaluation datasets, each one must include the \
`label` attribute.')
loaders.append(loader)
else:
assert train_config.eval_loader is not None # pyright being pyright
loaders.append(train_config.eval_loader)
loaders.append(loader)
if train_config.eval_loader is not None:
assert train_config.eval_loaders is None, 'Only one of `eval_loader` or `eval_loaders` should be provided.'
loaders.append(train_config.eval_loader)
for loader in loaders:
if loader['name'] == 'text':
if train_config.model['name'] == 'hf_t5':
Expand Down Expand Up @@ -194,14 +193,13 @@ def validate_config(train_config: TrainConfig):


def main(cfg: DictConfig) -> Trainer:
# Resolve all interpolation variables as early as possible
unstructured_config = om.to_container(cfg, resolve=True)
assert isinstance(unstructured_config, dict)
# Resolve all interpolation variables as early as possible
om.resolve(unstructured_config)

# Structured config does not support unions of containers, so separate single and plural containers
if (loader := unstructured_config.get('eval_loader', None)) is not None:
if isinstance(loader, ListConfig):
if isinstance(loader, list) or isinstance(loader, ListConfig):
unstructured_config['eval_loaders'] = list(
unstructured_config.pop('eval_loader'))
if (tasks := unstructured_config.get('icl_tasks', None)) is not None:
Expand Down Expand Up @@ -441,7 +439,7 @@ def main(cfg: DictConfig) -> Trainer:

# Callbacks
callbacks: List[Callback] = [
build_callback(str(name), callback_cfg, om.to_container(logged_cfg))
build_callback(str(name), callback_cfg, logged_cfg)
for name, callback_cfg in callback_configs.items()
] if callback_configs else []

Expand Down Expand Up @@ -588,7 +586,7 @@ def main(cfg: DictConfig) -> Trainer:

if should_log_config:
log.info('Logging config')
log_config(logged_cfg)
log_config(DictConfig(logged_cfg))
torch.cuda.empty_cache()
gc.collect()

Expand Down
5 changes: 3 additions & 2 deletions tests/a_scripts/train/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om

from llmfoundry.utils.config_utils import update_batch_size_info
from llmfoundry.utils.config_utils import to_str_dict, update_batch_size_info
from scripts.train.train import TrainConfig, main, validate_config # noqa: E402
from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall,
gpt_tiny_cfg)
Expand Down Expand Up @@ -158,7 +158,8 @@ def test_validate_config():
test_cfg: DictConfig = om.load(f) # type: ignore
test_cfg.model.ffn_config.moe_world_size = 4
test_cfg.fsdp_config.use_orig_params = False
test_cfg = update_batch_size_info(test_cfg)
test_cfg_dict = to_str_dict(test_cfg)
test_cfg_dict = update_batch_size_info(test_cfg_dict)
with pytest.raises(
ValueError,
match=
Expand Down
3 changes: 2 additions & 1 deletion tests/a_scripts/train/test_train_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None:
for param in mandatory_params:
orig_param = cfg.pop(param)
with pytest.raises(
(omegaconf.errors.MissingMandatoryValue, NameError)):
(omegaconf.errors.MissingMandatoryValue, NameError,
omegaconf.errors.InterpolationKeyError)):
main(cfg)
cfg[param] = orig_param

Expand Down

0 comments on commit e7a2bfc

Please sign in to comment.