Remove FSDP restriction from PyTorch 1.13 #3395

Merged (2 commits) on Jun 12, 2024
26 changes: 1 addition & 25 deletions composer/distributed/dist_strategy.py
@@ -328,36 +328,12 @@ def sync_hook(*args):
 
     mixed_precision = fsdp_config.mixed_precision
     keep_low_precision_grads = fsdp_config.keep_low_precision_grads
-    mixed_precision, param_dtype, _, _ = get_mixed_precision(
+    mixed_precision, _, _, _ = get_mixed_precision(
         precision,
         mixed_precision=mixed_precision,
         keep_low_precision_grads=keep_low_precision_grads,
     )
 
-    # Note: FSDP does support the use of torch.float32 with sharding.
-    # They just never expected a user to pass in torch.float32 into mixed_precision as a param_dtype.
-    # See: https://github.com/pytorch/pytorch/issues/90584
-    # The PR fixing this bug is merged into PyTorch, but it hasn't made its way into a release yet.
-    # Instead a user needs to pass in `None` as param_dtype to have the parameters as torch.float32.
-    # TODO: remove these checks when PyTorch has a release that includes the fix.
-    if sharding_map_key != 'NO_SHARD':
-        if (
-            precision == Precision.AMP_FP16 and param_dtype not in [torch.float16, None] or
-            precision == Precision.AMP_BF16 and param_dtype not in [torch.bfloat16, None]
-        ):
-            raise ValueError(
-                f'FSDP in PyTorch 1.13 does not support precision `{precision}` with sharding strategy `{sharding_strategy}` '
-                f'and param_dtype `{param_dtype}.` Consider using one of the predefined mixed_precision strategies '
-                "(choose: `'FULL'`, `'DEFAULT'`, `'PURE'`)",
-            )
-
-        if param_dtype == torch.float32:
-            raise ValueError(
-                f'FSDP in PyTorch 1.13 does not support param_dtype `{param_dtype}` with sharding_strategy `{sharding_map_key}` '
-                f'Consider using `amp` or `bf16` for precision or setting param_dtype in mixed_precision to `None` '
-                f'with sharding strategy `{sharding_map_key}.`',
-            )
-
     process_group = None
     if fsdp_config.process_group is not None:
         process_group_dict = {'process_group': fsdp_config.process_group}
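
For context on what the removed check guarded against: per the deleted comment and pytorch/pytorch#90584, FSDP in PyTorch 1.13 rejected `torch.float32` as a `param_dtype` when a sharded strategy was used, and the workaround was to pass `None` so parameters stayed in float32. Below is a minimal sketch of that workaround against the raw FSDP API; the model, dtypes, and sharding strategy are illustrative only and are not part of this PR.

# Minimal sketch (not from this PR): the pre-fix workaround described in the removed comment.
# Assumes torch.distributed is already initialized and a GPU is available.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

model = torch.nn.Linear(8, 8).cuda()  # illustrative model

mixed_precision = MixedPrecision(
    param_dtype=None,             # None keeps parameters in torch.float32 (the 1.13 workaround)
    reduce_dtype=torch.bfloat16,  # example reduce dtype
    buffer_dtype=torch.bfloat16,  # example buffer dtype
)

fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # any strategy other than NO_SHARD hit the 1.13 check
    mixed_precision=mixed_precision,
)

With the restriction gone, the removed check was the only consumer of `param_dtype` in this function, which is why the new unpacking of `get_mixed_precision` discards it.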