Skip to content

Commit

Permalink
Up test
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Jul 31, 2023
1 parent 16bef11 commit 8d6aa28
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 15 deletions.
19 changes: 4 additions & 15 deletions test/bumblebee/text/generation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ defmodule Bumblebee.Text.GenerationTest do

@moduletag model_test_tags()

# TMP CHANGED :o

describe "integration" do
test "generates text with greedy generation" do
{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/bart-large-cnn"})
Expand All @@ -25,9 +23,7 @@ defmodule Bumblebee.Text.GenerationTest do
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config)

assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article)
end
Expand All @@ -41,9 +37,7 @@ defmodule Bumblebee.Text.GenerationTest do
Bumblebee.configure(generation_config, max_new_tokens: 12, no_repeat_ngram_length: 2)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config)

# Without :no_repeat_ngram_length we get
# %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]}
Expand All @@ -64,10 +58,7 @@ defmodule Bumblebee.Text.GenerationTest do
)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
seed: 0,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0)

# Note that this is just a snapshot test, we do not use any
# reference value, because of PRNG difference
Expand All @@ -91,9 +82,7 @@ defmodule Bumblebee.Text.GenerationTest do
)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config)

assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} =
Nx.Serving.run(serving, "I was going")
Expand Down
1 change: 1 addition & 0 deletions test/test_helper.exs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA)

Application.put_env(:bumblebee, :progress_bar_enabled, false)

Expand Down

0 comments on commit 8d6aa28

Please sign in to comment.