Skip to content

Commit

Permalink
Use pytorch_model.bin to load stanford model instead of safetensor
Browse files Browse the repository at this point in the history
  • Loading branch information
NohTow committed Oct 28, 2024
1 parent b34bba0 commit a6a97d3
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions pylate/models/Dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os

import torch
from safetensors import safe_open
from safetensors.torch import load_model as load_safetensors_model
from sentence_transformers.models import Dense as DenseSentenceTransformer
from sentence_transformers.util import import_from_string
Expand Down Expand Up @@ -114,18 +113,23 @@ def from_stanford_weights(
# Else download the model/use the cached version
model_name_or_path = cached_file(
model_name_or_path,
filename="model.safetensors",
filename="pytorch_model.bin",
cache_dir=cache_folder,
revision=revision,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
)
# If the model a local folder, load the safetensor
# If the model a local folder, load the PyTorch model
else:
model_name_or_path = os.path.join(model_name_or_path, "model.safetensors")
with safe_open(model_name_or_path, framework="pt", device="cpu") as f:
state_dict = {"linear.weight": f.get_tensor("linear.weight")}
model_name_or_path = os.path.join(model_name_or_path, "pytorch_model.bin")

# Load the state dict using torch.load instead of safe_open
state_dict = {
"linear.weight": torch.load(model_name_or_path, map_location="cpu")[
"linear.weight"
]
}

# Determine input and output dimensions
in_features = state_dict["linear.weight"].shape[1]
Expand Down

0 comments on commit a6a97d3

Please sign in to comment.