diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 550bb5f6..0deb58f4 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -5,6 +5,8 @@ defmodule Bumblebee.Text.GenerationTest do @moduletag model_test_tags() + # TMP CHANGED :o + describe "integration" do test "generates text with greedy generation" do {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/bart-large-cnn"}) @@ -23,7 +25,9 @@ defmodule Bumblebee.Text.GenerationTest do generation_config = Bumblebee.configure(generation_config, max_new_tokens: 8) serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config) + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + defn_options: [compiler: EXLA] + ) assert %{results: [%{text: "PG&E scheduled the black"}]} = Nx.Serving.run(serving, article) end @@ -37,7 +41,9 @@ defmodule Bumblebee.Text.GenerationTest do Bumblebee.configure(generation_config, max_new_tokens: 12, no_repeat_ngram_length: 2) serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config) + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + defn_options: [compiler: EXLA] + ) # Without :no_repeat_ngram_length we get # %{results: [%{text: "I was going to say, 'Well, I'm going to say,"}]} @@ -58,7 +64,10 @@ defmodule Bumblebee.Text.GenerationTest do ) serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config, seed: 0) + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + seed: 0, + defn_options: [compiler: EXLA] + ) # Note that this is just a snapshot test, we do not use any # reference value, because of PRNG difference @@ -82,7 +91,9 @@ defmodule Bumblebee.Text.GenerationTest do ) serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config) + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + defn_options: [compiler: EXLA] + ) assert %{results: [%{text: "I was going to say, 'Well, I don't know what you"}]} = Nx.Serving.run(serving, "I was going")