Rename linear projection #37
Changes from all commits
49032a8
6a51db4
84059a9
ed0d433
eb1e7de
c5f9085
@@ -0,0 +1,95 @@
import json
import os

import torch
from safetensors.torch import load_model as load_safetensors_model
from sentence_transformers.models import Dense as DenseSentenceTransformer
from sentence_transformers.util import import_from_string
from torch import nn

__all__ = ["Dense"]


class Dense(DenseSentenceTransformer):
    """Performs linear projection on the token embeddings to a lower dimension.

    Parameters
    ----------
    in_features
        Size of the embeddings in output of the transformer.
    out_features
        Size of the output embeddings after linear projection.
    bias
        Add a bias vector.
    init_weight
        Initial value for the matrix of the linear layer.
    init_bias
        Initial value for the bias of the linear layer.

    Examples
    --------
    >>> from giga_cherche import models

    >>> model = models.Dense(
    ...     in_features=768,
    ...     out_features=128,
    ... )

    >>> features = {
    ...     "token_embeddings": torch.randn(2, 768),
    ... }

    >>> projected_features = model(features)

    >>> assert projected_features["token_embeddings"].shape == (2, 128)
    >>> assert isinstance(model, DenseSentenceTransformer)

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        activation_function=nn.Identity(),
        init_weight: torch.Tensor = None,
        init_bias: torch.Tensor = None,
    ) -> None:
        super(Dense, self).__init__(
            in_features, out_features, bias, activation_function, init_weight, init_bias
        )

    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Performs linear projection on the token embeddings."""
        token_embeddings = features["token_embeddings"]
        projected_embeddings = self.linear(token_embeddings)
        features["token_embeddings"] = projected_embeddings
        return features

    @staticmethod
    def from_sentence_transformers(dense_st: DenseSentenceTransformer):
        config = dense_st.get_config_dict()
        config["activation_function"] = nn.Identity()
        model = Dense(**config)
        model.load_state_dict(dense_st.state_dict())
        return model

    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)

        config["activation_function"] = import_from_string(
            config["activation_function"]
        )()
        model = Dense(**config)
        if os.path.exists(os.path.join(input_path, "model.safetensors")):
            load_safetensors_model(model, os.path.join(input_path, "model.safetensors"))
        else:
Review comment: You can remove the else and return the model in the if, right after the safetensors load (see the sketch after this diff).
            model.load_state_dict(
                torch.load(
                    os.path.join(input_path, "pytorch_model.bin"),
                    map_location=torch.device("cpu"),
                )
            )
        return model
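A minimal sketch of what the reviewer's suggestion could look like, reusing the imports and the Dense.load signature shown above; this is an illustration of the early-return refactor, not the committed implementation:

    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)

        config["activation_function"] = import_from_string(
            config["activation_function"]
        )()
        model = Dense(**config)

        # Early return: if a safetensors checkpoint exists, load it and stop here,
        # so no else branch is needed for the legacy pytorch_model.bin path.
        if os.path.exists(os.path.join(input_path, "model.safetensors")):
            load_safetensors_model(model, os.path.join(input_path, "model.safetensors"))
            return model

        model.load_state_dict(
            torch.load(
                os.path.join(input_path, "pytorch_model.bin"),
                map_location=torch.device("cpu"),
            )
        )
        return model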
This file was deleted.
@@ -1,4 +1,4 @@
from .colbert import ColBERT
from .LinearProjection import LinearProjection
from .Dense import Dense

__all__ = ["ColBERT", "LinearProjection"]
__all__ = ["ColBERT", "Dense"]
@@ -21,6 +21,7 @@
    SentenceTransformerModelCardData,
    generate_model_card,
)
from sentence_transformers.models import Dense as DenseSentenceTransformer
from sentence_transformers.models import Transformer
from sentence_transformers.quantization import quantize_embeddings
from sentence_transformers.similarity_functions import SimilarityFunction
@@ -32,7 +33,7 @@
from tqdm.autonotebook import trange

from ..utils import _start_multi_process_pool
from .LinearProjection import LinearProjection
from .Dense import Dense

logger = logging.getLogger(__name__)
@@ -196,7 +197,8 @@ def __init__(
        token: bool | str | None = None,
        use_auth_token: bool | str | None = None,
        truncate_dim: int | None = None,
        embedding_size: int | None = 128,
        embedding_size: int | None = None,
        bias: bool = False,
        query_prefix: str | None = "[Q] ",
        document_prefix: str | None = "[D] ",
        add_special_tokens: bool = True,
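The hunk above changes embedding_size to default to None (the 128 fallback moves into the constructor body, shown in the next hunk) and adds a bias flag for the projection layer. A hypothetical instantiation, assuming the giga_cherche package name from the Dense docstring and a placeholder checkpoint name:

from giga_cherche import models

# "bert-base-uncased" is a placeholder checkpoint; embedding_size and bias are the new parameters.
model = models.ColBERT(
    model_name_or_path="bert-base-uncased",
    embedding_size=128,
    bias=False,
)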
@@ -242,18 +244,38 @@ def __init__(
            model_card_data=model_card_data,
        )

        hidden_size = self._modules["0"].get_word_embedding_dimension()
        hidden_size = self[0].get_word_embedding_dimension()

        # If there is no linear projection layer, add one
        # TODO: do this more cleanly
        if len(self._modules) < 2:
        if self.num_modules() < 2:
            # Add a linear projection layer to the model in order to project the embeddings to the desired size
            embedding_size = embedding_size or 128
            self.append(
                Dense(in_features=hidden_size, out_features=embedding_size, bias=bias)
            )
            logger.warning(
                f"The checkpoint does not contain a linear projection layer. Adding one with output dimensions ({hidden_size}, {embedding_size})"
            )
            # Add a linear projection layer to the model in order to project the embeddings to the desired size
            self._modules[f"{len(self._modules)}"] = LinearProjection(
                in_features=hidden_size, out_features=embedding_size, bias=False

        elif (
            embedding_size is not None
            and self[1].get_sentence_embedding_dimension() != embedding_size
        ):
            logger.warning(
                f"The checkpoint contains a dense layer but with incorrect dimension. Replacing it with a Dense layer with output dimensions ({hidden_size}, {embedding_size})"
Review comment: It would be great to print the actual embedding dimension you found, self[1].get_sentence_embedding_dimension() (see the sketch after this hunk).
            )
            self[1] = Dense(
                in_features=hidden_size, out_features=embedding_size, bias=bias
            )
        # If it is an instance of Dense from ST, convert it to our Dense to remove activation function in the forward pass, else the layer is already correct
        elif not isinstance(self[1], Dense):
            logger.warning(
                f"Converting the existing Dense layer from SentenceTransformer with output dimensions ({hidden_size}, {self[1].get_sentence_embedding_dimension()})"
            )
            self[1] = Dense.from_sentence_transformers(self[1])
        else:
            logger.warning("Correctly loaded the Dense layer")

        self.to(device)
        self.is_hpu_graph_enabled = False
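A sketch of the warning with the reviewer's suggestion applied, printing the dimension actually found in the checkpoint; the found_dim name and the message wording are assumptions, not the committed code:

            found_dim = self[1].get_sentence_embedding_dimension()
            logger.warning(
                f"The checkpoint contains a dense layer with output dimension {found_dim}, "
                f"but {embedding_size} was requested. Replacing it with a Dense layer with "
                f"output dimensions ({hidden_size}, {embedding_size})."
            )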
@@ -294,6 +316,18 @@ def __init__(
    def load(input_path) -> "ColBERT":
        return ColBERT(model_name_or_path=input_path)

    def num_modules(self) -> int:
Review comment: You can rename this function to def __len__(self) -> int: return len(self._modules), then you can call len(self) to get the actual number of modules (see the sketch after this hunk).
        return len(self._modules)

    def convert_dense_layer_from_sentence_transformer(
Review comment: Could you add a docstring here? (See the sketch after this hunk.)
        self, in_features: int, out_features: int
    ):
        dense_layer = Dense(
            in_features=in_features, out_features=out_features, bias=False
        )
        dense_layer.load_state_dict(self[1].state_dict(), strict=False)
        self[1] = dense_layer

    @staticmethod
    def insert_prefix_token(input_ids: torch.Tensor, prefix_id: int) -> torch.Tensor:
        """Inserts a prefix token at the beginning of each sequence in the input tensor."""
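A sketch of how the two suggestions above could look, reusing the method bodies shown in this hunk; the docstring wording is an assumption, not the committed text:

    def __len__(self) -> int:
        # Renamed from num_modules so callers can simply use len(self).
        return len(self._modules)

    def convert_dense_layer_from_sentence_transformer(
        self, in_features: int, out_features: int
    ):
        """Replace the Sentence Transformers dense layer with our Dense layer.

        The weights of the existing layer at index 1 are copied into a bias-free
        Dense layer so that no activation function is applied in the forward pass.
        """
        dense_layer = Dense(
            in_features=in_features, out_features=out_features, bias=False
        )
        dense_layer.load_state_dict(self[1].state_dict(), strict=False)
        self[1] = dense_layer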
@@ -1120,5 +1154,6 @@ def _load_sbert_model(
        return [
            module
            for module in modules.values()
            if isinstance(module, Transformer) or isinstance(module, LinearProjection)
            if isinstance(module, Transformer)
            or isinstance(module, DenseSentenceTransformer)
        ]
Review comment: Might be worth adding a docstring here to explain that the Sentence Transformers Dense layer has an activation function and we don't want one (see the sketch below).
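A sketch of the docstring the reviewer asks for, to sit at the top of the _load_sbert_model override shown above; the wording is an assumption, and the method signature is elided here:

        """Load and filter the modules of a Sentence Transformers checkpoint.

        Only Transformer and dense modules are kept. A Dense layer saved by
        Sentence Transformers carries an activation function, which we do not
        want: such layers are converted later to our own Dense layer, whose
        forward pass applies no activation.
        """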