Add token summary to text generation output #336

Merged · 1 commit · Feb 14, 2024
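
This pull request changes the low-level generation functions to return a map with `token_ids` and a per-sequence `length` instead of a bare token tensor, and surfaces the counts in the serving output as a `token_summary`. A minimal sketch of the resulting shape (the prompt and all counts below are illustrative):

    serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)

    Nx.Serving.run(serving, "Elixir is")
    #=> %{
    #=>   results: [
    #=>     %{
    #=>       text: " a functional programming language.",
    #=>       token_summary: %{input: 3, output: 6, padding: 0}
    #=>     }
    #=>   ]
    #=> }
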
lib/bumblebee/audio/speech_to_text_whisper.ex (3 changes: 2 additions & 1 deletion)

@@ -52,7 +52,8 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
     generate_fun = fn params, {inputs, seed} ->
       inputs = Bumblebee.Featurizer.process_batch(featurizer, inputs)
       inputs = Map.put(inputs, "seed", seed)
-      generate_fun.(params, inputs)
+      %{token_ids: token_ids} = generate_fun.(params, inputs)
+      token_ids
     end

     Nx.Serving.new(
lib/bumblebee/text.ex (13 changes: 12 additions & 1 deletion)

@@ -129,7 +129,12 @@ defmodule Bumblebee.Text do
   @type generation_input ::
           String.t() | %{:text => String.t(), optional(:seed) => integer()}
   @type generation_output :: %{results: list(generation_result())}
-  @type generation_result :: %{text: String.t()}
+  @type generation_result :: %{text: String.t(), token_summary: token_summary()}
+  @type token_summary :: %{
+          input: pos_integer(),
+          output: pos_integer(),
+          padding: non_neg_integer()
+        }

   @doc """
   Builds serving for prompt-driven text generation.

@@ -172,6 +177,12 @@
     serving. To process a batch, call the serving with each input
     separately. Defaults to `false`

+    * `:stream_done` - when `:stream` is enabled, this enables a final
+      event, after all chunks have been emitted. The event has the
+      shape `{:done, result}`, where `result` includes the same fields
+      as `t:generation_result/0`, except for `:text`, which has been
+      already streamed. Defaults to `false`
+
   ## Examples

       {:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
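
For reference, a sketch of how the new `:stream_done` option could be consumed (the option and the `{:done, result}` event shape come from the docs above; the prompt and printed values are illustrative):

    serving =
      Bumblebee.Text.generation(model_info, tokenizer, generation_config,
        stream: true,
        stream_done: true
      )

    serving
    |> Nx.Serving.run("Tell me a story")
    |> Enum.each(fn
      # Text chunks arrive first, then a single final event
      chunk when is_binary(chunk) -> IO.write(chunk)
      {:done, result} -> IO.inspect(result.token_summary, label: "token summary")
    end)
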
lib/bumblebee/text/generation.ex (164 changes: 108 additions & 56 deletions)

@@ -102,7 +102,8 @@ defmodule Bumblebee.Text.Generation do
           Bumblebee.ModelSpec.t(),
           Bumblebee.Text.GenerationConfig.t(),
           keyword()
-        ) :: (params :: map(), inputs :: map() -> Nx.t())
+        ) ::
+          (params :: map(), inputs :: map() -> %{token_ids: Nx.Tensor.t(), length: Nx.Tensor.t()})
   def build_generate(model, spec, config, opts \\ []) do
     opts = Keyword.validate!(opts, logits_processors: [])

@@ -358,7 +359,7 @@

     strategy = opts[:strategy]

-    sequences =
+    {sequences, finished_length} =
       case strategy.type do
         :greedy_search ->
           greedy(

@@ -399,8 +400,11 @@
           )
       end

-    # Output only the newly generated tokens
-    sequences[[.., length..-1//1]]
+    %{
+      # Output only the newly generated tokens
+      token_ids: sequences[[.., length..-1//1]],
+      length: finished_length - length
+    }
   end

   deftransformp pop_seed(inputs), do: Map.pop!(inputs, "seed")
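
A worked example of the returned map (numbers made up): with a prompt occupying `length = 5` positions and a sequence that finishes at total length 12, `token_ids` holds positions `5..` of the (padded) sequences, and `length` is `12 - 5 = 7` generated tokens for that sequence.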

@@ -422,17 +426,17 @@
     pad_token_id = opts[:pad_token_id]
     eos_token_id = opts[:eos_token_id]

-    {sequences, length = input_length, finished?} =
+    {sequences, length = input_length, finished_length} =
       init_sequences(decoder_input_ids, max_length, pad_token_id)

     # The loop works with inputs of length 1, so if the initial input
     # is longer, we make the initial pass outside
-    {sequences, length, finished?, inputs} =
+    {sequences, length, finished_length, inputs} =
       if length > 1 do
         greedy_step(
           sequences,
           length,
-          finished?,
+          finished_length,
           inputs,
           input_length,
           predict_fun,

@@ -443,17 +447,17 @@
           eos_token_id: eos_token_id
         )
       else
-        {sequences, length, finished?, inputs}
+        {sequences, length, finished_length, inputs}
       end

-    {sequences, _length, _finished?, _inputs, _params} =
-      while {sequences, length, finished?, inputs, params},
-            continue?(finished?, length, max_length) do
-        {sequences, length, finished?, inputs} =
+    {sequences, _length, finished_length, _inputs, _params} =
+      while {sequences, length, finished_length, inputs, params},
+            continue?(finished_length) do
+        {sequences, length, finished_length, inputs} =
           greedy_step(
             sequences,
             length,
-            finished?,
+            finished_length,
             inputs,
             input_length,
             predict_fun,

@@ -464,10 +468,10 @@
             eos_token_id: eos_token_id
           )

-        {sequences, length, finished?, inputs, params}
+        {sequences, length, finished_length, inputs, params}
       end

-    sequences
+    {sequences, finished_length}
   end

   defnp init_sequences(decoder_input_ids, max_length, pad_token_id) do

@@ -476,19 +480,21 @@
     sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
     sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids)

-    finished? = Nx.broadcast(Nx.tensor(0, type: :u8), {batch_size})
+    # For each sequence, we keep track of its final length, where 0
+    # means that it has not been finished yet
+    finished_length = Nx.broadcast(0, {batch_size})

-    {sequences, length, finished?}
+    {sequences, length, finished_length}
   end

-  defnp continue?(finished?, length, max_length) do
-    not Nx.all(finished?) and length < max_length
+  defnp continue?(finished_length) do
+    Nx.any(finished_length == 0)
   end

   defnp greedy_step(
          sequences,
          length,
-         finished?,
+         finished_length,
          inputs,
          input_length,
          predict_fun,
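
The per-sequence u8 `finished?` flags are replaced by `finished_length`, where `0` means "still generating" and a positive value records the sequence's final length. A tiny illustration of the encoding (batch of three, values made up):

    finished_length = Nx.tensor([0, 7, 0])
    # Sequence 1 finished at length 7; sequences 0 and 2 are still going,
    # so continue? keeps the loop running:
    Nx.any(Nx.equal(finished_length, 0))
    #=> u8 tensor equal to 1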

@@ -506,29 +512,59 @@
     logits = batch_process_logits(logits_processor_fun, logits, sequences, length, input_length)
     token_id = Nx.argmax(logits, axis: -1)

-    {sequences, length, finished?} =
-      update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id)
+    {sequences, length, finished_length} =
+      update_sequences(
+        sequences,
+        input_length,
+        length,
+        finished_length,
+        token_id,
+        pad_token_id,
+        eos_token_id
+      )

     inputs = update_inputs_fun.(inputs, outputs.cache, Nx.new_axis(token_id, -1))

-    {sequences, length, finished?, inputs}
+    {sequences, length, finished_length, inputs}
   end

-  defnp update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id) do
-    token_id = Nx.select(finished?, pad_token_id, token_id)
+  defnp update_sequences(
+          sequences,
+          input_length,
+          length,
+          finished_length,
+          token_id,
+          pad_token_id,
+          eos_token_id
+        ) do
+    token_id = Nx.select(finished_length > 0, pad_token_id, token_id)

-    finished? =
+    token_ids = Nx.new_axis(token_id, -1)
+    sequences = Nx.put_slice(sequences, [0, length], token_ids)
+    length = length + 1
+
+    {batch_size, max_length} = Nx.shape(sequences)
+
+    finished_length =
       case eos_token_id do
-        nil -> finished?
-        eos_token_id -> finished? or token_id == eos_token_id
-      end
+        nil ->
+          finished_length

-    {token_id, finished?} = hook({token_id, finished?}, :token)
+        eos_token_id ->
+          Nx.select(
+            finished_length == 0 and (token_id == eos_token_id or length == max_length),
+            length,
+            finished_length
+          )
+      end

-    token_ids = Nx.new_axis(token_id, -1)
-    sequences = Nx.put_slice(sequences, [0, length], token_ids)
+    finished? = finished_length > 0
+    output_length = Nx.broadcast(length - input_length, {batch_size})
+    data = %{token_id: token_id, finished?: finished?, length: output_length}
+    token = create_token()
+    {token, _} = hook_token(token, data, :token)

-    {sequences, length + 1, finished?}
+    attach_token(token, {sequences, length, finished_length})
   end

   defnp batch_process_logits(logits_processor_fun, logits, sequences, length, input_length) do
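
Note that `update_sequences` now emits the `:token` hook with per-step data (`token_id`, `finished?`, and the running output `length`) rather than a bare tuple, which is what token streaming builds on. A hypothetical listener, assuming hooks are attached via the defn compilation options (the wiring below is illustrative, not part of this diff):

    hooks = %{
      token: fn %{token_id: token_id, finished?: finished?, length: length} ->
        # token_id: {batch_size} tensor with this step's tokens
        # finished?: {batch_size} u8 mask of sequences that are done
        # length: {batch_size} count of generated (non-prompt) tokens so far
        IO.inspect(Nx.to_flat_list(token_id), label: "tokens")
      end
    }

    generate_fun = Nx.Defn.jit(generate_fun, hooks: hooks)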

@@ -560,7 +596,7 @@
     top_k = opts[:top_k]
     penalty_alpha = opts[:penalty_alpha]

-    {sequences, length = input_length, finished?} =
+    {sequences, length = input_length, finished_length} =
       init_sequences(decoder_input_ids, max_length, pad_token_id)

     # Step (1)

@@ -593,10 +629,10 @@
     # pick the best one using the contrastive rank. From the same model
     # pass we also get the next top-k continuation tokens

-    {sequences, _length, _finished?, _inputs, _params, _joint_hidden_state, _top_k_values} =
-      while {sequences, length, finished?, inputs, params, joint_hidden_state,
+    {sequences, _length, finished_length, _inputs, _params, _joint_hidden_state, _top_k_values} =
+      while {sequences, length, finished_length, inputs, params, joint_hidden_state,
              {top_k_scores, top_k_token_ids}},
-            continue?(finished?, length, max_length) do
+            continue?(finished_length) do
         outputs = predict_fun.(params, inputs)

         hidden_state = decoder_hidden_state(outputs)

@@ -618,8 +654,16 @@

         token_id = top_k_token_ids |> Nx.flatten() |> Utils.Nx.chunked_take(top_k, selected_idx)

-        {sequences, length, finished?} =
-          update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id)
+        {sequences, length, finished_length} =
+          update_sequences(
+            sequences,
+            input_length,
+            length,
+            finished_length,
+            token_id,
+            pad_token_id,
+            eos_token_id
+          )

         logits = outputs.logits[[.., -1]]
         logits = Utils.Nx.chunked_take(logits, top_k, selected_idx)

@@ -634,11 +678,11 @@
         cache = reflect_cache(outputs.cache, top_k, selected_idx, traverse_cache_fun)
         inputs = update_inputs_fun.(inputs, cache, Nx.reshape(top_k_token_ids, {:auto, 1}))

-        {sequences, length, finished?, inputs, params, joint_hidden_state,
+        {sequences, length, finished_length, inputs, params, joint_hidden_state,
          {top_k_scores, top_k_token_ids}}
       end

-    sequences
+    {sequences, finished_length}
   end

   deftransformp decoder_hidden_state(outputs) do

@@ -723,19 +767,19 @@
     pad_token_id = opts[:pad_token_id]
     eos_token_id = opts[:eos_token_id]

-    {sequences, length = input_length, finished?} =
+    {sequences, length = input_length, finished_length} =
       init_sequences(decoder_input_ids, max_length, pad_token_id)

     prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key()

     # The loop works with inputs of length 1, so if the initial input
     # is longer, we make the initial pass outside
-    {sequences, length, finished?, inputs, prng_key} =
+    {sequences, length, finished_length, inputs, prng_key} =
       if length > 1 do
         sampling_step(
           sequences,
           length,
-          finished?,
+          finished_length,
           inputs,
           input_length,
           predict_fun,

@@ -747,17 +791,17 @@
           eos_token_id: eos_token_id
         )
       else
-        {sequences, length, finished?, inputs, prng_key}
+        {sequences, length, finished_length, inputs, prng_key}
      end

-    {sequences, _length, _finished?, _inputs, _params, _key} =
-      while {sequences, length, finished?, inputs, params, prng_key},
-            continue?(finished?, length, max_length) do
-        {sequences, length, finished?, inputs, prng_key} =
+    {sequences, _length, finished_length, _inputs, _params, _key} =
+      while {sequences, length, finished_length, inputs, params, prng_key},
+            continue?(finished_length) do
+        {sequences, length, finished_length, inputs, prng_key} =
           sampling_step(
             sequences,
             length,
-            finished?,
+            finished_length,
             inputs,
             input_length,
             predict_fun,

@@ -769,16 +813,16 @@
             eos_token_id: eos_token_id
           )

-        {sequences, length, finished?, inputs, params, prng_key}
+        {sequences, length, finished_length, inputs, params, prng_key}
       end

-    sequences
+    {sequences, finished_length}
   end

   defnp sampling_step(
          sequences,
          length,
-         finished?,
+         finished_length,
          inputs,
          input_length,
          predict_fun,

@@ -801,12 +845,20 @@
     scores = Axon.Activations.softmax(logits)
     token_id = batched_choice(key, scores)

-    {sequences, length, finished?} =
-      update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id)
+    {sequences, length, finished_length} =
+      update_sequences(
+        sequences,
+        input_length,
+        length,
+        finished_length,
+        token_id,
+        pad_token_id,
+        eos_token_id
+      )

     inputs = update_inputs_fun.(inputs, outputs.cache, Nx.new_axis(token_id, -1))

-    {sequences, length, finished?, inputs, prng_key}
+    {sequences, length, finished_length, inputs, prng_key}
   end

   deftransformp batched_choice(key, scores) do