
Commit

Fix
michaelbenayoun committed Jul 20, 2023
1 parent 69f0fa9 commit a325df9
Showing 9 changed files with 128 additions and 40 deletions.
5 changes: 4 additions & 1 deletion optimum/neuron/accelerate/accelerator.py
@@ -301,7 +301,10 @@ def _prepare_model_for_tp(
         cpu_ids = [id(v) for v in model.parameters()]
         # TODO: enable self.device (if needed).
         model = self.state.tp_plugin.parallelize_model(model, return_orig_to_parallel=False, device=None)
-        model.to(torch.float32)
+        if os.environ.get("XLA_USE_BF16", "0") == "1":
+            model.to(torch.bfloat16)
+        else:
+            model.to(torch.float32)
         parallel_layers.move_model_to_device(model, self.device)
         model.tie_weights()
         self._model_cpu_parameters_to_xla[id(model)] = dict(zip(cpu_ids, model.parameters()))
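
Note: this change keys the model dtype off the XLA_USE_BF16 environment variable instead of always casting to float32. A minimal standalone sketch of the same selection logic (the helper name resolve_xla_dtype is illustrative, not part of the commit):

import os

import torch


def resolve_xla_dtype() -> torch.dtype:
    # XLA_USE_BF16=1 tells torch-xla to downcast float32 to bfloat16, so the model
    # is moved to bfloat16 up front to stay consistent with the backend.
    if os.environ.get("XLA_USE_BF16", "0") == "1":
        return torch.bfloat16
    return torch.float32


model = torch.nn.Linear(4, 4).to(resolve_xla_dtype())
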
3 changes: 1 addition & 2 deletions optimum/neuron/distributed/utils.py
@@ -93,15 +93,14 @@ def load_tensor_for_weight(
     """
     from safetensors import safe_open

-    # device = weight_info.device if weight_info.device is not torch.device("cpu") else None
     device = str(weight_info.device)
     with safe_open(weight_info.filename, framework="pt", device=device) as fp:
         if tensor_slices is None:
             tensor = fp.get_tensor(weight_info.qualified_name)
         else:
             tensor_slice = fp.get_slice(weight_info.qualified_name)
             slices = [slice(*slice_) if slice_ is not None else slice(None, None, None) for slice_ in tensor_slices]
-            tensor = tensor_slice[slices]
+            tensor = tensor_slice[slices].contiguous()
             # This is needed to make sure tensor.numel() == tensor.storage().size().
             tensor = torch.empty_like(tensor).copy_(tensor)

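
Note: a safetensors slice can come back as a view over a larger backing buffer, so without the final copy the loaded tensor keeps more storage than it needs; .contiguous() plus the empty_like/copy_ round trip pins tensor.numel() to tensor.storage().size(), which is exactly what the new test below checks. A rough plain-PyTorch illustration of that invariant (not the repository code):

import torch

full = torch.empty((24, 24), dtype=torch.bfloat16)
view = full[:2, :2]                           # 4 elements, but shares a 576-element storage
compact = torch.empty_like(view).copy_(view)  # owns exactly 4 elements

assert view.numel() == 4 and view.storage().size() == 576
assert compact.numel() == compact.storage().size() == 4
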
8 changes: 5 additions & 3 deletions optimum/neuron/trainer_callback.py
@@ -68,7 +68,7 @@ def from_trainer_state(cls, state: TrainerState) -> "NeuronTrainerState":
         return neuron_trainer_state


-class NeuronCacheCallaback(TrainerCallback):
+class NeuronCacheCallback(TrainerCallback):
     def __init__(
         self,
         tmp_neuron_cache: Optional[TemporaryDirectory] = None,
@@ -207,10 +207,12 @@ def neuron_hash_for_model(
         else:
             data_type = torch.float32

-        key = (model, input_shapes, data_type)
+        key_args = (model, input_shapes, data_type)
+        key_kwargs = {"tensor_parallel_size": args.tensor_parallel_size}
+        key = key_args + tuple(key_kwargs.values())
         neuron_hash = self.neuron_hashes.get(key, None)
         if neuron_hash is None:
-            neuron_hash = NeuronHash(*key)
+            neuron_hash = NeuronHash(*key_args, **key_kwargs)
             self.neuron_hashes[key] = neuron_hash
         if try_to_fetch_cached_model:
             self.try_to_fetch_cached_model(neuron_hash)
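
Note: besides fixing the NeuronCacheCallback typo, the cache key now includes the tensor parallel size, so the same model compiled with different TP degrees maps to different NeuronHash entries. A toy sketch of that keying pattern, using an illustrative dataclass rather than the real NeuronHash API:

from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class ToyHash:
    model_name: str
    input_shapes: Tuple[Tuple[int, ...], ...]
    data_type: str
    tensor_parallel_size: int = 1


_hash_cache: Dict[tuple, ToyHash] = {}


def hash_for(model_name, input_shapes, data_type, tensor_parallel_size):
    # Every field that changes the compiled artifact must be part of the lookup key.
    key = (model_name, input_shapes, data_type, tensor_parallel_size)
    if key not in _hash_cache:
        _hash_cache[key] = ToyHash(model_name, input_shapes, data_type, tensor_parallel_size)
    return _hash_cache[key]


h2 = hash_for("gpt2", ((1, 128),), "bf16", tensor_parallel_size=2)
h8 = hash_for("gpt2", ((1, 128),), "bf16", tensor_parallel_size=8)
assert h2 is not h8  # different TP size, different cache entry
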
19 changes: 10 additions & 9 deletions optimum/neuron/trainers.py
@@ -44,7 +44,7 @@
 from .accelerate import NeuronAccelerator, NeuronDistributedType
 from .distributed import ParallelizersManager
 from .distributed.utils import make_optimizer_constructor_lazy
-from .trainer_callback import NeuronCacheCallaback
+from .trainer_callback import NeuronCacheCallback
 from .utils import (
     DynamicPatch,
     ModelPatcher,
@@ -105,7 +105,8 @@

 if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
     _ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()
-    _TMP_NEURON_CACHE_DIR = NeuronCacheCallaback.create_temporary_neuron_cache(get_neuron_cache_path())
+    if not is_precompilation():
+        _TMP_NEURON_CACHE_DIR = NeuronCacheCallback.create_temporary_neuron_cache(get_neuron_cache_path())
     torch.distributed.init_process_group(backend="xla")
     if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
         raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
@@ -150,20 +151,20 @@ def __init__(self, *args, **kwargs):
             logger.setLevel(logging.INFO)

         if not is_precompilation():
-            push = self.args.local_rank == 0
-            fetch = self.args.local_rank == 0
+            push = self.args.local_rank <= 0
+            fetch = self.args.local_rank <= 0

             if is_neuronx_distributed_available():
                 from neuronx_distributed.parallel_layers.parallel_state import (
-                    get_data_parallel_size,
                     model_parallel_is_initialized,
                 )

                 if model_parallel_is_initialized():
-                    push = get_data_parallel_size() == 0
-                    fetch = get_data_parallel_size() == 0
+                    pass
+                    # push = get_data_parallel_rank() == 0
+                    # fetch = get_data_parallel_rank() == 0

-            callback = NeuronCacheCallaback(
+            callback = NeuronCacheCallback(
                 tmp_neuron_cache=_TMP_NEURON_CACHE_DIR,
                 original_neuron_cache_path=_ORIGINAL_NEURON_CACHE_PATH,
                 fetch=fetch,
@@ -238,7 +239,7 @@ def _wrap_model(self, model, training=True, dataloader=None):
     # TODO: make this cleaner.
     def trigger_on_step_middle_for_neuron_cache_callback(self, model: "PreTrainedModel"):
         for callback in self.callback_handler.callbacks:
-            if isinstance(callback, NeuronCacheCallaback):
+            if isinstance(callback, NeuronCacheCallback):
                 # kwargs might not have everything expected (like metrics) but all we need is here.
                 kwargs = {
                     "model": model,
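
Note: the cache is now only pushed to and fetched from by the main process, and local_rank <= 0 (rather than == 0) also covers single-process runs where transformers sets local_rank to -1. A tiny illustration of that convention (assumed, based on the usual TrainingArguments semantics):

def should_touch_cache(local_rank: int) -> bool:
    # -1: no distributed launch, single process -> still the "main" process.
    #  0: main process of a distributed launch.
    # >0: secondary workers, which should neither push nor fetch.
    return local_rank <= 0


assert should_touch_cache(-1) and should_touch_cache(0) and not should_touch_cache(3)
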
38 changes: 27 additions & 11 deletions optimum/neuron/utils/cache_utils.py
@@ -518,20 +518,37 @@ def constructor():
         return constructor

     def check_requirements_are_met(self, neuron_compiler_version: str):
+        if self.should_be_inserted_in_hash_dict(neuron_compiler_version) and self.default is None:
+            raise ValueError("A default value must be specified.")
+        # from ..version import __version__
+
+        # optimum_neuron_requirement = True
+        # if self.min_optimum_neuron_version is not None:
+        #     if version.parse(__version__) >= version.parse(self.min_optimum_neuron_version):
+        #         optimum_neuron_requirement = self.default is not None
+
+        # neuron_compiler_requirement = True
+        # if self.min_neuron_compiler_version is not None:
+        #     if version.parse(neuron_compiler_version) >= version.parse(self.min_neuron_compiler_version):
+        #         neuron_compiler_requirement = self.default is not None
+
+        # if not optimum_neuron_requirement or not neuron_compiler_requirement:
+        #     raise ValueError("A default value must be specified.")
+
+    def should_be_inserted_in_hash_dict(self, neuron_compiler_version: str) -> bool:
         from ..version import __version__

-        optimum_neuron_requirement = True
+        optimum_neuron_requirement = False
         if self.min_optimum_neuron_version is not None:
-            if version.parse(__version__) >= version.parse(self.min_optimum_neuron_version):
-                optimum_neuron_requirement = self.default is not None
+            optimum_neuron_requirement = version.parse(__version__) >= version.parse(self.min_optimum_neuron_version)

-        neuron_compiler_requirement = True
+        neuron_compiler_requirement = False
         if self.min_neuron_compiler_version is not None:
-            if version.parse(neuron_compiler_version) >= version.parse(self.min_neuron_compiler_version):
-                neuron_compiler_requirement = self.default is not None
+            neuron_compiler_requirement = version.parse(neuron_compiler_version) >= version.parse(
+                self.min_neuron_compiler_version
+            )

-        if not optimum_neuron_requirement or not neuron_compiler_requirement:
-            raise ValueError("A default value must be specified.")
+        return optimum_neuron_requirement or neuron_compiler_requirement


 @dataclass(frozen=True)
@@ -561,9 +578,8 @@ def _insert_potential_unspecified_hash_attribute(
     """
     Inserts `attribute` in `hash_dict` only if it is a specified attribute or if it has a default value.
     """
-    if isinstance(attribute, _UnspecifiedHashAttribute):
-        if attribute.default is not None:
-            hash_dict[attribute_name] = attribute
+    if isinstance(attribute, _UnspecifiedHashAttribute) and attribute.should_be_inserted_in_hash_dict:
+        hash_dict[attribute_name] = attribute.default
     else:
         hash_dict[attribute_name] = attribute

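
Note: the refactor separates two questions that were previously tangled: whether a version-gated attribute participates in the hash at all (should_be_inserted_in_hash_dict) and whether a missing default is an error (check_requirements_are_met, which now only complains for attributes that do participate). A condensed sketch of that rule under the same assumptions, with a stand-in version string instead of optimum.neuron.__version__:

from dataclasses import dataclass
from typing import Any, Optional

from packaging import version

OPTIMUM_NEURON_VERSION = "0.0.8"  # stand-in for optimum.neuron.__version__


@dataclass
class VersionedAttribute:
    default: Optional[Any] = None
    min_optimum_neuron_version: Optional[str] = None
    min_neuron_compiler_version: Optional[str] = None

    def should_be_inserted_in_hash_dict(self, neuron_compiler_version: str) -> bool:
        # The attribute joins the hash as soon as either minimum version is reached.
        optimum_ok = self.min_optimum_neuron_version is not None and version.parse(
            OPTIMUM_NEURON_VERSION
        ) >= version.parse(self.min_optimum_neuron_version)
        compiler_ok = self.min_neuron_compiler_version is not None and version.parse(
            neuron_compiler_version
        ) >= version.parse(self.min_neuron_compiler_version)
        return optimum_ok or compiler_ok

    def check_requirements_are_met(self, neuron_compiler_version: str) -> None:
        # A default is only mandatory for attributes that actually end up in the hash.
        if self.should_be_inserted_in_hash_dict(neuron_compiler_version) and self.default is None:
            raise ValueError("A default value must be specified.")


attr = VersionedAttribute(default=2, min_optimum_neuron_version="0.0.7")
assert attr.should_be_inserted_in_hash_dict("2.7.0")
attr.check_requirements_are_met("2.7.0")  # passes because a default is provided
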
27 changes: 19 additions & 8 deletions optimum/neuron/utils/misc.py
@@ -218,36 +218,44 @@ def download_checkpoints_in_cache(
                 os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
             ):
                 # Load from a TF 1.0 checkpoint in priority if from_tf
-                os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
             elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
                 # Load from a TF 2.0 checkpoint in priority if from_tf
-                os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
             elif from_flax and os.path.isfile(
                 os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
             ):
                 # Load from a Flax checkpoint in priority if from_flax
-                os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
             elif use_safetensors is not False and os.path.isfile(
                 os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
             ):
                 # Load from a safetensors checkpoint
-                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
+                )
             elif use_safetensors is not False and os.path.isfile(
                 os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
             ):
                 # Load from a sharded safetensors checkpoint
-                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+                )
                 is_sharded = True
             elif os.path.isfile(
                 os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
             ):
                 # Load from a PyTorch checkpoint
-                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
+                )
             elif os.path.isfile(
                 os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
             ):
                 # Load from a sharded PyTorch checkpoint
-                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+                )
                 is_sharded = True
             # At this stage we don't have a weight file so we will raise an error.
             elif os.path.isfile(
@@ -276,14 +284,15 @@ def download_checkpoints_in_cache(
                     f" {pretrained_model_name_or_path}."
                 )
         elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+            archive_file = pretrained_model_name_or_path
             is_local = True
         elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
             if not from_tf:
                 raise ValueError(
                     f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
                     "from_tf to True to load from this checkpoint."
                 )
-            os.path.join(subfolder, pretrained_model_name_or_path + ".index")
+            archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
             is_local = True
         elif is_remote_url(pretrained_model_name_or_path):
             filename = pretrained_model_name_or_path
@@ -397,6 +406,8 @@ def download_checkpoints_in_cache(
                 f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
             )

+        if is_local:
+            resolved_archive_file = archive_file
     else:
         resolved_archive_file = None

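
Note: each removed line in this file built a checkpoint path and then threw it away (a leftover from adapting the transformers loading code), so archive_file and resolved_archive_file could end up undefined; the fix stores the results. A stripped-down sketch of the resolution pattern, with illustrative constants rather than the full transformers logic:

import os

# Illustrative weight-file names; the real code uses the constants imported from transformers.
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"


def resolve_archive_file(model_dir: str, subfolder: str = "") -> tuple:
    """Return (archive_file, is_sharded), preferring safetensors over torch pickles."""
    candidates = [
        (SAFE_WEIGHTS_NAME, False),
        (WEIGHTS_NAME, False),
        (WEIGHTS_INDEX_NAME, True),
    ]
    for name, is_sharded in candidates:
        path = os.path.join(model_dir, subfolder, name)
        if os.path.isfile(path):
            return path, is_sharded  # assign the path instead of discarding it
    raise EnvironmentError(f"No weight file found in {model_dir}.")
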
2 changes: 2 additions & 0 deletions setup.py
@@ -30,6 +30,7 @@
     "datasets",
     "sacremoses",
     "diffusers>=0.17.0",
+    "safetensors",
 ]

 QUALITY_REQUIRES = [
@@ -57,6 +58,7 @@
         "transformers-neuronx",
         "torch==1.13.1.*",
         "torchvision==0.14.*",
+        "neuronx_distributed",
     ],
     "diffusers": ["diffusers"],
 }
54 changes: 54 additions & 0 deletions tests/distributed/test_utils.py
@@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for distributed utility functions and classes."""

from pathlib import Path
from tempfile import TemporaryDirectory

import torch
from safetensors.torch import save_file

from optimum.neuron.distributed.utils import WeightInformation, load_tensor_for_weight


def test_load_tensor_for_weight():
    with TemporaryDirectory() as tmpdirname:
        tmpdir = Path(tmpdirname)
        filename = tmpdir / "tensors.safetensors"

        t1 = torch.empty((24, 24), dtype=torch.bfloat16)
        # Creating a slice from t1, meaning that it shares the same storage as t1.
        # It is important to make sure that the resulting loaded file does not have a bigger storage than needed.
        t2 = t1[:2, :2]
        save_file({"t1": t1, "t2": t2}, filename)

        weight_info_1 = WeightInformation(filename, "t1")
        weight_info_2 = WeightInformation(filename, "t2", device=torch.device("cpu"))

        loaded_t1 = load_tensor_for_weight(weight_info_1)
        loaded_t2 = load_tensor_for_weight(weight_info_2)
        loaded_sliced_t1 = load_tensor_for_weight(weight_info_1, tensor_slices=((2,), (2,)))

        torch.testing.assert_close(t1, loaded_t1)
        torch.testing.assert_close(t2, loaded_t2)
        torch.testing.assert_close(t2, loaded_sliced_t1)

        assert loaded_t1.numel() == loaded_t1.storage().size()
        assert loaded_t2.numel() == loaded_t2.storage().size()
        assert loaded_sliced_t1.numel() == loaded_sliced_t1.storage().size()


def test_embedding_to_parallel_embedding():
    pass
12 changes: 6 additions & 6 deletions tests/test_trainer_callback.py
@@ -23,7 +23,7 @@
 from transformers import TrainingArguments
 from transformers.testing_utils import is_staging_test

-from optimum.neuron.trainers import NeuronCacheCallaback
+from optimum.neuron.trainers import NeuronCacheCallback
 from optimum.neuron.utils.cache_utils import (
     NEURON_COMPILE_CACHE_NAME,
     NeuronHash,
@@ -38,7 +38,7 @@

 @is_trainium_test
 @is_staging_test
-class NeuronCacheCallabackTestCase(StagingTestMixin, TestCase):
+class NeuronCacheCallbackTestCase(StagingTestMixin, TestCase):
     def test_neuron_hash_for_model(self):
         with TemporaryDirectory() as tmpdirname:
             args = TrainingArguments(tmpdirname)
@@ -47,7 +47,7 @@ def test_neuron_hash_for_model(self):
                 "x": torch.rand((1,)),
             }

-            callback = NeuronCacheCallaback()
+            callback = NeuronCacheCallback()

             # We first check that no hashes is in the hash cache already.
             self.assertFalse(callback.neuron_hashes)
@@ -74,7 +74,7 @@ def test_try_to_fetch_cached_model(self):

         with TemporaryDirectory() as tmpdirname:
             set_neuron_cache_path(tmpdirname)
-            callback = NeuronCacheCallaback()
+            callback = NeuronCacheCallback()
             args = TrainingArguments(tmpdirname)
             inputs = {"x": torch.rand((24, 1))}
             neuron_hash = callback.neuron_hash_for_model(args, model, inputs)
@@ -109,7 +109,7 @@ def test_try_to_fetch_cached_model(self):
     def test_synchronize_temporary_neuron_cache_state(self):
         with TemporaryDirectory() as tmpdirname:
             set_neuron_cache_path(tmpdirname)
-            callback = NeuronCacheCallaback()
+            callback = NeuronCacheCallback()

             diff = callback.synchronize_temporary_neuron_cache_state()
             self.assertListEqual(diff, [], "The diff should be empty.")
@@ -133,7 +133,7 @@ def test_synchronize_temporary_neuron_cache(self):
         with TemporaryDirectory() as tmpdirname:
             set_neuron_cache_path(tmpdirname)
             args = TrainingArguments(tmpdirname)
-            callback = NeuronCacheCallaback()
+            callback = NeuronCacheCallback()

             callback.synchronize_temporary_neuron_cache()
             files_in_repo = HfApi().list_repo_files(repo_id=self.CUSTOM_PRIVATE_CACHE_REPO)
