Skip to content

Commit

Permalink
Add multilingual translation serving (#395)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Aug 22, 2024
1 parent 17e4397 commit 9421eca
Show file tree
Hide file tree
Showing 10 changed files with 398 additions and 21 deletions.
6 changes: 3 additions & 3 deletions lib/bumblebee/audio/speech_to_text_whisper.ex
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do

{stream, {}}
end)
|> maybe_stream(opts[:stream], spec, featurizer, tokenizer, options)
|> add_postprocessing(opts[:stream], spec, featurizer, tokenizer, options)
end

defp validate_input(%{audio: audio} = input, sampling_rate, chunk_num_seconds) do
Expand Down Expand Up @@ -351,7 +351,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
end
end

defp maybe_stream(serving, false, spec, featurizer, tokenizer, options) do
defp add_postprocessing(serving, false, spec, featurizer, tokenizer, options) do
Nx.Serving.client_postprocessing(serving, fn {outputs, _metadata}, {} ->
outputs = Nx.to_list(outputs)
state = decode_chunk_outputs_init(spec, featurizer, tokenizer)
Expand All @@ -362,7 +362,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
end)
end

defp maybe_stream(serving, true, spec, featurizer, tokenizer, options) do
defp add_postprocessing(serving, true, spec, featurizer, tokenizer, options) do
serving
|> Nx.Serving.streaming()
|> Nx.Serving.client_postprocessing(fn stream, {} ->
Expand Down
68 changes: 66 additions & 2 deletions lib/bumblebee/text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ defmodule Bumblebee.Text do
#=> %{
#=> results: [
#=> %{
#=> text: "Elixir is a functional programming language that is designed to be used in a variety of applications. It"
#=> text: " programming language that is designed to be used in a variety of applications. It",
#=> token_summary: %{input: 5, output: 15, padding: 0}
#=> }
#=> ]
#=> }
Expand Down Expand Up @@ -224,6 +225,69 @@ defmodule Bumblebee.Text do
defdelegate generation(model_info, tokenizer, generation_config, opts \\ []),
to: Bumblebee.Text.TextGeneration

@type translation_input ::
%{
:text => String.t(),
:source_language_token => String.t(),
:target_language_token => String.t(),
optional(:seed) => integer() | nil
}
@type translation_output :: generation_output()

@doc """
Builds serving for text translation.
The serving accepts `t:translation_input/0` and returns `t:translation_output/0`.
A list of inputs is also supported.
This serving is an extension of `generation/4` that handles per-input
language configuration.
Note that this serving is designed for multilingual models that
require source/target language to be specified. Some text models are
trained for specific language pairs, others expect a command such as
"translate English to Spanish", in such cases you most likely want
to use `generation/4`.
## Options
See `generation/4` for available options.
## Examples
repository_id = "facebook/nllb-200-distilled-600M"
{:ok, model_info} = Bumblebee.load_model({:hf, repository_id})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, repository_id})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, repository_id})
serving = Bumblebee.Text.translation(model_info, tokenizer, generation_config)
text = "The bank of the river is beautiful in spring"
Nx.Serving.run(serving, %{
text: text,
source_language_token: "eng_Latn",
target_language_token: "pol_Latn"
})
#=> %{
#=> results: [
#=> %{
#=> text: "W wiosnę brzeg rzeki jest piękny",
#=> token_summary: %{input: 11, output: 13, padding: 0}
#=> }
#=> ]
#=> }
"""
@spec translation(
Bumblebee.model_info(),
Bumblebee.Tokenizer.t(),
Bumblebee.Text.GenerationConfig.t(),
keyword()
) :: Nx.Serving.t()
defdelegate translation(model_info, tokenizer, generation_config, opts \\ []),
to: Bumblebee.Text.Translation

@type text_classification_input :: String.t()
@type text_classification_output :: %{predictions: list(text_classification_prediction())}
@type text_classification_prediction :: %{score: number(), label: String.t()}
Expand Down Expand Up @@ -316,7 +380,7 @@ defmodule Bumblebee.Text do
it is not already a pooled embedding. Supported values:
* `:mean_pooling` - performs a mean across all tokens
* `cls_token_pooling` - takes the embedding for the special CLS token.
Note that we currently assume that the CLS token is the first token
in the sequence
Expand Down
10 changes: 5 additions & 5 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,13 @@ defmodule Bumblebee.Text.Generation do
encoder_outputs = encoder_predict_fun.(params, inputs)

batch_size = Nx.axis_size(encoder_input(inputs), 0)
decoder_input_ids = Nx.broadcast(decoder_start_token_id, {batch_size, 1})

inputs = Map.put(inputs, "encoder_hidden_state", encoder_outputs.hidden_state)

inputs =
Map.merge(inputs, %{
"encoder_hidden_state" => encoder_outputs.hidden_state,
"decoder_input_ids" => decoder_input_ids
})
Map.put_new_lazy(inputs, "decoder_input_ids", fn ->
Nx.broadcast(decoder_start_token_id, {batch_size, 1})
end)

max_length = max_length_fun.(1)
inputs = prepare_decoder_inputs(inputs, "decoder_", spec, model, max_length)
Expand Down
65 changes: 61 additions & 4 deletions lib/bumblebee/text/pre_trained_tokenizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
whether to return the sequence length. The length is the effective number of tokens,
so it is calculated after truncation, but does not include padding
"""
],
template_options: [
default: [],
doc: """
options configuring the tokenization template, specific to the given tokenizer type.
Recognised options are:
* `:language_token` - for tokenizers: `:nllb`
"""
]
]

Expand Down Expand Up @@ -187,7 +197,8 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
pad: "<pad>",
cls: "<s>",
mask: "<mask>"
}
},
default_template_options: [language_token: "eng_Latn"]
},
roberta: %{
special_tokens: %{
Expand Down Expand Up @@ -246,8 +257,15 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
# special tokens added by a template post-processor. By setting
# truncation upfront, the tokenizer will apply it before the
# post-processor accounting for the extra special tokens
if Keyword.has_key?(opts, :length) or Keyword.has_key?(opts, :truncation_direction) do
update_truncation(tokenizer)
tokenizer =
if Keyword.has_key?(opts, :length) or Keyword.has_key?(opts, :truncation_direction) do
update_truncation(tokenizer)
else
tokenizer
end

if Keyword.has_key?(opts, :template_options) do
set_template(tokenizer)
else
tokenizer
end
Expand All @@ -269,6 +287,42 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
)
end

defp set_template(%{type: :nllb} = tokenizer) do
language_token = Keyword.fetch!(tokenizer.template_options, :language_token)
eos_token = tokenizer.special_tokens.eos

set_template_postprocessor(
tokenizer,
"#{language_token} $A #{eos_token}",
"#{language_token} $A $B #{eos_token}",
[language_token, eos_token]
)
end

defp set_template(%{type: type} = tokenizer) do
if tokenizer.template_options != [] do
raise ArgumentError,
"#{inspect(type)} tokenizer expects no :template_options," <>
" got: #{inspect(tokenizer.template_options)}"
end

tokenizer
end

defp set_template_postprocessor(tokenizer, single, pair, special_tokens) do
post_processor =
Tokenizers.PostProcessor.template(
single: single,
pair: pair,
special_tokens:
for token <- special_tokens do
{token, Tokenizer.token_to_id(tokenizer.native_tokenizer, token)}
end
)

update_in(tokenizer.native_tokenizer, &Tokenizer.set_post_processor(&1, post_processor))
end

@impl true
def apply(tokenizer, input) do
input = List.wrap(input)
Expand Down Expand Up @@ -480,7 +534,7 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
" but got: #{inspect(tokenizer.type)}"
end

%{special_tokens: special_tokens} = tokenizer_types[tokenizer.type]
tokenizer_type = %{special_tokens: special_tokens} = tokenizer_types[tokenizer.type]

special_tokens = load_special_tokens(special_tokens, special_tokens_map)

Expand All @@ -493,12 +547,15 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
[]
end

template_options = tokenizer_type[:default_template_options] || []

%{
tokenizer
| native_tokenizer: native_tokenizer,
special_tokens: special_tokens,
additional_special_tokens: additional_special_tokens
}
|> @for.config(template_options: template_options)
end

defp load_special_tokens(special_tokens, data) do
Expand Down
9 changes: 6 additions & 3 deletions lib/bumblebee/text/text_generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ defmodule Bumblebee.Text.TextGeneration do

{batch, {multi?, input_length, input_padded_length}}
end)
|> maybe_stream(opts[:stream], opts[:stream_done], tokenizer)
|> add_postprocessing(opts[:stream], opts[:stream_done], tokenizer)
end

defp validate_input(text) when is_binary(text), do: validate_input(%{text: text})
Expand All @@ -117,7 +117,10 @@ defmodule Bumblebee.Text.TextGeneration do
{:error, "expected either a string or a map, got: #{inspect(input)}"}
end

defp maybe_stream(serving, false, _stream_done, tokenizer) do
@doc false
def add_postprocessing(serving, stream, stream_done, tokenizer)

def add_postprocessing(serving, false, _stream_done, tokenizer) do
Nx.Serving.client_postprocessing(
serving,
fn {%{token_ids: token_ids, length: length}, _metadata},
Expand All @@ -138,7 +141,7 @@ defmodule Bumblebee.Text.TextGeneration do
)
end

defp maybe_stream(serving, true, stream_done, tokenizer) do
def add_postprocessing(serving, true, stream_done, tokenizer) do
serving
|> Nx.Serving.streaming(hooks: [:token])
|> Nx.Serving.client_postprocessing(fn stream,
Expand Down
Loading

0 comments on commit 9421eca

Please sign in to comment.