[WIP] llama-70b
michaelbenayoun committed Feb 7, 2024
1 parent c4d50b1 commit f734478
Showing 3 changed files with 56 additions and 38 deletions.
optimum/neuron/distributed/base.py (17 changes: 10 additions & 7 deletions)
@@ -387,13 +387,16 @@ def predicate_func(layer):

weight_map = getattr(model, "_weight_map", {})

-for fully_qualified_name, layer in model.named_modules():
-    if isinstance(layer, BaseParallelLinear):
-        xm.master_print(fully_qualified_name)
-        linear_weight_info, linear_bias_weight_info = ParallelLayer._get_linear_weight_info(weight_map, fully_qualified_name)
-        if linear_weight_info is not None:
-            maybe_load_linear_weight_to_parallel_linear(layer, linear_layer_weight_info=linear_weight_info, linear_layer_bias_weight_info=linear_bias_weight_info)
-xm.master_print("PARALLEL LAYERS DONE")
+# for fully_qualified_name, layer in model.named_modules():
+#     if isinstance(layer, BaseParallelLinear):
+#         xm.master_print(fully_qualified_name)
+#         try:
+#             linear_weight_info, linear_bias_weight_info = ParallelLayer._get_linear_weight_info(weight_map, fully_qualified_name)
+#         except ValueError:
+#             linear_weight_info = None
+#         if linear_weight_info is not None:
+#             maybe_load_linear_weight_to_parallel_linear(layer, linear_layer_weight_info=linear_weight_info, linear_layer_bias_weight_info=linear_bias_weight_info)
+# xm.master_print("PARALLEL LAYERS DONE")

with torch.no_grad():
tied_weights = {}
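In runnable form, the disabled loop with the guarded lookup would look roughly like the sketch below. The names (model, weight_map, BaseParallelLinear, ParallelLayer, maybe_load_linear_weight_to_parallel_linear, xm) come from the surrounding code; resetting both returned values in the except branch is an adjustment, since the commented version only resets linear_weight_info:

# Sketch of the (currently disabled) per-layer weight loading, with a guarded lookup.
for fully_qualified_name, layer in model.named_modules():
    if not isinstance(layer, BaseParallelLinear):
        continue
    try:
        linear_weight_info, linear_bias_weight_info = ParallelLayer._get_linear_weight_info(
            weight_map, fully_qualified_name
        )
    except ValueError:
        # No entry for this layer in the checkpoint's weight map: skip loading.
        linear_weight_info, linear_bias_weight_info = None, None
    if linear_weight_info is not None:
        maybe_load_linear_weight_to_parallel_linear(
            layer,
            linear_layer_weight_info=linear_weight_info,
            linear_layer_bias_weight_info=linear_bias_weight_info,
        )
xm.master_print("PARALLEL LAYERS DONE")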
optimum/neuron/distributed/utils.py (51 changes: 33 additions & 18 deletions)
@@ -33,7 +33,7 @@
from ..utils import DynamicPatch, Patcher
from ..utils.deprecate_utils import deprecate
from ..utils.import_utils import is_neuronx_distributed_available
-from ..utils.misc import download_checkpoints_in_cache
+from ..utils.misc import download_checkpoints_in_cache, is_main_worker
from ..utils.require_utils import (
is_torch_xla_available,
requires_neuronx_distributed,
@@ -291,7 +291,7 @@ def embedding_to_parallel_embedding(
)

del embedding_layer.weight

if lm_head_layer is None:
return parallel_embedding_layer

@@ -305,11 +305,11 @@ def maybe_load_linear_weight_to_parallel_linear(
linear_layer_bias_weight_info: Optional[WeightInformation] = None,
linear_layer: Optional["torch.nn.Linear"] = None,
):
-if linear_layer_weight_info is not None and linear_layer is not None:
+if (linear_layer_weight_info is not None or linear_layer_bias_weight_info is not None) and linear_layer is not None:
raise ValueError(
"Specify either a linear layer's WeightInformation, or a linear layer to copy the weights from, but not both."
)
-if linear_layer_weight_info is None and linear_layer_bias_weight_info is None and linear_layer is None:
+if linear_layer_weight_info is None and linear_layer_bias_weight_info is None and linear_layer is None:
raise ValueError(
"A linear's layer WeightInformation or a linear layer to copy the weight from need to specified."
)
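The relaxed checks mean a bias-only WeightInformation now passes validation, while mixing any WeightInformation with a source linear layer, or passing neither, still raises. A minimal standalone sketch of the updated rule (an illustration of the logic in this hunk, not the library's implementation; the argument names mirror the signature above):

from typing import Any, Optional


def check_load_arguments(
    weight_info: Optional[Any] = None,
    bias_info: Optional[Any] = None,
    linear_layer: Optional[Any] = None,
) -> None:
    """Validate the argument combinations accepted after this change."""
    has_weight_info = weight_info is not None or bias_info is not None
    if has_weight_info and linear_layer is not None:
        raise ValueError(
            "Specify either a linear layer's WeightInformation, or a linear layer to copy "
            "the weights from, but not both."
        )
    if not has_weight_info and linear_layer is None:
        raise ValueError(
            "A linear layer's WeightInformation or a linear layer to copy the weight from "
            "needs to be specified."
        )


check_load_arguments(bias_info=object())      # accepted: bias-only weight information
check_load_arguments(linear_layer=object())   # accepted: copy from an existing layer
# check_load_arguments()                      # would raise: nothing to load from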
@@ -412,7 +412,7 @@ def linear_to_parallel_linear(
linear_layer_bias_weight_info: Optional[WeightInformation] = None,
embedding_weight_to_tie: Optional["torch.nn.Parameter"] = None,
sequence_parallel_enabled: bool = False,
-skip_weight_load: bool = True,
+skip_weight_load: bool = False,
device: Optional["torch.device"] = None,
) -> Union["layers.RowParallelLinear", "layers.ColumnParallelLinear"]:
"""
@@ -724,22 +724,37 @@ def from_pretrained_for_mp(
if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
adapter_kwargs["token"] = token

-filenames, sharded_metadata = download_checkpoints_in_cache(
-    pretrained_model_name_or_path,
-    cache_dir=cache_dir,
-    force_download=force_download,
-    local_files_only=local_files_only,
-    token=token,
-    revision=revision,
-    use_safetensors=use_safetensors,
-    use_safetensors_in_priority=True,
-    convert_to_safetensors=True,
-    **kwargs,
-)

+import torch_xla.core.xla_model as xm
+
+if is_main_worker():
+    filenames, sharded_metadata = download_checkpoints_in_cache(
+        pretrained_model_name_or_path,
+        cache_dir=cache_dir,
+        force_download=force_download,
+        local_files_only=local_files_only,
+        token=token,
+        revision=revision,
+        use_safetensors=use_safetensors,
+        use_safetensors_in_priority=True,
+        convert_to_safetensors=True,
+        **kwargs,
+    )
+
+xm.rendezvous("waiting after download and conversion")
+
+if not is_main_worker():
+    filenames, sharded_metadata = download_checkpoints_in_cache(
+        pretrained_model_name_or_path,
+        cache_dir=cache_dir,
+        force_download=force_download,
+        local_files_only=local_files_only,
+        token=token,
+        revision=revision,
+        use_safetensors=use_safetensors,
+        use_safetensors_in_priority=True,
+        convert_to_safetensors=True,
+        **kwargs,
+    )

if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
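The added block makes only the main worker download and convert the checkpoint, then uses an XLA rendezvous as a barrier so the remaining workers re-run the same cache-backed call once the files are in place. A minimal sketch of that pattern follows; the local is_main_worker fallback is an assumption, since the diff imports the real helper from ..utils.misc:

import torch_xla.core.xla_model as xm


def is_main_worker() -> bool:
    # Assumption: the main worker is the one with global ordinal 0.
    return xm.get_ordinal() == 0


def download_on_main_worker_first(download_fn, *args, **kwargs):
    """Call a cache-backed, idempotent `download_fn` on the main worker first,
    then on every other worker once the shared cache has been populated."""
    result = None
    if is_main_worker():
        result = download_fn(*args, **kwargs)
    # Barrier: everyone waits until the main worker has finished writing the cache.
    xm.rendezvous("waiting after download and conversion")
    if not is_main_worker():
        result = download_fn(*args, **kwargs)
    return result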
optimum/neuron/trainers.py (26 changes: 13 additions & 13 deletions)
@@ -127,7 +127,7 @@
import torch_xla.distributed.xla_backend as xbn

if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
-_ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()
+# _ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()

# _ORIGINAL_NEURON_CACHE_PATH is `None` when the `--no-cache` flag is set.
# if _ORIGINAL_NEURON_CACHE_PATH is not None:
@@ -196,18 +196,18 @@ def __init__(self, *args, **kwargs):
if self.args.local_rank <= 0:
logger.setLevel(logging.INFO)

-rank = xm.get_ordinal()
-push = rank <= 0 and not is_precompilation() and not self.args.skip_cache_push
-fetch = rank <= 0 or self.args.mp_plugin.should_parallelize
-
-callback = NeuronCacheCallback(
-    tmp_neuron_cache=_TMP_NEURON_CACHE_PATH,
-    original_neuron_cache_path=_ORIGINAL_NEURON_CACHE_PATH,
-    fetch=fetch,
-    push=push,
-    wait_for_everyone_on_fetch=True,
-    wait_for_everyone_on_push=True,
-)
+# rank = xm.get_ordinal()
+# push = rank <= 0 and not is_precompilation() and not self.args.skip_cache_push
+# fetch = rank <= 0 or self.args.mp_plugin.should_parallelize
+
+# callback = NeuronCacheCallback(
+#     tmp_neuron_cache=_TMP_NEURON_CACHE_PATH,
+#     original_neuron_cache_path=_ORIGINAL_NEURON_CACHE_PATH,
+#     fetch=fetch,
+#     push=push,
+#     wait_for_everyone_on_fetch=True,
+#     wait_for_everyone_on_push=True,
+# )
# self.add_callback(callback)

# Make the model Neuron-compatible for generation.
