
Commit

[WIP] fix greedy
michaelbenayoun committed Oct 20, 2023
1 parent bd5c339 commit 7be5db6
Showing 2 changed files with 12 additions and 6 deletions.
optimum/neuron/distributed/base.py: 15 changes (10 additions, 5 deletions)
@@ -461,6 +461,15 @@ def save_model_checkpoint_as_sharded(
         optimizer: Optional["torch.optim.Optimizer"] = None,
     ):
         cls._check_model_was_parallelized(model)
+
+        from neuronx_distributed.parallel_layers.parallel_state import (
+            get_data_parallel_rank,
+            get_tensor_model_parallel_rank,
+        )
+
+        data_parallel_rank = get_data_parallel_rank()
+        tensor_parallel_rank = get_tensor_model_parallel_rank()
+
         if not isinstance(output_dir, Path):
             output_dir = Path(output_dir)

@@ -474,12 +483,8 @@ def save_model_checkpoint_as_sharded(
             state_dict["optimizer_state_dict"] = optimizer.state_dict()

         output_path = output_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME
-        from neuronx_distributed.parallel_layers.parallel_state import (
-            get_data_parallel_rank,
-            get_tensor_model_parallel_rank,
-        )

-        if get_data_parallel_rank() == 0 and get_tensor_model_parallel_rank() == 0:
+        if data_parallel_rank == 0 and tensor_parallel_rank == 0:
             if output_path.is_dir():
                 shutil.rmtree(output_path, ignore_errors=True)
             output_path.mkdir()
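
Note: the hunks above move the parallel_state import and the rank lookups to the top of save_model_checkpoint_as_sharded, so the data-parallel and tensor-parallel ranks are read once and reused when deciding which worker resets the shard directory. The following is a minimal standalone sketch of that rank-gated directory reset, not the repository's code: prepare_sharded_output_dir is a hypothetical helper written for illustration, and the "tensor_parallel_shards" string stands in for the TENSOR_PARALLEL_SHARDS_DIR_NAME constant used above.

# Minimal sketch (illustration only): read the parallel ranks once, then let
# only the dp=0 / tp=0 worker recreate the shard output directory before
# anything is written into it.
import shutil
from pathlib import Path

from neuronx_distributed.parallel_layers.parallel_state import (
    get_data_parallel_rank,
    get_tensor_model_parallel_rank,
)


def prepare_sharded_output_dir(output_dir: str) -> Path:
    # Hypothetical helper mirroring the hunk above.
    data_parallel_rank = get_data_parallel_rank()
    tensor_parallel_rank = get_tensor_model_parallel_rank()

    # "tensor_parallel_shards" stands in for TENSOR_PARALLEL_SHARDS_DIR_NAME.
    output_path = Path(output_dir) / "tensor_parallel_shards"
    if data_parallel_rank == 0 and tensor_parallel_rank == 0:
        if output_path.is_dir():
            shutil.rmtree(output_path, ignore_errors=True)
        output_path.mkdir()
    return output_path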
optimum/neuron/trainers.py: 3 changes (2 additions, 1 deletion)
@@ -314,6 +314,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
         if self.control.should_log:
             logs: Dict[str, float] = {}

+            xm.mark_step()
+
             if self.args.tp_plugin.tensor_parallel_size > 1:
                 from neuronx_distributed.parallel_layers.parallel_state import (
                     get_data_parallel_group,
@@ -330,7 +332,6 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
                 tr_loss_scalar = tr_loss_scalar.detach().item()
             else:
                 # all_gather + mean() to get average loss over all processes
-                xm.mark_step()
                 tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

             # reset tr_loss to zero
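
Note: these hunks hoist xm.mark_step() out of the non-tensor-parallel branch so it runs before either loss-reduction path when logging. The following is a minimal sketch of that placement, not the Trainer implementation: loss_scalar_for_logging is a hypothetical helper written for illustration, and the real code additionally averages the loss over the data-parallel group (tensor-parallel case) or gathers it across processes (otherwise).

# Minimal sketch (illustration only): flush pending XLA operations with
# mark_step() before reading the loss back on the host, regardless of which
# reduction path runs afterwards.
import torch
import torch_xla.core.xla_model as xm


def loss_scalar_for_logging(tr_loss: torch.Tensor, tensor_parallel_size: int) -> float:
    # Hypothetical helper mirroring the hunks above.
    xm.mark_step()  # materialize the lazily-recorded graph before any host transfer

    if tensor_parallel_size > 1:
        # The real code first averages tr_loss over the data-parallel group
        # using neuronx_distributed before calling .detach().item().
        return tr_loss.detach().item()
    # The real code gathers tr_loss across processes and takes the mean;
    # a plain read-back stands in for that here.
    return tr_loss.item()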
