Skip to content

Commit

Permalink
test-maxtext.sh: set default mem fraction to 0.9.
Browse files Browse the repository at this point in the history
This fixes the gemma-2b example.
  • Loading branch information
sergachev committed Aug 5, 2024
1 parent 3f6999e commit 7c37880
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions .github/container/test-maxtext.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ usage() {
echo " -h, --help Print usage. Some examples:
1. test-maxtext.sh -b 2 --model-name=gpt3-52k
2. test-maxtext.sh -b 2 --model-name=gemma-2b --dtype=fp8
3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --multiprocess
4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal-flash --dtype=fp8 --steps=10 --output train_output --multiprocess
5. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess
6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess
7. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=128 --multiprocess
8. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=64 --data-parallel=2 --multiprocess
3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --multiprocess
4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --dtype=fp8 --steps=10 --output train_output --multiprocess
5. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess
6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess
7. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=128 --multiprocess
8. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=64 --data-parallel=2 --multiprocess
Note:
a) FSDP and TP need to be defined for use; DP is not necessary to define, as it will always be inferred from the other two.
Expand All @@ -54,7 +54,7 @@ fi
# Default arguments
HARDWARE='gpu'
OUTPUT=$(mktemp -d)
MEM_FRACTION=0.65
MEM_FRACTION=0.90

MODEL="gpt3-52k"
DECODER_BLOCK=""
Expand Down Expand Up @@ -263,4 +263,4 @@ fi
echo "Command: python3 $RUN_SETTINGS"
python3 $RUN_SETTINGS

echo "Output at ${OUTPUT}"
echo "Output at ${OUTPUT}"

0 comments on commit 7c37880

Please sign in to comment.