Skip to content

Commit

Permalink
Fix batched text generation finishing too early
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Jul 4, 2024
1 parent 6cf6839 commit 3b0eb08
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -609,8 +609,10 @@ defmodule Bumblebee.Text.Generation do
eos_token_ids = List.wrap(eos_token_id)

token_id
|> Nx.vectorize(:batch)
|> Nx.equal(Nx.tensor(eos_token_ids))
|> Nx.any()
|> Nx.devectorize()
else
Nx.tensor(false)
end
Expand Down
2 changes: 1 addition & 1 deletion test/bumblebee/audio/speech_to_text_whisper_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisperTest do
generation_config,
chunk_num_seconds: 30,
defn_options: [compiler: EXLA],
compile: [batch_size: 1]
compile: [batch_size: 4]
)

audio =
Expand Down

0 comments on commit 3b0eb08

Please sign in to comment.