Merge branch 'main' into 24.10-devel
DwarKapex committed Sep 18, 2024
2 parents 97232b9 + f116054 commit a8c5273
Showing 11 changed files with 27 additions and 32 deletions.
1 change: 1 addition & 0 deletions .github/container/Dockerfile.jax
@@ -62,6 +62,7 @@ pip install ninja && rm -rf ~/.cache/pip
# TransformerEngine now needs JAX at build time
git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
pushd ${SRC_PATH_TRANSFORMER_ENGINE}
export NVTE_BUILD_THREADS_PER_JOB=8
python setup.py bdist_wheel && rm -rf build
ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist"
EOF
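For context, `NVTE_BUILD_THREADS_PER_JOB` caps how many threads the TransformerEngine wheel build uses per compile job, presumably to keep compile-time resource usage in check. A minimal sketch of reproducing this build step outside the Dockerfile, reusing the same helper script and path variables assumed from the container build (the thread value simply mirrors the one set above):

```bash
# Sketch only: assumes git-clone.sh, URLREF_TRANSFORMER_ENGINE and
# SRC_PATH_TRANSFORMER_ENGINE are available as in the container build.
export NVTE_BUILD_THREADS_PER_JOB=8   # tune to the build machine's RAM/CPUs
git-clone.sh "${URLREF_TRANSFORMER_ENGINE}" "${SRC_PATH_TRANSFORMER_ENGINE}"
cd "${SRC_PATH_TRANSFORMER_ENGINE}"
python setup.py bdist_wheel && rm -rf build
ls dist/                              # the built wheel lands here
```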
3 changes: 3 additions & 0 deletions .github/container/build-jax.sh
@@ -316,6 +316,9 @@ pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUIL
# jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib
pip list | grep jax

# Ensure directories are readable by all for non-root users
chmod 755 $BUILD_PATH_JAXLIB/*

## Cleanup

pushd $SRC_PATH_JAX
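The `chmod 755` added above makes the editable-install directories readable and traversable by unprivileged container users. A small illustration of what that mode grants, plus a hypothetical check from a non-root account (the user name is made up), assuming the same `BUILD_PATH_JAXLIB` layout:

```bash
# 7 = rwx for the owner, 5 = r-x for group and others: non-root users can
# list and read the jaxlib/jax build directories but cannot modify them.
chmod 755 "${BUILD_PATH_JAXLIB}"/*

# Hypothetical verification from an unprivileged account:
sudo -u someuser ls "${BUILD_PATH_JAXLIB}/jaxlib"
sudo -u someuser python -c "import jaxlib; print(jaxlib.__file__)"
```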
5 changes: 2 additions & 3 deletions .github/container/test-maxtext.sh
@@ -223,7 +223,7 @@ export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN}
export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
export CUDA_DEVICE_MAX_CONNECTIONS=1

export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
@@ -232,8 +232,7 @@ export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_schedule
--xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true
--xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization}
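The `${BASE_XLA_FLAGS:-...}` form above is plain Bash default-value expansion: the long flag list is used only when the caller has not set `BASE_XLA_FLAGS` (or has set it to the empty string). A short sketch of the behaviour, with an illustrative flag and a hypothetical invocation:

```bash
# Default applies when the variable is unset or empty:
unset BASE_XLA_FLAGS
echo "${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true}"

# A caller-supplied value wins over the built-in default (the invocation and
# its arguments are placeholders, not a tested command line):
BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false" \
  ./test-maxtext.sh ...
```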
4 changes: 2 additions & 2 deletions .github/workflows/baselines/test_maxtext_metrics.py
@@ -19,7 +19,7 @@
def test_loss(baseline_filename):
baseline_filepath = os.path.join(baselines_dir, baseline_filename)
test_config = baseline_filename.split(".")[0]
event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*")
event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*")
event_file = glob.glob(event_file)[0]
with open(baseline_filepath, "r") as baseline_file:
end_step = json.load(baseline_file)["end_step"]
@@ -31,7 +31,7 @@ def test_loss(baseline_filename):
def test_step_time(baseline_filename):
baseline_filepath = os.path.join(baselines_dir, baseline_filename)
test_config = baseline_filename.split(".")[0]
event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*")
event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*")
event_file = glob.glob(event_file)[0]
with open(baseline_filepath, "r") as baseline_file:
step_time_avg_expected = json.load(baseline_file)["step_time_avg"]
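Both tests glob for a TensorBoard events file under the updated `logdir/tensorboard/logdir` path relative to `results_dir` and the test config. A quick shell check that a run actually produced files in that layout (the variable names below are stand-ins, not ones defined by the test module):

```bash
# Stand-in variables; substitute the actual results directory and config name.
ls "${RESULTS_DIR}/${TEST_CONFIG}/logdir/tensorboard/logdir/"events*
```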
2 changes: 2 additions & 0 deletions README.md
@@ -300,6 +300,8 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb

There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).

For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page.
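As the paragraph above notes, XLA flags can be scoped per workflow; XLA reads the `XLA_FLAGS` environment variable at process start, so exporting it immediately before a launch confines the flags to that run. A hedged sketch (the script name and flag values are placeholders, not a recommended set):

```bash
# Placeholder workload and illustrative flags only.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true \
                  --xla_gpu_enable_triton_gemm=false"
python my_workload.py
```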

## Profiling JAX programs on GPU
See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.

6 changes: 5 additions & 1 deletion rosetta/docs/GPU_performance.md
@@ -128,6 +128,10 @@ Fine-grain control to improve performance by initializing a NCCL communicator to
- --xla_gpu_enable_cudnn_fmha=false (enables XLA pattern matcher to detect multi-headed attention pattern in JAX)
- --xla_disable_hlo_passes=<> (turns off specific HLO passes; can be used for debugging)

## Previously used XLA Flags

The following flags were previously used but are no longer required; an illustrative cleanup sketch follows this list.
- --xla_gpu_enable_async_reduce_scatter, --xla_gpu_enable_async_all_reduce, --xla_gpu_enable_async_all_gather ; Turned on by default, no longer needed
- --xla_gpu_enable_highest_priority_async_stream ; Turned on by default
- --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used
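An illustrative cleanup along these lines, not taken from the docs, that strips the flags above out of an existing `XLA_FLAGS` value (assumes a `sed` that supports `-E`):

```bash
# Remove flags that are now defaults or no longer recognized by XLA.
export XLA_FLAGS="$(echo "${XLA_FLAGS}" | sed -E \
  -e 's/--xla_gpu_enable_async_(reduce_scatter|all_reduce|all_gather)(=[^ ]*)?//g' \
  -e 's/--xla_gpu_enable_highest_priority_async_stream(=[^ ]*)?//g' \
  -e 's/--xla_gpu_enable_triton_softmax_fusion(=[^ ]*)?//g')"
```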

13 changes: 5 additions & 8 deletions rosetta/docs/NATIVE_FP8.md
@@ -111,22 +111,19 @@ Enabling this feature is effortless. Users only need to include the option `--fd
In addition to the suggested XLA flags mentioned in [this section](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags), we also recommend setting the following XLA flags. The execution script should look like:
```bash
export XLA_FLAGS=" \
--xla_gpu_enable_reduction_epilogue_fusion=false \
--xla_gpu_enable_triton_gemm=false \
--xla_gpu_enable_cudnn_fmha=false \
--xla_gpu_enable_cudnn_layer_norm=true \
--xla_gpu_enable_cublaslt=true \
--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_gpu_all_reduce_combine_threshold_bytes=51200 "
--xla_gpu_enable_pipelined_all_reduce=false \
--xla_gpu_enable_pipelined_all_gather=false \
--xla_gpu_enable_pipelined_reduce_scatter=false \
"
export ENABLE_TE=0
python -m paxml.main \
...
--fdl.USE_FP8=True \
...
```

Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions.

Please note that disabling the Triton GEMM and pipelined collectives is essential for enabling FP8 functionality and performance.

## Transformer Engine vs Native FP8 Support
Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components.
1 change: 0 additions & 1 deletion rosetta/docs/PGLE.md
@@ -70,7 +70,6 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_pipelined_reduce_scatter=true
--xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization
8 changes: 2 additions & 6 deletions rosetta/rosetta/projects/maxtext/README.md
@@ -67,20 +67,16 @@ In order to obtain the best performance, please set the appropriate XLA flags. W
The [GPU Performance document](../../../docs/GPU_performance.md) provides a detailed description of the XLA flags that can be set to optimize performance. These are the recommended XLA flags to get good performance for MaxText.

```
XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_async_all_gather=true
--xla_gpu_enable_async_reduce_scatter=true
XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
--xla_gpu_enable_async_all_reduce=true
--xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
--xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true
--xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"
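A hypothetical launch consuming this flag block: XLA picks `XLA_FLAGS` up from the environment, and the `train.py` arguments below are placeholders rather than a tested configuration (the full flag string above is abbreviated here for readability):

```bash
# Abbreviated flags and placeholder MaxText arguments; adjust both for a real run.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false"
python3 MaxText/train.py MaxText/configs/base.yml run_name=xla_flags_example
```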
8 changes: 1 addition & 7 deletions rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub
@@ -53,24 +53,18 @@ export NCCL_IB_SL=1

# Set XLA Flags
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_async_all_gather=true
--xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
--xla_gpu_enable_async_all_reduce=true
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
--xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true
--xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization
--xla_gpu_enable_custom_fusions=false
--xla_gpu_enable_address_computation_fusion=false"
--xla_disable_hlo_passes=rematerialization"

# Make directories that may not exist
mkdir -p $BASE_WORKSPACE_DIR
8 changes: 4 additions & 4 deletions rosetta/rosetta/projects/pax/README.md
@@ -138,10 +138,10 @@ The [GPU Performance document](../../../docs/GPU_performance.md) provides a deta
For the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_threshold_bytes=33554432`, which is different from the value recommended in `paxml/contrib/gpu/scripts_gpu/run_pile_multinode.sh`. To overwrite the default XLA flags set in the script, set the `BASE_XLA_FLAGS` environment variable prior to running `run_pile_multinode` as follows:

```
BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
--xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432
--xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true" bash run_pile_multinode.sh ...
BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_gemm=false
--xla_gpu_all_reduce_combine_threshold_bytes=33554432
--xla_gpu_graph_level=0" bash run_pile_multinode.sh ...
```
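An equivalent form of the override above, assuming `run_pile_multinode.sh` only falls back to its built-in defaults when `BASE_XLA_FLAGS` is unset: export the variable first, then launch as usual (trailing arguments elided, as in the snippet above).

```bash
# Exported override replaces the script's default XLA flags wholesale
# (assumption about the script's fallback behaviour, as described above).
export BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true \
  --xla_gpu_enable_triton_gemm=false \
  --xla_gpu_all_reduce_combine_threshold_bytes=33554432 \
  --xla_gpu_graph_level=0"
bash run_pile_multinode.sh ...
```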

# Configs