Add support for targeting multiple sequence lengths in text servings (#…
jonatanklosko committed Jul 31, 2023
1 parent c65b143 commit b9bc8f0
Showing 16 changed files with 249 additions and 47 deletions.
7 changes: 4 additions & 3 deletions lib/bumblebee.ex
@@ -650,9 +650,10 @@ defmodule Bumblebee do
* `:return_offsets` - whether to return token offsets for encoded
sequence. Defaults to `false`
- * `:length` - applies fixed length padding or truncation to the given
-   input if set
+ * `:length` - applies fixed length padding or truncation to the
+   given input if set. Can be either a specific number or a list
+   of numbers. When a list is given, the smallest number that
+   exceeds all input lengths is used as the padding length
## Examples
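The list semantics documented for `:length` can be illustrated with plain Elixir. The `pick_length` helper below is hypothetical (not part of Bumblebee's API); it selects the smallest configured length that bounds every input, which is the length padding then targets:

```elixir
# Hypothetical helper, only to illustrate the `:length` list semantics:
# given the configured lengths and the token counts of the raw inputs,
# pick the smallest configured length that fits them all.
pick_length = fn configured, input_lengths ->
  max_input = Enum.max(input_lengths)

  configured
  |> Enum.sort()
  |> Enum.find(&(&1 >= max_input))
end

pick_length.([16, 32, 64], [9, 21])
#=> 32
```

With inputs of 9 and 21 tokens, 16 is too small, so both sequences are padded to 32 rather than the full 64.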
33 changes: 33 additions & 0 deletions lib/bumblebee/shared.ex
@@ -359,6 +359,39 @@ defmodule Bumblebee.Shared do
end
end

@doc """
Returns batch keys for the given sequence length specified in text
serving compile options.
"""
@spec sequence_batch_keys(nil | non_neg_integer() | list(non_neg_integer())) :: list()
def sequence_batch_keys(sequence_length)

def sequence_batch_keys(nil), do: [:default]

def sequence_batch_keys(length) when is_number(length) do
[{:sequence_length, length}]
end

def sequence_batch_keys(lengths) when is_list(lengths) do
Enum.map(lengths, &{:sequence_length, &1})
end
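The clauses above are pure functions and easy to exercise in isolation. A standalone sketch mirroring them (module name is ours, not the library's):

```elixir
defmodule BatchKeysExample do
  # Mirrors Bumblebee.Shared.sequence_batch_keys/1: nil yields a single
  # :default key; a number or list of numbers becomes tagged keys.
  def sequence_batch_keys(nil), do: [:default]

  def sequence_batch_keys(length) when is_number(length),
    do: [{:sequence_length, length}]

  def sequence_batch_keys(lengths) when is_list(lengths),
    do: Enum.map(lengths, &{:sequence_length, &1})
end

BatchKeysExample.sequence_batch_keys(nil)
#=> [:default]
BatchKeysExample.sequence_batch_keys([128, 512])
#=> [{:sequence_length, 128}, {:sequence_length, 512}]
```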

@doc """
Determines batch key compatible with `sequence_batch_keys/1` based
on tokenized inputs.
"""
@spec sequence_batch_key_for_inputs(
inputs :: any(),
nil | non_neg_integer() | list(non_neg_integer())
) :: term()
def sequence_batch_key_for_inputs(inputs, sequence_length) do
if sequence_length do
{:sequence_length, Nx.axis_size(inputs["input_ids"], 1)}
else
:default
end
end
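A toy rendering of the same key derivation without Nx: plain padded lists stand in for the `"input_ids"` tensor, so the row length plays the role of `Nx.axis_size(tensor, 1)`:

```elixir
# Toy stand-in: rows are already padded to the bounding length, so the
# batch key is derived from the row length rather than Nx.axis_size/2.
derive_key = fn inputs, sequence_length ->
  if sequence_length do
    {:sequence_length, inputs["input_ids"] |> hd() |> length()}
  else
    :default
  end
end

padded = %{"input_ids" => [[101, 2054, 2003, 0], [101, 7592, 0, 0]]}

derive_key.(padded, [4, 8])
#=> {:sequence_length, 4}
derive_key.(padded, nil)
#=> :default
```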

@doc """
Generates tokenizer implementation.
"""
39 changes: 32 additions & 7 deletions lib/bumblebee/text.ex
@@ -76,7 +76,10 @@ defmodule Bumblebee.Text do
are optionally padded to always match this batch size
* `:sequence_length` - the maximum input sequence length. Input
- sequences are always padded/truncated to match that length
+ sequences are always padded/truncated to match that length.
+ A list can be given, in which case the serving compiles
+ a separate computation for each length and then inputs are
+ matched to the smallest bounding length
It is advised to set this option in production and also configure
a defn compiler using `:defn_options` to maximally reduce inference
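Concretely, the `:sequence_length` compile option accepts a list as in the sketch below. The model repository and option values are illustrative, and this assumes Bumblebee and EXLA are available as dependencies; one computation is compiled per listed length, and each request is routed to the smallest length that bounds it:

```elixir
# Illustrative sketch only; repository name and lengths are examples.
{:ok, model_info} =
  Bumblebee.load_model({:hf, "finiteautomata/bertweet-base-sentiment-analysis"})

{:ok, tokenizer} =
  Bumblebee.load_tokenizer({:hf, "finiteautomata/bertweet-base-sentiment-analysis"})

serving =
  Bumblebee.Text.text_classification(model_info, tokenizer,
    compile: [batch_size: 8, sequence_length: [32, 128, 512]],
    defn_options: [compiler: EXLA]
  )

Nx.Serving.run(serving, "Short inputs only pay for the 32-token computation.")
```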
@@ -138,7 +141,10 @@ defmodule Bumblebee.Text do
are optionally padded to always match this batch size
* `:sequence_length` - the maximum input sequence length. Input
- sequences are always padded/truncated to match that length
+ sequences are always padded/truncated to match that length.
+ A list can be given, in which case the serving compiles
+ a separate computation for each length and then inputs are
+ matched to the smallest bounding length
It is advised to set this option in production and also configure
a defn compiler using `:defn_options` to maximally reduce inference
@@ -203,6 +209,9 @@ defmodule Bumblebee.Text do
* `:sequence_length` - the maximum input sequence length. Input
sequences are always padded/truncated to match that length.
+ A list can be given, in which case the serving compiles
+ a separate computation for each length and then inputs are
+ matched to the smallest bounding length.
Note that in this case, the whole conversation history is the
input, so this value should be relatively large to allow long
history (though the supported upper limit depends on the model)
@@ -267,7 +276,10 @@ defmodule Bumblebee.Text do
are optionally padded to always match this batch size
* `:sequence_length` - the maximum input sequence length. Input
- sequences are always padded/truncated to match that length
+ sequences are always padded/truncated to match that length.
+ A list can be given, in which case the serving compiles
+ a separate computation for each length and then inputs are
+ matched to the smallest bounding length
It is advised to set this option in production and also configure
a defn compiler using `:defn_options` to maximally reduce inference
@@ -336,7 +348,10 @@ defmodule Bumblebee.Text do
are optionally padded to always match this batch size
* `:sequence_length` - the maximum input sequence length. Input
- sequences are always padded/truncated to match that length
+ sequences are always padded/truncated to match that length.
+ A list can be given, in which case the serving compiles
+ a separate computation for each length and then inputs are
+ matched to the smallest bounding length
It is advised to set this option in production and also configure
a defn compiler using `:defn_options` to maximally reduce inference
@@ -398,7 +413,10 @@ defmodule Bumblebee.Text do
are optionally padded to always match this batch size
* `:sequence_length` - the maximum input sequence length. Input
- sequences are always padded/truncated to match that length
+ sequences are always padded/truncated to match that length.
+ A list can be given, in which case the serving compiles
+ a separate computation for each length and then inputs are
+ matched to the smallest bounding length
It is advised to set this option in production and also configure
a defn compiler using `:defn_options` to maximally reduce inference
@@ -441,6 +459,7 @@ defmodule Bumblebee.Text do
end: number(),
score: number()
}

@doc """
Builds serving for the question answering task.
@@ -463,7 +482,10 @@ defmodule Bumblebee.Text do
prompt and label
* `:sequence_length` - the maximum input sequence length. Input
- sequences are always padded/truncated to match that length
+ sequences are always padded/truncated to match that length.
+ A list can be given, in which case the serving compiles
+ a separate computation for each length and then inputs are
+ matched to the smallest bounding length
It is advised to set this option in production and also configure
a defn compiler using `:defn_options` to maximally reduce inference
@@ -528,7 +550,10 @@ defmodule Bumblebee.Text do
prompt and label
* `:sequence_length` - the maximum input sequence length. Input
- sequences are always padded/truncated to match that length
+ sequences are always padded/truncated to match that length.
+ A list can be given, in which case the serving compiles
+ a separate computation for each length and then inputs are
+ matched to the smallest bounding length
It is advised to set this option in production and also configure
a defn compiler using `:defn_options` to maximally reduce inference
13 changes: 10 additions & 3 deletions lib/bumblebee/text/conversation.ex
@@ -45,10 +45,14 @@ defmodule Bumblebee.Text.Conversation do
generate_fun =
Text.Generation.build_generate(model, spec, generation_config, Keyword.take(opts, [:seed]))

+ batch_keys = Shared.sequence_batch_keys(sequence_length)
+
Nx.Serving.new(
- fn defn_options ->
+ fn batch_key, defn_options ->
generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
+ {:sequence_length, sequence_length} = batch_key
+
inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
@@ -74,7 +78,7 @@
end,
defn_options
)
- |> Nx.Serving.process_options(batch_size: batch_size)
+ |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{histories, multi?} = Shared.validate_serving_input!(input, &validate_input/1)

@@ -91,7 +95,10 @@ defmodule Bumblebee.Text.Conversation do
return_token_type_ids: false
)

- {Nx.Batch.concatenate([inputs]), {histories, multi?}}
+ batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
+ batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
+
+ {batch, {histories, multi?}}
end)
|> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, {histories, multi?} ->
decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids)
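The batch-key routing used by this serving (and the ones below) can be mimicked in miniature without Nx. Everything here is a toy stand-in, not the `Nx.Serving` API: a map from batch key to a pretend compiled function, with each batch dispatched by its padded length.

```elixir
# Toy stand-in for batch-key routing: one (pretend) compiled function
# per sequence length, selected by the key computed at preprocessing.
compiled =
  Map.new([16, 32], fn len ->
    {{:sequence_length, len}, fn batch -> {:ran_with, len, length(batch)} end}
  end)

run = fn batch, padded_len ->
  fun = Map.fetch!(compiled, {:sequence_length, padded_len})
  fun.(batch)
end

run.([[1, 2, 0, 0], [3, 0, 0, 0]], 16)
#=> {:ran_with, 16, 2}
```

This is why the serving process declares `batch_keys` up front: only declared lengths have a compiled function to dispatch to.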
13 changes: 10 additions & 3 deletions lib/bumblebee/text/fill_mask.ex
@@ -47,10 +47,14 @@ defmodule Bumblebee.Text.FillMask do
|> Nx.squeeze(axes: [1])
end

+ batch_keys = Shared.sequence_batch_keys(sequence_length)
+
Nx.Serving.new(
- fn defn_options ->
+ fn batch_key, defn_options ->
scores_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
+ {:sequence_length, sequence_length} = batch_key
+
inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
@@ -66,7 +70,7 @@
end,
defn_options
)
- |> Nx.Serving.process_options(batch_size: batch_size)
+ |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

@@ -78,7 +82,10 @@ defmodule Bumblebee.Text.FillMask do
return_token_type_ids: false
)

- {Nx.Batch.concatenate([inputs]), multi?}
+ batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
+ batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
+
+ {batch, multi?}
end)
|> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? ->
for scores <- Bumblebee.Utils.Nx.batch_to_list(scores) do
13 changes: 10 additions & 3 deletions lib/bumblebee/text/generation.ex
@@ -75,10 +75,14 @@ defmodule Bumblebee.Text.Generation do

generate_fun = build_generate(model, spec, generation_config, Keyword.take(opts, [:seed]))

+ batch_keys = Shared.sequence_batch_keys(sequence_length)
+
Nx.Serving.new(
- fn defn_options ->
+ fn batch_key, defn_options ->
generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
+ {:sequence_length, sequence_length} = batch_key
+
inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
@@ -94,7 +98,7 @@
end,
defn_options
)
- |> Nx.Serving.process_options(batch_size: batch_size)
+ |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

@@ -105,7 +109,10 @@ defmodule Bumblebee.Text.Generation do
return_token_type_ids: false
)

- {Nx.Batch.concatenate([inputs]), multi?}
+ batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
+ batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
+
+ {batch, multi?}
end)
|> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, multi? ->
decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids)
14 changes: 11 additions & 3 deletions lib/bumblebee/text/question_answering.ex
@@ -31,10 +31,14 @@ defmodule Bumblebee.Text.QuestionAnswering do
%{start_scores: start_scores, end_scores: end_scores}
end

+ batch_keys = Shared.sequence_batch_keys(sequence_length)
+
Nx.Serving.new(
- fn defn_options ->
+ fn batch_key, defn_options ->
predict_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
+ {:sequence_length, sequence_length} = batch_key
+
inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32),
@@ -52,7 +56,7 @@
end,
defn_options
)
- |> Nx.Serving.process_options(batch_size: batch_size)
+ |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn raw_input ->
{raw_inputs, multi?} =
Shared.validate_serving_input!(raw_input, fn
@@ -73,7 +77,11 @@ defmodule Bumblebee.Text.QuestionAnswering do
)

inputs = Map.take(all_inputs, ["input_ids", "attention_mask", "token_type_ids"])
- {Nx.Batch.concatenate([inputs]), {all_inputs, raw_inputs, multi?}}
+
+ batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
+ batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
+
+ {batch, {all_inputs, raw_inputs, multi?}}
end)
|> Nx.Serving.client_postprocessing(fn {outputs, _metadata}, {inputs, raw_inputs, multi?} ->
Enum.zip_with(
13 changes: 10 additions & 3 deletions lib/bumblebee/text/text_classification.ex
@@ -31,10 +31,14 @@ defmodule Bumblebee.Text.TextClassification do
Shared.logits_to_scores(outputs.logits, scores_function)
end

+ batch_keys = Shared.sequence_batch_keys(sequence_length)
+
Nx.Serving.new(
- fn defn_options ->
+ fn batch_key, defn_options ->
scores_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
+ {:sequence_length, sequence_length} = batch_key
+
inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
@@ -50,7 +54,7 @@
end,
defn_options
)
- |> Nx.Serving.process_options(batch_size: batch_size)
+ |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

@@ -60,7 +64,10 @@ defmodule Bumblebee.Text.TextClassification do
return_token_type_ids: false
)

- {Nx.Batch.concatenate([inputs]), multi?}
+ batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
+ batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
+
+ {batch, multi?}
end)
|> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? ->
for scores <- Bumblebee.Utils.Nx.batch_to_list(scores) do
13 changes: 10 additions & 3 deletions lib/bumblebee/text/text_embedding.ex
@@ -76,10 +76,14 @@ defmodule Bumblebee.Text.TextEmbedding do
output
end

+ batch_keys = Shared.sequence_batch_keys(sequence_length)
+
Nx.Serving.new(
- fn defn_options ->
+ fn batch_key, defn_options ->
embedding_fun =
Shared.compile_or_jit(embedding_fun, defn_options, compile != nil, fn ->
+ {:sequence_length, sequence_length} = batch_key
+
inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
@@ -95,7 +99,7 @@
end,
defn_options
)
- |> Nx.Serving.process_options(batch_size: batch_size)
+ |> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

@@ -105,7 +109,10 @@ defmodule Bumblebee.Text.TextEmbedding do
return_token_type_ids: false
)

- {Nx.Batch.concatenate([inputs]), multi?}
+ batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
+ batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
+
+ {batch, multi?}
end)
|> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? ->
for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do