Skip to content

Commit

Permalink
simplifying Nx slice
Browse files Browse the repository at this point in the history
  • Loading branch information
nyo16 committed Aug 2, 2024
1 parent 32dfdc3 commit 3535141
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lib/bumblebee/text/text_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ defmodule Bumblebee.Text.TextEmbedding do
case Nx.rank(output) do
3 ->
# Assuming CLS token is always at the first position
Nx.slice_along_axis(output, 0, 1, axis: 1) |> Nx.squeeze(axes: [1])

Nx.take(output, 0, axis: 1)
rank ->
raise ArgumentError,
"expected the output tensor to have rank 3 to apply :cls pooling, got: #{rank}." <>
Expand Down

0 comments on commit 3535141

Please sign in to comment.