From a6a97d3ed9b0f6bf341eac078bc1b6c2b7b0fd64 Mon Sep 17 00:00:00 2001 From: Antoine Chaffin Date: Mon, 28 Oct 2024 11:04:12 +0000 Subject: [PATCH] Use pytorch_model.bin to load stanford model instead of safetensor --- pylate/models/Dense.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pylate/models/Dense.py b/pylate/models/Dense.py index b9f3575..2ae1783 100644 --- a/pylate/models/Dense.py +++ b/pylate/models/Dense.py @@ -2,7 +2,6 @@ import os import torch -from safetensors import safe_open from safetensors.torch import load_model as load_safetensors_model from sentence_transformers.models import Dense as DenseSentenceTransformer from sentence_transformers.util import import_from_string @@ -114,18 +113,23 @@ def from_stanford_weights( # Else download the model/use the cached version model_name_or_path = cached_file( model_name_or_path, - filename="model.safetensors", + filename="pytorch_model.bin", cache_dir=cache_folder, revision=revision, local_files_only=local_files_only, token=token, use_auth_token=use_auth_token, ) - # If the model a local folder, load the safetensor + # If the model is a local folder, load the PyTorch model else: - model_name_or_path = os.path.join(model_name_or_path, "model.safetensors") - with safe_open(model_name_or_path, framework="pt", device="cpu") as f: - state_dict = {"linear.weight": f.get_tensor("linear.weight")} + model_name_or_path = os.path.join(model_name_or_path, "pytorch_model.bin") + + # Load the state dict using torch.load instead of safe_open + state_dict = { + "linear.weight": torch.load(model_name_or_path, map_location="cpu")[ + "linear.weight" + ] + } # Determine input and output dimensions in_features = state_dict["linear.weight"].shape[1]