Added move_to_cpu flag in encode method to avoid OOM for token embeddings. #1812

Open · wants to merge 2 commits into master
12 changes: 10 additions & 2 deletions sentence_transformers/SentenceTransformer.py
@@ -114,6 +114,7 @@ def encode(self, sentences: Union[str, List[str]],
                output_value: str = 'sentence_embedding',
                convert_to_numpy: bool = True,
                convert_to_tensor: bool = False,
+               move_to_cpu: bool = False,
                device: str = None,
                normalize_embeddings: bool = False) -> Union[List[Tensor], ndarray, Tensor]:
         """
@@ -125,6 +126,7 @@ def encode(self, sentences: Union[str, List[str]],
         :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
         :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
         :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
+        :param move_to_cpu: If true, the obtained embedding tensors are sequentially moved to the CPU.
         :param device: Which torch.device to use for the computation
         :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
 
@@ -142,6 +144,9 @@ def encode(self, sentences: Union[str, List[str]],
             convert_to_tensor = False
             convert_to_numpy = False
 
+        if convert_to_numpy:
+            move_to_cpu = True
+
         input_was_string = False
         if isinstance(sentences, str) or not hasattr(sentences, '__len__'): #Cast an individual sentence to a list with length 1
             sentences = [sentences]
@@ -171,7 +176,10 @@ def encode(self, sentences: Union[str, List[str]],
                         while last_mask_id > 0 and attention[last_mask_id].item() == 0:
                             last_mask_id -= 1
 
-                        embeddings.append(token_emb[0:last_mask_id+1])
+                        token_embeddings = token_emb[0:last_mask_id+1]
+                        if move_to_cpu:
+                            token_embeddings = token_embeddings.cpu()
+                        embeddings.append(token_embeddings)
                 elif output_value is None: #Return all outputs
                     embeddings = []
                     for sent_idx in range(len(out_features['sentence_embedding'])):
@@ -184,7 +192,7 @@ def encode(self, sentences: Union[str, List[str]],
                         embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
 
                     # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
-                    if convert_to_numpy:
+                    if move_to_cpu:
                         embeddings = embeddings.cpu()
 
                 all_embeddings.extend(embeddings)
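For context, a minimal usage sketch of the new flag (assuming this branch is installed; the model name, device, and sentences below are placeholders, not part of the PR). When extracting token embeddings for a large corpus, passing move_to_cpu=True moves each sentence's token-embedding tensor to the CPU right after it is sliced out of the batch output, instead of keeping every tensor on the GPU until encoding finishes:

from sentence_transformers import SentenceTransformer

# Placeholder model name and device; any SentenceTransformer model works the same way.
model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")

sentences = ["First example sentence.", "Second example sentence."]

# Token embeddings are returned as a list of tensors (one per sentence).
# With this patch, move_to_cpu=True moves each tensor to the CPU inside the
# batch loop, so GPU memory is not held for the whole corpus.
token_embeddings = model.encode(
    sentences,
    output_value="token_embeddings",
    move_to_cpu=True,  # new flag from this PR
)

print(token_embeddings[0].device)  # expected: cpu

Note that for the default output_value='sentence_embedding' with convert_to_numpy=True, the flag is forced on, so existing behaviour (the fix for #522 and #487) is unchanged; the explicit flag mainly matters for token embeddings, which previously always stayed on the GPU.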