Commit

Reorder functions
jonatanklosko committed Aug 3, 2023
1 parent 0154888 commit 68d5f48
Showing 1 changed file with 170 additions and 168 deletions.
lib/bumblebee/text/generation.ex (338 changes: 170 additions & 168 deletions)
@@ -50,174 +50,6 @@ defmodule Bumblebee.Text.Generation do
module.traverse_cache(spec, cache, fun)
end

@doc false
def generation(model_info, tokenizer, %Text.GenerationConfig{} = generation_config, opts \\ []) do
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], stream: false])

%{model: model, params: params, spec: spec} = model_info

Shared.validate_architecture!(spec, [
:for_conditional_generation,
:for_causal_language_modeling
])

defn_options = opts[:defn_options]

compile =
if compile = opts[:compile] do
compile
|> Keyword.validate!([:batch_size, :sequence_length])
|> Shared.require_options!([:batch_size, :sequence_length])
end

batch_size = compile[:batch_size]
sequence_length = compile[:sequence_length]

generate_fun = build_generate(model, spec, generation_config, Keyword.take(opts, [:seed]))

batch_keys = Shared.sequence_batch_keys(sequence_length)

Nx.Serving.new(
fn batch_key, defn_options ->
generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key

inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
}

[params, inputs]
end)

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
generate_fun.(params, inputs)
end
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
if opts[:stream] do
Shared.validate_input_for_stream!(input)
end

{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

inputs =
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
pad_direction: :left,
return_token_type_ids: false
)

batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)

{batch, multi?}
end)
|> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, multi? ->
decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids)

decoded
|> Enum.map(&%{results: [%{text: &1}]})
|> Shared.normalize_output(multi?)
end)
|> maybe_stream(opts[:stream], tokenizer)
end

defp maybe_stream(serving, false, _tokenizer), do: serving

defp maybe_stream(serving, true, tokenizer) do
serving
|> Nx.Serving.streaming(hooks: [:token])
|> Nx.Serving.client_postprocessing(fn stream, false = _multi? ->
Stream.transform(stream, %{tokens: [], consumed_size: 0, finished?: false}, fn
_event, %{finished?: true} = state ->
{:halt, state}

{:token, {token_id, finished?}}, state ->
token_id = Nx.to_number(token_id[0])
finished? = Nx.to_number(finished?[0]) == 1

state = %{state | tokens: state.tokens ++ [token_id], finished?: finished?}

chunk = pending_chunk(tokenizer, state)

cond do
# When the sequence finishes early or we reach a newline,
# we flush the buffered tokens
finished? or String.ends_with?(chunk, "\n") ->
{[chunk], %{state | tokens: [], consumed_size: 0}}

# CJK characters are tokenized atomically, so we can emit
# the chunk
chunk != "" and cjk_codepoint?(last_codepoint(chunk)) ->
state = update_in(state.consumed_size, &(&1 + byte_size(chunk)))
{[chunk], state}

# Emit the chunk up to the last space. We need to keep the
# tokens, because certain tokenizers do not encode whitespace
# in tokens and add a space based on the previous tokens
space_idx = find_last_occurrence(chunk, " ") ->
if space_idx > 0 do
chunk = binary_slice(chunk, 0, space_idx)
state = update_in(state.consumed_size, &(&1 + space_idx))
{[chunk], state}
else
{[], state}
end

true ->
{[], state}
end

{:done, _, _}, state ->
chunk = pending_chunk(tokenizer, state)

if chunk == "" do
{:halt, state}
else
{[chunk], %{state | tokens: [], consumed_size: 0}}
end
end)
end)
end

defp pending_chunk(tokenizer, state) do
text = Bumblebee.Tokenizer.decode(tokenizer, state.tokens)
binary_slice(text, state.consumed_size..-1//1)
end

defp find_last_occurrence(string, pattern) do
case :binary.matches(string, pattern) do
[] -> nil
matches -> matches |> List.last() |> elem(0)
end
end

defp last_codepoint(<<codepoint::utf8>>), do: codepoint
defp last_codepoint(<<_::utf8, rest::binary>>), do: last_codepoint(rest)

defp cjk_codepoint?(codepoint) do
# The specific ranges originated in [1] and are generally mirrored
# in other tokenizers using WordPiece. Also see [2].
#
# [1]: https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/tokenization.py#L264-L284
# [2]: https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/multilingual.md#tokenization

codepoint in 0x4E00..0x9FFF or
codepoint in 0x3400..0x4DBF or
codepoint in 0x20000..0x2A6DF or
codepoint in 0x2A700..0x2B73F or
codepoint in 0x2B740..0x2B81F or
codepoint in 0x2B820..0x2CEAF or
codepoint in 0xF900..0xFAFF or
codepoint in 0x2F800..0x2FA1F
end

@doc """
Builds a numerical definition that generates sequences of tokens using
the given language model.
@@ -961,4 +793,174 @@ defmodule Bumblebee.Text.Generation do
|> Nx.squeeze()
|> Nx.devectorize()
end

# Serving

@doc false
def generation(model_info, tokenizer, %Text.GenerationConfig{} = generation_config, opts \\ []) do
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], stream: false])

%{model: model, params: params, spec: spec} = model_info

Shared.validate_architecture!(spec, [
:for_conditional_generation,
:for_causal_language_modeling
])

defn_options = opts[:defn_options]

compile =
if compile = opts[:compile] do
compile
|> Keyword.validate!([:batch_size, :sequence_length])
|> Shared.require_options!([:batch_size, :sequence_length])
end

batch_size = compile[:batch_size]
sequence_length = compile[:sequence_length]

generate_fun = build_generate(model, spec, generation_config, Keyword.take(opts, [:seed]))

batch_keys = Shared.sequence_batch_keys(sequence_length)

Nx.Serving.new(
fn batch_key, defn_options ->
generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key

inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
}

[params, inputs]
end)

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
generate_fun.(params, inputs)
end
end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size, batch_keys: batch_keys)
|> Nx.Serving.client_preprocessing(fn input ->
if opts[:stream] do
Shared.validate_input_for_stream!(input)
end

{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

inputs =
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
pad_direction: :left,
return_token_type_ids: false
)

batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)

{batch, multi?}
end)
|> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, multi? ->
decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids)

decoded
|> Enum.map(&%{results: [%{text: &1}]})
|> Shared.normalize_output(multi?)
end)
|> maybe_stream(opts[:stream], tokenizer)
end
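
# Illustrative usage of the serving built above (a sketch; the model
# name, compile options, and compiler are assumptions, not requirements):
#
#     {:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
#     {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "gpt2"})
#     {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})
#
#     serving =
#       Bumblebee.Text.generation(model_info, tokenizer, generation_config,
#         compile: [batch_size: 1, sequence_length: 128],
#         defn_options: [compiler: EXLA]
#       )
#
#     Nx.Serving.run(serving, "Elixir is")
#     #=> %{results: [%{text: "..."}]}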

defp maybe_stream(serving, false, _tokenizer), do: serving

defp maybe_stream(serving, true, tokenizer) do
serving
|> Nx.Serving.streaming(hooks: [:token])
|> Nx.Serving.client_postprocessing(fn stream, false = _multi? ->
Stream.transform(stream, %{tokens: [], consumed_size: 0, finished?: false}, fn
_event, %{finished?: true} = state ->
{:halt, state}

{:token, {token_id, finished?}}, state ->
token_id = Nx.to_number(token_id[0])
finished? = Nx.to_number(finished?[0]) == 1

state = %{state | tokens: state.tokens ++ [token_id], finished?: finished?}

chunk = pending_chunk(tokenizer, state)

cond do
# When the sequence finishes early or we reach a newline,
# we flush the buffered tokens
finished? or String.ends_with?(chunk, "\n") ->
{[chunk], %{state | tokens: [], consumed_size: 0}}

# CJK characters are tokenized atomically, so we can emit
# the chunk
chunk != "" and cjk_codepoint?(last_codepoint(chunk)) ->
state = update_in(state.consumed_size, &(&1 + byte_size(chunk)))
{[chunk], state}

# Emit the chunk up to the last space. We need to keep the
# tokens, because certain tokenizers do not encode whitespace
# in tokens and add a space based on the previous tokens
space_idx = find_last_occurrence(chunk, " ") ->
if space_idx > 0 do
chunk = binary_slice(chunk, 0, space_idx)
state = update_in(state.consumed_size, &(&1 + space_idx))
{[chunk], state}
else
{[], state}
end

true ->
{[], state}
end

{:done, _, _}, state ->
chunk = pending_chunk(tokenizer, state)

if chunk == "" do
{:halt, state}
else
{[chunk], %{state | tokens: [], consumed_size: 0}}
end
end)
end)
end
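
# Illustrative streaming usage (a sketch; the prompt is arbitrary): with
# stream: true the postprocessing above turns the serving output into a
# lazy stream of text chunks that can be consumed as they arrive:
#
#     serving =
#       Bumblebee.Text.generation(model_info, tokenizer, generation_config,
#         stream: true
#       )
#
#     serving
#     |> Nx.Serving.run("Tell me a story")
#     |> Enum.each(&IO.write/1)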

defp pending_chunk(tokenizer, state) do
text = Bumblebee.Tokenizer.decode(tokenizer, state.tokens)
binary_slice(text, state.consumed_size..-1//1)
end
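
# For intuition (assuming the buffered tokens decode to "hello world"
# and 6 bytes were already emitted), the slice keeps only the tail:
#
#     binary_slice("hello world", 6..-1//1)
#     #=> "world"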

defp find_last_occurrence(string, pattern) do
case :binary.matches(string, pattern) do
[] -> nil
matches -> matches |> List.last() |> elem(0)
end
end
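
# For example, :binary.matches/2 returns {offset, length} pairs in order,
# so the last match gives the byte offset of the final occurrence:
#
#     :binary.matches("foo bar baz", " ")
#     #=> [{3, 1}, {7, 1}]
#
# and find_last_occurrence("foo bar baz", " ") returns 7.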

defp last_codepoint(<<codepoint::utf8>>), do: codepoint
defp last_codepoint(<<_::utf8, rest::binary>>), do: last_codepoint(rest)
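
# For example, the recursion walks the UTF-8 binary down to its final
# codepoint:
#
#     last_codepoint("abć")
#     #=> 263 (0x0107, the codepoint of "ć")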

defp cjk_codepoint?(codepoint) do
# The specific ranges originated in [1] and are generally mirrored
# in other tokenizers using WordPiece. Also see [2].
#
# [1]: https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/tokenization.py#L264-L284
# [2]: https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/multilingual.md#tokenization

codepoint in 0x4E00..0x9FFF or
codepoint in 0x3400..0x4DBF or
codepoint in 0x20000..0x2A6DF or
codepoint in 0x2A700..0x2B73F or
codepoint in 0x2B740..0x2B81F or
codepoint in 0x2B820..0x2CEAF or
codepoint in 0xF900..0xFAFF or
codepoint in 0x2F800..0x2FA1F
end
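
# For example (the sample characters are arbitrary):
#
#     cjk_codepoint?(?漢)
#     #=> true, since ?漢 is 0x6F22, inside 0x4E00..0x9FFF
#
#     cjk_codepoint?(?a)
#     #=> false, since 0x61 falls in none of the ranges above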
end
