diff --git a/README.md b/README.md
index 5390a719d..0c379fbe8 100644
--- a/README.md
+++ b/README.md
@@ -299,6 +299,8 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb
 
 There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).
 
+For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page.
+
 ## Profiling JAX programs on GPU
 
 See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.
diff --git a/rosetta/docs/GPU_performance.md b/rosetta/docs/GPU_performance.md
index c5456e3c4..abe78ae1f 100644
--- a/rosetta/docs/GPU_performance.md
+++ b/rosetta/docs/GPU_performance.md
@@ -128,6 +128,8 @@ Fine-grain control to improve performance by initializing a NCCL communicator to
 
 - --xla_gpu_enable_cudnn_fmha=false (enables XLA pattern matcher to detect multi-headed attention pattern in JAX)
 - --xla_disable_hlo_passes=<> (turns off specific HLO passes; can be used for debugging)
-
+## Previously used XLA Flags
+The following flags were previously used but are no longer required.
+- --xla_gpu_enable_triton_gemm=false (use cuBLAS instead of Triton GEMM kernels); starting from JAX 0.4.32, it is no longer needed.
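Since the Triton-GEMM opt-out is only needed on JAX versions older than 0.4.32, a script can gate it on the installed version instead of hard-coding it. The sketch below is not part of the PR; `jax_version` is a hypothetical placeholder (in practice, query the installed version, e.g. `python -c "import jax; print(jax.__version__)"`).

```shell
# Assumed placeholder for the installed JAX version.
jax_version="0.4.30"

extra_flags=""
# `sort -V` puts the older version first; if 0.4.32 does not sort first,
# the installed version predates it and still needs the opt-out flag.
if [ "$(printf '%s\n' "$jax_version" "0.4.32" | sort -V | head -n1)" != "0.4.32" ]; then
  extra_flags="--xla_gpu_enable_triton_gemm=false"
fi
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true ${extra_flags}"
```

With `jax_version="0.4.30"` the legacy flag is appended; on 0.4.32 or newer, `extra_flags` stays empty and the flag is dropped.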
diff --git a/rosetta/docs/NATIVE_FP8.md b/rosetta/docs/NATIVE_FP8.md
index b8c7d46ae..ccd55fe76 100644
--- a/rosetta/docs/NATIVE_FP8.md
+++ b/rosetta/docs/NATIVE_FP8.md
@@ -112,6 +112,7 @@ In addition to the suggested XLA flags mentioned in [this section](https://githu
 ```bash
 export XLA_FLAGS=" \
 --xla_gpu_enable_reduction_epilogue_fusion=false \
+--xla_gpu_enable_triton_gemm=false \
 --xla_gpu_enable_cudnn_fmha=false \
 --xla_gpu_enable_cudnn_layer_norm=true \
 --xla_gpu_enable_cublaslt=true \
@@ -124,8 +125,7 @@ python -m paxml.main \
 ...
 ```
 
-Please ensure you include the first flag, `--xla_gpu_enable_reduction_epilogue_fusion=false` as it is essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions.
-
+Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions.
 ## Transformer Engine vs Native FP8 Support
 
 Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components.
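Because two of the flags are now essential for FP8 while the rest are optional performance tweaks, a launch script can separate the two groups and fail fast if a required flag is ever dropped. This is only a sketch; the variable names are ours, not part of the documented workflow.

```shell
# Flags the FP8 path requires (per the doc above) vs. optional perf flags.
required_fp8_flags="--xla_gpu_enable_reduction_epilogue_fusion=false --xla_gpu_enable_triton_gemm=false"
perf_flags="--xla_gpu_enable_cudnn_fmha=false --xla_gpu_enable_cudnn_layer_norm=true --xla_gpu_enable_cublaslt=true"

export XLA_FLAGS="${required_fp8_flags} ${perf_flags}"

# Sanity check: abort if a required FP8 flag is missing from XLA_FLAGS.
for f in $required_fp8_flags; do
  case " $XLA_FLAGS " in
    *" $f "*) ;;  # flag present, keep going
    *) echo "missing required FP8 flag: $f" >&2; exit 1 ;;
  esac
done
```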
diff --git a/rosetta/docs/PGLE.md b/rosetta/docs/PGLE.md
index e2230c549..11c8b9be9 100644
--- a/rosetta/docs/PGLE.md
+++ b/rosetta/docs/PGLE.md
@@ -64,7 +64,7 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
 --xla_gpu_graph_level=0
 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824
 --xla_gpu_all_gather_combine_threshold_bytes=1073741824
---xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
+--xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824
 --xla_gpu_enable_pipelined_all_gather=true
 --xla_gpu_enable_pipelined_reduce_scatter=true
 --xla_gpu_enable_pipelined_all_reduce=true
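After this change, all three combine thresholds in the hunk are the same value, 1 GiB (1073741824 bytes, versus the old reduce-scatter value of 128 MiB). A sketch that derives the value once makes the unit explicit instead of repeating the raw number:

```shell
# 1 GiB in bytes: 1024^3 = 1073741824.
GIB=$((1024 * 1024 * 1024))

# All three combine thresholds share the same 1 GiB value.
export XLA_FLAGS="--xla_gpu_all_reduce_combine_threshold_bytes=${GIB} \
--xla_gpu_all_gather_combine_threshold_bytes=${GIB} \
--xla_gpu_reduce_scatter_combine_threshold_bytes=${GIB}"
```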