Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CLS token pooling in text embedding #385

Merged
merged 6 commits into from
Aug 6, 2024

Conversation

nyo16
Copy link
Contributor

@nyo16 nyo16 commented Aug 2, 2024

This PR adding support for token pooling for models like BGE-M3.

model_id = "BAAI/bge-m3"
{:ok, model_info} = Bumblebee.load_model({:hf,  model_id})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_id})
serving =
      Bumblebee.Text.text_embedding(model_info, tokenizer,
        output_attribute: :hidden_state,
        output_pool: :cls_token_pooling,
        embedding_processor: :l2_norm
      )

Nx.Serving.run(serving, "A long text to test the embeddings")
%{
  embedding: #Nx.Tensor<
    f32[1024]
    [-0.04741977900266647, -0.02085975557565689, -0.028225498273968697, 0.02299957536160946, -0.011559668928384781, -0.07012897729873657, 0.03340381383895874, -0.01976216770708561, -0.03233117237687111, 0.006122056394815445, -0.010670339688658714, -0.013755269348621368, 0.022028403356671333, 0.01282929815351963, -0.011613521724939346, -0.03025030717253685, -0.0033455477096140385, -0.020186245441436768, -0.003963503520935774, -0.030614720657467842, -0.04809923842549324, -0.04979144409298897, -0.005613392218947411, 0.03222518414258957, 0.008074230514466763, 0.04548024386167526, -0.04274187982082367, -8.126439643092453e-4, 0.006089380942285061, -0.013495399616658688, 0.014120037667453289, 6.303095142357051e-4, 0.01951751671731472, -0.021895553916692734, 0.029148466885089874, -0.03433115407824516, 0.02577170729637146, 0.0065758670680224895, -0.023974351584911346, -0.032213930040597916, 6.08053058385849e-4, 0.011225073598325253, 0.00818843673914671, -0.044494032859802246, 0.009402443654835224, -0.02151455357670784, 0.02456141822040081, -0.04354599490761757, -0.03464091941714287, ...]

Bellow are the results from the python implementation (i am using only the "dense" output and not the sparse one)

In [16]: from FlagEmbedding import BGEM3FlagModel
In [17]: model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
In [18]: sentences_1 = "A long text to test the embeddings"

In [19]: embeddings_1 = model.encode(sentences_1, batch_size=12, max_length=8192 )['dense_vecs']

In [20]: embeddings_1
Out[20]:
array([-0.0474  , -0.02089 , -0.0282  , ..., -0.002155, -0.02812 ,
        0.05673 ], dtype=float16)

I believe the small differences are because of different implementation of floating point between python <> elixir.

@nyo16
Copy link
Contributor Author

nyo16 commented Aug 2, 2024

@jonatanklosko let me know if this looks good. For testing i didn't saw anything regarding :mean_pooling so I didnt add one and i was thinking actually what is testable in that case.

Comment on lines 71 to 73
raise ArgumentError,
"expected the output tensor to have rank 3 to apply :cls pooling, got: #{rank}." <>
" You should either disable pooling or pick a different output using :output_attribute"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe for any pooling we expect 3, because we reduce n tokens into 1. Now that we have more types, we can move the check before the case, like this:

if output_pool != nil and Nx.rank(output) != 3 do
  raise ...
end

We can use the message from the other clause!

@nyo16
Copy link
Contributor Author

nyo16 commented Aug 6, 2024

@jonatanklosko thank you for the comments. I changed the code to reflect them. Let me know if this is good!

lib/bumblebee/text.ex Outdated Show resolved Hide resolved
Copy link
Member

@jonatanklosko jonatanklosko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@jonatanklosko jonatanklosko changed the title Add CLS token pooling support. Add support for CLS token pooling in text embedding Aug 6, 2024
@jonatanklosko jonatanklosko merged commit 7db36b8 into elixir-nx:main Aug 6, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants