
Added model support to test-maxtext.sh #971

Merged (3 commits) on Aug 3, 2024
Changes from 2 commits
113 changes: 80 additions & 33 deletions .github/container/test-maxtext.sh
@@ -13,26 +13,40 @@ usage() {
echo "Usage: $0 [OPTIONS]"
echo ""
echo " OPTIONS DESCRIPTION"
echo " -a, --additional-args Additional fiddle args to pass to MaxText/train.py"
echo " -a, --additional-args Additional args to pass to MaxText/train.py"
echo " --mem-fraction Specify the percentage of memory to preallocate for XLA. Example: 0.90, 0.85, 0.65"
echo " --decoder-block Specify decoder block to run. Example: llama2, default"
echo " --attn-type Specify the attention type. Example: dot_product, cudnn_flash_te"
echo " --model-name Specify the model names to run [Preffered]. If you specify model name then you do not need to specify decoder-block. Currently supported ootb models:
kocchop marked this conversation as resolved.
Show resolved Hide resolved
nouiz marked this conversation as resolved.
Show resolved Hide resolved
gemma-2b, gemma-7b, gpt3-175b, gpt3-22b, gpt3-52k, gpt3-6b, llama2-13b, llama2-70b, llama2-7b, llama3-70b, llama3-8b, mistral-7b, mixtral-8x7b"
echo " --decoder-block Specify decoder block to run. Example: llama2, default. Use this option only to define a custom model. This is not preferred, only used in CI"
echo " --attn-type Specify the attention type. For gpt3-52k, we only use dot_product since the head_dim=8 is too small. Example: dot_product, cudnn_flash_te"
echo " --remat-policy Specify remat policy. Example: minimal, minimal_flash, save_dot_except_mlp"
echo " -b, --batch-per-gpu Batch size per GPU, defaults to 2."
echo " --dtype Batch size, defaults to bfloat16."
echo " --dtype Data type, defaults to bfloat16. Example: bfloat16, fp8"
echo " -s, --steps Number of steps to run, defaults to 500."
echo " --multiprocess Enable the multiprocess GPU mode."
echo " --multiprocess Enable the multiprocess GPU mode. Should be used when run on multinode"
echo " -o, --output NAME Name for the output folder, a temporary folder will be created if none specified."
echo " --data-parallel Data parallelism to use. Defaults to 1."
echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1."
echo " --tensor-parallel Tensor parallelism to use. Defaults to 1."
echo " --data-parallel Data parallelism to use. Defaults to 1. If specified FSDP dims will be inferred."
echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1. If none of the sharding specs are provided it will assume its FSDP across all available gpus."
echo " --tensor-parallel Tensor parallelism to use. Defaults to 1. If specified, FSDP dims will be inferred."
echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining."
echo " -n, --nodes Number of nodes."
echo " -h, --help Print usage."
echo " -h, --help Print usage. Some examples:
1. test-maxtext.sh -b 2 --model-name=gpt3-52k
2. test-maxtext.sh -b 2 --model-name=gemma-2b --dtype=fp8
3. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --multiprocess
4. test-maxtext.sh -n 1 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal_flash --dtype=fp8 --steps=10 --output train_output --multiprocess
5. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=8 --data-parallel=8 --multiprocess
6. test-maxtext.sh -n 8 -b 2 --model-name=llama2-7b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=minimal_flash --steps=10 --output train_output --fsdp=4 --tensor-parallel=2 --data-parallel=8 --multiprocess
7. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=128 --multiprocess
8. test-maxtext.sh -n 16 -b 2 --model-name=llama2-70b --mem-fraction 0.90 --attn-type=cudnn_flash_te --remat-policy=save_dot_except_mlp --steps=10 --output train_output --fsdp=64 --data-parallel=2 --multiprocess

Note:
a) FSDP and TP need to be defined for use; DP does not need to be defined, as it will always be inferred from the other two.
b) Multinode tests have to be launched with the appropriate Slurm commands, i.e. sbatch and srun"
exit $1
}

args=$(getopt -o a:b:s:o:n:h --long additional-args:,mem-fraction:,decoder-block:,attn-type:,remat-policy:,batch-per-gpu:,dtype:,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
args=$(getopt -o a:b:s:o:n:h --long additional-args:,mem-fraction:,model-name:,decoder-block:,attn-type:,remat-policy:,batch-per-gpu:,dtype:,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
if [[ $? -ne 0 ]]; then
exit $1
fi
@@ -42,7 +56,8 @@ HARDWARE='gpu'
OUTPUT=$(mktemp -d)
MEM_FRACTION=0.65

DECODER_BLOCK="default"
MODEL="gpt3-52k"
DECODER_BLOCK=""
ATTN_TYPE="dot_product"
REMAT_POLICY="minimal"
BATCH_PER_GPU=2
@@ -67,6 +82,10 @@ while [ : ]; do
MEM_FRACTION="$2"
shift 2
;;
--model-name)
MODEL="$2"
shift 2
;;
--decoder-block)
DECODER_BLOCK="$2"
shift 2
@@ -139,26 +158,39 @@ done

# Set derived variables

# if not gpt3-52k, we can use any attention type such as cudnn_flash_te or dot_product
if [ $MODEL != "gpt3-52k" ]; then # gpt3-52k only works with dot_product
ADDITIONAL_ARGS+=" attention=${ATTN_TYPE}"
fi

# for fp8 runs
if [ $DTYPE == "fp8" ]; then
ADDITIONAL_ARGS+=" quantization=$DTYPE"
fi
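
As an illustration of how these two checks compose, a minimal sketch (not part of the patch) with a hypothetical invocation using --model-name=llama2-7b --attn-type=cudnn_flash_te --dtype=fp8:

    # sketch only: mirrors the two conditionals above
    MODEL=llama2-7b; ATTN_TYPE=cudnn_flash_te; DTYPE=fp8; ADDITIONAL_ARGS=""
    if [ "$MODEL" != "gpt3-52k" ]; then ADDITIONAL_ARGS+=" attention=${ATTN_TYPE}"; fi
    if [ "$DTYPE" == "fp8" ]; then ADDITIONAL_ARGS+=" quantization=$DTYPE"; fi
    echo "$ADDITIONAL_ARGS"    # -> " attention=cudnn_flash_te quantization=fp8"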

GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU')
NGPUS=$((GPUS_PER_NODE * NODES))

# Heuristic to figure out ici and dcn of DP
# We only use DP across different nodes
if [ $NGPUS -gt 8 ]; then
dcn_DP=$((NGPUS / 8))
ici_DP=$((DP / dcn_DP))
# TP is always ici; after TP comes FSDP, and
# from TP and FSDP we can find out ici and dcn DP
# in other words, the DP dims across the ici and dcn axes will always be inferred
ici_TP=${TP}
ici_DP=1
dcn_FSDP=1
if [ $((FSDP*TP)) -gt ${GPUS_PER_NODE} ]; then
ici_FSDP=$((GPUS_PER_NODE/TP))
dcn_FSDP=$((FSDP/ici_FSDP))
dcn_DP=$((NGPUS/(ici_FSDP*ici_TP*ici_DP*dcn_FSDP)))
else
dcn_DP=1
ici_DP=$DP
fi

if [[ $ATTN_TYPE == 'cudnn_flash_te' ]]; then
ENABLE_FUSED_ATTN=1
REMAT_POLICY="minimal_flash"
ici_FSDP=$FSDP
ici_DP=$((GPUS_PER_NODE/(FSDP*TP)))
dcn_DP=$((NGPUS/(ici_FSDP*ici_TP*ici_DP*dcn_FSDP)))
fi
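
Worked out by hand, the ici/dcn inference in this hunk can be reproduced with a standalone sketch (not part of the patch), assuming 8 GPUs per node as in the usage examples:

    # sketch only: reproduces the mesh inference above; GPUS_PER_NODE=8 is an assumption
    infer_mesh() {
        local NODES=$1 FSDP=$2 TP=$3 GPUS_PER_NODE=8
        local NGPUS=$((GPUS_PER_NODE * NODES)) ici_TP=$TP ici_DP=1 dcn_FSDP=1 ici_FSDP dcn_DP
        if [ $((FSDP * TP)) -gt $GPUS_PER_NODE ]; then
            ici_FSDP=$((GPUS_PER_NODE / TP))
            dcn_FSDP=$((FSDP / ici_FSDP))
        else
            ici_FSDP=$FSDP
            ici_DP=$((GPUS_PER_NODE / (FSDP * TP)))
        fi
        dcn_DP=$((NGPUS / (ici_FSDP * ici_TP * ici_DP * dcn_FSDP)))
        echo "ici: FSDP=$ici_FSDP TP=$ici_TP DP=$ici_DP | dcn: FSDP=$dcn_FSDP DP=$dcn_DP"
    }
    infer_mesh 8 4 2      # usage example 6: ici FSDP=4 TP=2 DP=1, dcn FSDP=1 DP=8
    infer_mesh 16 128 1   # usage example 7: ici FSDP=8 TP=1 DP=1, dcn FSDP=16 DP=1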

print_var ADDITIONAL_ARGS
print_var MEM_FRACTION
print_var MODEL
print_var DECODER_BLOCK
print_var ATTN_TYPE
print_var REMAT_POLICY
@@ -172,8 +204,9 @@ print_var ENABLE_FUSED_ATTN
print_var DP
print_var ici_DP
print_var dcn_DP
print_var FSDP
print_var TP
print_var ici_FSDP
print_var dcn_FSDP
print_var ici_TP
print_var PP

MAXTEXT_DIR="/opt/maxtext"
@@ -191,7 +224,7 @@ export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_schedule
--xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
--xla_gpu_all_gather_combine_threshold_bytes=1073741824
--xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
--xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true
--xla_gpu_enable_pipelined_all_reduce=true
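
Because BASE_XLA_FLAGS is set with the ${BASE_XLA_FLAGS:-...} default expansion and the script appends any pre-existing XLA_FLAGS afterwards, both can be overridden from the environment; a hypothetical example (not part of the patch):

    # sketch only: replace the default flag set, or append extra flags, at invocation time
    BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false" ./test-maxtext.sh -b 2 --model-name=gpt3-52k
    XLA_FLAGS="--xla_dump_to=/tmp/xla_dump" ./test-maxtext.sh -b 2 --model-name=gpt3-52k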
@@ -205,15 +238,29 @@ export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"

RUN_NAME="logdir" ## the RUN_NAME cannot be changed

RUN_SETTINGS="MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} logits_via_embedding=true decoder_block=${DECODER_BLOCK} \
steps=$STEPS per_device_batch_size=${BATCH_PER_GPU} base_emb_dim=2560 base_mlp_dim=8192 remat_policy=${REMAT_POLICY} attention=${ATTN_TYPE}\
base_num_query_heads=8 base_num_kv_heads=8 base_num_decoder_layers=8 head_dim=128 enable_checkpointing=false\
base_output_directory=$OUTPUT dataset_path=local dataset_type=synthetic hardware=$HARDWARE\
dcn_fsdp_parallelism=1 ici_fsdp_parallelism=$FSDP\
ici_data_parallelism=$ici_DP dcn_data_parallelism=$dcn_DP\
ici_tensor_parallelism=$TP dcn_tensor_parallelism=1 ${ADDITIONAL_ARGS}"

if [ -z "$DECODER_BLOCK" ]; then

# this part can be used to test different models OOTB
RUN_SETTINGS="MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} model_name=${MODEL}\
steps=$STEPS per_device_batch_size=${BATCH_PER_GPU} remat_policy=${REMAT_POLICY} enable_checkpointing=false\
base_output_directory=$OUTPUT dataset_path=local dataset_type=synthetic hardware=$HARDWARE\
dcn_fsdp_parallelism=$dcn_FSDP ici_fsdp_parallelism=$ici_FSDP\
ici_data_parallelism=$ici_DP dcn_data_parallelism=$dcn_DP\
ici_tensor_parallelism=$ici_TP dcn_tensor_parallelism=1 ${ADDITIONAL_ARGS}"

else
# this is essentially used for CI runs
RUN_SETTINGS="MaxText/train.py MaxText/configs/base.yml run_name=${RUN_NAME} logits_via_embedding=true decoder_block=${DECODER_BLOCK} \
steps=$STEPS per_device_batch_size=${BATCH_PER_GPU} base_emb_dim=2560 base_mlp_dim=8192 remat_policy=${REMAT_POLICY} attention=${ATTN_TYPE}\
base_num_query_heads=8 base_num_kv_heads=8 base_num_decoder_layers=8 head_dim=128 enable_checkpointing=false\
base_output_directory=$OUTPUT dataset_path=local dataset_type=synthetic hardware=$HARDWARE\
dcn_fsdp_parallelism=$dcn_FSDP ici_fsdp_parallelism=$ici_FSDP\
ici_data_parallelism=$ici_DP dcn_data_parallelism=$dcn_DP\
ici_tensor_parallelism=$ici_TP dcn_tensor_parallelism=1 ${ADDITIONAL_ARGS}"
fi

echo "Command: python3 $RUN_SETTINGS"
python3 $RUN_SETTINGS
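
For reference, a rough rendering (not part of the patch) of the command echoed above for usage example 1, assuming a single 8-GPU node, the defaults visible in this diff, and FSDP across all GPUs per the help text:

    # sketch only: approximate expansion of $RUN_SETTINGS for the OOTB path;
    # the parallelism values and the temp output directory are assumptions
    python3 MaxText/train.py MaxText/configs/base.yml run_name=logdir model_name=gpt3-52k \
        steps=500 per_device_batch_size=2 remat_policy=minimal enable_checkpointing=false \
        base_output_directory=/tmp/tmp.XXXXXX dataset_path=local dataset_type=synthetic hardware=gpu \
        dcn_fsdp_parallelism=1 ici_fsdp_parallelism=8 \
        ici_data_parallelism=1 dcn_data_parallelism=1 \
        ici_tensor_parallelism=1 dcn_tensor_parallelism=1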

echo "Output at ${OUTPUT}"
echo "Output at ${OUTPUT}"
Loading