reverted the changes related to triton_gemm; needed for fp8 functionality
kocchop committed Sep 3, 2024
1 parent 1abffe4 commit 76f222a
Showing 10 changed files with 23 additions and 17 deletions.
1 change: 1 addition & 0 deletions .github/container/Dockerfile.jax
@@ -84,6 +84,7 @@ ENV BUILD_DATE=${BUILD_DATE}
# The following environment variables tune performance
ENV XLA_FLAGS=""
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true"
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_triton_gemm=false"
ENV CUDA_DEVICE_MAX_CONNECTIONS=1
ENV NCCL_NVLS_ENABLE=0

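Each `ENV XLA_FLAGS="${XLA_FLAGS} …"` layer appends to the value set by the previous one, so the built image ships with both flags pre-set. A quick way to confirm what a built image exports (a hedged sketch; the image tag is illustrative):

```bash
# Print the XLA_FLAGS the container environment provides.
# ghcr.io/nvidia/jax:latest is an illustrative tag; use the tag you built or pulled.
docker run --rm ghcr.io/nvidia/jax:latest sh -c 'echo "$XLA_FLAGS"'
# Expected output given the two ENV lines above:
# --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
```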
@@ -637,7 +637,7 @@ index 89974dd..388d2ec 100755
-MODEL_DIR_LOCAL=${7:-"model_dir"}
-MODEL_DIR=${PWD}/${MODEL_DIR_LOCAL}
-NUM_MICROBATCHES=${8:-0}
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+
+#! Change these values !#
+FT_TASK=${FT_TASK:=mnli2} # currently supported: mnli2, squad1
@@ -751,7 +751,7 @@ index 18bb722..f807105 100755
-MODEL_DIR=${PWD}/${MODEL_DIR_LOCAL}
-NUM_MICROBATCHES=${6:-0}
-MP=${7:-1}
+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"

-echo Model Parallel partitions: ${MP}
+#! Change these values !#
@@ -3323,8 +3323,8 @@ index cd563ec..e075df3 100755
-set -x
+set -eoux pipefail

-export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}"
-export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}"
+export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}"

#! Change these values !#
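The `BASE_XLA_FLAGS="${BASE_XLA_FLAGS:-…}"` idiom introduced here relies on bash default-value expansion: the baked-in flags apply only when the caller has not already exported `BASE_XLA_FLAGS`, and `${XLA_FLAGS:-}` expands to an empty string when `XLA_FLAGS` is unset, so `set -u` (part of `set -eoux pipefail`) does not abort. A minimal sketch of the behavior (script name is hypothetical):

```bash
#!/usr/bin/env bash
set -eu  # -u makes references to unset variables fatal, as in the script above

# Use the built-in defaults only if the caller did not provide BASE_XLA_FLAGS.
export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true}"
# ${XLA_FLAGS:-} yields "" when XLA_FLAGS is unset, so this line is safe under set -u.
export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}"
echo "${XLA_FLAGS}"
```

Running it as `BASE_XLA_FLAGS="--xla_gpu_graph_level=0" ./wrap.sh` replaces the defaults wholesale, while `XLA_FLAGS="--xla_dump_to=/tmp/hlo" ./wrap.sh` appends extra flags after the base set.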
@@ -3442,8 +3442,8 @@ index d083540..56919a5 100755
#BENCHMARK_MODE=True
STAT_PERIOD=100 #only used if BENCHMARK_MODE is set

-export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}"
-export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}"
+export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}"
+export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}"

#! Change these values !#
3 changes: 2 additions & 1 deletion .github/container/test-maxtext.sh
@@ -223,7 +223,8 @@ export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN}
export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
export CUDA_DEVICE_MAX_CONNECTIONS=1

-export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
+export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
+                --xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
1 change: 1 addition & 0 deletions README.md
@@ -291,6 +291,7 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb
| XLA Flags | Value | Explanation |
| --------- | ----- | ----------- |
| `--xla_gpu_enable_latency_hiding_scheduler` | `true` | allows XLA to move communication collectives to increase overlap with compute kernels |
+| `--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Triton GEMM kernels |

| Environment Variable | Value | Explanation |
| -------------------- | ----- | ----------- |
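Taken together with the environment variables from the Dockerfile hunk above, a minimal launch sketch would look like the following (the training script name is illustrative):

```bash
# Recommended XLA flags from the table above.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false"
# Matching environment variables baked into the JAX container image.
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_NVLS_ENABLE=0
python train.py  # illustrative entry point
```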
4 changes: 3 additions & 1 deletion rosetta/docs/GPU_performance.md
@@ -131,5 +131,7 @@ Fine-grain control to improve performance by initializing a NCCL communicator to
## Previously used XLA Flags

The following flags were previously used but are no longer required.
- --xla_gpu_enable_triton_gemm=false (use cuBLAS instead of Triton GEMM kernels); no longer needed starting from JAX 0.4.32.
- --xla_gpu_enable_async_reduce_scatter, --xla_gpu_enable_async_all_reduce, --xla_gpu_enable_async_all_gather ; Turned on by default, no longer needed
- --xla_gpu_enable_highest_priority_async_stream ; Turned on by default
- --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used
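If an inherited launch script still carries one of these retired flags, plain bash string substitution is enough to drop it. A hedged sketch (exact-token matching; a flag passed with an `=value` suffix would need the value included in the pattern):

```bash
# Remove retired flags from an inherited XLA_FLAGS value.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_highest_priority_async_stream"
for retired in --xla_gpu_enable_highest_priority_async_stream --xla_gpu_enable_triton_softmax_fusion; do
  XLA_FLAGS="${XLA_FLAGS//${retired}/}"  # delete every occurrence of the token
done
XLA_FLAGS="$(echo ${XLA_FLAGS})"  # unquoted echo collapses leftover whitespace
export XLA_FLAGS
echo "${XLA_FLAGS}"
```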

12 changes: 5 additions & 7 deletions rosetta/docs/NATIVE_FP8.md
@@ -111,21 +111,19 @@ Enabling this feature is effortless. Users only need to include the option `--fd
In addition to the suggested XLA flags mentioned in [this section](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags), we also recommend setting the following XLA flags. The execution script should look like:
```bash
export XLA_FLAGS=" \
-    --xla_gpu_enable_reduction_epilogue_fusion=false \
     --xla_gpu_enable_triton_gemm=false \
-    --xla_gpu_enable_cudnn_fmha=false \
-    --xla_gpu_enable_cudnn_layer_norm=true \
-    --xla_gpu_enable_cublaslt=true \
-    --xla_gpu_enable_latency_hiding_scheduler=true \
-    --xla_gpu_all_reduce_combine_threshold_bytes=51200 "
+    --xla_gpu_enable_pipelined_all_reduce=false \
+    --xla_gpu_enable_pipelined_all_gather=false \
+    --xla_gpu_enable_pipelined_reduce_scatter=false \
+    "
export ENABLE_TE=0
python -m paxml.main \
...
--fdl.USE_FP8=True \
...
```

-Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions.
+Please note that disabling Triton GEMM and the pipelined collectives is essential for enabling FP8 functionality and performance.
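One way to double-check that Triton GEMM is actually disabled in a given run is to dump the compiled HLO and search it for Triton fusions. A hedged sketch using XLA's standard `--xla_dump_to` debugging flag (the `__triton_gemm` marker string is an assumption about the dumped HLO text):

```bash
# Dump compiled modules, then search them for Triton GEMM fusions.
export XLA_FLAGS="${XLA_FLAGS} --xla_dump_to=/tmp/xla_dump"
python -m paxml.main ... --fdl.USE_FP8=True ...  # arguments elided as in the block above
grep -rl "__triton_gemm" /tmp/xla_dump || echo "no Triton GEMM fusions found"
```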

## Transformer Engine vs Native FP8 Support
Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components.
1 change: 1 addition & 0 deletions rosetta/docs/PGLE.md
@@ -61,6 +61,7 @@ PGLE found latency for async op custom-call-start.1 and (assumed)custom-call-don
In order to get the best performance with PGLE, here is a list of all recommended XLA flags:
```
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
+--xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
3 changes: 2 additions & 1 deletion rosetta/rosetta/projects/maxtext/README.md
@@ -67,7 +67,8 @@ In order to obtain the best performance, please set the appropriate XLA flags. W
The [GPU Performance document](../../../docs/GPU_performance.md) provides a detailed description of the XLA flags that can be set to optimize performance. These are the recommended XLA flags to get good performance for MaxText.

```
XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_gemm-false
--xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
1 change: 1 addition & 0 deletions rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
@@ -53,6 +53,7 @@ export NCCL_IB_SL=1

# Set XLA Flags
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
+--xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
2 changes: 1 addition & 1 deletion rosetta/rosetta/projects/pax/README.md
Expand Up @@ -139,7 +139,7 @@ For the the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_thres

```
BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
-                --xla_gpu_enable_triton_softmax_fusion=false
+                --xla_gpu_enable_triton_gemm=false
--xla_gpu_all_reduce_combine_threshold_bytes=33554432
--xla_gpu_graph_level=0" bash run_pile_multinode.sh ...
```
