[WIP] llama-70b
michaelbenayoun committed Feb 1, 2024
1 parent fba21f2 · commit 861c782
Showing 5 changed files with 18 additions and 15 deletions.
3 changes: 2 additions & 1 deletion optimum/neuron/distributed/base.py
@@ -31,6 +31,7 @@
 from ..utils import is_neuronx_distributed_available, is_torch_xla_available
 from ..utils.patching import Patcher
 from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla
+from ..utils.misc import is_main_worker
 from .parallel_layers import (
     IOSequenceParallelizer,
     LayerNormSequenceParallelizer,
@@ -685,7 +686,7 @@ def save_model_checkpoint_as_regular(
         if not isinstance(output_dir, Path):
             output_dir = Path(output_dir)
 
-        if optimizer is not None:
+        if is_main_worker() and optimizer is not None:
             logger.warning(
                 "Saving the optimizer state as a regular file under the tensor parallel setting is not supported yet."
             )
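This guard pattern recurs throughout the commit: log calls are wrapped in is_main_worker() so that a distributed job emits each warning once instead of once per process. A minimal illustrative sketch of the idea, not the library's code; it uses a simplified stand-in for the helper, whose real implementation goes through torch_xla ordinals (see the misc.py hunk at the end of this diff), and the wrapper function name is made up for the example:

# Illustrative sketch only: without the guard, a job launched with N processes prints the
# same warning N times; with it, only the main worker logs.
import logging

import torch.distributed as dist

logger = logging.getLogger(__name__)


def is_main_worker() -> bool:
    # Simplified stand-in for optimum.neuron.utils.misc.is_main_worker.
    return not dist.is_initialized() or dist.get_rank() == 0


def save_checkpoint(optimizer=None) -> None:
    # Hypothetical wrapper, mirroring the save_model_checkpoint_as_regular hunk above.
    if is_main_worker() and optimizer is not None:
        logger.warning("Saving the optimizer state as a regular file is not supported yet.")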
5 changes: 3 additions & 2 deletions optimum/neuron/distributed/parallel_layers.py
@@ -28,6 +28,7 @@
 from ...utils import NormalizedConfigManager, logging
 from ..utils import patch_everywhere, patch_within_function
 from ..utils.require_utils import requires_neuronx_distributed
+from ..utils.misc import is_main_worker
 from .utils import (
     GroupedQueryAttentionInfo,
     WeightInformation,
@@ -227,7 +228,7 @@ def transform(
         if embedding_layer.num_embeddings % tp_size != 0:
             import torch_xla.core.xla_model as xm
 
-            if xm.get_ordinal() == 0:
+            if is_main_worker():
                 logger.warning(
                     f"Embedding parallelization for TP was skipped because the tensor parallel size ({tp_size}) does not "
                     f"divide the number of embeddings ({embedding_layer.num_embeddings})"
@@ -344,7 +345,7 @@ def transform(
             raise ValueError(
                 "Only the cases where the number of key value heads is divisible by the TP size, or the other way around are supported."
             )
-        elif num_key_value_heads < tp_size:
+        elif is_main_worker() and num_key_value_heads < tp_size:
             logger.warning(
                 f"The TP size ({tp_size}) is bigger than the number of key value heads ({num_key_value_heads}). "
                 "This is not ideal because the key and value projections will not be sharded accross the TP ranks. "
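For context on the warning in the hunk above: the surrounding GQA handling only supports configurations where the TP size divides the number of key/value heads or vice versa. A hedged, self-contained sketch of that compatibility rule; an illustration rather than the library's code, with the function name made up for this example:

# Either tp_size divides num_key_value_heads (each TP rank owns whole KV heads), or
# num_key_value_heads divides tp_size (each KV head is replicated over several ranks,
# so the key/value projections are not sharded across TP ranks).
def check_gqa_tp_compatibility(num_key_value_heads: int, tp_size: int) -> None:
    if num_key_value_heads % tp_size != 0 and tp_size % num_key_value_heads != 0:
        raise ValueError(
            "Only the cases where the number of key value heads is divisible by the TP size, "
            "or the other way around, are supported."
        )
    if num_key_value_heads < tp_size:
        # e.g. a LLaMA-2-70B-style model with 8 KV heads at TP size 32: each KV head ends up
        # replicated on tp_size // num_key_value_heads = 4 ranks.
        print(f"TP size ({tp_size}) is bigger than the number of KV heads ({num_key_value_heads}).")


check_gqa_tp_compatibility(num_key_value_heads=8, tp_size=32)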
1 change: 1 addition & 0 deletions optimum/neuron/distributed/utils.py
@@ -879,3 +879,4 @@ def is_tied(self):
     @property
     def is_sharded(self):
         return self.kind == "sharded"
+
20 changes: 10 additions & 10 deletions optimum/neuron/utils/cache_utils.py
@@ -110,7 +110,7 @@ def set_custom_cache_repo_name_in_hf_home(repo_id: str, hf_home: str = HF_HOME,
         )
 
     existing_custom_cache_repo = load_custom_cache_repo_name_from_hf_home(hf_home_cache_repo_file)
-    if is_main_worker() and existing_custom_cache_repo is not None:
+    if is_main_worker(global_main=False) and existing_custom_cache_repo is not None:
         logger.warning(
             f"A custom cache repo was already registered: {existing_custom_cache_repo}. It will be overwritten to "
             f"{repo_id}."
@@ -173,7 +173,7 @@ def has_write_access_to_repo(repo_id: str) -> bool:
         if org["name"] == username_or_organization:
             # Role in an organization can be either:
             # "admin", "write", "contributor", "read".
-            if is_main_worker() and org["roleInOrg"] == "contributor":
+            if is_main_worker(global_main=False) and org["roleInOrg"] == "contributor":
                 logger.warning(
                     f"You are logged in as a contributor to the cache repo {repo_id}. It is not possible to infer "
                     "whether you have write access on this repo or not, so it will be assumed you do not."
@@ -193,7 +193,7 @@ def get_hf_hub_cache_repos():
     if custom_cache_repo is not None and custom_cache_repo not in hf_hub_repos:
         hf_hub_repos = [custom_cache_repo] + hf_hub_repos
 
-    if is_main_worker() and saved_custom_cache_repo is None and custom_cache_repo is None:
+    if is_main_worker(global_main=False) and saved_custom_cache_repo is None and custom_cache_repo is None:
         warn_once(
             logger,
             "No Neuron cache name is saved locally. This means that only the official Neuron cache will be used. You "
@@ ... @@
     # making it easier for higher-level abstractions using the cache utils to reason on which
     # parts should only run on the master process and which parts should run on everyone.
 
-    if is_main_worker() and hf_hub_repos and not has_write_access_to_repo(hf_hub_repos[0]):
+    if is_main_worker(global_main=False) and hf_hub_repos and not has_write_access_to_repo(hf_hub_repos[0]):
         warn_once(
             logger,
             f"You do not have write access to {hf_hub_repos[0]} so you will not be able to push any cached compilation "
@@ -474,7 +474,7 @@ def add_in_registry(repo_id: str, neuron_hash: "NeuronHash"):
                 )
             except Exception as e:
                 if "A commit has happened since" in str(e):
-                    if is_main_worker():
+                    if is_main_worker(global_main=False):
                         logger.info(
                             "A commit has happened in cache repository since we tried to update the registry, starting "
                             "again..."
@@ -951,7 +951,7 @@ def push_to_cache_on_hub(
         )
         if fail_when_could_not_push:
             raise ValueError(error_message)
-        if is_main_worker():
+        if is_main_worker(global_main=False):
            logger.warning(error_message)
         return
 
@@ ... @@
         )
         if fail_when_could_not_push:
             raise ValueError(error_message)
-        if is_main_worker():
+        if is_main_worker(global_main=False):
             logger.warning(error_message)
         return
 
@@ ... @@
         exists = any(filename.startswith(path_in_repo_str) for filename in repo_filenames)
     else:
         exists = any(filename == path_in_repo_str for filename in repo_filenames)
-    if is_main_worker() and exists:
+    if is_main_worker(global_main=False) and exists:
         if not overwrite_existing:
             logger.info(
                 f"Did not push the cached model located at {local_cache_dir_or_file} to the repo named {cache_repo_id} "
@@ -1022,7 +1022,7 @@ def push_to_cache_on_hub(
                 raise e
             msg = could_not_push_message.format(cache_repo_id=cache_repo_id, error=e)
             msg = re.sub(_HF_HUB_HTTP_ERROR_REQUEST_ID_PATTERN, "", msg)
-            if is_main_worker():
+            if is_main_worker(global_main=False):
                 warn_once(logger, msg)
             success = False
     else:
@@ ... @@
                 raise e
             msg = could_not_push_message.format(cache_repo_id=cache_repo_id, error=e)
             msg = re.sub(_HF_HUB_HTTP_ERROR_REQUEST_ID_PATTERN, "", msg)
-            if is_main_worker():
+            if is_main_worker(global_main=False):
                 warn_once(logger, msg)
             success = False
 
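The add_in_registry hunk above appears to sit inside a retry loop: when a concurrent commit lands in the cache repo between reading and updating the registry, the Hub rejects the push with "A commit has happened since" and the update is attempted again. A hedged sketch of that pattern; the helper name, attempt count, and backoff are made up for this illustration, and only the error-message check comes from the diff:

import logging
import time
from typing import Callable

logger = logging.getLogger(__name__)


def update_registry_with_retry(do_update: Callable[[], None], max_attempts: int = 5) -> None:
    # Hypothetical helper mirroring the retry-on-race behaviour suggested by the hunk above.
    for _ in range(max_attempts):
        try:
            do_update()  # e.g. read the registry file, add the new entry, push a commit
            return
        except Exception as e:
            if "A commit has happened since" not in str(e):
                raise
            logger.info("A commit has happened in the cache repository since the registry was read, retrying...")
            time.sleep(1)  # hypothetical backoff
    raise RuntimeError("Could not update the registry after several attempts.")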
4 changes: 2 additions & 2 deletions optimum/neuron/utils/misc.py
@@ -45,11 +45,11 @@
 logger = logging.get_logger()
 
 
-def is_main_worker() -> bool:
+def is_main_worker(global_main: bool = True) -> bool:
     if torch.distributed.is_initialized() and is_torch_xla_available():
         import torch_xla.core.xla_model as xm
 
-        return xm.get_local_ordinal() == 0
+        return xm.get_ordinal() == 0 if global_main else xm.get_local_ordinal() == 0
     return True
 
 
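The misc.py change above is what the rest of the commit builds on: global_main distinguishes the job-wide main worker (XLA global ordinal 0) from the per-node main worker (local ordinal 0), and the cache_utils.py call sites switch to global_main=False. A self-contained illustration of which workers each mode selects, assuming a hypothetical layout of 2 nodes with 4 workers each and contiguous ordinals (not taken from the diff):

# Illustration only: the two modes of is_main_worker() on global ordinals 0..7.
WORKERS_PER_NODE = 4


def is_global_main(global_ordinal: int) -> bool:
    # What is_main_worker(global_main=True) checks via xm.get_ordinal() == 0.
    return global_ordinal == 0


def is_local_main(global_ordinal: int) -> bool:
    # What is_main_worker(global_main=False) checks via xm.get_local_ordinal() == 0.
    return global_ordinal % WORKERS_PER_NODE == 0


print([g for g in range(8) if is_global_main(g)])  # [0]     -> one worker in the whole job
print([g for g in range(8) if is_local_main(g)])   # [0, 4]  -> one main worker per node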
