diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex index 2e9233b0..aa4ca19b 100644 --- a/lib/bumblebee/text.ex +++ b/lib/bumblebee/text.ex @@ -313,7 +313,14 @@ defmodule Bumblebee.Text do this option is ignored. Defaults to `:pooled_state` * `:output_pool` - pooling to apply on top of the model output, in case - it is not already a pooled embedding. Supported values: `:mean_pooling`. + it is not already a pooled embedding. Supported values: + + * `:mean_pooling` - performs a mean across all tokens + + * `cls_token_pooling` - takes the embedding for the special CLS token. + Note that we currently assume that the CLS token is the first token + in the sequence + By default no pooling is applied * `:embedding_processor` - a post-processing step to apply to the diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex index bf6a7a97..41e284e8 100644 --- a/lib/bumblebee/text/text_embedding.ex +++ b/lib/bumblebee/text/text_embedding.ex @@ -56,22 +56,21 @@ defmodule Bumblebee.Text.TextEmbedding do output end + if output_pool != nil and Nx.rank(output) != 3 do + raise ArgumentError, + "expected the output tensor to have rank 3 to apply :output_pool, got: #{Nx.rank(output)}." <> + " You should either disable pooling or pick a different output using :output_attribute" + end + output = case output_pool do nil -> output - :mean_pooling -> - case Nx.rank(output) do - 3 -> - :ok - - rank -> - raise ArgumentError, - "expected the output tensor to have rank 3 to apply :output_pool, got: #{rank}." <> - " You should either disable pooling or pick a different output using :output_attribute" - end + :cls_token_pooling -> + Nx.take(output, 0, axis: 1) + :mean_pooling -> input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1) output @@ -81,7 +80,7 @@ defmodule Bumblebee.Text.TextEmbedding do other -> raise ArgumentError, - "expected :output_pool to be one of nil or :mean_pooling, got: #{inspect(other)}" + "expected :output_pool to be one of :cls_token_pooling, :mean_pooling or nil, got: #{inspect(other)}" end output =