Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up tests #230

Merged
merged 1 commit into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions test/bumblebee/text/bart_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -144,21 +144,16 @@ defmodule Bumblebee.Text.BartTest do

test "conditional generation" do
{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/bart-large-cnn"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-cnn"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "facebook/bart-large-cnn"})

assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = model_info.spec

article = """
PG&E stated it scheduled the blackouts in response to forecasts for high \
winds amid dry conditions. The aim is to reduce the risk of wildfires. \
Nearly 800 thousand customers were scheduled to be affected by the shutoffs \
which were expected to last through at least midday tomorrow.
"""
inputs = %{
"input_ids" => Nx.tensor([[0, 133, 812, 9, 1470, 16, 2201, 2]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
}

inputs = Bumblebee.apply_tokenizer(tokenizer, article)

generation_config = Bumblebee.configure(generation_config, max_length: 8)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -167,8 +162,8 @@ defmodule Bumblebee.Text.BartTest do
generation_config
)

token_ids = EXLA.jit(generate).(model_info.params, inputs)
token_ids = generate.(model_info.params, inputs)

assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["PG&E scheduled the"]
assert_equal(token_ids, Nx.tensor([[2, 0, 133, 812, 9, 2]]))
end
end
14 changes: 7 additions & 7 deletions test/bumblebee/text/blenderbot_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ defmodule Bumblebee.Text.BlenderbotTest do

test "conditional generation" do
{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/blenderbot-400M-distill"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/blenderbot-400M-distill"})

{:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "facebook/blenderbot-400M-distill"})

assert %Bumblebee.Text.Blenderbot{architecture: :for_conditional_generation} = model_info.spec

english_phrase = " Hey, how are you?"
inputs = %{
"input_ids" => Nx.tensor([[2675, 19, 544, 366, 304, 38, 2]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1]])
}

inputs = Bumblebee.apply_tokenizer(tokenizer, english_phrase)

generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 6)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -75,8 +75,8 @@ defmodule Bumblebee.Text.BlenderbotTest do
generation_config
)

token_ids = EXLA.jit(generate).(model_info.params, inputs)
token_ids = generate.(model_info.params, inputs)

assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == [" I'm doing well"]
assert_equal(token_ids, Nx.tensor([[1, 281, 476, 929, 731, 2]]))
end
end
17 changes: 4 additions & 13 deletions test/bumblebee/text/generation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ defmodule Bumblebee.Text.GenerationTest do
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config)

assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article)
end
Expand All @@ -39,9 +37,7 @@ defmodule Bumblebee.Text.GenerationTest do
Bumblebee.configure(generation_config, max_new_tokens: 12, no_repeat_ngram_length: 2)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config)

# Without :no_repeat_ngram_length we get
# %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]}
Expand All @@ -62,10 +58,7 @@ defmodule Bumblebee.Text.GenerationTest do
)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
seed: 0,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0)

# Note that this is just a snapshot test, we do not use any
# reference value, because of PRNG difference
Expand All @@ -89,9 +82,7 @@ defmodule Bumblebee.Text.GenerationTest do
)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config)

assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} =
Nx.Serving.run(serving, "I was going")
Expand Down
15 changes: 7 additions & 8 deletions test/bumblebee/text/mbart_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,17 @@ defmodule Bumblebee.Text.MbartTest do
module: Bumblebee.Text.Mbart
)

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/mbart-large-en-ro"})

{:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "facebook/mbart-large-en-ro"})

assert %Bumblebee.Text.Mbart{architecture: :for_conditional_generation} = model_info.spec

english_phrase = "42 is the answer"

inputs = Bumblebee.apply_tokenizer(tokenizer, english_phrase)
inputs = %{
"input_ids" => Nx.tensor([[4828, 83, 70, 35166, 2, 250_004]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1]])
}

generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 6)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -178,8 +177,8 @@ defmodule Bumblebee.Text.MbartTest do
generation_config
)

token_ids = EXLA.jit(generate).(model_info.params, inputs)
token_ids = generate.(model_info.params, inputs)

assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["42 este răspunsul"]
assert_equal(token_ids, Nx.tensor([[250_020, 4828, 473, 54051, 202, 2]]))
end
end
28 changes: 14 additions & 14 deletions test/bumblebee/text/t5_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,15 @@ defmodule Bumblebee.Text.T5Test do
end

test "text generation" do
assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "t5-small"})
assert {:ok, model_info} = Bumblebee.load_model({:hf, "t5-small"})
assert {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "t5-small"})

text = "translate English to German: How old are you?"

inputs = Bumblebee.apply_tokenizer(tokenizer, text)
inputs = %{
"input_ids" => Nx.tensor([[13959, 1566, 12, 2968, 10, 571, 625, 33, 25, 58, 1]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
}

generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 10)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -163,23 +163,23 @@ defmodule Bumblebee.Text.T5Test do
generation_config
)

token_ids = EXLA.jit(generate).(model_info.params, inputs)
token_ids = generate.(model_info.params, inputs)

assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["Wie alt sind Sie?"]
assert_equal(token_ids, Nx.tensor([[0, 2739, 4445, 436, 292, 58]]))
end

test "text generation (tied embeddings)" do
assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google/flan-t5-small"})
assert {:ok, model_info} = Bumblebee.load_model({:hf, "google/flan-t5-small"})

assert {:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "google/flan-t5-small"})

text = "translate English to German: How old are you?"

inputs = Bumblebee.apply_tokenizer(tokenizer, text)
inputs = %{
"input_ids" => Nx.tensor([[13959, 1566, 12, 2968, 10, 571, 625, 33, 25, 58, 1]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
}

generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 10)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -188,9 +188,9 @@ defmodule Bumblebee.Text.T5Test do
generation_config
)

token_ids = EXLA.jit(generate).(model_info.params, inputs)
token_ids = generate.(model_info.params, inputs)

assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["Wie ich er bitten?"]
assert_equal(token_ids, Nx.tensor([[0, 2739, 3, 362, 3, 49]]))
end
end
end
4 changes: 1 addition & 3 deletions test/bumblebee/vision/image_to_text_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ defmodule Bumblebee.Vision.ImageToTextTest do
Bumblebee.load_generation_config({:hf, "Salesforce/blip-image-captioning-base"})

serving =
Bumblebee.Vision.ImageToText.image_to_text(blip, featurizer, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)
Bumblebee.Vision.ImageToText.image_to_text(blip, featurizer, tokenizer, generation_config)

image = StbImage.read_file!(Path.join(@images_dir, "coco/39769.jpeg"))

Expand Down
Loading