diff --git a/pylate/models/Dense.py b/pylate/models/Dense.py index fb37256..b9f3575 100644 --- a/pylate/models/Dense.py +++ b/pylate/models/Dense.py @@ -121,6 +121,9 @@ def from_stanford_weights( token=token, use_auth_token=use_auth_token, ) + # If the model a local folder, load the safetensor + else: + model_name_or_path = os.path.join(model_name_or_path, "model.safetensors") with safe_open(model_name_or_path, framework="pt", device="cpu") as f: state_dict = {"linear.weight": f.get_tensor("linear.weight")}