
[fix] Quantization of token embeddings #2885

Open
wants to merge 4 commits into master
Conversation

kacperlukawski

Problem

The encode method raises a ValueError when we request a precision other than float32 together with output_value="token_embeddings", as reported in #2882.
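
For reference, a minimal sketch of the failing call (the model name and sentences are only placeholders):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")
    sentences = ["first sentence", "second one"]

    # float32 token embeddings work fine
    token_embeddings = model.encode(sentences, output_value="token_embeddings")

    # requesting any other precision together with token embeddings raises the ValueError
    token_embeddings_int8 = model.encode(
        sentences,
        output_value="token_embeddings",
        precision="int8",
    )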

Solution

This PR provides a fix that combines all the token embeddings into a single array, runs the quantization, and eventually reconstructs the shape of the original array so we can distinguish the token embeddings coming from each input example.
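
In other words, the idea is roughly the following (a simplified sketch, where quantize_2d stands in for the existing quantization of a single 2D array):

    import numpy as np

    def quantize_token_embeddings(token_embeddings, quantize_2d):
        # token_embeddings: one (num_tokens_i, dim) array per input example
        lengths = [embedding.shape[0] for embedding in token_embeddings]
        flat = np.concatenate(token_embeddings)    # (sum(lengths), dim)
        quantized = quantize_2d(flat)              # quantize all token rows in one pass
        # split back so every input example keeps its own token matrix
        return np.split(quantized, np.cumsum(lengths)[:-1])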

The review discussion below refers to these changed lines:

    if not isinstance(embeddings[0], list) and len(embeddings[0].shape) == 2:
        # It will happen when we request token_embeddings
        lengths = [embedding.shape[0] for embedding in embeddings]
        embeddings = np.concatenate(embeddings)
    if isinstance(embeddings[0], Tensor):
@ir2718
Contributor

Shouldn't this if statement be above the previous if, as sending in a list of Tensors is also valid?

@kacperlukawski
Author

@ir2718 You were absolutely right, thank you! Changed the order of the statements.
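
The reordered checks then look roughly like this (a sketch rather than the exact diff; the body of the Tensor branch is assumed to be a plain conversion to NumPy):

    if isinstance(embeddings[0], Tensor):
        # assumed: move tensors to CPU and convert to NumPy before any shape checks
        embeddings = [embedding.cpu().numpy() for embedding in embeddings]
    if not isinstance(embeddings[0], list) and len(embeddings[0].shape) == 2:
        # It will happen when we request token_embeddings
        lengths = [embedding.shape[0] for embedding in embeddings]
        embeddings = np.concatenate(embeddings)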

@tomaarsen
Collaborator

Hello!

Apologies, I haven't yet had time to look into this a bit deeper, but I think an edge case that might be missed is output_value=None. This is not very well documented, but it returns both the sentence embedding and the token embeddings. I can imagine that this might be valuable for some use cases.

  • Tom Aarsen
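
For illustration, the edge case looks like this (the model name and sentences are only placeholders):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")

    # output_value=None returns one dict per input text, containing both
    # "sentence_embedding" and "token_embeddings" (among other features)
    outputs = model.encode(["first sentence", "second one"], output_value=None)
    print(outputs[0]["sentence_embedding"].shape)  # (dim,)
    print(outputs[0]["token_embeddings"].shape)    # (num_tokens, dim)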

@ir2718
Contributor

ir2718 commented Aug 9, 2024

Not sure if I can modify the PR, but following Tom's dict edge case, I think adding this should suffice:

        if isinstance(embeddings[0], dict):
            sentence_embeddings = [x["sentence_embedding"].unsqueeze(0).cpu().numpy() for x in embeddings]

            token_embeddings = []
            for emb_dict in embeddings:
                token_emb = emb_dict["token_embeddings"]
                attention = emb_dict["attention_mask"]
                last_mask_id = len(attention) - 1
                # walk back from the end to skip padding positions
                while last_mask_id > 0 and attention[last_mask_id].item() == 0:
                    last_mask_id -= 1

                token_embeddings.append(token_emb[0 : last_mask_id + 1])

            token_embeddings = [x.cpu().numpy() for x in token_embeddings]
            embeddings = token_embeddings + sentence_embeddings
            lengths = [x.shape[0] for x in embeddings]

with a modification in SentenceTransformer.py, line 638, right before the return statement:

        if output_value is None:
            return {
                "token_embeddings": all_embeddings[:len(all_embeddings)//2],
                "sentence_embedding": all_embeddings[len(all_embeddings)//2:]
            }

@kacperlukawski
Author

Thanks, @ir2718! I wonder whether we should return a dictionary. That breaks the interface of the encode method. @tomaarsen Would that be the expected behaviour?

@kacperlukawski
Author

kacperlukawski commented Aug 30, 2024

I decided to implement the quantization for this edge case differently from what was suggested. The quantize_embeddings function wasn't modified; instead, I extended the encode method. The all_embeddings were already a dictionary there, so I combined the token and sentence embeddings and passed them all together to quantize. The output dictionary structure remains unchanged, apart from the different precision.

@ir2718 I didn't use the attention mask on purpose. I thought it would be best to keep the shapes consistent, regardless of whether we use float32 or any other precision.
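
A rough sketch of that idea (only an illustration of the approach described above, not the actual diff; quantize_output_dicts is a hypothetical helper and the embeddings are assumed to already be NumPy arrays):

    import numpy as np
    from sentence_transformers.quantization import quantize_embeddings

    def quantize_output_dicts(all_embeddings, precision):
        # all_embeddings: one dict per input text with "token_embeddings" of shape
        # (num_tokens, dim) and "sentence_embedding" of shape (dim,)
        matrices = [
            np.concatenate([out["token_embeddings"], out["sentence_embedding"].reshape(1, -1)])
            for out in all_embeddings
        ]
        lengths = [matrix.shape[0] for matrix in matrices]
        # quantize token and sentence embeddings together in a single call
        quantized = quantize_embeddings(np.concatenate(matrices), precision=precision)
        # restore the original dict structure; only the precision changes
        for out, chunk in zip(all_embeddings, np.split(quantized, np.cumsum(lengths)[:-1])):
            out["token_embeddings"] = chunk[:-1]
            out["sentence_embedding"] = chunk[-1]
        return all_embeddings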

@ir2718
Contributor

ir2718 commented Aug 30, 2024

> I wonder whether we should return a dictionary. That breaks the interface of the encode method.

Agreed, I was thinking about that myself, but since transformers mostly handles things in dicts, my first idea was to implement it that way. Not breaking the interface is probably the better solution, but it requires adding some kind of note in the docs about the ordering of the embeddings.
