Make wrapper tensor subclasses work in serialization (#2440)
* Make wrapper tensor subclasses work in huggingface_hub serialization (non-safetensors)

Summary:
huggingface_hub serialization relies on the storage_ptr of a tensor to implement its sharding logic, but
a wrapper tensor subclass does not have a storage of its own. Instead, we flatten the tensor and derive a
storage_id by combining the storage_ids of its internal plain tensors. This is a bit hacky; open to more robust ideas.
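
A rough illustration of the idea (a minimal sketch, not the exact helper added in `_torch.py` below; `unique_storage_id` is a hypothetical name, `TwoTensor` is the wrapper subclass from PyTorch's test suite, and the subclass path assumes torch >= 2.1):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor  # test-only wrapper tensor subclass
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def unique_storage_id(tensor):
    # A wrapper subclass has no storage of its own: recurse into the inner
    # plain tensors exposed by __tensor_flatten__ and combine their ids.
    if is_traceable_wrapper_subclass(tensor):
        attrs, _ = tensor.__tensor_flatten__()
        return tuple(unique_storage_id(getattr(tensor, attr)) for attr in attrs)
    return tensor.untyped_storage().data_ptr()


t = torch.tensor([1.0, 2.0])
print(unique_storage_id(t))                # a single storage pointer
print(unique_storage_id(TwoTensor(t, t)))  # a tuple of the inner tensors' ids
```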

Test Plan:
Tested with the script in huggingface/transformers#32364.
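
For a quick standalone check (a sketch, not the script referenced above; `checkpoint_dir` is a placeholder path), the non-safetensors path can be exercised like this:

```python
import os

import torch
from torch.testing._internal.two_tensor import TwoTensor

from huggingface_hub import save_torch_state_dict

t = torch.rand(16)
state_dict = {"plain": torch.rand(16), "wrapped": TwoTensor(t, t)}

# Pickle-based (non-safetensors) save: wrapper subclasses are handled when
# computing storage ids and shard sizes.
os.makedirs("checkpoint_dir", exist_ok=True)
save_torch_state_dict(state_dict, "checkpoint_dir", safe_serialization=False)
```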

Reviewers:

Subscribers:

Tasks:

Tags:

* add tests

* update signature to include new changes for tensor subclass

* add torch version checks and move around import

* more fixes

* tested with torch 2.0.0 and 2.5.0

* remove torch_version_at_least from _torch.py

* simplify code for checking if tensor subclass is available or not

* minor fix

* address comments and run tests with torch 2.4.0

* some linting

* add test_split_torch_state_dict_into_shards for tensor subclass state dict

* lint

* style

* quality

---------

Co-authored-by: Lucain <[email protected]>
Co-authored-by: Lucain Pouget <[email protected]>
3 people authored Aug 30, 2024
1 parent ecbbeb3 commit f12ba86
Showing 2 changed files with 176 additions and 11 deletions.
75 changes: 64 additions & 11 deletions src/huggingface_hub/serialization/_torch.py
@@ -20,7 +20,7 @@
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

from .. import constants, logging
from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
@@ -336,17 +336,24 @@ def split_torch_state_dict_into_shards(
)


-def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]:
-    """
-    Return unique identifier to a tensor storage.
-
-    Multiple different tensors can share the same underlying storage. For
-    example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
-    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
-    non-overlapping lifetimes may have the same id.
-
-    Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
-    """
+def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
+    """Returns a unique id for plain tensor
+    or a (potentially nested) Tuple of unique id for the flattened Tensor
+    if the input is a wrapper tensor subclass Tensor
+    """
+    try:
+        # for torch 2.1 and above we can also handle tensor subclasses
+        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
+        if is_traceable_wrapper_subclass(tensor):
+            attrs, _ = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
+            return tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs)
+
+    except ImportError:
+        # for torch version less than 2.1, we can fallback to original implementation
+        pass
+
if tensor.device.type == "xla" and is_torch_tpu_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
@@ -358,13 +365,38 @@ def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]:
else:
unique_id = storage_ptr(tensor)

-    return tensor.device, unique_id, get_torch_storage_size(tensor)
+    return unique_id


def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", Union[int, Tuple[Any, ...]], int]:
"""
Return unique identifier to a tensor storage.
Multiple different tensors can share the same underlying storage. For
example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
"""
return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)


def get_torch_storage_size(tensor: "torch.Tensor") -> int:
"""
Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
"""
try:
# for torch 2.1 and above we can also handle tensor subclasses
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

if is_traceable_wrapper_subclass(tensor):
attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined]
return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs)
except ImportError:
# for torch version less than 2.1, we can fallback to original implementation
pass

try:
return tensor.untyped_storage().nbytes()
except AttributeError:
@@ -398,10 +430,20 @@ def is_torch_tpu_available(check_device=True):
return False


-def storage_ptr(tensor: "torch.Tensor") -> int:
+def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
"""
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11.
"""
try:
# for torch 2.1 and above we can also handle tensor subclasses
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

if is_traceable_wrapper_subclass(tensor):
return _get_unique_id(tensor)
except ImportError:
# for torch version less than 2.1, we can fallback to original implementation
pass

try:
return tensor.untyped_storage().data_ptr()
except Exception:
@@ -496,6 +538,17 @@ def _is_complete(tensor: "torch.Tensor") -> bool:
"""
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
"""
try:
# for torch 2.1 and above we can also handle tensor subclasses
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

if is_traceable_wrapper_subclass(tensor):
attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined]
return all(_is_complete(getattr(tensor, attr)) for attr in attrs)
except ImportError:
# for torch version less than 2.1, we can fallback to original implementation
pass

return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size(
tensor.dtype
) == get_torch_storage_size(tensor)
112 changes: 112 additions & 0 deletions tests/test_serialization.py
@@ -7,12 +7,14 @@
import pytest
from pytest_mock import MockerFixture

from huggingface_hub import constants
from huggingface_hub.serialization import (
get_tf_storage_size,
get_torch_storage_size,
save_torch_model,
save_torch_state_dict,
split_state_dict_into_shards_factory,
split_torch_state_dict_into_shards,
)
from huggingface_hub.serialization._base import parse_size_to_int

@@ -31,6 +33,16 @@ def _dummy_get_storage_size(item):
return sum(item)


# util functions for checking the version for pytorch
def is_wrapper_tensor_subclass_available():
try:
from torch.utils._python_dispatch import is_traceable_wrapper_subclass # noqa: F401

return True
except ImportError:
return False


@pytest.fixture
def dummy_state_dict() -> Dict[str, List[int]]:
return {
@@ -58,6 +70,25 @@ def torch_state_dict() -> Dict[str, "torch.Tensor"]:
pytest.skip("torch is not available")


@pytest.fixture
def torch_state_dict_tensor_subclass() -> Dict[str, "torch.Tensor"]:
try:
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = torch.tensor([4])
return {
"layer_1": torch.tensor([4]),
"layer_2": torch.tensor([10]),
"layer_3": torch.tensor([30]),
"layer_4": torch.tensor([2]),
"layer_5": torch.tensor([2]),
"layer_6": TwoTensor(t, t),
}
except ImportError:
pytest.skip("torch is not available")


@pytest.fixture
def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:
try:
@@ -75,6 +106,31 @@ def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:
pytest.skip("torch is not available")


@pytest.fixture
def torch_state_dict_shared_layers_tensor_subclass() -> Dict[str, "torch.Tensor"]:
try:
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = torch.tensor([4])
tensor_subclass_tensor = TwoTensor(t, t)

t = torch.tensor([4])
shared_tensor_subclass_tensor = TwoTensor(t, t)
return {
"layer_1": torch.tensor([4]),
"layer_2": torch.tensor([10]),
"layer_3": torch.tensor([30]),
"layer_4": torch.tensor([2]),
"layer_5": torch.tensor([2]),
"layer_6": tensor_subclass_tensor,
"ts_shared_1": shared_tensor_subclass_tensor,
"ts_shared_2": shared_tensor_subclass_tensor,
}
except ImportError:
pytest.skip("torch is not available")


def test_single_shard(dummy_state_dict):
state_dict_split = split_state_dict_into_shards_factory(
dummy_state_dict,
@@ -170,6 +226,18 @@ def test_get_torch_storage_size():
assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2


@requires("torch")
@pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
def test_get_torch_storage_size_wrapper_tensor_subclass():
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)
assert get_torch_storage_size(TwoTensor(t, t)) == 5 * 8 * 2
t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
assert get_torch_storage_size(TwoTensor(t, TwoTensor(t, t))) == 5 * 2 * 3


def test_parse_size_to_int():
assert parse_size_to_int("1KB") == 1 * 10**3
assert parse_size_to_int("2MB") == 2 * 10**6
@@ -247,6 +315,38 @@ def test_save_torch_state_dict_unsafe_not_sharded(
assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


@pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
def test_save_torch_state_dict_tensor_subclass_unsafe_not_sharded(
tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_tensor_subclass: Dict[str, "torch.Tensor"]
) -> None:
"""Save as pickle without sharding."""
with caplog.at_level("WARNING"):
save_torch_state_dict(
torch_state_dict_tensor_subclass, tmp_path, max_shard_size="1GB", safe_serialization=False
)
assert "we strongly recommend using safe serialization" in caplog.text

assert (tmp_path / "pytorch_model.bin").is_file()
assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


@pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
def test_save_torch_state_dict_shared_layers_tensor_subclass_unsafe_not_sharded(
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"],
) -> None:
"""Save as pickle without sharding."""
with caplog.at_level("WARNING"):
save_torch_state_dict(
torch_state_dict_shared_layers_tensor_subclass, tmp_path, max_shard_size="1GB", safe_serialization=False
)
assert "we strongly recommend using safe serialization" in caplog.text

assert (tmp_path / "pytorch_model.bin").is_file()
assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


def test_save_torch_state_dict_unsafe_sharded(
tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: Dict[str, "torch.Tensor"]
) -> None:
@@ -314,6 +414,18 @@ def test_save_torch_state_dict_shared_layers_sharded(
assert "shared_2" not in state_dict


def test_split_torch_state_dict_into_shards(
tmp_path: Path, torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"]
):
# the model size is 72, setting max_shard_size to 32 means we'll shard the file
state_dict_split = split_torch_state_dict_into_shards(
torch_state_dict_shared_layers_tensor_subclass,
filename_pattern=constants.PYTORCH_WEIGHTS_FILE_PATTERN,
max_shard_size=32,
)
assert state_dict_split.is_sharded


def test_save_torch_state_dict_custom_filename(tmp_path: Path, torch_state_dict: Dict[str, "torch.Tensor"]) -> None:
"""Custom filename pattern is respected."""
# Not sharded
