Merge branch 'dev' into dist-file-utils
dakinggg committed Jun 12, 2024
2 parents 038106a + ba82cc9 commit 3772221
1 changed file: composer/distributed/dist_strategy.py (1 addition, 25 deletions)
@@ -328,36 +328,12 @@ def sync_hook(*args):
 
 mixed_precision = fsdp_config.mixed_precision
 keep_low_precision_grads = fsdp_config.keep_low_precision_grads
-mixed_precision, param_dtype, _, _ = get_mixed_precision(
+mixed_precision, _, _, _ = get_mixed_precision(
     precision,
     mixed_precision=mixed_precision,
     keep_low_precision_grads=keep_low_precision_grads,
 )
 
-# Note: FSDP does support the use of torch.float32 with sharding.
-# They just never expected a user to pass in torch.float32 into mixed_precision as a param_dtype.
-# See: https://github.com/pytorch/pytorch/issues/90584
-# The PR fixing this bug is merged into PyTorch, but it hasn't made its way into a release yet.
-# Instead a user needs to pass in `None` as param_dtype to have the parameters as torch.float32.
-# TODO: remove these checks when PyTorch has a release that includes the fix.
-if sharding_map_key != 'NO_SHARD':
-    if (
-        precision == Precision.AMP_FP16 and param_dtype not in [torch.float16, None] or
-        precision == Precision.AMP_BF16 and param_dtype not in [torch.bfloat16, None]
-    ):
-        raise ValueError(
-            f'FSDP in PyTorch 1.13 does not support precision `{precision}` with sharding strategy `{sharding_strategy}` '
-            f'and param_dtype `{param_dtype}.` Consider using one of the predefined mixed_precision strategies '
-            "(choose: `'FULL'`, `'DEFAULT'`, `'PURE'`)",
-        )
-
-    if param_dtype == torch.float32:
-        raise ValueError(
-            f'FSDP in PyTorch 1.13 does not support param_dtype `{param_dtype}` with sharding_strategy `{sharding_map_key}` '
-            f'Consider using `amp` or `bf16` for precision or setting param_dtype in mixed_precision to `None` '
-            f'with sharding strategy `{sharding_map_key}.`',
-        )
-
 process_group = None
 if fsdp_config.process_group is not None:
     process_group_dict = {'process_group': fsdp_config.process_group}
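The deleted guard worked around pytorch/pytorch#90584: FSDP in PyTorch 1.13 did not expect torch.float32 as a mixed_precision param_dtype under sharded strategies, so users had to pass None instead to keep parameters in full precision. With the fix now in released PyTorch, the combination is accepted directly and param_dtype no longer needs to be inspected here, which is why only the MixedPrecision object is kept from get_mixed_precision. A minimal sketch of the now-allowed configuration against the public torch.distributed.fsdp API (the wrapping function below is an illustrative assumption, not Composer's code path):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

def wrap_with_fp32_params(model: torch.nn.Module) -> FSDP:
    # Keep parameters in float32 while reducing gradients in bfloat16.
    # Under PyTorch 1.13, param_dtype=torch.float32 with a sharded strategy
    # hit the bug the removed check guarded against; current releases accept it.
    mp = MixedPrecision(
        param_dtype=torch.float32,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.float32,
    )
    # Note: constructing FSDP requires torch.distributed to be initialized.
    return FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # sharding enabled, i.e. not NO_SHARD
        mixed_precision=mp,
    )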
