diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 7fac4d8a..1f034e6d 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -650,9 +650,10 @@ defmodule Bumblebee do * `:return_offsets` - whether to return token offsets for encoded sequence. Defaults to `false` - * `:length` - applies fixed length padding or truncation to the given - input if set - + * `:length` - applies fixed length padding or truncation to the + given input if set. Can be either a specific number or a list + of numbers. When a list is given, the smallest number that + exceeds all input lengths is used as the padding length ## Examples diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex index 52c71d6c..7ec5d2c0 100644 --- a/lib/bumblebee/shared.ex +++ b/lib/bumblebee/shared.ex @@ -359,6 +359,39 @@ defmodule Bumblebee.Shared do end end + @doc """ + Returns batch keys for the given sequence length specified in text + serving compile options. + """ + @spec sequence_batch_keys(nil | non_neg_integer() | list(non_neg_integer())) :: list() + def sequence_batch_keys(sequence_length) + + def sequence_batch_keys(nil), do: [:default] + + def sequence_batch_keys(length) when is_number(length) do + [{:sequence_length, length}] + end + + def sequence_batch_keys(lengths) when is_list(lengths) do + Enum.map(lengths, &{:sequence_length, &1}) + end + + @doc """ + Determines batch key compatible with `sequence_batch_keys/1` based + on tokenized inputs. + """ + @spec sequence_batch_key_for_inputs( + inputs :: any(), + nil | non_neg_integer() | list(non_neg_integer()) + ) :: term() + def sequence_batch_key_for_inputs(inputs, sequence_length) do + if sequence_length do + {:sequence_length, Nx.axis_size(inputs["input_ids"], 1)} + else + :default + end + end + @doc """ Generates tokenizer implementation. """ diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex index 368a70db..ef8a6538 100644 --- a/lib/bumblebee/text.ex +++ b/lib/bumblebee/text.ex @@ -76,7 +76,10 @@ defmodule Bumblebee.Text do are optionally padded to always match this batch size * `:sequence_length` - the maximum input sequence length. Input - sequences are always padded/truncated to match that length + sequences are always padded/truncated to match that length. + A list can be given, in which case the serving compiles + a separate computation for each length and then inputs are + matched to the smallest bounding length It is advised to set this option in production and also configure a defn compiler using `:defn_options` to maximally reduce inference @@ -138,7 +141,10 @@ defmodule Bumblebee.Text do are optionally padded to always match this batch size * `:sequence_length` - the maximum input sequence length. Input - sequences are always padded/truncated to match that length + sequences are always padded/truncated to match that length. + A list can be given, in which case the serving compiles + a separate computation for each length and then inputs are + matched to the smallest bounding length It is advised to set this option in production and also configure a defn compiler using `:defn_options` to maximally reduce inference @@ -203,6 +209,9 @@ defmodule Bumblebee.Text do * `:sequence_length` - the maximum input sequence length. Input sequences are always padded/truncated to match that length. + A list can be given, in which case the serving compiles + a separate computation for each length and then inputs are + matched to the smallest bounding length. Note that in this case, the whole conversation history is the input, so this value should be relatively large to allow long history (though the supported upper limit depends on the model) @@ -267,7 +276,10 @@ defmodule Bumblebee.Text do are optionally padded to always match this batch size * `:sequence_length` - the maximum input sequence length. Input - sequences are always padded/truncated to match that length + sequences are always padded/truncated to match that length. + A list can be given, in which case the serving compiles + a separate computation for each length and then inputs are + matched to the smallest bounding length It is advised to set this option in production and also configure a defn compiler using `:defn_options` to maximally reduce inference @@ -336,7 +348,10 @@ defmodule Bumblebee.Text do are optionally padded to always match this batch size * `:sequence_length` - the maximum input sequence length. Input - sequences are always padded/truncated to match that length + sequences are always padded/truncated to match that length. + A list can be given, in which case the serving compiles + a separate computation for each length and then inputs are + matched to the smallest bounding length It is advised to set this option in production and also configure a defn compiler using `:defn_options` to maximally reduce inference @@ -398,7 +413,10 @@ defmodule Bumblebee.Text do are optionally padded to always match this batch size * `:sequence_length` - the maximum input sequence length. Input - sequences are always padded/truncated to match that length + sequences are always padded/truncated to match that length. + A list can be given, in which case the serving compiles + a separate computation for each length and then inputs are + matched to the smallest bounding length It is advised to set this option in production and also configure a defn compiler using `:defn_options` to maximally reduce inference @@ -441,6 +459,7 @@ defmodule Bumblebee.Text do end: number(), score: number() } + @doc """ Builds serving for the question answering task. @@ -463,7 +482,10 @@ defmodule Bumblebee.Text do prompt and label * `:sequence_length` - the maximum input sequence length. Input - sequences are always padded/truncated to match that length + sequences are always padded/truncated to match that length. + A list can be given, in which case the serving compiles + a separate computation for each length and then inputs are + matched to the smallest bounding length It is advised to set this option in production and also configure a defn compiler using `:defn_options` to maximally reduce inference @@ -528,7 +550,10 @@ defmodule Bumblebee.Text do prompt and label * `:sequence_length` - the maximum input sequence length. Input - sequences are always padded/truncated to match that length + sequences are always padded/truncated to match that length. + A list can be given, in which case the serving compiles + a separate computation for each length and then inputs are + matched to the smallest bounding length It is advised to set this option in production and also configure a defn compiler using `:defn_options` to maximally reduce inference diff --git a/lib/bumblebee/text/conversation.ex b/lib/bumblebee/text/conversation.ex index 977d518f..ddff98b4 100644 --- a/lib/bumblebee/text/conversation.ex +++ b/lib/bumblebee/text/conversation.ex @@ -45,10 +45,14 @@ defmodule Bumblebee.Text.Conversation do generate_fun = Text.Generation.build_generate(model, spec, generation_config, Keyword.take(opts, [:seed])) + batch_keys = Shared.sequence_batch_keys(sequence_length) + Nx.Serving.new( - fn defn_options -> + fn batch_key, defn_options -> generate_fun = Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> + {:sequence_length, sequence_length} = batch_key + inputs = %{ "input_ids" => Nx.template({batch_size, sequence_length}, :u32), "attention_mask" => Nx.template({batch_size, sequence_length}, :u32) @@ -74,7 +78,7 @@ defmodule Bumblebee.Text.Conversation do end, defn_options ) - |> Nx.Serving.process_options(batch_size: batch_size) + |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys) |> Nx.Serving.client_preprocessing(fn input -> {histories, multi?} = Shared.validate_serving_input!(input, &validate_input/1) @@ -91,7 +95,10 @@ defmodule Bumblebee.Text.Conversation do return_token_type_ids: false ) - {Nx.Batch.concatenate([inputs]), {histories, multi?}} + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, {histories, multi?}} end) |> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, {histories, multi?} -> decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids) diff --git a/lib/bumblebee/text/fill_mask.ex b/lib/bumblebee/text/fill_mask.ex index 7e2abc5d..36956f87 100644 --- a/lib/bumblebee/text/fill_mask.ex +++ b/lib/bumblebee/text/fill_mask.ex @@ -47,10 +47,14 @@ defmodule Bumblebee.Text.FillMask do |> Nx.squeeze(axes: [1]) end + batch_keys = Shared.sequence_batch_keys(sequence_length) + Nx.Serving.new( - fn defn_options -> + fn batch_key, defn_options -> scores_fun = Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> + {:sequence_length, sequence_length} = batch_key + inputs = %{ "input_ids" => Nx.template({batch_size, sequence_length}, :u32), "attention_mask" => Nx.template({batch_size, sequence_length}, :u32) @@ -66,7 +70,7 @@ defmodule Bumblebee.Text.FillMask do end, defn_options ) - |> Nx.Serving.process_options(batch_size: batch_size) + |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys) |> Nx.Serving.client_preprocessing(fn input -> {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) @@ -78,7 +82,10 @@ defmodule Bumblebee.Text.FillMask do return_token_type_ids: false ) - {Nx.Batch.concatenate([inputs]), multi?} + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, multi?} end) |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? -> for scores <- Bumblebee.Utils.Nx.batch_to_list(scores) do diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 4e3b9eb8..979201e0 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -75,10 +75,14 @@ defmodule Bumblebee.Text.Generation do generate_fun = build_generate(model, spec, generation_config, Keyword.take(opts, [:seed])) + batch_keys = Shared.sequence_batch_keys(sequence_length) + Nx.Serving.new( - fn defn_options -> + fn batch_key, defn_options -> generate_fun = Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> + {:sequence_length, sequence_length} = batch_key + inputs = %{ "input_ids" => Nx.template({batch_size, sequence_length}, :u32), "attention_mask" => Nx.template({batch_size, sequence_length}, :u32) @@ -94,7 +98,7 @@ defmodule Bumblebee.Text.Generation do end, defn_options ) - |> Nx.Serving.process_options(batch_size: batch_size) + |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys) |> Nx.Serving.client_preprocessing(fn input -> {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) @@ -105,7 +109,10 @@ defmodule Bumblebee.Text.Generation do return_token_type_ids: false ) - {Nx.Batch.concatenate([inputs]), multi?} + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, multi?} end) |> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, multi? -> decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids) diff --git a/lib/bumblebee/text/question_answering.ex b/lib/bumblebee/text/question_answering.ex index a332cf4c..2e462671 100644 --- a/lib/bumblebee/text/question_answering.ex +++ b/lib/bumblebee/text/question_answering.ex @@ -31,10 +31,14 @@ defmodule Bumblebee.Text.QuestionAnswering do %{start_scores: start_scores, end_scores: end_scores} end + batch_keys = Shared.sequence_batch_keys(sequence_length) + Nx.Serving.new( - fn defn_options -> + fn batch_key, defn_options -> predict_fun = Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> + {:sequence_length, sequence_length} = batch_key + inputs = %{ "input_ids" => Nx.template({batch_size, sequence_length}, :u32), "attention_mask" => Nx.template({batch_size, sequence_length}, :u32), @@ -52,7 +56,7 @@ defmodule Bumblebee.Text.QuestionAnswering do end, defn_options ) - |> Nx.Serving.process_options(batch_size: batch_size) + |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys) |> Nx.Serving.client_preprocessing(fn raw_input -> {raw_inputs, multi?} = Shared.validate_serving_input!(raw_input, fn @@ -73,7 +77,11 @@ defmodule Bumblebee.Text.QuestionAnswering do ) inputs = Map.take(all_inputs, ["input_ids", "attention_mask", "token_type_ids"]) - {Nx.Batch.concatenate([inputs]), {all_inputs, raw_inputs, multi?}} + + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, {all_inputs, raw_inputs, multi?}} end) |> Nx.Serving.client_postprocessing(fn {outputs, _metadata}, {inputs, raw_inputs, multi?} -> Enum.zip_with( diff --git a/lib/bumblebee/text/text_classification.ex b/lib/bumblebee/text/text_classification.ex index ce35c5ab..9f4074f9 100644 --- a/lib/bumblebee/text/text_classification.ex +++ b/lib/bumblebee/text/text_classification.ex @@ -31,10 +31,14 @@ defmodule Bumblebee.Text.TextClassification do Shared.logits_to_scores(outputs.logits, scores_function) end + batch_keys = Shared.sequence_batch_keys(sequence_length) + Nx.Serving.new( - fn defn_options -> + fn batch_key, defn_options -> scores_fun = Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> + {:sequence_length, sequence_length} = batch_key + inputs = %{ "input_ids" => Nx.template({batch_size, sequence_length}, :u32), "attention_mask" => Nx.template({batch_size, sequence_length}, :u32) @@ -50,7 +54,7 @@ defmodule Bumblebee.Text.TextClassification do end, defn_options ) - |> Nx.Serving.process_options(batch_size: batch_size) + |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys) |> Nx.Serving.client_preprocessing(fn input -> {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) @@ -60,7 +64,10 @@ defmodule Bumblebee.Text.TextClassification do return_token_type_ids: false ) - {Nx.Batch.concatenate([inputs]), multi?} + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, multi?} end) |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? -> for scores <- Bumblebee.Utils.Nx.batch_to_list(scores) do diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex index abedb6e9..a293b8c0 100644 --- a/lib/bumblebee/text/text_embedding.ex +++ b/lib/bumblebee/text/text_embedding.ex @@ -76,10 +76,14 @@ defmodule Bumblebee.Text.TextEmbedding do output end + batch_keys = Shared.sequence_batch_keys(sequence_length) + Nx.Serving.new( - fn defn_options -> + fn batch_key, defn_options -> embedding_fun = Shared.compile_or_jit(embedding_fun, defn_options, compile != nil, fn -> + {:sequence_length, sequence_length} = batch_key + inputs = %{ "input_ids" => Nx.template({batch_size, sequence_length}, :u32), "attention_mask" => Nx.template({batch_size, sequence_length}, :u32) @@ -95,7 +99,7 @@ defmodule Bumblebee.Text.TextEmbedding do end, defn_options ) - |> Nx.Serving.process_options(batch_size: batch_size) + |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys) |> Nx.Serving.client_preprocessing(fn input -> {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) @@ -105,7 +109,10 @@ defmodule Bumblebee.Text.TextEmbedding do return_token_type_ids: false ) - {Nx.Batch.concatenate([inputs]), multi?} + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, multi?} end) |> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? -> for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do diff --git a/lib/bumblebee/text/token_classification.ex b/lib/bumblebee/text/token_classification.ex index fc4faf6e..25fd0f2f 100644 --- a/lib/bumblebee/text/token_classification.ex +++ b/lib/bumblebee/text/token_classification.ex @@ -39,10 +39,14 @@ defmodule Bumblebee.Text.TokenClassification do Shared.logits_to_scores(outputs.logits, scores_function) end + batch_keys = Shared.sequence_batch_keys(sequence_length) + Nx.Serving.new( - fn defn_options -> + fn batch_key, defn_options -> scores_fun = Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> + {:sequence_length, sequence_length} = batch_key + inputs = %{ "input_ids" => Nx.template({batch_size, sequence_length}, :u32), "attention_mask" => Nx.template({batch_size, sequence_length}, :u32) @@ -58,7 +62,7 @@ defmodule Bumblebee.Text.TokenClassification do end, defn_options ) - |> Nx.Serving.process_options(batch_size: batch_size) + |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys) |> Nx.Serving.client_preprocessing(fn input -> {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) @@ -71,7 +75,10 @@ defmodule Bumblebee.Text.TokenClassification do inputs = Map.take(all_inputs, ["input_ids", "attention_mask"]) - {Nx.Batch.concatenate([inputs]), {all_inputs, multi?}} + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, {all_inputs, multi?}} end) |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, {inputs, multi?} -> Enum.zip_with( diff --git a/lib/bumblebee/text/zero_shot_classification.ex b/lib/bumblebee/text/zero_shot_classification.ex index 314ab7c9..0f5e11d0 100644 --- a/lib/bumblebee/text/zero_shot_classification.ex +++ b/lib/bumblebee/text/zero_shot_classification.ex @@ -52,10 +52,14 @@ defmodule Bumblebee.Text.ZeroShotClassification do logits end + batch_keys = Shared.sequence_batch_keys(sequence_length) + Nx.Serving.new( - fn defn_options -> + fn batch_key, defn_options -> scores_fun = Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> + {:sequence_length, sequence_length} = batch_key + inputs = %{ "input_ids" => Nx.template({batch_size, sequences_per_batch, sequence_length}, :u32), @@ -74,7 +78,7 @@ defmodule Bumblebee.Text.ZeroShotClassification do end, defn_options ) - |> Nx.Serving.process_options(batch_size: batch_size) + |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys) |> Nx.Serving.client_preprocessing(fn input -> {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) @@ -86,9 +90,13 @@ defmodule Bumblebee.Text.ZeroShotClassification do return_token_type_ids: false ) + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + inputs = Utils.Nx.composite_unflatten_batch(inputs, length(texts)) - {Nx.Batch.concatenate([inputs]), multi?} + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, multi?} end) |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? -> for scores <- Utils.Nx.batch_to_list(scores) do diff --git a/lib/bumblebee/utils/tokenizers.ex b/lib/bumblebee/utils/tokenizers.ex index 80277a6a..bb78041f 100644 --- a/lib/bumblebee/utils/tokenizers.ex +++ b/lib/bumblebee/utils/tokenizers.ex @@ -22,15 +22,27 @@ defmodule Bumblebee.Utils.Tokenizers do input = List.wrap(input) {:ok, encodings} = - Tokenizer.encode(tokenizer, input, add_special_tokens: opts[:add_special_tokens]) + Tokenizer.encode_batch(tokenizer, input, add_special_tokens: opts[:add_special_tokens]) + + length = opts[:length] {pad_length, truncate_length} = - if length = opts[:length] do + if is_number(length) do {length, length} else - {encodings - |> Enum.map(&Encoding.n_tokens/1) - |> Enum.max(), nil} + max_length = + encodings + |> Enum.map(&Encoding.n_tokens/1) + |> Enum.max() + + case length do + nil -> + {max_length, nil} + + lengths when is_list(lengths) -> + bounding_length = find_bounding_length(max_length, lengths) + {bounding_length, bounding_length} + end end pad_id = Tokenizer.token_to_id(tokenizer, pad_token) @@ -62,6 +74,20 @@ defmodule Bumblebee.Utils.Tokenizers do |> maybe_put_offsets(encodings, opts[:return_offsets]) end + defp find_bounding_length(max_length, lengths) do + find_bounding_length(max_length, lengths, :infinity, 0) + end + + defp find_bounding_length(max_length, [length | rest], bound, max) when length >= max_length do + find_bounding_length(max_length, rest, min(bound, length), max(length, max)) + end + + defp find_bounding_length(max_length, [length | rest], bound, max) do + find_bounding_length(max_length, rest, bound, max(length, max)) + end + + defp find_bounding_length(_max_length, [], bound, max), do: min(bound, max) + defp maybe_put_attention_mask(encoded, encodings, return_attention_mask) do if return_attention_mask do attention_mask = @@ -125,13 +151,20 @@ defmodule Bumblebee.Utils.Tokenizers do |> Nx.reshape({length(list), :auto}) end - def decode(tokenizer, ids) do + def decode(tokenizer, [id | _] = ids) when is_number(id) do case Tokenizer.decode(tokenizer, ids) do {:ok, decoded} -> decoded {:error, term} -> raise "decoding failed with error: #{inspect(term)}" end end + def decode(tokenizer, batch_ids) do + case Tokenizer.decode_batch(tokenizer, batch_ids) do + {:ok, decoded} -> decoded + {:error, term} -> raise "decoding failed with error: #{inspect(term)}" + end + end + def id_to_token(tokenizer, id) do Tokenizer.id_to_token(tokenizer, id) end @@ -141,7 +174,7 @@ defmodule Bumblebee.Utils.Tokenizers do end def load!(path) do - case Tokenizers.Tokenizer.from_file(path) do + case Tokenizers.Tokenizer.from_file(path, padding: :none, truncation: :none) do {:ok, tokenizer} -> tokenizer {:error, error} -> raise "failed to read tokenizer from file, reason: #{error}" end diff --git a/mix.exs b/mix.exs index 40866b7d..7daf2abc 100644 --- a/mix.exs +++ b/mix.exs @@ -31,7 +31,9 @@ defmodule Bumblebee.MixProject do defp deps do [ {:axon, "~> 0.5.0", axon_opts()}, - {:tokenizers, "~> 0.3"}, + # {:tokenizers, "~> 0.3"}, + {:tokenizers, github: "elixir-nx/tokenizers", override: true}, + {:rustler, ">= 0.0.0", optional: true}, # {:nx, "~> 0.5.0"}, # {:exla, "~> 0.5.0", only: [:dev, :test]}, # {:torchx, "~> 0.5.0", only: [:dev, :test]}, diff --git a/mix.lock b/mix.lock index 1a44af27..b2952d53 100644 --- a/mix.lock +++ b/mix.lock @@ -12,14 +12,14 @@ "earmark_parser": {:hex, :earmark_parser, "1.4.33", "3c3fd9673bb5dcc9edc28dd90f50c87ce506d1f71b70e3de69aa8154bc695d44", [:mix], [], "hexpm", "2d526833729b59b9fdb85785078697c72ac5e5066350663e5be6a1182da61b8f"}, "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, "ex_doc": {:hex, :ex_doc, "0.30.3", "bfca4d340e3b95f2eb26e72e4890da83e2b3a5c5b0e52607333bf5017284b063", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "fbc8702046c1d25edf79de376297e608ac78cdc3a29f075484773ad1718918b6"}, - "exla": {:git, "https://github.com/elixir-nx/nx.git", "2771b27d251093e517fdfd05bac2b2aa1bc3df14", [sparse: "exla"]}, + "exla": {:git, "https://github.com/elixir-nx/nx.git", "2c6a9d48890d70fb3937cd19b0cb3e2356008488", [sparse: "exla"]}, "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, "mime": {:hex, :mime, "2.0.3", "3676436d3d1f7b81b5a2d2bd8405f412c677558c81b1c92be58c00562bb59095", [:mix], [], "hexpm", "27a30bf0db44d25eecba73755acf4068cbfe26a4372f9eb3e4ea3a45956bff6b"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:git, "https://github.com/elixir-nx/nx.git", "2771b27d251093e517fdfd05bac2b2aa1bc3df14", [sparse: "nx"]}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "2c6a9d48890d70fb3937cd19b0cb3e2356008488", [sparse: "nx"]}, "nx_image": {:hex, :nx_image, "0.1.1", "69cf0d2fd873d12b028583aa49b5e0a25f6aca307afc337a5d871851a20fba1d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "55c8206a822237f6027168f11214e3887263c5b8a1f8e0634eea82c96e5093e3"}, "nx_signal": {:hex, :nx_signal, "0.1.0", "403ac73140e2f368e827e0aca1a3035abaf6d890b00376742b359a6838e00d7f", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "1c68f2f0d186700819287f37ee6154a11e06bf5dbb30b73fcc92776293309a05"}, "plug": {:hex, :plug, "1.14.2", "cff7d4ec45b4ae176a227acd94a7ab536d9b37b942c8e8fa6dfc0fff98ff4d80", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "842fc50187e13cf4ac3b253d47d9474ed6c296a8732752835ce4a86acdf68d13"}, @@ -27,11 +27,13 @@ "plug_crypto": {:hex, :plug_crypto, "1.2.5", "918772575e48e81e455818229bf719d4ab4181fcbf7f85b68a35620f78d89ced", [:mix], [], "hexpm", "26549a1d6345e2172eb1c233866756ae44a9609bd33ee6f99147ab3fd87fd842"}, "progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"}, "ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"}, - "rustler_precompiled": {:hex, :rustler_precompiled, "0.6.1", "160b545bce8bf9a3f1b436b2c10f53574036a0db628e40f393328cbbe593602f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "0dd269fa261c4e3df290b12031c575fff07a542749f7b0e8b744d72d66c43600"}, + "rustler": {:hex, :rustler, "0.29.1", "880f20ae3027bd7945def6cea767f5257bc926f33ff50c0d5d5a5315883c084d", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "109497d701861bfcd26eb8f5801fe327a8eef304f56a5b63ef61151ff44ac9b6"}, + "rustler_precompiled": {:hex, :rustler_precompiled, "0.6.2", "d2218ba08a43fa331957f30481d00b666664d7e3861431b02bd3f4f30eec8e5b", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "b9048eaed8d7d14a53f758c91865cc616608a438d2595f621f6a4b32a5511709"}, "stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, - "tokenizers": {:hex, :tokenizers, "0.3.2", "78c6238690a0467c613c8ba3c59338235594a78f870e8f8151b9614516dee0fd", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "f6dd9a798e81cf2f3359e1731836ed0a351cae4da5d5d570a7ef3d0543e9cf85"}, - "torchx": {:git, "https://github.com/elixir-nx/nx.git", "2771b27d251093e517fdfd05bac2b2aa1bc3df14", [sparse: "torchx"]}, + "tokenizers": {:git, "https://github.com/elixir-nx/tokenizers.git", "26d864bdedc11ddbc8bae52eaad0858f8a90987f", []}, + "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"}, + "torchx": {:git, "https://github.com/elixir-nx/nx.git", "2c6a9d48890d70fb3937cd19b0cb3e2356008488", [sparse: "torchx"]}, "unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"}, "unzip": {:hex, :unzip, "0.8.0", "ee21d87c21b01567317387dab4228ac570ca15b41cfc221a067354cbf8e68c4d", [:mix], [], "hexpm", "ffa67a483efcedcb5876971a50947222e104d5f8fea2c4a0441e6f7967854827"}, "xla": {:hex, :xla, "0.4.4", "c3a8ed1f579bda949df505e49ff65415c8281d991fbd6ae1d8f3c5d0fd155f54", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "484f3f9011db3c9f1ff1e98eecefd382f3882a07ada540fd58803db1d2dab671"}, diff --git a/test/bumblebee/text/bert_tokenizer_test.exs b/test/bumblebee/text/bert_tokenizer_test.exs index ab48efff..bf61e20b 100644 --- a/test/bumblebee/text/bert_tokenizer_test.exs +++ b/test/bumblebee/text/bert_tokenizer_test.exs @@ -70,5 +70,21 @@ defmodule Bumblebee.Text.BertTokenizerTest do assert_equal(inputs["start_offsets"], Nx.tensor([[0, 0, 5, 14, 19, 25, 0]])) assert_equal(inputs["end_offsets"], Nx.tensor([[0, 4, 13, 18, 25, 26, 0]])) end + + test "encoding with multiple lengths" do + assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-cased"}) + + inputs = + Bumblebee.apply_tokenizer(tokenizer, "This is short.", length: [8, 16]) + + assert {1, 8} = Nx.shape(inputs["input_ids"]) + + inputs = + Bumblebee.apply_tokenizer(tokenizer, "This is definitely much longer than the above.", + length: [8, 16] + ) + + assert {1, 16} = Nx.shape(inputs["input_ids"]) + end end end diff --git a/test/bumblebee/text/text_embedding_test.exs b/test/bumblebee/text/text_embedding_test.exs index c71c1768..7a9688be 100644 --- a/test/bumblebee/text/text_embedding_test.exs +++ b/test/bumblebee/text/text_embedding_test.exs @@ -47,5 +47,37 @@ defmodule Bumblebee.Text.TextEmbeddingTest do assert_equal(Nx.sum(Nx.pow(embedding, 2)), Nx.tensor(1.0)) end + + test "supports compilation for single or multiple sequence lengths" do + {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"}) + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-large"}) + + serving_short = + Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, + compile: [batch_size: 1, sequence_length: 8] + ) + + serving_long = + Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, + compile: [batch_size: 1, sequence_length: 16] + ) + + serving_both = + Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, + compile: [batch_size: 1, sequence_length: [8, 16]] + ) + + short_text = "short text" + long_text = "definitely much longer text that should exceed 16 tokens" + + assert %{embedding: embedding_short} = Nx.Serving.run(serving_short, short_text) + assert %{embedding: embedding_long} = Nx.Serving.run(serving_long, long_text) + + assert %{embedding: embedding_short2} = Nx.Serving.run(serving_both, short_text) + assert %{embedding: embedding_long2} = Nx.Serving.run(serving_both, long_text) + + assert_equal(embedding_short, embedding_short2) + assert_equal(embedding_long, embedding_long2) + end end end