
Rename linear projection #37

Merged: 6 commits, Aug 20, 2024
95 changes: 95 additions & 0 deletions giga_cherche/models/Dense.py
@@ -0,0 +1,95 @@
import json
import os

import torch
from safetensors.torch import load_model as load_safetensors_model
from sentence_transformers.models import Dense as DenseSentenceTransformer
from sentence_transformers.util import import_from_string
from torch import nn

__all__ = ["Dense"]


class Dense(DenseSentenceTransformer):
"""Performs linear projection on the token embeddings to a lower dimension.

Parameters
----------
in_features
Size of the embeddings in output of the tansformer.
out_features
Size of the output embeddings after linear projection
bias
Add a bias vector
init_weight
Initial value for the matrix of the linear layer
init_bias
Initial value for the bias of the linear layer.

Examples
--------
>>> from giga_cherche import models

>>> model = models.Dense(
... in_features=768,
... out_features=128,
... )

>>> features = {
... "token_embeddings": torch.randn(2, 768),
... }

>>> projected_features = model(features)

>>> assert projected_features["token_embeddings"].shape == (2, 128)
>>> assert isinstance(model, DenseSentenceTransformer)

"""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        activation_function=nn.Identity(),
        init_weight: torch.Tensor = None,
        init_bias: torch.Tensor = None,
    ) -> None:
        super(Dense, self).__init__(
            in_features, out_features, bias, activation_function, init_weight, init_bias
        )

    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Performs linear projection on the token embeddings."""
        token_embeddings = features["token_embeddings"]
        projected_embeddings = self.linear(token_embeddings)
        features["token_embeddings"] = projected_embeddings
        return features

    @staticmethod
    def from_sentence_transformers(dense_st: DenseSentenceTransformer):
        config = dense_st.get_config_dict()
        config["activation_function"] = nn.Identity()
        model = Dense(**config)
        model.load_state_dict(dense_st.state_dict())
        return model

Review comment from @raphaelsty (Collaborator), Aug 20, 2024: Might be worth adding a docstring here to explain that the Sentence Transformers Dense has an activation function and we don't want one.
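
In the spirit of that comment, a small sketch of the behavior such a docstring would describe (dimensions are illustrative; assumes giga_cherche and sentence-transformers are installed):

import torch
from torch import nn
from sentence_transformers.models import Dense as DenseSentenceTransformer

from giga_cherche.models import Dense

st_dense = DenseSentenceTransformer(
    in_features=768, out_features=128, activation_function=nn.Tanh()
)
dense = Dense.from_sentence_transformers(st_dense)

# The converted layer reuses the linear weights but applies no activation in forward.
features = dense({"token_embeddings": torch.randn(2, 768)})
assert features["token_embeddings"].shape == (2, 128)
assert isinstance(dense.activation_function, nn.Identity)
assert torch.equal(dense.linear.weight, st_dense.linear.weight)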

    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)

        config["activation_function"] = import_from_string(
            config["activation_function"]
        )()
        model = Dense(**config)
        if os.path.exists(os.path.join(input_path, "model.safetensors")):
            load_safetensors_model(model, os.path.join(input_path, "model.safetensors"))
        else:
            model.load_state_dict(
                torch.load(
                    os.path.join(input_path, "pytorch_model.bin"),
                    map_location=torch.device("cpu"),
                )
            )
        return model

Review comment (Collaborator): You can remove the else and return the model in the if, right after load_safetensors_model; that way the model.load_state_dict call no longer needs to be nested in an else.
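
A sketch of that suggestion applied to load, returning early instead of branching on else (same behavior; not part of this PR):

    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)

        config["activation_function"] = import_from_string(
            config["activation_function"]
        )()
        model = Dense(**config)
        if os.path.exists(os.path.join(input_path, "model.safetensors")):
            load_safetensors_model(model, os.path.join(input_path, "model.safetensors"))
            return model
        model.load_state_dict(
            torch.load(
                os.path.join(input_path, "pytorch_model.bin"),
                map_location=torch.device("cpu"),
            )
        )
        return model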
107 changes: 0 additions & 107 deletions giga_cherche/models/LinearProjection.py

This file was deleted.

4 changes: 2 additions & 2 deletions giga_cherche/models/__init__.py
@@ -1,4 +1,4 @@
from .colbert import ColBERT
from .LinearProjection import LinearProjection
from .Dense import Dense

__all__ = ["ColBERT", "LinearProjection"]
__all__ = ["ColBERT", "Dense"]
51 changes: 43 additions & 8 deletions giga_cherche/models/colbert.py
@@ -21,6 +21,7 @@
    SentenceTransformerModelCardData,
    generate_model_card,
)
from sentence_transformers.models import Dense as DenseSentenceTransformer
from sentence_transformers.models import Transformer
from sentence_transformers.quantization import quantize_embeddings
from sentence_transformers.similarity_functions import SimilarityFunction
@@ -32,7 +33,7 @@
from tqdm.autonotebook import trange

from ..utils import _start_multi_process_pool
from .LinearProjection import LinearProjection
from .Dense import Dense

logger = logging.getLogger(__name__)

@@ -196,7 +197,8 @@ def __init__(
        token: bool | str | None = None,
        use_auth_token: bool | str | None = None,
        truncate_dim: int | None = None,
        embedding_size: int | None = 128,
        embedding_size: int | None = None,
        bias: bool = False,
        query_prefix: str | None = "[Q] ",
        document_prefix: str | None = "[D] ",
        add_special_tokens: bool = True,
@@ -242,18 +244,38 @@
            model_card_data=model_card_data,
        )

        hidden_size = self._modules["0"].get_word_embedding_dimension()
        hidden_size = self[0].get_word_embedding_dimension()

        # If there is no linear projection layer, add one
        # TODO: do this more cleanly
        if len(self._modules) < 2:
        if self.num_modules() < 2:
            # Add a linear projection layer to the model in order to project the embeddings to the desired size
            embedding_size = embedding_size or 128
            self.append(
                Dense(in_features=hidden_size, out_features=embedding_size, bias=bias)
            )
            logger.warning(
                f"The checkpoint does not contain a linear projection layer. Adding one with output dimensions ({hidden_size}, {embedding_size})"
            )
            # Add a linear projection layer to the model in order to project the embeddings to the desired size
            self._modules[f"{len(self._modules)}"] = LinearProjection(
                in_features=hidden_size, out_features=embedding_size, bias=False

        elif (
            embedding_size is not None
            and self[1].get_sentence_embedding_dimension() != embedding_size
        ):
            logger.warning(
                f"The checkpoint contains a dense layer but with an incorrect dimension. Replacing it with a Dense layer with output dimensions ({hidden_size}, {embedding_size})"
            )

Review comment (Collaborator): It would be great to print the actual embedding dimension you found, i.e. self[1].get_sentence_embedding_dimension().
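            # A possible form of that suggestion (a sketch, not part of this PR):
            # logger.warning(
            #     f"The checkpoint contains a dense layer with output dimension "
            #     f"{self[1].get_sentence_embedding_dimension()} instead of {embedding_size}. "
            #     f"Replacing it with a Dense layer with output dimensions ({hidden_size}, {embedding_size})"
            # )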
            self[1] = Dense(
                in_features=hidden_size, out_features=embedding_size, bias=bias
            )
        # If it is an instance of Dense from ST, convert it to our Dense to remove the activation function in the forward pass; otherwise the layer is already correct
        elif not isinstance(self[1], Dense):
            logger.warning(
                f"Converting the existing Dense layer from Sentence Transformers with output dimensions ({hidden_size}, {self[1].get_sentence_embedding_dimension()})"
            )
            self[1] = Dense.from_sentence_transformers(self[1])
        else:
            logger.warning("Correctly loaded the Dense layer")

        self.to(device)
        self.is_hpu_graph_enabled = False
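
A rough usage sketch of the constructor logic above (the checkpoint name and dimensions are illustrative, not taken from the PR; assumes giga_cherche is installed):

from giga_cherche import models

model = models.ColBERT(
    model_name_or_path="bert-base-uncased",
    embedding_size=128,  # requested output dimension of the projection layer
)

# After construction, the module at index 1 should be the package's Dense projection.
print(model[1])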
@@ -294,6 +316,18 @@ def __init__(
    def load(input_path) -> "ColBERT":
        return ColBERT(model_name_or_path=input_path)

    def num_modules(self) -> int:
        return len(self._modules)

Review comment (Collaborator): You can rename this function to

    def __len__(self) -> int:
        return len(self._modules)

then you can call len(self) to get the actual number of modules.

    def convert_dense_layer_from_sentence_transformer(
        self, in_features: int, out_features: int
    ):
        dense_layer = Dense(
            in_features=in_features, out_features=out_features, bias=False
        )
        dense_layer.load_state_dict(self[1].state_dict(), strict=False)
        self[1] = dense_layer

Review comment (Collaborator): Could you add a docstring here?
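
One possible docstring, in the spirit of that request (wording is a suggestion, not part of the PR):

    def convert_dense_layer_from_sentence_transformer(
        self, in_features: int, out_features: int
    ):
        """Replace the module at index 1 with this package's Dense layer, reusing the existing linear weights.

        Parameters
        ----------
        in_features
            Input dimension of the projection (hidden size of the transformer).
        out_features
            Output dimension of the projection (size of the ColBERT embeddings).
        """
        dense_layer = Dense(
            in_features=in_features, out_features=out_features, bias=False
        )
        dense_layer.load_state_dict(self[1].state_dict(), strict=False)
        self[1] = dense_layer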
    @staticmethod
    def insert_prefix_token(input_ids: torch.Tensor, prefix_id: int) -> torch.Tensor:
        """Inserts a prefix token at the beginning of each sequence in the input tensor."""
@@ -1120,5 +1154,6 @@ def _load_sbert_model(
        return [
            module
            for module in modules.values()
            if isinstance(module, Transformer) or isinstance(module, LinearProjection)
            if isinstance(module, Transformer)
            or isinstance(module, DenseSentenceTransformer)
        ]