Skip to content

Commit

Permalink
Up test
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Jul 31, 2023
1 parent 7a5aa7d commit 16bef11
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions test/bumblebee/text/generation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ defmodule Bumblebee.Text.GenerationTest do

@moduletag model_test_tags()

# TMP CHANGED :o

describe "integration" do
test "generates text with greedy generation" do
{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/bart-large-cnn"})
Expand All @@ -23,7 +25,9 @@ defmodule Bumblebee.Text.GenerationTest do
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)

assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article)
end
Expand All @@ -37,7 +41,9 @@ defmodule Bumblebee.Text.GenerationTest do
Bumblebee.configure(generation_config, max_new_tokens: 12, no_repeat_ngram_length: 2)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)

# Without :no_repeat_ngram_length we get
# %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]}
Expand All @@ -58,7 +64,10 @@ defmodule Bumblebee.Text.GenerationTest do
)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0)
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
seed: 0,
defn_options: [compiler: EXLA]
)

# Note that this is just a snapshot test, we do not use any
# reference value, because of PRNG difference
Expand All @@ -82,7 +91,9 @@ defmodule Bumblebee.Text.GenerationTest do
)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)

assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} =
Nx.Serving.run(serving, "I was going")
Expand Down

0 comments on commit 16bef11

Please sign in to comment.