diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh index 51f82b9e0..21591c91c 100755 --- a/.github/container/test-maxtext.sh +++ b/.github/container/test-maxtext.sh @@ -33,12 +33,13 @@ usage() { echo " -h, --help Print usage. Some examples: 1. test-maxtext.sh -b 2 --model-name=gpt3-52k 2. test-maxtext.sh -b 2 --model-name=gemma-2b --dtype=fp8 - 3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --multiprocess - 4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --dtype=fp8 --steps=10 --output train_output --multiprocess - 5. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess - 6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess - 7. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=128 --multiprocess - 8. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=64 --data-parallel=2 --multiprocess + 3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess + 4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --fsdp=8 --output train_output --multiprocess -a scan_layers=false max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false + 5. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --dtype=fp8 --steps=10 --fsdp=8 --output train_output --multiprocess + 6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess + 7. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess + 8. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=128 --multiprocess + 9. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=64 --data-parallel=2 --multiprocess Note: a) FSDP and TP needs to defined for use; DP is not necessary to define, it will always be inferred from the other two. @@ -163,6 +164,10 @@ if [ $MODEL != "gpt3-52k" ]; then # gpt3-52k only works with dot_product ADDITIONAL_ARGS+=" attention=${ATTN_TYPE}" fi +if [ ${ATTN_TYPE} == "cudnn_flash_te" ]; then + ENABLE_FUSED_ATTN=1 +fi + # for fp8 runs if [ $DTYPE == "fp8" ]; then ADDITIONAL_ARGS+=" quantization=$DTYPE" @@ -263,4 +268,4 @@ fi echo "Command: python3 $RUN_SETTINGS" python3 $RUN_SETTINGS -echo "Output at ${OUTPUT}" +echo "Output at ${OUTPUT}" \ No newline at end of file