Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FP8 AllGather Support in Fairscale #1185

Open
wants to merge 21 commits into
base: ngoyal_changes_for_pp_fp8_jiecaoyu_debug
Choose a base branch
from

Conversation

levendlee
Copy link
Member

What does this PR do?

Fixes # (issue).

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

ngoyal2707 and others added 21 commits March 29, 2024 15:12
This commit works with a 4 GPU run on SMALL model with FSDP and PP
enabled.
- Clean up flatten and non_flatten parameter generation logic.
- Avoid checking `main_grad` attribute all equal to zeros.
- Cleans up amax and scale update logic. Amax and scale should be
  done for both weights and parameters. So it should be done at
  forward of each microbatch.

- Consolidate `cast_params` and `all_gather` stream.
This commit works with a 4 GPU run on SMALL model with FSDP and PP
enabled.
- Clean up flatten and non_flatten parameter generation logic.
- Avoid checking `main_grad` attribute all equal to zeros.
- Cleans up amax and scale update logic. Amax and scale should be
  done for both weights and parameters. So it should be done at
  forward of each microbatch.

- Consolidate `cast_params` and `all_gather` stream.
…kresearch/fairscale into shikaili_fp8_allgather_no_pp_fix
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 20, 2024
Copy link

@awgu awgu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @levendlee for the great work! I left some comments for my own learning.

and all(_is_te_module_with_weights(info[1]) for info in p._param_infos))
if fused_wgard_accumulation:
if getattr(p, "main_grad", None) is None:
p.main_grad = torch.empty_like(p, dtype=torch.float32)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding, why empty_like instead of zeros_like?

if params is None:
params = self.params
with torch.cuda.stream(self._streams["fp32_to_fp16"]):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why did you use the "all_gather" stream instead of the "fp32_to_fp16" stream?

@@ -2087,6 +2179,9 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:

self.has_full_params = False

if self.fp8_all_gather:
self._update_amax_and_scale_fwd(is_first_microbatch_fwd=is_first_microbatch_fwd)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding, is there a reason that this is not done together with _cast_params_for_all_gather? (For example, could this call be delayed a few lines to below where _cast_params_for_all_gather is called?)




@torch.no_grad()
def _rebuild_full_params(self, force_full_precision: bool = False, wait_for_all_gather = True) -> Optional[List[Tuple[torch.Tensor, bool]]]:
def _rebuild_full_params(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For fp8_all_gather=True, what happens when this method is called without the TE autocast context?

@@ -1448,16 +1505,22 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:

# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
self._rebuild_full_params()
self.module.has_unflatten_views = getattr(self.module, "has_unflatten_views", False)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants