diff --git a/test/bumblebee/text/bart_test.exs b/test/bumblebee/text/bart_test.exs
index 0854ec11..29a874b2 100644
--- a/test/bumblebee/text/bart_test.exs
+++ b/test/bumblebee/text/bart_test.exs
@@ -144,16 +144,21 @@ defmodule Bumblebee.Text.BartTest do
   test "conditional generation" do
     {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/bart-large-cnn"})
+    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-cnn"})

     {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "facebook/bart-large-cnn"})

     assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = model_info.spec

-    inputs = %{
-      "input_ids" => Nx.tensor([[0, 133, 812, 9, 1470, 16, 2201, 2]]),
-      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
-    }
+    article = """
+    PG&E stated it scheduled the blackouts in response to forecasts for high \
+    winds amid dry conditions. The aim is to reduce the risk of wildfires. \
+    Nearly 800 thousand customers were scheduled to be affected by the shutoffs \
+    which were expected to last through at least midday tomorrow.
+    """

-    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)
+    inputs = Bumblebee.apply_tokenizer(tokenizer, article)
+
+    generation_config = Bumblebee.configure(generation_config, max_length: 8)

     generate =
       Bumblebee.Text.Generation.build_generate(
@@ -162,8 +167,8 @@ defmodule Bumblebee.Text.BartTest do
         generation_config
       )

-    token_ids = generate.(model_info.params, inputs)
+    token_ids = EXLA.jit(generate).(model_info.params, inputs)

-    assert_equal(token_ids, Nx.tensor([[2, 0, 133, 812, 9, 2]]))
+    assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["PG&E scheduled the"]
   end
 end
diff --git a/test/bumblebee/text/blenderbot_test.exs b/test/bumblebee/text/blenderbot_test.exs
index 6a20bf60..0203b605 100644
--- a/test/bumblebee/text/blenderbot_test.exs
+++ b/test/bumblebee/text/blenderbot_test.exs
@@ -55,18 +55,18 @@ defmodule Bumblebee.Text.BlenderbotTest do
   test "conditional generation" do
     {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/blenderbot-400M-distill"})
+    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/blenderbot-400M-distill"})

     {:ok, generation_config} =
       Bumblebee.load_generation_config({:hf, "facebook/blenderbot-400M-distill"})

     assert %Bumblebee.Text.Blenderbot{architecture: :for_conditional_generation} =
             model_info.spec

-    inputs = %{
-      "input_ids" => Nx.tensor([[2675, 19, 544, 366, 304, 38, 2]]),
-      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1]])
-    }
+    english_phrase = " Hey, how are you?"

-    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)
+    inputs = Bumblebee.apply_tokenizer(tokenizer, english_phrase)
+
+    generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 6)

     generate =
       Bumblebee.Text.Generation.build_generate(
@@ -75,8 +75,8 @@ defmodule Bumblebee.Text.BlenderbotTest do
         generation_config
       )

-    token_ids = generate.(model_info.params, inputs)
+    token_ids = EXLA.jit(generate).(model_info.params, inputs)

-    assert_equal(token_ids, Nx.tensor([[1, 281, 476, 929, 731, 2]]))
+    assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == [" I'm doing well"]
   end
 end
diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs
index 550bb5f6..73aafeab 100644
--- a/test/bumblebee/text/generation_test.exs
+++ b/test/bumblebee/text/generation_test.exs
@@ -23,7 +23,9 @@ defmodule Bumblebee.Text.GenerationTest do
       generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8)

       serving =
-        Bumblebee.Text.generation(model_info, tokenizer, generation_config)
+        Bumblebee.Text.generation(model_info, tokenizer, generation_config,
+          defn_options: [compiler: EXLA]
+        )

       assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article)
     end
@@ -37,7 +39,9 @@ defmodule Bumblebee.Text.GenerationTest do
         Bumblebee.configure(generation_config, max_new_tokens: 12, no_repeat_ngram_length: 2)

       serving =
-        Bumblebee.Text.generation(model_info, tokenizer, generation_config)
+        Bumblebee.Text.generation(model_info, tokenizer, generation_config,
+          defn_options: [compiler: EXLA]
+        )

       # Without :no_repeat_ngram_length we get
       # %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]}
@@ -58,7 +62,10 @@ defmodule Bumblebee.Text.GenerationTest do
         )

       serving =
-        Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0)
+        Bumblebee.Text.generation(model_info, tokenizer, generation_config,
+          seed: 0,
+          defn_options: [compiler: EXLA]
+        )

       # Note that this is just a snapshot test, we do not use any
       # reference value, because of PRNG difference
@@ -82,10 +89,12 @@ defmodule Bumblebee.Text.GenerationTest do
         )

       serving =
-        Bumblebee.Text.generation(model_info, tokenizer, generation_config)
+        Bumblebee.Text.generation(model_info, tokenizer, generation_config,
+          defn_options: [compiler: EXLA]
+        )

       assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} =
               Nx.Serving.run(serving, "I was going")
     end
   end
 end
diff --git a/test/bumblebee/text/mbart_test.exs b/test/bumblebee/text/mbart_test.exs
index 197db02d..8a3ae285 100644
--- a/test/bumblebee/text/mbart_test.exs
+++ b/test/bumblebee/text/mbart_test.exs
@@ -158,17 +158,18 @@ defmodule Bumblebee.Text.MbartTest do
         module: Bumblebee.Text.Mbart
       )

+    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/mbart-large-en-ro"})
+
     {:ok, generation_config} =
       Bumblebee.load_generation_config({:hf, "facebook/mbart-large-en-ro"})

     assert %Bumblebee.Text.Mbart{architecture: :for_conditional_generation} = model_info.spec

-    inputs = %{
-      "input_ids" => Nx.tensor([[4828, 83, 70, 35166, 2, 250_004]]),
-      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1]])
-    }
+    english_phrase = "42 is the answer"
+
+    inputs = Bumblebee.apply_tokenizer(tokenizer, english_phrase)

-    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)
+    generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 6)

     generate =
       Bumblebee.Text.Generation.build_generate(
@@ -177,8 +178,8 @@ defmodule Bumblebee.Text.MbartTest do
         generation_config
       )

-    token_ids = generate.(model_info.params, inputs)
+    token_ids = EXLA.jit(generate).(model_info.params, inputs)

-    assert_equal(token_ids, Nx.tensor([[250_020, 4828, 473, 54051, 202, 2]]))
+    assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["42 este răspunsul"]
   end
 end
diff --git a/test/bumblebee/text/t5_test.exs b/test/bumblebee/text/t5_test.exs
index 441e9a6a..85f894fa 100644
--- a/test/bumblebee/text/t5_test.exs
+++ b/test/bumblebee/text/t5_test.exs
@@ -146,15 +146,15 @@ defmodule Bumblebee.Text.T5Test do
     end

     test "text generation" do
+      assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "t5-small"})
       assert {:ok, model_info} = Bumblebee.load_model({:hf, "t5-small"})
       assert {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "t5-small"})

-      inputs = %{
-        "input_ids" => Nx.tensor([[13959, 1566, 12, 2968, 10, 571, 625, 33, 25, 58, 1]]),
-        "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
-      }
+      text = "translate English to German: How old are you?"
+
+      inputs = Bumblebee.apply_tokenizer(tokenizer, text)

-      generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)
+      generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 10)

       generate =
         Bumblebee.Text.Generation.build_generate(
@@ -163,23 +163,23 @@ defmodule Bumblebee.Text.T5Test do
           generation_config
         )

-      token_ids = generate.(model_info.params, inputs)
+      token_ids = EXLA.jit(generate).(model_info.params, inputs)

-      assert_equal(token_ids, Nx.tensor([[0, 2739, 4445, 436, 292, 58]]))
+      assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["Wie alt sind Sie?"]
     end

     test "text generation (tied embeddings)" do
+      assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google/flan-t5-small"})
       assert {:ok, model_info} = Bumblebee.load_model({:hf, "google/flan-t5-small"})

       assert {:ok, generation_config} =
               Bumblebee.load_generation_config({:hf, "google/flan-t5-small"})

-      inputs = %{
-        "input_ids" => Nx.tensor([[13959, 1566, 12, 2968, 10, 571, 625, 33, 25, 58, 1]]),
-        "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
-      }
+      text = "translate English to German: How old are you?"
+
+      inputs = Bumblebee.apply_tokenizer(tokenizer, text)

-      generation_config = Bumblebee.configure(generation_config, max_new_tokens: 5)
+      generation_config = Bumblebee.configure(generation_config, min_length: 0, max_length: 10)

       generate =
         Bumblebee.Text.Generation.build_generate(
@@ -188,9 +188,9 @@ defmodule Bumblebee.Text.T5Test do
           generation_config
         )

-      token_ids = generate.(model_info.params, inputs)
+      token_ids = EXLA.jit(generate).(model_info.params, inputs)

-      assert_equal(token_ids, Nx.tensor([[0, 2739, 3, 362, 3, 49]]))
+      assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["Wie ich er bitten?"]
     end
   end
 end
diff --git a/test/bumblebee/vision/image_to_text_test.exs b/test/bumblebee/vision/image_to_text_test.exs
index 7ff840ad..62a2b680 100644
--- a/test/bumblebee/vision/image_to_text_test.exs
+++ b/test/bumblebee/vision/image_to_text_test.exs
@@ -20,7 +20,9 @@ defmodule Bumblebee.Vision.ImageToTextTest do
         Bumblebee.load_generation_config({:hf, "Salesforce/blip-image-captioning-base"})

       serving =
-        Bumblebee.Vision.ImageToText.image_to_text(blip, featurizer, tokenizer, generation_config)
+        Bumblebee.Vision.ImageToText.image_to_text(blip, featurizer, tokenizer, generation_config,
+          defn_options: [compiler: EXLA]
+        )

       image = StbImage.read_file!(Path.join(@images_dir, "coco/39769.jpeg"))

diff --git a/test/test_helper.exs b/test/test_helper.exs
index c0a34615..7592a443 100644
--- a/test/test_helper.exs
+++ b/test/test_helper.exs
@@ -1,5 +1,4 @@
 Nx.global_default_backend(EXLA.Backend)
-Nx.Defn.global_default_options(compiler: EXLA)

 Application.put_env(:bumblebee, :progress_bar_enabled, false)