Skip to content

Commit

Permalink
skip weights/neff sep test for torch 2.*
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Mar 4, 2024
1 parent 356ef2a commit c81dc3f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
8 changes: 8 additions & 0 deletions optimum/neuron/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from packaging import version
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPProcessor
from transformers.modeling_utils import _add_variant
from transformers.utils import (
Expand All @@ -41,6 +42,7 @@
from ...utils import logging
from .import_utils import is_torch_xla_available
from .require_utils import requires_safetensors
from .version_utils import get_torch_version


if TYPE_CHECKING:
Expand Down Expand Up @@ -522,6 +524,12 @@ def replace_weights(
"""
Replaces the weights in a Neuron Model with weights from another model, the original neuron model should have separated weights(by setting `inline_weights_to_neff=Talse` during the tracing).
"""
torch_version = get_torch_version()
if version.parse(torch_version) >= version.parse("2.0.0"):
raise RuntimeError(
"Weights Neff separation is not yet supported by Neuron SDK for PyTorch 2.*. You can downgrade your PyTorch version to 1.13.1."
)

if isinstance(weights, torch.nn.Module):
weights = weights.state_dict()

Expand Down
9 changes: 9 additions & 0 deletions optimum/neuron/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

import unittest

from packaging import version

from .import_utils import is_neuron_available, is_neuronx_available
from .version_utils import get_torch_version


def requires_neuron(test_case):
Expand All @@ -33,6 +36,12 @@ def requires_neuron_or_neuronx(test_case):
)(test_case)


def requires_pytorch_1_13(test_case):
return unittest.skipUnless(
version.parse(get_torch_version()) < version.parse("2.0.0"), "test requires PyTorch < 2.0.0"
)(test_case)


def is_trainium_test(test_case):
test_case = requires_neuronx(test_case)
try:
Expand Down
3 changes: 2 additions & 1 deletion tests/inference/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
pipeline,
)
from optimum.neuron.utils import NEURON_FILE_NAME, is_neuron_available, is_neuronx_available
from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx, requires_pytorch_1_13
from optimum.utils import (
CONFIG_NAME,
logging,
Expand Down Expand Up @@ -144,6 +144,7 @@ def test_save_compiler_intermediary_files(self):
self.assertTrue(os.path.isdir(save_path))
self.assertTrue(os.path.exists(neff_path))

@requires_pytorch_1_13
@requires_neuronx
def test_decouple_weights_neff_and_replace_weight(self):
with tempfile.TemporaryDirectory() as tempdir:
Expand Down

0 comments on commit c81dc3f

Please sign in to comment.