Skip to content

Commit

Permalink
Speed up tests (#230)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Jul 31, 2023
1 parent f23faee commit c65b143
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 57 deletions.
19 changes: 7 additions & 12 deletions test/bumblebee/text/bart_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -144,21 +144,16 @@ defmodule Bumblebee.Text.BartTest do

test "conditional generation" do
{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/bart-large-cnn"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-cnn"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "facebook/bart-large-cnn"})

assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = model_info.spec

article = """
PG&E stated it scheduled the blackouts in response to forecasts for high \
winds amid dry conditions. The aim is to reduce the risk of wildfires. \
Nearly 800 thousand customers were scheduled to be affected by the shutoffs \
which were expected to last through at least midday tomorrow.
"""
inputs = %{
"input_ids" => Nx.tensor([[0, 133, 812, 9, 1470, 16, 2201, 2]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
}

inputs = Bumblebee.apply_tokenizer(tokenizer, article)

generation_config = Bumblebee.configure(generation_config, max_length: 8)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -167,8 +162,8 @@ defmodule Bumblebee.Text.BartTest do
generation_config
)

token_ids = EXLA.jit(generate).(model_info.params, inputs)
token_ids = generate.(model_info.params, inputs)

assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["PG&E scheduled the"]
assert_equal(token_ids, Nx.tensor([[2, 0, 133, 812, 9, 2]]))
end
end
14 changes: 7 additions & 7 deletions test/bumblebee/text/blenderbot_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ defmodule Bumblebee.Text.BlenderbotTest do

test "conditional generation" do
{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/blenderbot-400M-distill"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/blenderbot-400M-distill"})

{:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "facebook/blenderbot-400M-distill"})

assert %Bumblebee.Text.Blenderbot{architecture: :for_conditional_generation} = model_info.spec

english_phrase = " Hey, how are you?"
inputs = %{
"input_ids" => Nx.tensor([[2675, 19, 544, 366, 304, 38, 2]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1]])
}

inputs = Bumblebee.apply_tokenizer(tokenizer, english_phrase)

generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 6)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -75,8 +75,8 @@ defmodule Bumblebee.Text.BlenderbotTest do
generation_config
)

token_ids = EXLA.jit(generate).(model_info.params, inputs)
token_ids = generate.(model_info.params, inputs)

assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == [" I'm doing well"]
assert_equal(token_ids, Nx.tensor([[1, 281, 476, 929, 731, 2]]))
end
end
17 changes: 4 additions & 13 deletions test/bumblebee/text/generation_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ defmodule Bumblebee.Text.GenerationTest do
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config)

assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article)
end
Expand All @@ -39,9 +37,7 @@ defmodule Bumblebee.Text.GenerationTest do
Bumblebee.configure(generation_config, max_new_tokens: 12, no_repeat_ngram_length: 2)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config)

# Without :no_repeat_ngram_length we get
# %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]}
Expand All @@ -62,10 +58,7 @@ defmodule Bumblebee.Text.GenerationTest do
)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
seed: 0,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0)

# Note that this is just a snapshot test, we do not use any
# reference value, because of PRNG difference
Expand All @@ -89,9 +82,7 @@ defmodule Bumblebee.Text.GenerationTest do
)

serving =
Bumblebee.Text.generation(model_info, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)
Bumblebee.Text.generation(model_info, tokenizer, generation_config)

assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} =
Nx.Serving.run(serving, "I was going")
Expand Down
15 changes: 7 additions & 8 deletions test/bumblebee/text/mbart_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,17 @@ defmodule Bumblebee.Text.MbartTest do
module: Bumblebee.Text.Mbart
)

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/mbart-large-en-ro"})

{:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "facebook/mbart-large-en-ro"})

assert %Bumblebee.Text.Mbart{architecture: :for_conditional_generation} = model_info.spec

english_phrase = "42 is the answer"

inputs = Bumblebee.apply_tokenizer(tokenizer, english_phrase)
inputs = %{
"input_ids" => Nx.tensor([[4828, 83, 70, 35166, 2, 250_004]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1]])
}

generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 6)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -178,8 +177,8 @@ defmodule Bumblebee.Text.MbartTest do
generation_config
)

token_ids = EXLA.jit(generate).(model_info.params, inputs)
token_ids = generate.(model_info.params, inputs)

assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["42 este răspunsul"]
assert_equal(token_ids, Nx.tensor([[250_020, 4828, 473, 54051, 202, 2]]))
end
end
28 changes: 14 additions & 14 deletions test/bumblebee/text/t5_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,15 @@ defmodule Bumblebee.Text.T5Test do
end

test "text generation" do
assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "t5-small"})
assert {:ok, model_info} = Bumblebee.load_model({:hf, "t5-small"})
assert {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "t5-small"})

text = "translate English to German: How old are you?"

inputs = Bumblebee.apply_tokenizer(tokenizer, text)
inputs = %{
"input_ids" => Nx.tensor([[13959, 1566, 12, 2968, 10, 571, 625, 33, 25, 58, 1]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
}

generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 10)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -163,23 +163,23 @@ defmodule Bumblebee.Text.T5Test do
generation_config
)

token_ids = EXLA.jit(generate).(model_info.params, inputs)
token_ids = generate.(model_info.params, inputs)

assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["Wie alt sind Sie?"]
assert_equal(token_ids, Nx.tensor([[0, 2739, 4445, 436, 292, 58]]))
end

test "text generation (tied embeddings)" do
assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google/flan-t5-small"})
assert {:ok, model_info} = Bumblebee.load_model({:hf, "google/flan-t5-small"})

assert {:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "google/flan-t5-small"})

text = "translate English to German: How old are you?"

inputs = Bumblebee.apply_tokenizer(tokenizer, text)
inputs = %{
"input_ids" => Nx.tensor([[13959, 1566, 12, 2968, 10, 571, 625, 33, 25, 58, 1]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
}

generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 10)
generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)

generate =
Bumblebee.Text.Generation.build_generate(
Expand All @@ -188,9 +188,9 @@ defmodule Bumblebee.Text.T5Test do
generation_config
)

token_ids = EXLA.jit(generate).(model_info.params, inputs)
token_ids = generate.(model_info.params, inputs)

assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["Wie ich er bitten?"]
assert_equal(token_ids, Nx.tensor([[0, 2739, 3, 362, 3, 49]]))
end
end
end
4 changes: 1 addition & 3 deletions test/bumblebee/vision/image_to_text_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ defmodule Bumblebee.Vision.ImageToTextTest do
Bumblebee.load_generation_config({:hf, "Salesforce/blip-image-captioning-base"})

serving =
Bumblebee.Vision.ImageToText.image_to_text(blip, featurizer, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)
Bumblebee.Vision.ImageToText.image_to_text(blip, featurizer, tokenizer, generation_config)

image = StbImage.read_file!(Path.join(@images_dir, "coco/39769.jpeg"))

Expand Down

0 comments on commit c65b143

Please sign in to comment.