diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 48bdf656..d95e64be 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -609,8 +609,10 @@ defmodule Bumblebee.Text.Generation do eos_token_ids = List.wrap(eos_token_id) token_id + |> Nx.vectorize(:batch) |> Nx.equal(Nx.tensor(eos_token_ids)) |> Nx.any() + |> Nx.devectorize() else Nx.tensor(false) end diff --git a/test/bumblebee/audio/speech_to_text_whisper_test.exs b/test/bumblebee/audio/speech_to_text_whisper_test.exs index b8788847..b33e4f85 100644 --- a/test/bumblebee/audio/speech_to_text_whisper_test.exs +++ b/test/bumblebee/audio/speech_to_text_whisper_test.exs @@ -109,7 +109,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisperTest do generation_config, chunk_num_seconds: 30, defn_options: [compiler: EXLA], - compile: [batch_size: 1] + compile: [batch_size: 4] ) audio =