fix remat policy and fused attn variable in test-maxtext.sh (#982)
1. Fixed the remat policy typo: `minimal-flash` → `minimal_flash`
2. Added a few additional args to tune performance
3. Fixed the fused-attention enable variable (`ENABLE_FUSED_ATTN`)
kocchop committed Aug 9, 2024
1 parent 15e65e9 commit 1fd097d
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions .github/container/test-maxtext.sh
@@ -33,12 +33,13 @@ usage() {
echo " -h, --help Print usage. Some examples:
1. test-maxtext.sh -b 2 --model-name=gpt3-52k
2. test-maxtext.sh -b 2 --model-name=gemma-2b --dtype=fp8
-3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --multiprocess
-4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --dtype=fp8 --steps=10 --output train_output --multiprocess
-5. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess
-6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess
-7. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=128 --multiprocess
-8. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=64 --data-parallel=2 --multiprocess
+3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess
+4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false
+5. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --dtype=fp8 --steps=10 --fsdp=8 --output train_output --multiprocess
+6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess
+7. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess
+8. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=128 --multiprocess
+9. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=64 --data-parallel=2 --multiprocess
Note:
a) FSDP and TP need to be defined for use; DP is not necessary to define, as it will always be inferred from the other two.
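As a quick illustration of note (a), example 7 above is internally consistent: 8 processes with 8 GPUs each (the per-process GPU count is an assumption for illustration, not stated in the diff) give 64 devices, and with `--fsdp=4` and `--tensor-parallel=2` the remaining data-parallel degree works out to 8, exactly the value the example passes explicitly.

```bash
# Worked example of note (a): 64 devices, fsdp=4, tensor-parallel=2
# leave data-parallel = 64 / (4 * 2) = 8, matching example 7.
DEVICES=64; FSDP=4; TP=2
DP=$(( DEVICES / (FSDP * TP) ))
echo "inferred data-parallel = ${DP}"   # prints 8
```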
@@ -163,6 +164,10 @@ if [ $MODEL != "gpt3-52k" ]; then # gpt3-52k only works with dot_product
ADDITIONAL_ARGS+=" attention=${ATTN_TYPE}"
fi

+if [ ${ATTN_TYPE} == "cudnn_flash_te" ]; then
+    ENABLE_FUSED_ATTN=1
+fi

# for fp8 runs
if [ $DTYPE == "fp8" ]; then
ADDITIONAL_ARGS+=" quantization=$DTYPE"
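The new block above only turns on `ENABLE_FUSED_ATTN` when the `cudnn_flash_te` attention backend is selected. The diff does not show where the flag is consumed; a minimal sketch of the usual pattern, assuming it is forwarded to Transformer Engine through its `NVTE_FUSED_ATTN` environment variable, would be:

```bash
# Sketch (assumption, not shown in this diff): forward the flag to
# Transformer Engine, which reads NVTE_FUSED_ATTN to enable or disable
# its fused attention kernels.
export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN}
```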
@@ -263,4 +268,4 @@ fi
echo "Command: python3 $RUN_SETTINGS"
python3 $RUN_SETTINGS

echo "Output at ${OUTPUT}"
echo "Output at ${OUTPUT}"
