Skip to content

Commit

Permalink
add optional model_revision to SentenceTransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
plynch-chwy committed Nov 3, 2023
1 parent c5f93f7 commit 0014d2b
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ class SentenceTransformer(nn.Sequential):
:param device: Device (like 'cuda' / 'cpu') that should be used for computation. If None, checks if a GPU can be used.
:param cache_folder: Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable.
:param use_auth_token: HuggingFace authentication token to download private models.
:param model_revision: The specific model version to use. It can be a branch name, a tag name, or a commit id, for a stored model on huggingface.co.
"""
def __init__(self, model_name_or_path: Optional[str] = None,
modules: Optional[Iterable[nn.Module]] = None,
device: Optional[str] = None,
cache_folder: Optional[str] = None,
use_auth_token: Union[bool, str, None] = None
use_auth_token: Union[bool, str, None] = None,
model_revision: Optional[str] = None
):
self._model_card_vars = {}
self._model_card_text = None
Expand Down Expand Up @@ -89,7 +91,8 @@ def __init__(self, model_name_or_path: Optional[str] = None,
library_name='sentence-transformers',
library_version=__version__,
ignore_files=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'],
use_auth_token=use_auth_token)
use_auth_token=use_auth_token,
revision=model_revision)

if os.path.exists(os.path.join(model_path, 'modules.json')): #Load as SentenceTransformer model
modules = self._load_sbert_model(model_path)
Expand Down

0 comments on commit 0014d2b

Please sign in to comment.