Commit fd3e9b7: Fix

michaelbenayoun committed Nov 21, 2023
1 parent a10e37b

Showing 2 changed files with 6 additions and 4 deletions.

optimum/neuron/distributed/base.py (4 additions, 1 deletion)

@@ -198,6 +198,7 @@ def parallelize(
         from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_rank
 
         # Parallelizing the model.
+        # This needs to be done prior to preparing the model for sequence parallelism because modules can be overridden.
         model = cls._parallelize(
             model,
             device=device,
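
The added comment records an ordering constraint: `cls._parallelize` may replace submodules, so any later step that prepares the model for sequence parallelism has to see the final module tree. A minimal runnable sketch of that constraint; both functions below are stand-ins for illustration, not the actual Optimum Neuron implementation:

    import torch.nn as nn

    def parallelize(model: nn.Module) -> nn.Module:
        # Stand-in for cls._parallelize: it may override modules, e.g. by
        # replacing nn.Linear layers with tensor-parallel equivalents.
        for name, child in list(model.named_children()):
            if isinstance(child, nn.Linear):
                setattr(model, name, nn.Linear(child.in_features, child.out_features))
            else:
                parallelize(child)
        return model

    def prepare_for_sequence_parallelism(model: nn.Module) -> nn.Module:
        # Stand-in for the preparation step: it attaches hooks to whatever
        # modules currently exist, so it must run after parallelization.
        for module in model.modules():
            module.register_forward_pre_hook(lambda mod, args: args)
        return model

    model = nn.Sequential(nn.Linear(4, 4))
    model = parallelize(model)                       # first: may override modules
    model = prepare_for_sequence_parallelism(model)  # second: hooks the final modules
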
@@ -543,7 +544,9 @@ def load_model_sharded_checkpoint(cls, model: "PreTrainedModel", load_dir: Union
 
         if not isinstance(load_dir, Path):
             load_dir = Path(load_dir)
-        neuronx_distributed.parallel_layers.load(load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME, model=model, sharded=True)
+        neuronx_distributed.parallel_layers.load(
+            load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME, model_or_optimizer=model, sharded=True
+        )
 
     @classmethod
     def load_model_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]):
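
The second hunk passes the model through the `model_or_optimizer` keyword instead of `model` when calling `neuronx_distributed.parallel_layers.load`. If surrounding code ever needed to run against both the old and the new keyword, a small dispatch shim along these lines would work; the shim is purely illustrative and not part of this commit:

    import inspect
    from pathlib import Path

    def load_sharded(load_fn, load_dir: Path, model, **kwargs):
        # Illustrative compatibility shim: pick whichever keyword the installed
        # version of `load_fn` (e.g. neuronx_distributed.parallel_layers.load)
        # accepts for the model argument.
        params = inspect.signature(load_fn).parameters
        key = "model_or_optimizer" if "model_or_optimizer" in params else "model"
        return load_fn(load_dir, sharded=True, **{key: model}, **kwargs)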

tests/distributed/test_model_parallelization.py (2 additions, 3 deletions)

@@ -212,8 +212,7 @@ def _test_model_parallel(
             "lazy_load": "true" if with_lazy_load else "false",
             "parallelize_embeddings": "true" if parallelize_embeddings else "false",
             "sequence_parallel_enabled": "true" if sequence_parallel_enabled else "false",
-            # TODO: disable that once that loss computation compilation for LLama does not take forever.
-            "computing_loss_is_supported": "true" if not model_class_name.startswith("Llama") else "true",
+            "computing_loss_is_supported": "true",
             **os.environ,
         }
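
Because these flags cross a process boundary, the test encodes booleans as the strings "true"/"false" in the child environment. A small round-trip sketch; `env_flag` is a hypothetical helper, not a function from the test suite:

    import os

    def env_flag(name: str, default: bool = False) -> bool:
        # Hypothetical helper: decode a "true"/"false" environment variable.
        return os.environ.get(name, "true" if default else "false") == "true"

    os.environ["computing_loss_is_supported"] = "true"
    assert env_flag("computing_loss_is_supported")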

@@ -330,7 +329,7 @@ def test_model_parallel_from_config_no_lazy_load(
         # 3. Do not enable sequence parallelism => this feature should not depend on whether the model is initialized
         # lazily or not.
         self._test_model_parallel(
-            num_neuron_cores=2,
+            num_neuron_cores=8,
             tp_size=2,
             run_test_in_parallel=True,
             model_class_name=model_class_name,
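
The test now requests 8 Neuron cores instead of 2 while keeping `tp_size=2`. Assuming cores divide evenly into tensor-parallel groups (an interpretation, not something stated in the diff), that leaves several independent groups available when `run_test_in_parallel` is set:

    num_neuron_cores = 8
    tp_size = 2

    # With cores partitioned into groups of tp_size, 8 cores / 2 = 4 groups,
    # versus a single group when only 2 cores were requested.
    assert num_neuron_cores % tp_size == 0
    num_tp_groups = num_neuron_cores // tp_size
    print(num_tp_groups)  # 4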
