Merge pull request #52 from lightonai/loading_logic
Loading logic rework
raphaelsty authored Sep 12, 2024
2 parents 68d3b86 + 5daff31 commit b647061
Showing 3 changed files with 85 additions and 13 deletions.
57 changes: 57 additions & 0 deletions pylate/models/Dense.py
@@ -2,10 +2,12 @@
import os

import torch
from safetensors import safe_open
from safetensors.torch import load_model as load_safetensors_model
from sentence_transformers.models import Dense as DenseSentenceTransformer
from sentence_transformers.util import import_from_string
from torch import nn
from transformers.utils import cached_file

__all__ = ["Dense"]

@@ -77,6 +79,61 @@ def from_sentence_transformers(dense: DenseSentenceTransformer) -> "Dense":
model.load_state_dict(dense.state_dict())
return model

    @staticmethod
    def from_stanford_weights(
        model_name_or_path: str | os.PathLike,
        cache_folder: str | os.PathLike | None = None,
        revision: str | None = None,
        local_files_only: bool | None = None,
        token: str | bool | None = None,
        use_auth_token: str | bool | None = None,
    ) -> "Dense":
        """Load the weights of the Dense layer from a stanford-nlp ColBERT checkpoint.

        Parameters
        ----------
        model_name_or_path (`str` or `os.PathLike`):
            This can be either:
            - a string, the *model id* of a model repo on huggingface.co.
            - a path to a *directory* potentially containing the file.
        cache_folder (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. Models and other artifacts are stored on huggingface.co with a
            git-based system, so `revision` can be any identifier allowed by git: a branch name, a tag name,
            or a commit id.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, only try to load the model weights from local files.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
            generated when running `huggingface-cli login` (stored in `~/.huggingface`).
        use_auth_token (`str` or `bool`, *optional*):
            Deprecated alias for `token`.
        """
        # Resolve the path to the safetensors file: a local directory containing it,
        # a direct local path, or a model id to download (or fetch from the cache)
        if os.path.isdir(model_name_or_path):
            model_name_or_path = os.path.join(model_name_or_path, "model.safetensors")
        elif not os.path.exists(model_name_or_path):
            model_name_or_path = cached_file(
                model_name_or_path,
                filename="model.safetensors",
                cache_dir=cache_folder,
                revision=revision,
                local_files_only=local_files_only,
                token=token,
                use_auth_token=use_auth_token,
            )

        # Only the projection weight is needed, so read that single tensor rather
        # than loading the whole checkpoint into memory
        with safe_open(model_name_or_path, framework="pt", device="cpu") as f:
            state_dict = {"linear.weight": f.get_tensor("linear.weight")}

        # Determine input and output dimensions from the weight matrix
        out_features, in_features = state_dict["linear.weight"].shape

        # Create the Dense layer (stanford-nlp checkpoints use no bias) and load the weights
        model = Dense(in_features=in_features, out_features=out_features, bias=False)
        model.load_state_dict(state_dict)
        return model

    @staticmethod
    def load(input_path) -> "Dense":
        """Load a Dense layer."""
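A quick usage sketch of the new loader (illustrative, not part of the diff): it assumes `Dense` is exported from `pylate.models` and that the checkpoint follows the stanford-nlp layout, for example colbert-ir/colbertv2.0.

from pylate.models import Dense

# Fetch only the projection weights of a stanford-nlp ColBERT checkpoint
# (downloaded or taken from the local cache); no bias, matching the original layout.
dense = Dense.from_stanford_weights("colbert-ir/colbertv2.0")

# The weight matrix is (out_features, in_features), e.g. (128, 768).
print(dense.linear.weight.shape)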
39 changes: 27 additions & 12 deletions pylate/models/colbert.py
@@ -219,9 +219,6 @@ def __init__(
config_kwargs: dict | None = None,
model_card_data: Optional[SentenceTransformerModelCardData] = None,
) -> None:
model_kwargs = {} if model_kwargs is None else model_kwargs
model_kwargs["add_pooling_layer"] = False

self.query_prefix = query_prefix
self.document_prefix = document_prefix
self.query_length = query_length
@@ -248,20 +245,35 @@ def __init__(
config_kwargs=config_kwargs,
model_card_data=model_card_data,
)

        hidden_size = self[0].get_word_embedding_dimension()

        if len(self) < 2:
            # If the model is a stanford-nlp ColBERT, reuse the weights of its dense layer
            if self[0].auto_model.config.architectures[0] == "HF_ColBERT":
                self.append(
                    Dense.from_stanford_weights(
                        model_name_or_path,
                        cache_folder,
                        revision,
                        local_files_only,
                        token,
                        use_auth_token,
                    )
                )
                logger.warning("Loaded the ColBERT model from Stanford NLP.")
            else:
                # Otherwise, add a linear projection layer to project the embeddings to the desired size
                embedding_size = embedding_size or 128
                logger.warning(
                    f"The checkpoint does not contain a linear projection layer. Adding one with output dimensions ({hidden_size}, {embedding_size})."
                )
                self.append(
                    Dense(in_features=hidden_size, out_features=embedding_size, bias=bias)
                )

elif (
embedding_size is not None
@@ -282,6 +294,9 @@ def __init__(
else:
logger.warning("Pylate model loaded successfully.")

        # The projection layer is created after the base model, so it does not pick up
        # the torch_dtype passed through model_kwargs; cast it explicitly.
        if model_kwargs is not None and "torch_dtype" in model_kwargs:
            self[1].to(model_kwargs["torch_dtype"])

self.to(device)
self.is_hpu_graph_enabled = False

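A sketch of what the rework changes at load time (illustrative, not part of the diff; `models.ColBERT` and the `is_query` flag follow PyLate's public API, and colbert-ir/colbertv2.0 stands in for any checkpoint whose architecture is HF_ColBERT):

import torch

from pylate import models

# The stanford-nlp dense weights are reused instead of appending a randomly
# initialized projection layer, and torch_dtype is propagated to that layer.
model = models.ColBERT(
    model_name_or_path="colbert-ir/colbertv2.0",
    model_kwargs={"torch_dtype": torch.float16},
)

queries_embeddings = model.encode(["what is colbert?"], is_query=True)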
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
long_description = fh.read()

base_packages = [
"sentence-transformers >= 3.0.1",
"sentence-transformers == 3.0.1",
"datasets >= 2.20.0",
"accelerate >= 0.31.0",
"voyager >= 2.0.9",
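The exact pin on sentence-transformers presumably guards the new loading logic against upstream changes; to reproduce the environment, a minimal sketch:

pip install "sentence-transformers==3.0.1"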
