Jina loading (#54)
* More general dense layer dtype casting

* Setting the prefixes from stanford-nlp

* Modifying logging to be more accurate

* Adding logic for models that do not handle vocabulary resizing (such as jina-colbert-v2)

* Bumping version
NohTow authored Sep 13, 2024
1 parent 1d9c184 commit 0c74287
Showing 2 changed files with 21 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pylate/__version__.py

@@ -1,3 +1,3 @@
-VERSION = (1, 1, 0)
+VERSION = (1, 1, 1)
 
 __version__ = ".".join(map(str, VERSION))
25 changes: 20 additions & 5 deletions pylate/models/colbert.py
@@ -261,6 +261,9 @@ def __init__(
                         use_auth_token,
                     )
                 )
+                # Setting the prefixes from stanford-nlp models
+                self.query_prefix = "[unused0]"
+                self.document_prefix = "[unused1]"
                 logger.warning("Loaded the ColBERT model from Stanford NLP.")
             else:
                 # Add a linear projection layer to the model in order to project the embeddings to the desired size
@@ -269,6 +272,7 @@ def __init__(
                 logger.warning(
                     f"The checkpoint does not contain a linear projection layer. Adding one with output dimensions ({hidden_size}, {embedding_size})."
                 )
+                logger.warning("Created a PyLate model from base encoder.")
                 self.append(
                     Dense(
                         in_features=hidden_size, out_features=embedding_size, bias=bias
@@ -294,14 +298,25 @@ def __init__(
         else:
             logger.warning("Pylate model loaded successfully.")
 
-        if model_kwargs is not None and "torch_dtype" in model_kwargs:
-            self[1].to(model_kwargs["torch_dtype"])
+        # Ensure all tensors in the model are of the same dtype as the first tensor
+        try:
+            dtype = next(self.parameters()).dtype
+            self.to(dtype)
+        except StopIteration:
+            pass
 
         self.to(device)
         self.is_hpu_graph_enabled = False
 
-        self.tokenizer.add_tokens([self.query_prefix, self.document_prefix])
-        self._first_module().auto_model.resize_token_embeddings(len(self.tokenizer))
+        # Try adding the prefixes to the tokenizer. We call resize_token_embeddings twice to ensure the tokens are added only if resize_token_embeddings works. There should be a better way to do this.
+        try:
+            self._first_module().auto_model.resize_token_embeddings(len(self.tokenizer))
+            self.tokenizer.add_tokens([self.query_prefix, self.document_prefix])
+            self._first_module().auto_model.resize_token_embeddings(len(self.tokenizer))
+        except NotImplementedError:
+            logger.warning(
+                "The tokenizer does not support resizing the token embeddings; the prefix tokens have not been added to the vocabulary."
+            )
 
         self.document_prefix_id = self.tokenizer.convert_tokens_to_ids(
             self.document_prefix
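The dtype block above generalizes the old behavior, which only cast self[1] (the Dense projection) when torch_dtype was passed in model_kwargs; now every submodule is aligned to the dtype of the first parameter. A standalone sketch of the same pattern on a toy model (module names are illustrative, not PyLate's):

# Sketch: align a mixed-precision module list to its first parameter's dtype.
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8).half(), nn.Linear(8, 4))  # mixed dtypes
try:
    dtype = next(model.parameters()).dtype  # dtype of the first tensor (float16 here)
    model.to(dtype)                         # cast every submodule to match
except StopIteration:
    pass  # parameterless model: nothing to align

assert all(p.dtype == torch.float16 for p in model.parameters())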
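The second try/except is a probe-then-mutate pattern: resize_token_embeddings is first called as a no-op (the vocabulary size is unchanged) to verify that the model supports resizing at all (jina-colbert-v2 raises NotImplementedError), and only then are the prefix tokens added, so the tokenizer can never drift out of sync with the embedding matrix. The pattern in isolation (the helper name and the print are hypothetical):

# Sketch: probe resize support before mutating the tokenizer.
def add_prefix_tokens(model, tokenizer, prefixes):
    try:
        model.resize_token_embeddings(len(tokenizer))  # probe: no-op if supported
        tokenizer.add_tokens(prefixes)                 # safe to mutate now
        model.resize_token_embeddings(len(tokenizer))  # actual resize
    except NotImplementedError:
        print("resize_token_embeddings unsupported; prefixes not added")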
@@ -1072,7 +1087,7 @@ def _load_auto_model(
         """
         logger.warning(
-            f"No sentence-transformers model found with name {model_name_or_path}. Creating a ColBERT model from base encoder."
+            f"No sentence-transformers model found with name {model_name_or_path}."
         )
 
         shared_kwargs = {
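Taken together, these changes are what allow PyLate to load checkpoints such as jinaai/jina-colbert-v2, whose embedding layer cannot be resized. A usage sketch; the constructor kwargs shown (query_prefix, document_prefix, trust_remote_code) are assumptions based on PyLate's documented interface, not part of this diff:

# Usage sketch (assumes pylate >= 1.1.1).
from pylate import models

model = models.ColBERT(
    model_name_or_path="jinaai/jina-colbert-v2",
    query_prefix="[QueryMarker]",        # jina-colbert-v2 ships its own markers,
    document_prefix="[DocumentMarker]",  # not BERT's [unused0]/[unused1]
    trust_remote_code=True,
)
queries_embeddings = model.encode(["what is late interaction?"], is_query=True)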
