Skip to content

Commit

Permalink
Up test
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Jul 31, 2023
1 parent 8d6aa28 commit cc2c091
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 41 deletions.
20 changes: 13 additions & 7 deletions test/bumblebee/text/bart_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,21 @@ defmodule Bumblebee.Text.BartTest do

test "conditional generation" do
{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/bart-large-cnn"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-cnn"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "facebook/bart-large-cnn"})

assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = model_info.spec

inputs = %{
"input_ids" => Nx.tensor([[0, 133, 812, 9, 1470, 16, 2201, 2]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
}
article = """
PG&E stated it scheduled the blackouts in response to forecasts for high \
winds amid dry conditions. The aim is to reduce the risk of wildfires. \
Nearly 800 thousand customers were scheduled to be affected by the shutoffs \
which were expected to last through at least midday tomorrow.
"""

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)
inputs = Bumblebee.apply_tokenizer(tokenizer, article)

generation_config = Bumblebee.configure(generation_config, max_length: 8)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -162,8 +167,9 @@ defmodule Bumblebee.Text.BartTest do
generation_config
)

token_ids = generate.(model_info.params, inputs)
token_ids = EXLA.jit(generate).(model_info.params, inputs)

assert_equal(token_ids, Nx.tensor([[2, 0, 133, 812, 9, 2]]))
assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["PG&E scheduled the"]
end
end
# tmp
15 changes: 8 additions & 7 deletions test/bumblebee/text/blenderbot_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ defmodule Bumblebee.Text.BlenderbotTest do

test "conditional generation" do
{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/blenderbot-400M-distill"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/blenderbot-400M-distill"})

{:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "facebook/blenderbot-400M-distill"})

assert %Bumblebee.Text.Blenderbot{architecture: :for_conditional_generation} = model_info.spec

inputs = %{
"input_ids" => Nx.tensor([[2675, 19, 544, 366, 304, 38, 2]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1]])
}
english_phrase = " Hey, how are you?"

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)
inputs = Bumblebee.apply_tokenizer(tokenizer, english_phrase)

generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 6)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -75,8 +75,9 @@ defmodule Bumblebee.Text.BlenderbotTest do
generation_config
)

token_ids = generate.(model_info.params, inputs)
token_ids = EXLA.jit(generate).(model_info.params, inputs)

assert_equal(token_ids, Nx.tensor([[1, 281, 476, 929, 731, 2]]))
assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == [" I'm doing well"]
end
end
# tmp
18 changes: 14 additions & 4 deletions test/bumblebee/text/generation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ defmodule Bumblebee.Text.GenerationTest do
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)

assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article)
end
Expand All @@ -37,7 +39,9 @@ defmodule Bumblebee.Text.GenerationTest do
Bumblebee.configure(generation_config, max_new_tokens: 12, no_repeat_ngram_length: 2)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)

# Without :no_repeat_ngram_length we get
# %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]}
Expand All @@ -58,7 +62,10 @@ defmodule Bumblebee.Text.GenerationTest do
)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0)
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
seed: 0,
defn_options: [compiler: EXLA]
)

# Note that this is just a snapshot test, we do not use any
# reference value, because of PRNG difference
Expand All @@ -82,10 +89,13 @@ defmodule Bumblebee.Text.GenerationTest do
)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)

assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} =
Nx.Serving.run(serving, "I was going")
end
end
end
# tmp
16 changes: 9 additions & 7 deletions test/bumblebee/text/mbart_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,18 @@ defmodule Bumblebee.Text.MbartTest do
module: Bumblebee.Text.Mbart
)

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/mbart-large-en-ro"})

{:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "facebook/mbart-large-en-ro"})

assert %Bumblebee.Text.Mbart{architecture: :for_conditional_generation} = model_info.spec

inputs = %{
"input_ids" => Nx.tensor([[4828, 83, 70, 35166, 2, 250_004]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1]])
}
english_phrase = "42 is the answer"

inputs = Bumblebee.apply_tokenizer(tokenizer, english_phrase)

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)
generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 6)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -177,8 +178,9 @@ defmodule Bumblebee.Text.MbartTest do
generation_config
)

token_ids = generate.(model_info.params, inputs)
token_ids = EXLA.jit(generate).(model_info.params, inputs)

assert_equal(token_ids, Nx.tensor([[250_020, 4828, 473, 54051, 202, 2]]))
assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["42 este răspunsul"]
end
end
# tmp
29 changes: 15 additions & 14 deletions test/bumblebee/text/t5_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,15 @@ defmodule Bumblebee.Text.T5Test do
end

test "text generation" do
assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "t5-small"})
assert {:ok, model_info} = Bumblebee.load_model({:hf, "t5-small"})
assert {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "t5-small"})

inputs = %{
"input_ids" => Nx.tensor([[13959, 1566, 12, 2968, 10, 571, 625, 33, 25, 58, 1]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
}
text = "translate English to German: How old are you?"

inputs = Bumblebee.apply_tokenizer(tokenizer, text)

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)
generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 10)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -163,23 +163,23 @@ defmodule Bumblebee.Text.T5Test do
generation_config
)

token_ids = generate.(model_info.params, inputs)
token_ids = EXLA.jit(generate).(model_info.params, inputs)

assert_equal(token_ids, Nx.tensor([[0, 2739, 4445, 436, 292, 58]]))
assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["Wie alt sind Sie?"]
end

test "text generation (tied embeddings)" do
assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google/flan-t5-small"})
assert {:ok, model_info} = Bumblebee.load_model({:hf, "google/flan-t5-small"})

assert {:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "google/flan-t5-small"})

inputs = %{
"input_ids" => Nx.tensor([[13959, 1566, 12, 2968, 10, 571, 625, 33, 25, 58, 1]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
}
text = "translate English to German: How old are you?"

inputs = Bumblebee.apply_tokenizer(tokenizer, text)

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)
generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 10)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -188,9 +188,10 @@ defmodule Bumblebee.Text.T5Test do
generation_config
)

token_ids = generate.(model_info.params, inputs)
token_ids = EXLA.jit(generate).(model_info.params, inputs)

assert_equal(token_ids, Nx.tensor([[0, 2739, 3, 362, 3, 49]]))
assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["Wie ich er bitten?"]
end
end
end
# tmp
5 changes: 4 additions & 1 deletion test/bumblebee/vision/image_to_text_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ defmodule Bumblebee.Vision.ImageToTextTest do
Bumblebee.load_generation_config({:hf, "Salesforce/blip-image-captioning-base"})

serving =
Bumblebee.Vision.ImageToText.image_to_text(blip, featurizer, tokenizer, generation_config)
Bumblebee.Vision.ImageToText.image_to_text(blip, featurizer, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)

image = StbImage.read_file!(Path.join(@images_dir, "coco/39769.jpeg"))

Expand All @@ -30,3 +32,4 @@ defmodule Bumblebee.Vision.ImageToTextTest do
end
end
end
# tmp
1 change: 0 additions & 1 deletion test/test_helper.exs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA)

Application.put_env(:bumblebee, :progress_bar_enabled, false)

Expand Down

0 comments on commit cc2c091

Please sign in to comment.