[WIP]
michaelbenayoun committed Sep 15, 2023
1 parent 5f63f45 commit 70c3e5c
Showing 3 changed files with 18 additions and 15 deletions.
optimum/neuron/distributed/parallel_layers.py (2 changes: 1 addition & 1 deletion)
@@ -605,7 +605,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         # from torch_neuronx.xla_impl.ops import SimpleCrossEntropyLoss
         # output = SimpleCrossEntropyLoss.gen_override().forward(self, input, target)
         output = safe_parallel_cross_entropy(
-            input.clone(),
+            input,
             target,
             weight=self.weight,
             ignore_index=self.ignore_index,
optimum/neuron/utils/runner.py (4 changes: 2 additions & 2 deletions)
@@ -90,8 +90,8 @@ class Precision(str, Enum):
     bf16 = "bf16"


-def run_command_with_realtime_output(cmd: List[str]) -> Tuple[int, str]:
-    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+def run_command_with_realtime_output(cmd: List[str], **popen_kwargs) -> Tuple[int, str]:
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **popen_kwargs)
     stdout = []
     decoder = codecs.getincrementaldecoder("utf-8")()
     while True:
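The new **popen_kwargs lets callers forward keyword arguments such as env straight to subprocess.Popen, which is how the test file below passes its Neuron environment. A minimal usage sketch follows; the command, the environment value, and the assumption that the first returned element is the process return code are illustrative only, not part of this commit:

import os

from optimum.neuron.utils.runner import run_command_with_realtime_output

# Hypothetical environment: expose only the first 8 Neuron cores to the child process.
env = {**os.environ, "NEURON_RT_VISIBLE_CORES": "0-7"}

# Stream the command's output in real time while also capturing it.
returncode, stdout = run_command_with_realtime_output(
    ["torchrun", "--nproc_per_node=8", "train.py"],  # placeholder command
    env=env,
)
assert returncode == 0, stdout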
tests/distributed/test_model_parallelization.py (27 changes: 15 additions & 12 deletions)
@@ -45,6 +45,7 @@

 from optimum.neuron.utils.cache_utils import get_num_neuron_cores, set_neuron_cache_path
 from optimum.neuron.utils.import_utils import is_neuronx_available
+from optimum.neuron.utils.runner import run_command_with_realtime_output

 from ..test_utils import is_trainium_test

@@ -207,33 +208,35 @@ def _test_model_parallel(
             cmd.insert(1, f"--rdzv_endpoint={rdzv_endpoint_host}:{rdzv_endpoint_port}")
             env["NEURON_RT_VISIBLE_CORES"] = f"0-{num_neuron_cores - 1}"

-        p_original = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env)

         # When running tests in parallel, synchronization is done after both processes started.
         if not run_test_in_parallel:
-            stdout, _ = p_original.communicate()
-            stdout = stdout.decode("utf-8")
-            full_output = f"Original model standard output:\n{stdout}"
-            print(full_output)
+            _, stdout = run_command_with_realtime_output(cmd, env=env)
+        else:
+            p_original = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env)

         # Parallel model.
         env = {"is_parallel": "true", **specialization_env, "NEURON_CC_FLAGS": neuron_cc_flags}
         if run_test_in_parallel:
             # Updating the rendez-vous endpoint for the parallel model process.
             cmd[1] = f"--rdzv_endpoint={rdzv_endpoint_host}:{rdzv_endpoint_port + 1}"
             env["NEURON_RT_VISIBLE_CORES"] = f"{num_neuron_cores}-{2 * num_neuron_cores - 1}"
-        p_parallel = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env)
-
-        if run_test_in_parallel:
+
+            p_parallel = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env)
+
             stdout, _ = p_original.communicate()
             stdout = stdout.decode("utf-8")
             full_output = f"Original model standard output:\n{stdout}"
             print(full_output)

-        stdout, _ = p_parallel.communicate()
-        stdout = stdout.decode("utf-8")
-        full_output = f"Parallel model standard output:\n{stdout}"
-        print(full_output)
+            stdout, _ = p_parallel.communicate()
+            stdout = stdout.decode("utf-8")
+            full_output = f"Parallel model standard output:\n{stdout}"
+            print(full_output)

+        else:
+            _, stdout = run_command_with_realtime_output(cmd, env=env)
+
+
         temporary_dir = Path(tmpdirname)
         original_model_outputs = torch.load(temporary_dir / "original.bin")
@@ -256,7 +259,7 @@ def test_model_parallel_from_config_without_lazy_load(
         self._test_model_parallel(
             num_neuron_cores=8,
             tp_size=2,
-            run_test_in_parallel=True,
+            run_test_in_parallel=False,
             model_class_name=model_class_name,
             model_name_or_path=model_name_or_path,
             from_config=True,
