Skip to content

Commit

Permalink
Model XLA Flags (#1052)
Browse files Browse the repository at this point in the history
Moves XLA flags from model CI into their own files that can be sourced.
Each file can be sourced and will print what it sets.

Some files source other files, which was intentional to avoid
introducing sim-links into the repo, which can sometimes have platform
issues (like on windows).

---------

Signed-off-by: Terry Kong <[email protected]>
  • Loading branch information
terrykong authored Sep 26, 2024
1 parent ccededf commit 1a3febb
Show file tree
Hide file tree
Showing 14 changed files with 111 additions and 0 deletions.
24 changes: 24 additions & 0 deletions rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
set -x
NUM_NODES=1
NUM_GPUS=8
THRESHOLD_BYTES=1073741824
export XLA_FLAGS="\
--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_gpu_enable_triton_gemm=false \
--xla_gpu_graph_level=0 \
--xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \
--xla_gpu_all_gather_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS))) \
--xla_gpu_reduce_scatter_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS*2))) \
--xla_gpu_enable_pipelined_all_gather=true \
--xla_gpu_enable_pipelined_reduce_scatter=true \
--xla_gpu_enable_pipelined_all_reduce=true \
--xla_gpu_enable_while_loop_double_buffering=true \
--xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_enable_all_gather_combine_by_dim=false \
--xla_gpu_enable_reduce_scatter_combine_by_dim=false \
--xla_disable_hlo_passes=rematerialization \
"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
unset NUM_NODES NUM_GPUS THRESHOLD_BYTES
set +x
13 changes: 13 additions & 0 deletions rosetta/rosetta/projects/pax/xla_flags/common.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
set -x
THRESHOLD_BYTES=51200
export XLA_FLAGS="\
--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_allow_excess_precision \
--xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \
--xla_gpu_graph_level=0 \
"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
unset THRESHOLD_BYTES
set +x
3 changes: 3 additions & 0 deletions rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source $SCRIPT_DIR/common.env
unset SCRIPT_DIR
3 changes: 3 additions & 0 deletions rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source $SCRIPT_DIR/common.env
unset SCRIPT_DIR
14 changes: 14 additions & 0 deletions rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
set -x
THRESHOLD_BYTES=33554432
export XLA_FLAGS="\
--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_allow_excess_precision \
--xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \
--xla_gpu_graph_level=0 \
--xla_gpu_enable_cudnn_fmha=false \
"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
unset THRESHOLD_BYTES
set +x
3 changes: 3 additions & 0 deletions rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source $SCRIPT_DIR/common.env
unset SCRIPT_DIR
3 changes: 3 additions & 0 deletions rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source $SCRIPT_DIR/common.env
unset SCRIPT_DIR
25 changes: 25 additions & 0 deletions rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
set -x
ALL_REDUCE_THRESHOLD_BYTES=3221225472
ALL_GATHER_THRESHOLD_BYTES=3221225472
REDUCE_SCATTER_THRESHOLD_BYTES=402653184
export XLA_FLAGS="\
--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_allow_excess_precision \
--xla_gpu_enable_highest_priority_async_stream=true \
--xla_gpu_enable_triton_softmax_fusion=false \
--xla_gpu_all_reduce_combine_threshold_bytes=${ALL_REDUCE_THRESHOLD_BYTES} \
--xla_gpu_graph_level=0 \
--xla_gpu_all_gather_combine_threshold_bytes=${ALL_GATHER_THRESHOLD_BYTES} \
--xla_gpu_reduce_scatter_combine_threshold_bytes=${REDUCE_SCATTER_THRESHOLD_BYTES} \
--xla_gpu_enable_pipelined_all_gather=true \
--xla_gpu_enable_pipelined_reduce_scatter=true \
--xla_gpu_enable_pipelined_all_reduce=true \
--xla_gpu_enable_while_loop_double_buffering=true \
--xla_gpu_enable_all_gather_combine_by_dim=false \
--xla_gpu_enable_reduce_scatter_combine_by_dim=false \
--xla_disable_hlo_passes=rematerialization \
--xla_gpu_enable_custom_fusions=true
"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
unset ALL_REDUCE_THRESHOLD_BYTES ALL_GATHER_THRESHOLD_BYTES REDUCE_SCATTER_THRESHOLD_BYTES
set +x
3 changes: 3 additions & 0 deletions rosetta/rosetta/projects/pax/xla_flags/llama-70b.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source $SCRIPT_DIR/common.env
unset SCRIPT_DIR
4 changes: 4 additions & 0 deletions rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set -x
echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
set +x
4 changes: 4 additions & 0 deletions rosetta/rosetta/projects/pax/xla_flags/llama-7b.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set -x
echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
set +x
4 changes: 4 additions & 0 deletions rosetta/rosetta/projects/t5x/xla_flags/t5.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set -x
echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
set +x
4 changes: 4 additions & 0 deletions rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set -x
echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.75
set +x
4 changes: 4 additions & 0 deletions rosetta/rosetta/projects/vit/xla_flags/vit-base.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set -x
echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'"
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
set +x

0 comments on commit 1a3febb

Please sign in to comment.