From ecacd5b614d2df6d589b8af1e2aad6a570d3511d Mon Sep 17 00:00:00 2001 From: Md Fahim Faysal Khan Date: Thu, 5 Sep 2024 11:57:27 -0700 Subject: [PATCH 1/4] remove deprecated XLA flag (#1010) 1. `xla_gpu_enable_triton_gemm` is still needed. 2. Removed some other deprecated XLA flags: `xla_gpu_enable_triton_softmax_fusion` 3. Also removed some XLA flags that are now turned on by default. `xla_enable_async_all_gather` etc. --- .github/container/test-maxtext.sh | 5 ++--- README.md | 2 ++ rosetta/docs/GPU_performance.md | 6 +++++- rosetta/docs/NATIVE_FP8.md | 13 +++++-------- rosetta/docs/PGLE.md | 1 - rosetta/rosetta/projects/maxtext/README.md | 8 ++------ .../projects/maxtext/scripts/example_slurm.sub | 8 +------- rosetta/rosetta/projects/pax/README.md | 8 ++++---- 8 files changed, 21 insertions(+), 30 deletions(-) diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh index 164fa5912..0dc26c8c1 100755 --- a/.github/container/test-maxtext.sh +++ b/.github/container/test-maxtext.sh @@ -223,7 +223,7 @@ export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN} export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION} export CUDA_DEVICE_MAX_CONNECTIONS=1 -export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true +export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 @@ -232,8 +232,7 @@ export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_schedule --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true - --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false + --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization} diff --git a/README.md b/README.md index 66d9b2a4e..1764c5f00 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,8 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33). +For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page. + ## Profiling JAX programs on GPU See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU. diff --git a/rosetta/docs/GPU_performance.md b/rosetta/docs/GPU_performance.md index c5456e3c4..fabbc6963 100644 --- a/rosetta/docs/GPU_performance.md +++ b/rosetta/docs/GPU_performance.md @@ -128,6 +128,10 @@ Fine-grain control to improve performance by initializing a NCCL communicator to - --xla_gpu_enable_cudnn_fmha=false (enables XLA pattern matcher to detect multi-headed attention pattern in JAX) - --xla_disable_hlo_passes=<> (turns off specific HLO passes; can be used for debugging) +## Previously used XLA Flags - +The following flags were used previously used but no longer required. +- --xla_gpu_enable_async_reduce_scatter, --xla_gpu_enable_async_all_reduce, --xla_gpu_enable_async_all_gather ; Turned on by default, no longer needed +- --xla_gpu_enable_highest_priority_async_stream ; Turned on by default +- --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used diff --git a/rosetta/docs/NATIVE_FP8.md b/rosetta/docs/NATIVE_FP8.md index dd3aa1bae..069b06fdd 100644 --- a/rosetta/docs/NATIVE_FP8.md +++ b/rosetta/docs/NATIVE_FP8.md @@ -111,13 +111,11 @@ Enabling this feature is effortless. Users only need to include the option `--fd In addition to the suggested XLA flags mentioned in [this section](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags), we also recommend setting these following XLA flags. The execution script should look like: ```bash export XLA_FLAGS=" \ - --xla_gpu_enable_reduction_epilogue_fusion=false \ --xla_gpu_enable_triton_gemm=false \ - --xla_gpu_enable_cudnn_fmha=false \ - --xla_gpu_enable_cudnn_layer_norm=true \ - --xla_gpu_enable_cublaslt=true \ - --xla_gpu_enable_latency_hiding_scheduler=true \ - --xla_gpu_all_reduce_combine_threshold_bytes=51200 " + --xla_gpu_enable_pipelined_all_reduce=false \ + --xla_gpu_enable_pipelined_all_gather=false \ + --xla_gpu_enable_pipelined_reduce_scatter=false \ +" export ENABLE_TE=0 python -m paxml.main \ ... @@ -125,8 +123,7 @@ python -m paxml.main \ ... ``` -Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions. - +Please not that disabling the triton gemm and pipelined collectives is essential for enabling the FP8 functionality and performance. ## Transformer Engine vs Native FP8 Support Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components. diff --git a/rosetta/docs/PGLE.md b/rosetta/docs/PGLE.md index 02e5f5294..2425ddffe 100644 --- a/rosetta/docs/PGLE.md +++ b/rosetta/docs/PGLE.md @@ -70,7 +70,6 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true ---xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization diff --git a/rosetta/rosetta/projects/maxtext/README.md b/rosetta/rosetta/projects/maxtext/README.md index fde5a9125..2320a7ed9 100644 --- a/rosetta/rosetta/projects/maxtext/README.md +++ b/rosetta/rosetta/projects/maxtext/README.md @@ -67,12 +67,9 @@ In order to obtain the best performance, please set the appropriate XLA flags. W The [GPU Performance document](../../../docs/GPU_performance.md) provides a detailed description of the XLA flags that can be set to optimize performance. These are the recommended XLA flags to get good performance for MaxText. ``` -XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true - --xla_gpu_enable_async_all_gather=true - --xla_gpu_enable_async_reduce_scatter=true +XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false - --xla_gpu_graph_level=0 - --xla_gpu_enable_async_all_reduce=true + --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 @@ -80,7 +77,6 @@ XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization" diff --git a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub index e96eaa781..0ca3fd802 100644 --- a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub +++ b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub @@ -53,11 +53,8 @@ export NCCL_IB_SL=1 # Set XLA Flags export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true - --xla_gpu_enable_async_all_gather=true - --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 - --xla_gpu_enable_async_all_reduce=true --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 @@ -65,12 +62,9 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false - --xla_disable_hlo_passes=rematerialization - --xla_gpu_enable_custom_fusions=false - --xla_gpu_enable_address_computation_fusion=false" + --xla_disable_hlo_passes=rematerialization" # Make directories that may not exist mkdir -p $BASE_WORKSPACE_DIR diff --git a/rosetta/rosetta/projects/pax/README.md b/rosetta/rosetta/projects/pax/README.md index 6ac4dc150..d1829b847 100644 --- a/rosetta/rosetta/projects/pax/README.md +++ b/rosetta/rosetta/projects/pax/README.md @@ -138,10 +138,10 @@ The [GPU Performance document](../../../docs/GPU_performance.md) provides a deta For the the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_threshold_bytes=33554432`, which is different from the value recommended in `paxml/contrib/gpu/scripts_gpu/run_pile_multinode.sh`. To overwrite the default XLA flags set in the script, set the `BASE_XLA_FLAGS` environment variable prior to running `run_pile_multinode` as follows: ``` -BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false - --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432 - --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true" bash run_pile_multinode.sh ... +BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_triton_gemm=false + --xla_gpu_all_reduce_combine_threshold_bytes=33554432 + --xla_gpu_graph_level=0" bash run_pile_multinode.sh ... ``` # Configs From 44b4dfee401a03c0cf3bebfec8700d9b61eb231f Mon Sep 17 00:00:00 2001 From: Md Fahim Faysal Khan Date: Thu, 5 Sep 2024 21:47:28 -0700 Subject: [PATCH 2/4] fix tensorboard events dir path (#1032) Fixed the tensorboard dir path after a recent change in MaxText software: https://github.com/google/maxtext/pull/863 --- .github/workflows/baselines/test_maxtext_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/baselines/test_maxtext_metrics.py b/.github/workflows/baselines/test_maxtext_metrics.py index bd180ecfe..a130c86c6 100644 --- a/.github/workflows/baselines/test_maxtext_metrics.py +++ b/.github/workflows/baselines/test_maxtext_metrics.py @@ -19,7 +19,7 @@ def test_loss(baseline_filename): baseline_filepath = os.path.join(baselines_dir, baseline_filename) test_config = baseline_filename.split(".")[0] - event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*") + event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*") event_file = glob.glob(event_file)[0] with open(baseline_filepath, "r") as baseline_file: end_step = json.load(baseline_file)["end_step"] @@ -31,7 +31,7 @@ def test_loss(baseline_filename): def test_step_time(baseline_filename): baseline_filepath = os.path.join(baselines_dir, baseline_filename) test_config = baseline_filename.split(".")[0] - event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*") + event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*") event_file = glob.glob(event_file)[0] with open(baseline_filepath, "r") as baseline_file: step_time_avg_expected = json.load(baseline_file)["step_time_avg"] From f808df5883c23ab0fe9e9310384e0875386a1d8b Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 6 Sep 2024 15:06:18 -0700 Subject: [PATCH 3/4] Makes jaxlib wheel dirs readable for non-root users (#1023) Example as of 8-28-2024 ``` $ docker run --entrypoint='' --rm -it ghcr.io/nvidia/jax:pax-2024-08-28 ls -lah /opt/jaxlibs total 20K drwxr-xr-x 1 root root 4.0K Aug 28 09:43 . drwxr-xr-x 1 root root 4.0K Aug 28 10:04 .. drwx------ 1 root root 4.0K Aug 28 09:43 jax_gpu_pjrt drwx------ 1 root root 4.0K Aug 28 09:43 jax_gpu_plugin drwx------ 1 root root 4.0K Aug 28 09:43 jaxlib ``` Signed-off-by: Terry Kong --- .github/container/build-jax.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh index fa4c055b8..8ff65ca99 100755 --- a/.github/container/build-jax.sh +++ b/.github/container/build-jax.sh @@ -316,6 +316,9 @@ pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUIL # jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib pip list | grep jax +# Ensure directories are readable by all for non-root users +chmod 755 $BUILD_PATH_JAXLIB/* + ## Cleanup pushd $SRC_PATH_JAX From f116054dbed654cbb280764457fa9a78fe003b51 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Mon, 9 Sep 2024 10:24:07 -0700 Subject: [PATCH 4/4] TE multithread build (#1009) --- .github/container/Dockerfile.jax | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index c85bee347..726656a7a 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -62,6 +62,7 @@ pip install ninja && rm -rf ~/.cache/pip # TransformerEngine now needs JAX at build time git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE} pushd ${SRC_PATH_TRANSFORMER_ENGINE} +export NVTE_BUILD_THREADS_PER_JOB=8 python setup.py bdist_wheel && rm -rf build ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist" EOF