[PyTorch] Remove special handling for FP8 params in FP8 recipe infrastructure #1326

Open · wants to merge 4 commits into main

Conversation

timmoon10 (Collaborator) commented Nov 9, 2024

Description

#1142 exposed a very subtle bug that caused non-deterministic test failures in test_fusible_ops_with_userbuffers.py.

Bug description

test_fusible_ops_with_userbuffers.py runs multiple test cases per launch because starting a parallel job is expensive, so it constructs and destroys multiple TE models with FP8 parameters. CPython may reuse an object's id after the object is deallocated, so the Python ID of an FP8 tensor is sometimes recycled for a new tensor. However, Float8Tensor.post_optimizer_step_fwd_amax_reduction uses Python IDs to decide whether to perform amax reductions and FP8 scale updates. The stale IDs triggered FP8 scale updates at unexpected times, which corrupted Userbuffers (UB) buffers and caused hangs.
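For illustration, here is a minimal plain-Python sketch of the id-reuse hazard (not TE code; `Param` and `seen_ids` are hypothetical stand-ins for Float8Tensor and the id-keyed bookkeeping):

```python
class Param:
    pass

seen_ids = set()

p1 = Param()
seen_ids.add(id(p1))   # record the object by its Python id
del p1                 # p1 is deallocated; CPython may recycle its address

p2 = Param()           # a brand-new, unrelated object...
if id(p2) in seen_ids:
    # ...can inherit the dead object's id and be mistaken for it,
    # triggering callbacks that were meant for the old object
    print("id reused:", id(p2))
```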

🫠

In short, the problem stems from this weird callback in Float8Tensor:

```python
def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None:
```

This hack was added in #575 so that we would properly update FP8 scales for FP8 params after the optimizer step. However, we've made improvements since then.

Thus, there's no need to do an FP8 scale update for the weights immediately after the optimizer step; we just need to do it sometime before the next optimizer step, with no change in numerics. In fact, these FP8 scales already participate in the forward-pass amax reduction and scale update, so dropping the post-step update reduces runtime overhead. This also makes Float8Tensor saner and less tightly coupled to the FP8 recipe infrastructure.
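To make the ordering concrete, here is a hedged sketch of a standard training loop in plain PyTorch (not the TE API; the comments map each step onto where TE's amax reduction and scale update happen):

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(2):
    x = torch.randn(2, 4)
    # Forward pass: in TE, this is where the amax reduction and scale update
    # that cover the weights run, strictly before the weights are cast to FP8.
    loss = model(x).sum()
    loss.backward()
    optimizer.step()       # weights change; their scales are momentarily stale
    optimizer.zero_grad()
    # No post-step scale update is needed: the next forward pass refreshes the
    # scales before the updated weights are quantized again, so numerics match.
```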

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

  • Remove special handling for FP8 params from FP8 recipe infrastructure

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 (Collaborator, Author)

/te-ci pytorch L1 L3

ksivaman (Member)

/te-ci pytorch

timmoon10 mentioned this pull request on Nov 13, 2024
timmoon10 (Collaborator, Author)

The convergence tests in pipeline 20334396 timed out, but all the tests that did run passed.
