Make test parallel
michaelbenayoun committed Sep 7, 2023
1 parent e119b28 commit 1b8d539
Showing 1 changed file with 76 additions and 39 deletions.
115 changes: 76 additions & 39 deletions tests/distributed/test_model_parallelization.py
@@ -19,7 +19,7 @@
 import unittest
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 from parameterized import parameterized
@@ -122,15 +122,23 @@ def _test_model_parallel(
         from_config: bool,
         with_lazy_load: bool,
         parallelize_embeddings: bool,
-        overwrite_model_config: Optional[Dict[str, str]] = None,
         num_neuron_cores: int = NUM_NEURON_CORES_AVAILABLE,
+        run_test_in_parallel: bool = False,
+        overwrite_model_config: Optional[Dict[str, str]] = None,
     ):
         if num_neuron_cores < tp_size:
             raise ValueError(
                 "The number of Neuron cores available is lower than the TP size, failing since the test might not be "
                 "testing what is expected."
             )
 
+        if run_test_in_parallel and (NUM_NEURON_CORES_AVAILABLE // num_neuron_cores) < 2:
+            raise ValueError(
+                "The test cannot be run in parallel because there are not enough Neuron cores available to preserve "
+                f"the number of Neuron cores requested ({NUM_NEURON_CORES_AVAILABLE} cores available and "
+                f"{num_neuron_cores} were requested)."
+            )
+
         template_content = None
         template_file_path = Path(__file__).parent.resolve() / TEMPLATE_FILE_NAME
         with open(template_file_path, "r") as fp:
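The guard added above is simple capacity arithmetic: running the original-model and parallel-model scripts side by side requires two disjoint blocks of num_neuron_cores cores. Here is a minimal sketch of that check, assuming a host exposing 32 Neuron cores; the helper function is hypothetical, only the constant name mirrors the test module:

# Minimal sketch of the capacity check, not part of the commit.
NUM_NEURON_CORES_AVAILABLE = 32  # assumed host capacity

def can_run_in_parallel(num_neuron_cores: int) -> bool:
    # The original-model and parallel-model jobs each need their own block
    # of `num_neuron_cores` cores, so at least twice that many must exist.
    return (NUM_NEURON_CORES_AVAILABLE // num_neuron_cores) >= 2

assert can_run_in_parallel(8)        # two jobs of 8 cores fit into 32
assert not can_run_in_parallel(32)   # two jobs would need 64 cores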
@@ -160,37 +168,52 @@ def _test_model_parallel(

cmd = ["torchrun", f"--nproc_per_node={num_neuron_cores}", f"{tmpdirname}/code.py"]

# Original model.
p = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env={"is_parallel": "false", **specialization_env}
)
stdout, stderr = p.communicate()
def get_outputs_from_process(process) -> Tuple[str, str]:
stdout, stderr = process.communicate()

stdout = stdout.decode("utf-8")
stderr = stderr.decode("utf-8")
stdout = stdout.decode("utf-8")
stderr = stderr.decode("utf-8")

if stdout == "":
stdout = "N/A"
if stderr == "":
stderr = "N/A"
if stdout == "":
stdout = "N/A"
if stderr == "":
stderr = "N/A"
return stdout, stderr

full_output = f"Original model standard output:\n{stdout}\nOriginal model standard error:\n{stderr}"
print(full_output)
# When running the test in parallel, we need 2 rendez-vous endpoints: one for the script running the
# original model and one for the script running the parallel model.
rdzv_endpoint_host = "localhost"
rdzv_endpoint_port = 29400

# Parallel model.
p = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env={"is_parallel": "true", **specialization_env}
)
stdout, stderr = p.communicate()
# Original model.
env = {"is_parallel": "false", **specialization_env}
if run_test_in_parallel:
# Setting the rendez-vous endpoint for the original model process.
cmd.insert(1, f"--rdzv_endpoint={rdzv_endpoint_host}:{rdzv_endpoint_port}")
env["NEURON_RT_VISIBLE_CORES"] = f"0-{num_neuron_cores - 1}"

stdout = stdout.decode("utf-8")
stderr = stderr.decode("utf-8")
p_original = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

if stdout == "":
stdout = "N/A"
if stderr == "":
stderr = "N/A"
# When running tests in parallel, synchronization is done after both processes started.
if not run_test_in_parallel:
stdout, stderr = get_outputs_from_process(p_original)
full_output = f"Original model standard output:\n{stdout}\nOriginal model standard error:\n{stderr}"
print(full_output)

# Parallel model.
env = {"is_parallel": "true", **specialization_env}
if run_test_in_parallel:
# Updating the rendez-vous endpoint for the parallel model process.
cmd[1] = f"--rdzv_endpoint={rdzv_endpoint_host}:{rdzv_endpoint_port + 1}"
env["NEURON_RT_VISIBLE_CORES"] = f"{num_neuron_cores}-{2 * num_neuron_cores - 1}"
p_parallel = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

if run_test_in_parallel:
stdout, stderr = get_outputs_from_process(p_original)
full_output = f"Original model standard output:\n{stdout}\nOriginal model standard error:\n{stderr}"
print(full_output)

stdout, stderr = get_outputs_from_process(p_parallel)
full_output = f"Parallel model standard output:\n{stdout}\nParallel model standard error:\n{stderr}"
print(full_output)
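For context, the hunk above boils down to this launch pattern: two torchrun jobs run concurrently, each pinned to a disjoint Neuron core range through NEURON_RT_VISIBLE_CORES and each given its own rendezvous endpoint so the two process groups cannot collide. The sketch below is a standalone illustration, not the test file itself; the script name and per-job core count are placeholder assumptions:

# Standalone sketch of the dual-launch pattern; "code.py" and the
# per-job core count are assumptions, not values from the commit.
import subprocess

num_cores = 8  # cores per job (assumed)
jobs = []
for index, is_parallel in enumerate(("false", "true")):
    cmd = [
        "torchrun",
        # Distinct rendezvous endpoint per job (29400 and 29401).
        f"--rdzv_endpoint=localhost:{29400 + index}",
        f"--nproc_per_node={num_cores}",
        "code.py",
    ]
    env = {
        "is_parallel": is_parallel,
        # Disjoint core ranges: "0-7" for the first job, "8-15" for the second.
        "NEURON_RT_VISIBLE_CORES": f"{index * num_cores}-{(index + 1) * num_cores - 1}",
    }
    jobs.append(subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env))

# Both jobs are already running; only now do we block on their output.
for job, name in zip(jobs, ("original", "parallel")):
    stdout, _ = job.communicate()
    print(f"{name} model output:\n{stdout.decode('utf-8')}")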

@@ -206,8 +229,9 @@ def _test_model_parallel(
     @parameterized.expand(MODELS_TO_TEST)
     def test_model_parallel_from_config_without_lazy_load(self, model_class_name: str, model_name_or_path: str):
         self._test_model_parallel(
-            num_neuron_cores=2,
+            num_neuron_cores=8,
             tp_size=2,
+            run_test_in_parallel=True,
             model_class_name=model_class_name,
             model_name_or_path=model_name_or_path,
             from_config=True,
@@ -233,9 +257,11 @@ def test_model_parallel_from_pretrained_without_lazy_load(self, model_class_name
         )
 
     def test_llama_v2_gqa_variants(self):
         # MHA setup
-        # TP size = 4, num_attention_heads = 8, num_key_value_heads = 8
+        # TP size = 2, num_attention_heads = 8, num_key_value_heads = 8
         self._test_model_parallel(
-            tp_size=4,
+            num_neuron_cores=8,
+            tp_size=2,
+            run_test_in_parallel=True,
             model_class_name="LlamaForCausalLM",
             model_name_or_path="anushehchaudry/llama-2-tiny-random",
             from_config=True,
@@ -251,7 +277,9 @@ def test_llama_v2_gqa_variants(self):
         # GQA setup with num_key_value_heads > tp_size.
         # TP size = 2, num_attention_heads = 8, num_key_value_heads = 4
         self._test_model_parallel(
+            num_neuron_cores=8,
             tp_size=2,
+            run_test_in_parallel=True,
             model_class_name="LlamaForCausalLM",
             model_name_or_path="anushehchaudry/llama-2-tiny-random",
             from_config=True,
@@ -265,49 +293,58 @@ def test_llama_v2_gqa_variants(self):
         )
 
         # GQA setup with num_key_value_heads = tp_size.
-        # TP size = 4, num_attention_heads = 8, num_key_value_heads = 4
+        # TP size = 8, num_attention_heads = 16, num_key_value_heads = 8
         self._test_model_parallel(
-            tp_size=4,
+            num_neuron_cores=8,
+            tp_size=8,
+            run_test_in_parallel=True,
             model_class_name="LlamaForCausalLM",
             model_name_or_path="anushehchaudry/llama-2-tiny-random",
             from_config=True,
             with_lazy_load=False,
             parallelize_embeddings=False,
             overwrite_model_config={
                 "num_hidden_layers": "2",
-                "num_attention_heads": "8",
-                "num_key_value_heads": "4",
+                "hidden_size": "32",
+                "num_attention_heads": "16",
+                "num_key_value_heads": "8",
             },
         )
 
         # GQA setup with num_key_value_heads < tp_size.
-        # TP size = 4, num_attention_heads = 8, num_key_value_heads = 2
+        # TP size = 8, num_attention_heads = 16, num_key_value_heads = 2
         self._test_model_parallel(
-            tp_size=4,
+            num_neuron_cores=8,
+            tp_size=8,
+            run_test_in_parallel=True,
             model_class_name="LlamaForCausalLM",
             model_name_or_path="anushehchaudry/llama-2-tiny-random",
             from_config=True,
             with_lazy_load=False,
             parallelize_embeddings=False,
             overwrite_model_config={
                 "num_hidden_layers": "2",
-                "num_attention_heads": "8",
+                "hidden_size": "32",
+                "num_attention_heads": "16",
                 "num_key_value_heads": "2",
             },
         )
 
         # MQA setup
-        # TP size = 4, num_attention_heads = 8, num_key_value_heads = 1
+        # TP size = 8, num_attention_heads = 16, num_key_value_heads = 1
         self._test_model_parallel(
-            tp_size=4,
+            num_neuron_cores=8,
+            tp_size=8,
+            run_test_in_parallel=True,
             model_class_name="LlamaForCausalLM",
             model_name_or_path="anushehchaudry/llama-2-tiny-random",
             from_config=True,
             with_lazy_load=False,
             parallelize_embeddings=False,
             overwrite_model_config={
                 "num_hidden_layers": "2",
-                "num_attention_heads": "8",
+                "hidden_size": "32",
+                "num_attention_heads": "16",
                 "num_key_value_heads": "1",
             },
         )
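For reference, the five setups in this test differ in how attention heads divide across tensor-parallel ranks. The sketch below is a hypothetical illustration of that bookkeeping; it assumes the common strategy of replicating key/value heads across ranks when there are fewer KV heads than ranks, which may differ from the library's actual implementation:

# Hypothetical illustration of the MHA/GQA/MQA variants exercised above.
def heads_per_rank(num_attention_heads: int, num_key_value_heads: int, tp_size: int):
    q_per_rank = num_attention_heads // tp_size
    if num_key_value_heads >= tp_size:
        kv_per_rank = num_key_value_heads // tp_size
    else:
        # Fewer KV heads than ranks: each KV head serves
        # tp_size // num_key_value_heads ranks, so every rank still
        # holds exactly one (replicated) KV head.
        kv_per_rank = 1
    return q_per_rank, kv_per_rank

print(heads_per_rank(8, 8, 2))    # MHA:          (4, 4)
print(heads_per_rank(8, 4, 2))    # GQA, kv > tp: (4, 2)
print(heads_per_rank(16, 8, 8))   # GQA, kv = tp: (2, 1)
print(heads_per_rank(16, 2, 8))   # GQA, kv < tp: (2, 1)
print(heads_per_rank(16, 1, 8))   # MQA:          (2, 1)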
