[PyTorch] Remove special handling for FP8 params in FP8 recipe infrastructure #1326
Description
#1142 exposed a very subtle bug that caused non-deterministic test failures in test_fusible_ops_with_userbuffers.py.

Bug description

test_fusible_ops_with_userbuffers.py runs multiple test cases per launch because starting a parallel job is expensive, so it constructs and destroys several TE models with FP8 parameters. CPython may reuse an object's ID after the object is deallocated, so the Python ID of an FP8 tensor is sometimes recycled. However, Float8Tensor.post_optimizer_step_fwd_amax_reduction uses Python IDs to decide whether to perform amax reductions and FP8 scale updates. I observed that this triggered FP8 scale updates at unexpected times, which corrupted UB (userbuffers) buffers and caused hangs. 🫠
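For illustration, here is a minimal standalone sketch (plain Python/PyTorch, not TE code) of how a registry keyed on id() can misfire once IDs are recycled:

```python
import torch

# Stand-in for a registry of FP8 params that should get a post-optimizer-step
# scale update; the real code keyed similar bookkeeping on Python IDs.
registry = set()

t1 = torch.empty(4)
registry.add(id(t1))

del t1               # t1 is deallocated, so CPython may recycle its id
t2 = torch.empty(4)  # an unrelated new tensor...

# ...which can receive the recycled id, so the stale registry entry now
# (wrongly) marks t2 as needing a scale update.
print(id(t2) in registry)  # may print True; this is non-deterministic
```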
In short, the problem comes from this weird callback in Float8Tensor (transformer_engine/pytorch/tensor/float8_tensor.py, line 77 at commit 2643ba1).
This hack was added in #575 so that FP8 scales for FP8 params would be properly updated after the optimizer step. However, we've made improvements since then.
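For context, here is a hypothetical, heavily simplified sketch of the pattern being removed. The registry and function bodies below are illustrative only, not the actual Float8Tensor implementation:

```python
# Hypothetical simplification of the removed pattern, for illustration only.
_needs_post_step_update = {}  # id(param) -> True if the optimizer just updated it

def _register(param):
    # Recorded for each FP8 param, keyed on its Python id.
    _needs_post_step_update[id(param)] = True

def post_optimizer_step_fwd_amax_reduction(param):
    # Same name as the real method, but this body is a made-up simplification:
    # on the next forward pass, a flagged param triggers an immediate amax
    # reduction and FP8 scale update.
    if _needs_post_step_update.pop(id(param), False):
        ...  # amax reduction + FP8 scale update

# Failure mode: a flagged param is deleted, a new tensor receives the recycled
# id, and the scale update fires for the wrong tensor at the wrong time (here,
# while userbuffers workspaces were live), corrupting state.
```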
Thus, there's no need to do an FP8 scale update for the weights immediately after the optimizer step; it just needs to happen sometime before the next optimizer step, and numerics are unchanged. In fact, these FP8 scales already participate in the forward-pass amax reduction and scale update, so dropping the post-optimizer update also reduces runtime overhead. On top of that, this makes Float8Tensor saner and less tightly coupled to the FP8 recipe infrastructure.
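A hedged usage sketch of the intended flow, assuming the standard TE PyTorch entry points (fp8_model_init, fp8_autocast, DelayedScaling; exact arguments may differ across TE versions): the weights' FP8 scales are refreshed by the forward-pass amax reduction under fp8_autocast, so nothing needs to run in a post-optimizer-step callback as long as a forward pass happens before the next optimizer step.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

recipe = DelayedScaling()

# Weights are created directly as FP8 params (Float8Tensor).
with te.fp8_model_init(enabled=True):
    model = te.Linear(1024, 1024, device="cuda")

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(3):
    x = torch.randn(32, 1024, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = model(x)        # amax reduction + FP8 scale update happen here
    y.sum().backward()
    optimizer.step()        # no special FP8 scale update needed here anymore
    optimizer.zero_grad()
```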
Type of change
Changes
Checklist: