Skip to content

Commit

Permalink
fix dist
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress committed Apr 18, 2024
1 parent 4c403fd commit f716642
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def calculate_batch_size_info(


# Coming soon: this conversion math will be done inside Composer Trainer
def update_batch_size_info(cfg: Dict[str, Any]) -> DictConfig:
def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]:
device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(
cfg['global_train_batch_size'], cfg['device_train_microbatch_size'])
cfg['n_gpus'] = dist.get_world_size()
Expand All @@ -102,7 +102,7 @@ def update_batch_size_info(cfg: Dict[str, Any]) -> DictConfig:
if cfg['device_train_microbatch_size'] == 'auto':
cfg['device_eval_batch_size'] = 1 # TODO debug auto eval microbatching
else:
cfg['device_eval_batch_size'] = cfg.device_train_microbatch_size
cfg['device_eval_batch_size'] = cfg['device_train_microbatch_size']
return cfg


Expand Down
4 changes: 4 additions & 0 deletions scripts/eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,11 @@ class EvalConfig:


def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
# Resolve all interpolation variables as early as possible
unstructured_config = om.to_container(cfg, resolve=True)
assert isinstance(unstructured_config, dict)
assert all(isinstance(k, str) for k in unstructured_config.keys())
unstructured_config = {str(k): v for k, v in unstructured_config.items()}

# flatten union types before creating structured config:
if 'eval_gauntlet' in unstructured_config:
Expand Down
2 changes: 2 additions & 0 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ def main(cfg: DictConfig) -> Trainer:
# Resolve all interpolation variables as early as possible
unstructured_config = om.to_container(cfg, resolve=True)
assert isinstance(unstructured_config, dict)
assert all(isinstance(k, str) for k in unstructured_config.keys())
unstructured_config = {str(k): v for k, v in unstructured_config.items()}

# Structured config does not support unions of containers, so separate single and plural containers
if (loader := unstructured_config.get('eval_loader', None)) is not None:
Expand Down

0 comments on commit f716642

Please sign in to comment.