diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index f526c9178..c85bee347 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -84,6 +84,7 @@ ENV BUILD_DATE=${BUILD_DATE} # The following environment variables tune performance ENV XLA_FLAGS="" ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true" +ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_triton_gemm=false" ENV CUDA_DEVICE_MAX_CONNECTIONS=1 ENV NCCL_NVLS_ENABLE=0 diff --git a/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch b/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch index 1ca04c711..d4b3917cf 100644 --- a/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch +++ b/.github/container/patches/t5x/mirror-patch-t5x_te_in_contrib_noindent.patch @@ -637,7 +637,7 @@ index 89974dd..388d2ec 100755 -MODEL_DIR_LOCAL=${7:-"model_dir"} -MODEL_DIR=${PWD}/${MODEL_DIR_LOCAL} -NUM_MICROBATCHES=${8:-0} -+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}" ++export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}" + +#! Change these values !# +FT_TASK=${FT_TASK:=mnli2} # currently supported: mnli2, squad1 @@ -751,7 +751,7 @@ index 18bb722..f807105 100755 -MODEL_DIR=${PWD}/${MODEL_DIR_LOCAL} -NUM_MICROBATCHES=${6:-0} -MP=${7:-1} -+export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}" ++export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}" -echo Model Parallel partitions: ${MP} +#! 
Change these values !# @@ -3323,8 +3323,8 @@ index cd563ec..e075df3 100755 -set -x +set -eoux pipefail --export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}" -+export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}" +-export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}" ++export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}" +export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}" #! Change these values !# @@ -3442,8 +3442,8 @@ index d083540..56919a5 100755 #BENCHMARK_MODE=True STAT_PERIOD=100 #only used if BENCHMARK_MODE is set --export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}" -+export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}" +-export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592 ${XLA_FLAGS}" ++export BASE_XLA_FLAGS="${BASE_XLA_FLAGS:---xla_gpu_enable_triton_gemm=false --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_all_reduce_combine_threshold_bytes=8589934592}" +export XLA_FLAGS="${BASE_XLA_FLAGS} ${XLA_FLAGS:-}" #! 
Change these values !# diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh index 1d8cc50f9..0377237f5 100755 --- a/.github/container/test-maxtext.sh +++ b/.github/container/test-maxtext.sh @@ -223,7 +223,8 @@ export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN} export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION} export CUDA_DEVICE_MAX_CONNECTIONS=1 -export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true +export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 diff --git a/README.md b/README.md index 0c379fbe8..1764c5f00 100644 --- a/README.md +++ b/README.md @@ -291,6 +291,7 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb | XLA Flags | Value | Explanation | | --------- | ----- | ----------- | | `--xla_gpu_enable_latency_hiding_scheduler` | `true` | allows XLA to move communication collectives to increase overlap with compute kernels | +| `--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Triton GeMM kernels | | Environment Variable | Value | Explanation | | -------------------- | ----- | ----------- | diff --git a/rosetta/docs/GPU_performance.md b/rosetta/docs/GPU_performance.md index abe78ae1f..fabbc6963 100644 --- a/rosetta/docs/GPU_performance.md +++ b/rosetta/docs/GPU_performance.md @@ -131,5 +131,7 @@ Fine-grain control to improve performance by initializing a NCCL communicator to ## Previously used XLA Flags The following flags were used previously used but no longer required. -- --xla_gpu_enable_triton_gemm=false (use cuBLAS instead of Trition GeMM kernels); starting from JAX 0.4.32 we don't need it. 
+- --xla_gpu_enable_async_reduce_scatter, --xla_gpu_enable_async_all_reduce, --xla_gpu_enable_async_all_gather ; Turned on by default, no longer needed +- --xla_gpu_enable_highest_priority_async_stream ; Turned on by default +- --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used diff --git a/rosetta/docs/NATIVE_FP8.md b/rosetta/docs/NATIVE_FP8.md index ccd55fe76..127b3ca4c 100644 --- a/rosetta/docs/NATIVE_FP8.md +++ b/rosetta/docs/NATIVE_FP8.md @@ -111,13 +111,11 @@ Enabling this feature is effortless. Users only need to include the option `--fd In addition to the suggested XLA flags mentioned in [this section](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags), we also recommend setting these following XLA flags. The execution script should look like: ```bash export XLA_FLAGS=" \ - --xla_gpu_enable_reduction_epilogue_fusion=false \ --xla_gpu_enable_triton_gemm=false \ - --xla_gpu_enable_cudnn_fmha=false \ - --xla_gpu_enable_cudnn_layer_norm=true \ - --xla_gpu_enable_cublaslt=true \ - --xla_gpu_enable_latency_hiding_scheduler=true \ - --xla_gpu_all_reduce_combine_threshold_bytes=51200 " + --xla_gpu_enable_pipelined_all_reduce=false \ + --xla_gpu_enable_pipelined_all_gather=false \ + --xla_gpu_enable_pipelined_reduce_scatter=false \ +" export ENABLE_TE=0 python -m paxml.main \ ... @@ -125,7 +123,7 @@ python -m paxml.main \ ... ``` -Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions. +Please note that disabling the Triton GeMM and pipelined collectives is essential for enabling the FP8 functionality and performance. ## Transformer Engine vs Native FP8 Support Native XLA-FP8 specifically targets matrix multiplication operations. 
In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components. diff --git a/rosetta/docs/PGLE.md b/rosetta/docs/PGLE.md index 11c8b9be9..dcf925fb9 100644 --- a/rosetta/docs/PGLE.md +++ b/rosetta/docs/PGLE.md @@ -61,6 +61,7 @@ PGLE found latency for async op custom-call-start.1 and (assumed)custom-call-don In order to get the best performance with PGLE, here is a list of all recommended XLA flags: ``` export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true +--xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 diff --git a/rosetta/rosetta/projects/maxtext/README.md b/rosetta/rosetta/projects/maxtext/README.md index 6199c2df8..b137edfd0 100644 --- a/rosetta/rosetta/projects/maxtext/README.md +++ b/rosetta/rosetta/projects/maxtext/README.md @@ -67,7 +67,8 @@ In order to obtain the best performance, please set the appropriate XLA flags. W The [GPU Performance document](../../../docs/GPU_performance.md) provides a detailed description of the XLA flags that can be set to optimize performance. These are the recommended XLA flags to get good performance for MaxText. 
``` -XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true +XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 diff --git a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub index e957d01d6..93894c75d 100644 --- a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub +++ b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub @@ -53,6 +53,7 @@ export NCCL_IB_SL=1 # Set XLA Flags export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 diff --git a/rosetta/rosetta/projects/pax/README.md b/rosetta/rosetta/projects/pax/README.md index 737550c06..d1829b847 100644 --- a/rosetta/rosetta/projects/pax/README.md +++ b/rosetta/rosetta/projects/pax/README.md @@ -139,7 +139,7 @@ For the the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_thres ``` BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true - --xla_gpu_enable_triton_softmax_fusion=false + --xla_gpu_enable_triton_gemm=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432 --xla_gpu_graph_level=0" bash run_pile_multinode.sh ... ```