Skip to content

Commit

Permalink
Correctly setting safetensor path if the stanford model is a local fo…
Browse files Browse the repository at this point in the history
…lder
  • Loading branch information
NohTow committed Oct 22, 2024
1 parent b89517f commit ba4efc9
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pylate/models/Dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ def from_stanford_weights(
token=token,
use_auth_token=use_auth_token,
)
# If the model a local folder, load the safetensor
else:
model_name_or_path = os.path.join(model_name_or_path, "model.safetensors")
with safe_open(model_name_or_path, framework="pt", device="cpu") as f:
state_dict = {"linear.weight": f.get_tensor("linear.weight")}

Expand Down

0 comments on commit ba4efc9

Please sign in to comment.