diff --git a/pylate/models/Dense.py b/pylate/models/Dense.py index b9f3575..2ae1783 100644 --- a/pylate/models/Dense.py +++ b/pylate/models/Dense.py @@ -2,7 +2,6 @@ import os import torch -from safetensors import safe_open from safetensors.torch import load_model as load_safetensors_model from sentence_transformers.models import Dense as DenseSentenceTransformer from sentence_transformers.util import import_from_string @@ -114,18 +113,23 @@ def from_stanford_weights( # Else download the model/use the cached version model_name_or_path = cached_file( model_name_or_path, - filename="model.safetensors", + filename="pytorch_model.bin", cache_dir=cache_folder, revision=revision, local_files_only=local_files_only, token=token, use_auth_token=use_auth_token, ) - # If the model a local folder, load the safetensor + # If the model is a local folder, load the PyTorch model else: - model_name_or_path = os.path.join(model_name_or_path, "model.safetensors") - with safe_open(model_name_or_path, framework="pt", device="cpu") as f: - state_dict = {"linear.weight": f.get_tensor("linear.weight")} + model_name_or_path = os.path.join(model_name_or_path, "pytorch_model.bin") + + # Load the state dict using torch.load instead of safe_open + state_dict = { + "linear.weight": torch.load(model_name_or_path, map_location="cpu")[ + "linear.weight" + ] + } # Determine input and output dimensions in_features = state_dict["linear.weight"].shape[1]