Commit f35fa16

fix doc
JingyaHuang committed Mar 22, 2024
1 parent a4b6334 commit f35fa16
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions optimum/neuron/utils/misc.py
@@ -23,7 +23,6 @@
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from torch_neuronx.xla_impl.data_parallel import DataParallel
 from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPProcessor, PretrainedConfig
 from transformers.modeling_utils import _add_variant
 from transformers.utils import (
@@ -42,10 +41,13 @@
 from transformers.utils.hub import get_checkpoint_shard_files
 
 from ...utils import is_diffusers_available, logging
-from .import_utils import is_torch_xla_available
+from .import_utils import is_torch_neuronx_available, is_torch_xla_available
 from .require_utils import requires_safetensors, requires_torch_xla
 
 
+if is_torch_neuronx_available():
+    from torch_neuronx.xla_impl.data_parallel import DataParallel
+
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
 
@@ -538,7 +540,7 @@ def download_checkpoints_in_cache(
 
 
 def replace_weights(
-    model: Union[torch.jit._script.RecursiveScriptModule, DataParallel],
+    model: Union[torch.jit._script.RecursiveScriptModule, "DataParallel"],
     weights: Union[Dict[str, torch.Tensor], torch.nn.Module],
     prefix: str = "model",
 ):
@@ -549,7 +551,10 @@ def replace_weights(
         weights = weights.state_dict()
 
     # extract module paths from the weights c module
-    model_weights = model.module.weights if isinstance(model, DataParallel) else model.weights
+    if is_torch_neuronx_available() and isinstance(model, DataParallel):
+        model_weights = model.module.weights
+    else:
+        model_weights = model.weights
     code = model_weights._c.code
     start_str = "__parameters__ = ["
     end_str = "]\n"
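
For context, the change applies a common optional-dependency pattern: import the heavy class only when its package is installed, quote it in type annotations so the module still imports without it, and let the availability check short-circuit the isinstance call. Below is a minimal, self-contained sketch of that pattern; the optional_lib and HeavyClass names and the find_spec-based check are hypothetical stand-ins, not part of optimum-neuron, which uses is_torch_neuronx_available() and torch_neuronx's DataParallel.

# Minimal sketch of the optional-import pattern used in this commit.
# "optional_lib" and "HeavyClass" are hypothetical stand-ins for
# torch_neuronx and DataParallel.
import importlib.util
from typing import Union


def is_optional_lib_available() -> bool:
    # Detect the package without importing it eagerly.
    return importlib.util.find_spec("optional_lib") is not None


if is_optional_lib_available():
    from optional_lib import HeavyClass


def describe(model: Union[str, "HeavyClass"]) -> str:
    # The quoted annotation is a forward reference, so defining this function
    # does not evaluate HeavyClass; the availability check then short-circuits
    # the isinstance call, so the bare name is never touched when optional_lib
    # is absent.
    if is_optional_lib_available() and isinstance(model, HeavyClass):
        return "data-parallel model"
    return "plain model"


print(describe("some traced model"))  # -> "plain model" when optional_lib is missing

When the dependency is missing, describe() never evaluates the undefined HeavyClass name because the `and` short-circuits; the new if/else branch in replace_weights relies on exactly that behavior for DataParallel.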
