Skip to content

Commit

Permalink
Add token summary to text generation output
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Feb 14, 2024
1 parent c92b9a6 commit d6af0f4
Show file tree
Hide file tree
Showing 11 changed files with 280 additions and 115 deletions.
3 changes: 2 additions & 1 deletion lib/bumblebee/audio/speech_to_text_whisper.ex
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
generate_fun = fn params, {inputs, seed} ->
inputs = Bumblebee.Featurizer.process_batch(featurizer, inputs)
inputs = Map.put(inputs, "seed", seed)
generate_fun.(params, inputs)
%{token_ids: token_ids} = generate_fun.(params, inputs)
token_ids
end

Nx.Serving.new(
Expand Down
13 changes: 12 additions & 1 deletion lib/bumblebee/text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,12 @@ defmodule Bumblebee.Text do
@type generation_input ::
String.t() | %{:text => String.t(), optional(:seed) => integer()}
@type generation_output :: %{results: list(generation_result())}
@type generation_result :: %{text: String.t()}
@type generation_result :: %{text: String.t(), token_summary: token_summary()}
@type token_summary :: %{
input: pos_integer(),
output: pos_integer(),
padding: non_neg_integer()
}

@doc """
Builds serving for prompt-driven text generation.
Expand Down Expand Up @@ -172,6 +177,12 @@ defmodule Bumblebee.Text do
serving. To process a batch, call the serving with each input
separately. Defaults to `false`.
* `:stream_done` - when `:stream` is enabled, this enables a final
event, after all chunks have been emitted. The event has the
shape `{:done, result}`, where `result` includes the same fields
as `t:generation_result/0`, except for `:text`, which has already
been streamed. Defaults to `false`.
## Examples
{:ok, model_info} = Bumblebee.load_model({:hf, "gpt2"})
Expand Down
164 changes: 108 additions & 56 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ defmodule Bumblebee.Text.Generation do
Bumblebee.ModelSpec.t(),
Bumblebee.Text.GenerationConfig.t(),
keyword()
) :: (params :: map(), inputs :: map() -> Nx.t())
) ::
(params :: map(), inputs :: map() -> %{token_ids: Nx.Tensor.t(), length: Nx.Tensor.t()})
def build_generate(model, spec, config, opts \\ []) do
opts = Keyword.validate!(opts, logits_processors: [])

Expand Down Expand Up @@ -358,7 +359,7 @@ defmodule Bumblebee.Text.Generation do

strategy = opts[:strategy]

sequences =
{sequences, finished_length} =
case strategy.type do
:greedy_search ->
greedy(
Expand Down Expand Up @@ -399,8 +400,11 @@ defmodule Bumblebee.Text.Generation do
)
end

# Output only the newly generated tokens
sequences[[.., length..-1//1]]
%{
# Output only the newly generated tokens
token_ids: sequences[[.., length..-1//1]],
length: finished_length - length
}
end

deftransformp pop_seed(inputs), do: Map.pop!(inputs, "seed")
Expand All @@ -422,17 +426,17 @@ defmodule Bumblebee.Text.Generation do
pad_token_id = opts[:pad_token_id]
eos_token_id = opts[:eos_token_id]

{sequences, length = input_length, finished?} =
{sequences, length = input_length, finished_length} =
init_sequences(decoder_input_ids, max_length, pad_token_id)

# The loop works with inputs of length 1, so if the initial input
# is longer, we make the initial pass outside
{sequences, length, finished?, inputs} =
{sequences, length, finished_length, inputs} =
if length > 1 do
greedy_step(
sequences,
length,
finished?,
finished_length,
inputs,
input_length,
predict_fun,
Expand All @@ -443,17 +447,17 @@ defmodule Bumblebee.Text.Generation do
eos_token_id: eos_token_id
)
else
{sequences, length, finished?, inputs}
{sequences, length, finished_length, inputs}
end

{sequences, _length, _finished?, _inputs, _params} =
while {sequences, length, finished?, inputs, params},
continue?(finished?, length, max_length) do
{sequences, length, finished?, inputs} =
{sequences, _length, finished_length, _inputs, _params} =
while {sequences, length, finished_length, inputs, params},
continue?(finished_length) do
{sequences, length, finished_length, inputs} =
greedy_step(
sequences,
length,
finished?,
finished_length,
inputs,
input_length,
predict_fun,
Expand All @@ -464,10 +468,10 @@ defmodule Bumblebee.Text.Generation do
eos_token_id: eos_token_id
)

{sequences, length, finished?, inputs, params}
{sequences, length, finished_length, inputs, params}
end

sequences
{sequences, finished_length}
end

defnp init_sequences(decoder_input_ids, max_length, pad_token_id) do
Expand All @@ -476,19 +480,21 @@ defmodule Bumblebee.Text.Generation do
sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
sequences = Nx.put_slice(sequences, [0, 0], decoder_input_ids)

finished? = Nx.broadcast(Nx.tensor(0, type: :u8), {batch_size})
# For each sequence, we keep track of its final length, where 0
# means that it has not been finished yet
finished_length = Nx.broadcast(0, {batch_size})

{sequences, length, finished?}
{sequences, length, finished_length}
end

defnp continue?(finished?, length, max_length) do
not Nx.all(finished?) and length < max_length
defnp continue?(finished_length) do
Nx.any(finished_length == 0)
end

defnp greedy_step(
sequences,
length,
finished?,
finished_length,
inputs,
input_length,
predict_fun,
Expand All @@ -506,29 +512,59 @@ defmodule Bumblebee.Text.Generation do
logits = batch_process_logits(logits_processor_fun, logits, sequences, length, input_length)
token_id = Nx.argmax(logits, axis: -1)

{sequences, length, finished?} =
update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id)
{sequences, length, finished_length} =
update_sequences(
sequences,
input_length,
length,
finished_length,
token_id,
pad_token_id,
eos_token_id
)

inputs = update_inputs_fun.(inputs, outputs.cache, Nx.new_axis(token_id, -1))

{sequences, length, finished?, inputs}
{sequences, length, finished_length, inputs}
end

defnp update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id) do
token_id = Nx.select(finished?, pad_token_id, token_id)
defnp update_sequences(
sequences,
input_length,
length,
finished_length,
token_id,
pad_token_id,
eos_token_id
) do
token_id = Nx.select(finished_length > 0, pad_token_id, token_id)

finished? =
token_ids = Nx.new_axis(token_id, -1)
sequences = Nx.put_slice(sequences, [0, length], token_ids)
length = length + 1

{batch_size, max_length} = Nx.shape(sequences)

finished_length =
case eos_token_id do
nil -> finished?
eos_token_id -> finished? or token_id == eos_token_id
end
nil ->
finished_length

{token_id, finished?} = hook({token_id, finished?}, :token)
eos_token_id ->
Nx.select(
finished_length == 0 and (token_id == eos_token_id or length == max_length),
length,
finished_length
)
end

token_ids = Nx.new_axis(token_id, -1)
sequences = Nx.put_slice(sequences, [0, length], token_ids)
finished? = finished_length > 0
output_length = Nx.broadcast(length - input_length, {batch_size})
data = %{token_id: token_id, finished?: finished?, length: output_length}
token = create_token()
{token, _} = hook_token(token, data, :token)

{sequences, length + 1, finished?}
attach_token(token, {sequences, length, finished_length})
end

defnp batch_process_logits(logits_processor_fun, logits, sequences, length, input_length) do
Expand Down Expand Up @@ -560,7 +596,7 @@ defmodule Bumblebee.Text.Generation do
top_k = opts[:top_k]
penalty_alpha = opts[:penalty_alpha]

{sequences, length = input_length, finished?} =
{sequences, length = input_length, finished_length} =
init_sequences(decoder_input_ids, max_length, pad_token_id)

# Step (1)
Expand Down Expand Up @@ -593,10 +629,10 @@ defmodule Bumblebee.Text.Generation do
# pick the best one using the contrastive rank. From the same model
# pass we also get the next top-k continuation tokens

{sequences, _length, _finished?, _inputs, _params, _joint_hidden_state, _top_k_values} =
while {sequences, length, finished?, inputs, params, joint_hidden_state,
{sequences, _length, finished_length, _inputs, _params, _joint_hidden_state, _top_k_values} =
while {sequences, length, finished_length, inputs, params, joint_hidden_state,
{top_k_scores, top_k_token_ids}},
continue?(finished?, length, max_length) do
continue?(finished_length) do
outputs = predict_fun.(params, inputs)

hidden_state = decoder_hidden_state(outputs)
Expand All @@ -618,8 +654,16 @@ defmodule Bumblebee.Text.Generation do

token_id = top_k_token_ids |> Nx.flatten() |> Utils.Nx.chunked_take(top_k, selected_idx)

{sequences, length, finished?} =
update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id)
{sequences, length, finished_length} =
update_sequences(
sequences,
input_length,
length,
finished_length,
token_id,
pad_token_id,
eos_token_id
)

logits = outputs.logits[[.., -1]]
logits = Utils.Nx.chunked_take(logits, top_k, selected_idx)
Expand All @@ -634,11 +678,11 @@ defmodule Bumblebee.Text.Generation do
cache = reflect_cache(outputs.cache, top_k, selected_idx, traverse_cache_fun)
inputs = update_inputs_fun.(inputs, cache, Nx.reshape(top_k_token_ids, {:auto, 1}))

{sequences, length, finished?, inputs, params, joint_hidden_state,
{sequences, length, finished_length, inputs, params, joint_hidden_state,
{top_k_scores, top_k_token_ids}}
end

sequences
{sequences, finished_length}
end

deftransformp decoder_hidden_state(outputs) do
Expand Down Expand Up @@ -723,19 +767,19 @@ defmodule Bumblebee.Text.Generation do
pad_token_id = opts[:pad_token_id]
eos_token_id = opts[:eos_token_id]

{sequences, length = input_length, finished?} =
{sequences, length = input_length, finished_length} =
init_sequences(decoder_input_ids, max_length, pad_token_id)

prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key()

# The loop works with inputs of length 1, so if the initial input
# is longer, we make the initial pass outside
{sequences, length, finished?, inputs, prng_key} =
{sequences, length, finished_length, inputs, prng_key} =
if length > 1 do
sampling_step(
sequences,
length,
finished?,
finished_length,
inputs,
input_length,
predict_fun,
Expand All @@ -747,17 +791,17 @@ defmodule Bumblebee.Text.Generation do
eos_token_id: eos_token_id
)
else
{sequences, length, finished?, inputs, prng_key}
{sequences, length, finished_length, inputs, prng_key}
end

{sequences, _length, _finished?, _inputs, _params, _key} =
while {sequences, length, finished?, inputs, params, prng_key},
continue?(finished?, length, max_length) do
{sequences, length, finished?, inputs, prng_key} =
{sequences, _length, finished_length, _inputs, _params, _key} =
while {sequences, length, finished_length, inputs, params, prng_key},
continue?(finished_length) do
{sequences, length, finished_length, inputs, prng_key} =
sampling_step(
sequences,
length,
finished?,
finished_length,
inputs,
input_length,
predict_fun,
Expand All @@ -769,16 +813,16 @@ defmodule Bumblebee.Text.Generation do
eos_token_id: eos_token_id
)

{sequences, length, finished?, inputs, params, prng_key}
{sequences, length, finished_length, inputs, params, prng_key}
end

sequences
{sequences, finished_length}
end

defnp sampling_step(
sequences,
length,
finished?,
finished_length,
inputs,
input_length,
predict_fun,
Expand All @@ -801,12 +845,20 @@ defmodule Bumblebee.Text.Generation do
scores = Axon.Activations.softmax(logits)
token_id = batched_choice(key, scores)

{sequences, length, finished?} =
update_sequences(sequences, length, finished?, token_id, pad_token_id, eos_token_id)
{sequences, length, finished_length} =
update_sequences(
sequences,
input_length,
length,
finished_length,
token_id,
pad_token_id,
eos_token_id
)

inputs = update_inputs_fun.(inputs, outputs.cache, Nx.new_axis(token_id, -1))

{sequences, length, finished?, inputs, prng_key}
{sequences, length, finished_length, inputs, prng_key}
end

deftransformp batched_choice(key, scores) do
Expand Down
Loading

0 comments on commit d6af0f4

Please sign in to comment.