From 554ae17cdfad1eccc4f6083bd7bb3f320829d6e6 Mon Sep 17 00:00:00 2001 From: Antoine Chaffin Date: Wed, 11 Sep 2024 14:00:42 +0000 Subject: [PATCH 1/4] Adding off-the-shelf Stanford-NLP loading, removing the add_pooling_layer parameters and casting the Dense layer to dtype if set in model_kwargs --- pylate/models/Dense.py | 36 ++++++++++++++++++++++++++++++++++++ pylate/models/colbert.py | 22 +++++++++++++++++----- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/pylate/models/Dense.py b/pylate/models/Dense.py index 71107f7..56f493c 100644 --- a/pylate/models/Dense.py +++ b/pylate/models/Dense.py @@ -2,10 +2,12 @@ import os import torch +from safetensors import safe_open from safetensors.torch import load_model as load_safetensors_model from sentence_transformers.models import Dense as DenseSentenceTransformer from sentence_transformers.util import import_from_string from torch import nn +from transformers.utils import cached_file __all__ = ["Dense"] @@ -77,6 +79,40 @@ def from_sentence_transformers(dense: DenseSentenceTransformer) -> "Dense": model.load_state_dict(dense.state_dict()) return model + @staticmethod + def from_stanford_weights( + model_name_or_path, + cache_folder, + revision, + local_files_only, + token, + use_auth_token, + ) -> "Dense": + # Check if the model is locally available + if not (os.path.exists(os.path.join(model_name_or_path))): + # Else download the model/use the cached version + model_name_or_path = cached_file( + model_name_or_path, + filename="model.safetensors", + cache_dir=cache_folder, + revision=revision, + local_files_only=local_files_only, + token=token, + use_auth_token=use_auth_token, + ) + with safe_open(model_name_or_path, framework="pt", device="cpu") as f: + state_dict = {"linear.weight": f.get_tensor("linear.weight")} + + # Determine input and output dimensions + in_features = state_dict["linear.weight"].shape[1] + out_features = state_dict["linear.weight"].shape[0] + + # Create Dense layer instance + model 
= Dense(in_features=in_features, out_features=out_features, bias=False) + + model.load_state_dict(state_dict) + return model + @staticmethod def load(input_path) -> "Dense": """Load a Dense layer.""" diff --git a/pylate/models/colbert.py b/pylate/models/colbert.py index b2f0251..5aa0ab5 100644 --- a/pylate/models/colbert.py +++ b/pylate/models/colbert.py @@ -219,9 +219,6 @@ def __init__( config_kwargs: dict | None = None, model_card_data: Optional[SentenceTransformerModelCardData] = None, ) -> None: - model_kwargs = {} if model_kwargs is None else model_kwargs - model_kwargs["add_pooling_layer"] = False - self.query_prefix = query_prefix self.document_prefix = document_prefix self.query_length = query_length @@ -250,9 +247,21 @@ def __init__( ) hidden_size = self[0].get_word_embedding_dimension() - + # If the model is a stanford-nlp ColBERT, load the weights of the dense layer + if self[0].auto_model.config.architectures[0] == "HF_ColBERT": + self.append( + Dense.from_stanford_weights( + model_name_or_path, + cache_folder, + revision, + local_files_only, + token, + use_auth_token, + ) + ) + logger.warning("Loaded the ColBERT model from Stanford NLP.") # Add a linear projection layer to the model in order to project the embeddings to the desired size. 
- if len(self) < 2: + elif len(self) < 2: # Add a linear projection layer to the model in order to project the embeddings to the desired size embedding_size = embedding_size or 128 @@ -282,6 +291,9 @@ def __init__( else: logger.warning("Pylate model loaded successfully.") + if model_kwargs is not None and "torch_dtype" in model_kwargs: + self[1].to(model_kwargs["torch_dtype"]) + self.to(device) self.is_hpu_graph_enabled = False From 0a50f9763cdc959b34e480269568acba92555ec9 Mon Sep 17 00:00:00 2001 From: Antoine Chaffin Date: Thu, 12 Sep 2024 08:09:45 +0000 Subject: [PATCH 2/4] Adding a docstring --- pylate/models/Dense.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/pylate/models/Dense.py b/pylate/models/Dense.py index 56f493c..fb37256 100644 --- a/pylate/models/Dense.py +++ b/pylate/models/Dense.py @@ -81,13 +81,34 @@ def from_sentence_transformers(dense: DenseSentenceTransformer) -> "Dense": @staticmethod def from_stanford_weights( - model_name_or_path, - cache_folder, - revision, - local_files_only, - token, - use_auth_token, + model_name_or_path: str | os.PathLike, + cache_folder: str | os.PathLike | None = None, + revision: str | None = None, + local_files_only: bool | None = None, + token: str | bool | None = None, + use_auth_token: str | bool | None = None, ) -> "Dense": + """Load the weights of the Dense layer from a stanford-nlp checkpoint. + + Parameters + ---------- + model_name_or_path (`str` or `os.PathLike`): + This can be either: + - a string, the *model id* of a model repo on huggingface.co. + - a path to a *directory* potentially containing the file. + cache_folder (`str` or `os.PathLike`, *optional*): + Path to a directory in which the downloaded model weights should be cached if the standard + cache should not be used. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. 
If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the model weights from local files. + """ # Check if the model is locally available if not (os.path.exists(os.path.join(model_name_or_path))): # Else download the model/use the cached version From ba13c2ea72a51161648c9c01e740b8730c28a489 Mon Sep 17 00:00:00 2001 From: Antoine Chaffin Date: Thu, 12 Sep 2024 08:10:46 +0000 Subject: [PATCH 3/4] Change to not load both PyLate AND stanford dense when a repository contains both --- pylate/models/colbert.py | 49 +++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/pylate/models/colbert.py b/pylate/models/colbert.py index 5aa0ab5..a5d279b 100644 --- a/pylate/models/colbert.py +++ b/pylate/models/colbert.py @@ -245,32 +245,35 @@ def __init__( config_kwargs=config_kwargs, model_card_data=model_card_data, ) - hidden_size = self[0].get_word_embedding_dimension() - # If the model is a stanford-nlp ColBERT, load the weights of the dense layer - if self[0].auto_model.config.architectures[0] == "HF_ColBERT": - self.append( - Dense.from_stanford_weights( - model_name_or_path, - cache_folder, - revision, - local_files_only, - token, - use_auth_token, - ) - ) - logger.warning("Loaded the ColBERT model from Stanford NLP.") + # Add a linear projection layer to the model in order to project the embeddings to the desired size. 
- elif len(self) < 2: - # Add a linear projection layer to the model in order to project the embeddings to the desired size - embedding_size = embedding_size or 128 + if len(self) < 2: + # If the model is a stanford-nlp ColBERT, load the weights of the dense layer + if self[0].auto_model.config.architectures[0] == "HF_ColBERT": + self.append( + Dense.from_stanford_weights( + model_name_or_path, + cache_folder, + revision, + local_files_only, + token, + use_auth_token, + ) + ) + logger.warning("Loaded the ColBERT model from Stanford NLP.") + else: + # Add a linear projection layer to the model in order to project the embeddings to the desired size + embedding_size = embedding_size or 128 - logger.warning( - f"The checkpoint does not contain a linear projection layer. Adding one with output dimensions ({hidden_size}, {embedding_size})." - ) - self.append( - Dense(in_features=hidden_size, out_features=embedding_size, bias=bias) - ) + logger.warning( + f"The checkpoint does not contain a linear projection layer. Adding one with output dimensions ({hidden_size}, {embedding_size})." + ) + self.append( + Dense( + in_features=hidden_size, out_features=embedding_size, bias=bias + ) + ) elif ( embedding_size is not None From 5daff31f7bdf6b7abdfb5a0c059bcc35dd48b808 Mon Sep 17 00:00:00 2001 From: Antoine Chaffin Date: Thu, 12 Sep 2024 08:11:17 +0000 Subject: [PATCH 4/4] Fixing the version of ST to pre 3.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fa8060a..14dd39d 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ long_description = fh.read() base_packages = [ - "sentence-transformers >= 3.0.1", + "sentence-transformers == 3.0.1", "datasets >= 2.20.0", "accelerate >= 0.31.0", "voyager >= 2.0.9",