test-maxtext.sh: set default mem fraction to 0.9 (#979)
Co-authored-by: Frédéric Bastien <[email protected]>
sergachev and nouiz committed Aug 6, 2024
1 parent 3f6999e commit 4696e4d
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions .github/container/test-maxtext.sh
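For context: JAX preallocates a fraction of GPU memory at startup, controlled by the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable, whose built-in default is 0.75. The sketch below shows how the script's MEM_FRACTION value presumably reaches JAX; the actual wiring is outside this diff, so the export line is an assumption, not a quote from the file.

    # Hypothetical sketch (the export is not shown in this diff):
    MEM_FRACTION=0.90                                      # new script default; was 0.65
    export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}  # JAX preallocates this fraction of GPU memory
    python3 MaxText/train.py ...                           # training then runs under the 0.90 budget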
@@ -14,7 +14,7 @@ usage() {
echo ""
echo " OPTIONS DESCRIPTION"
echo " -a, --additional-args Additional args to pass to MaxText/train.py"
echo " --mem-fraction Specify the percentage of memory to preallocate for XLA. Example: 0.90, 0.85, 0.65"
echo " --mem-fraction Specify the percentage of memory to preallocate for XLA. Example: 0.90, 0.85, 0.65". Default to 0.90, contradicting JAX default of 0.75.
echo " --model-name Specify the model names to run [Preferred]. If you specify model name then you do not need to specify decoder-block. Currently supported ootb models:
gemma-2b, gemma-7b, gpt3-175b, gpt3-22b, gpt3-52k, gpt3-6b, llama2-13b, llama2-70b, llama2-7b, llama3-70b, llama3-8b, mistral-7b, mixtral-8x7b"
echo " --decoder-block Specify decoder block to run. Example: llama2, default. Use this option only to define a custom model. This is not preferred, only used in CI"
@@ -33,12 +33,12 @@ usage() {
echo " -h, --help Print usage. Some examples:
1. test-maxtext.sh -b 2 --model-name=gpt3-52k
2. test-maxtext.sh -b 2 --model-name=gemma-2b --dtype=fp8
-3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --multiprocess
-4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal-flash --dtype=fp8 --steps=10 --output train_output --multiprocess
-5. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess
-6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess
-7. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=128 --multiprocess
-8. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=64 --data-parallel=2 --multiprocess
+3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --multiprocess
+4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --dtype=fp8 --steps=10 --output train_output --multiprocess
+5. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess
+6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --attn-type=cudnn_flash_te --remat-policy=minimal-flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess
+7. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=128 --multiprocess
+8. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=64 --data-parallel=2 --multiprocess
Note:
a) FSDP and TP need to be defined for use; DP does not need to be defined, as it is always inferred from the other two.
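A worked example of that inference, assuming 8 GPUs per node and that the total device count equals FSDP x TP x DP: in example 6 above, -n 8 gives 64 devices, so with --fsdp=4 and --tensor-parallel=2 the script would infer a data-parallel degree of 64 / (4 * 2) = 8, matching the --data-parallel=8 passed explicitly there.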
@@ -54,7 +54,7 @@ fi
# Default arguments
HARDWARE='gpu'
OUTPUT=$(mktemp -d)
-MEM_FRACTION=0.65
+MEM_FRACTION=0.90

MODEL="gpt3-52k"
DECODER_BLOCK=""
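Since 0.90 is now the baked-in default, a run that still wants the old, more conservative 0.65 budget has to request it explicitly via the flag documented above, e.g.:

    test-maxtext.sh -b 2 --model-name=gpt3-52k --mem-fraction 0.65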
@@ -263,4 +263,4 @@ fi
echo "Command: python3 $RUN_SETTINGS"
python3 $RUN_SETTINGS

echo "Output at ${OUTPUT}"
echo "Output at ${OUTPUT}"
