diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex index d65d94a6..668e01ac 100644 --- a/lib/bumblebee/text/text_embedding.ex +++ b/lib/bumblebee/text/text_embedding.ex @@ -65,8 +65,8 @@ defmodule Bumblebee.Text.TextEmbedding do case Nx.rank(output) do 3 -> # Assuming CLS token is always at the first position - Nx.slice_along_axis(output, 0, 1, axis: 1) |> Nx.squeeze(axes: [1]) - + Nx.take(output, 0, axis: 1) + rank -> raise ArgumentError, "expected the output tensor to have rank 3 to apply :cls pooling, got: #{rank}." <>