diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index 89fc363a2..e2ac71fff 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -461,6 +461,15 @@ def save_model_checkpoint_as_sharded(
         optimizer: Optional["torch.optim.Optimizer"] = None,
     ):
         cls._check_model_was_parallelized(model)
+
+        from neuronx_distributed.parallel_layers.parallel_state import (
+            get_data_parallel_rank,
+            get_tensor_model_parallel_rank,
+        )
+
+        data_parallel_rank = get_data_parallel_rank()
+        tensor_parallel_rank = get_tensor_model_parallel_rank()
+
         if not isinstance(output_dir, Path):
             output_dir = Path(output_dir)

@@ -474,12 +483,8 @@ def save_model_checkpoint_as_sharded(
             state_dict["optimizer_state_dict"] = optimizer.state_dict()

         output_path = output_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME
-        from neuronx_distributed.parallel_layers.parallel_state import (
-            get_data_parallel_rank,
-            get_tensor_model_parallel_rank,
-        )

-        if get_data_parallel_rank() == 0 and get_tensor_model_parallel_rank() == 0:
+        if data_parallel_rank == 0 and tensor_parallel_rank == 0:
             if output_path.is_dir():
                 shutil.rmtree(output_path, ignore_errors=True)
             output_path.mkdir()
diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 17aaaad2b..e52d39852 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -314,6 +314,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
         if self.control.should_log:
             logs: Dict[str, float] = {}

+            xm.mark_step()
+
             if self.args.tp_plugin.tensor_parallel_size > 1:
                 from neuronx_distributed.parallel_layers.parallel_state import (
                     get_data_parallel_group,
@@ -330,7 +332,6 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
                     tr_loss_scalar = tr_loss_scalar.detach().item()
             else:
                 # all_gather + mean() to get average loss over all processes
-                xm.mark_step()
                 tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

             # reset tr_loss to zero