Support chunking to enable long-form transcription
jonatanklosko committed Aug 30, 2023
1 parent 56ab13c commit b764859
Showing 11 changed files with 240 additions and 23 deletions.
13 changes: 13 additions & 0 deletions lib/bumblebee/audio.ex
@@ -26,6 +26,19 @@ defmodule Bumblebee.Audio do
## Options
* `:chunk_num_seconds` - enables long-form transcription by splitting
the input into chunks of the given length. Models generally have
a limit on the input length, so by chunking we can feed smaller
bits into the model, then merge the individual outputs into a
single result at the end. By default chunking is disabled
* `:context_num_seconds` - specifies the amount of overlap between
chunks on both sides of split points. The context is effectively
discarded when merging the chunks at the end, but it improves
the results at the chunk edges. Note that the context is included
in the total `:chunk_num_seconds`. Defaults to 1/6 of
`:chunk_num_seconds`
* `:seed` - random seed to use when sampling. By default the current
timestamp is used
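For reference, a usage sketch of these options (it mirrors the integration test added below; the checkpoint, the EXLA compiler and the audio path are illustrative):

    {:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"})
    {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"})

    serving =
      Bumblebee.Audio.speech_to_text(model_info, featurizer, tokenizer, generation_config,
        chunk_num_seconds: 30,
        # when omitted, defaults to 30 / 6 = 5 seconds
        context_num_seconds: 5,
        defn_options: [compiler: EXLA]
      )

    # accepts a 1-dimensional tensor or {:file, path}
    Nx.Serving.run(serving, {:file, "/path/to/audio.wav"})
    #=> %{results: [%{text: "..."}]}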
153 changes: 145 additions & 8 deletions lib/bumblebee/audio/speech_to_text.ex
@@ -11,12 +11,22 @@ defmodule Bumblebee.Audio.SpeechToText do
%Text.GenerationConfig{} = generation_config,
opts \\ []
) do
- opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false])
opts =
Keyword.validate!(opts, [
:chunk_num_seconds,
:context_num_seconds,
:seed,
:compile,
defn_options: [],
preallocate_params: false
])

%{model: model, params: params, spec: spec} = model_info

Shared.validate_architecture!(spec, [:for_conditional_generation])

chunk_num_seconds = opts[:chunk_num_seconds]
context_num_seconds = opts[:context_num_seconds]
preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

@@ -68,15 +78,142 @@
{:error, "expected a 1-dimensional tensor or {:file, path}, got: #{inspect(other)}"}
end)

- inputs = Bumblebee.apply_featurizer(featurizer, inputs, defn_options: defn_options)
- {Nx.Batch.concatenate([inputs]), multi?}
all_chunks =
for input <- inputs do
if chunk_num_seconds do
chunk_input(input, sampling_rate, chunk_num_seconds, context_num_seconds)
else
[input]
end
end

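# Keep the per-input chunk counts, so that postprocessing can group
# the flattened results back by input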
all_num_chunks = Enum.map(all_chunks, &length/1)

all_chunks = List.flatten(all_chunks)
inputs = Bumblebee.apply_featurizer(featurizer, all_chunks, defn_options: defn_options)
{Nx.Batch.concatenate([inputs]), {multi?, all_num_chunks}}
end)
|> Nx.Serving.client_postprocessing(fn {results, _metadata}, {multi?, all_num_chunks} ->
all_special_tokens = Bumblebee.Tokenizer.all_special_tokens(tokenizer)

sequences =
results
|> Bumblebee.Utils.Nx.to_list()
|> Enum.map(fn sequence ->
sequence
|> Enum.filter(fn token_id ->
if token = Bumblebee.Tokenizer.id_to_token(tokenizer, token_id) do
token not in all_special_tokens
end
end)
|> Nx.tensor()
end)

{outputs, []} =
Enum.map_reduce(all_num_chunks, sequences, fn num_chunks, sequences ->
{sequences, rest} = Enum.split(sequences, num_chunks)
token_ids = merge_overlapping_sequences(sequences)
text = Bumblebee.Tokenizer.decode(tokenizer, token_ids)
output = %{results: [%{text: normalize_text(text)}]}
{output, rest}
end)

Shared.normalize_output(outputs, multi?)
end)
- |> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, multi? ->
- decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids)
end

defp chunk_input(input, sampling_rate, chunk_num_seconds, context_num_seconds) do
context_num_seconds = context_num_seconds || chunk_num_seconds / 6

chunk_length = floor(chunk_num_seconds * sampling_rate)
context_left = floor(context_num_seconds * sampling_rate)
context_right = context_left

input_length = Nx.axis_size(input, 0)
step = chunk_length - context_left - context_right

0..(input_length - 1)//step
|> Enum.reduce_while([], fn chunk_start_idx, chunks ->
chunk_end_idx = chunk_start_idx + chunk_length

# All right contexts must be full, otherwise it is the last item
last? =
if context_right > 0 do
chunk_end_idx > input_length
else
chunk_end_idx >= input_length
end

chunk = input[chunk_start_idx..(min(chunk_end_idx, input_length) - 1)]
chunks = [chunk | chunks]

{if(last?, do: :halt, else: :cont), chunks}
end)
|> Enum.reverse()
end
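As a worked example of the arithmetic above, assuming a 16 kHz sampling rate, `chunk_num_seconds: 30` and the default context of 30 / 6 = 5 seconds:

    chunk_length = floor(30 * 16_000)
    #=> 480_000
    context_left = floor(5 * 16_000)
    #=> 80_000
    step = 480_000 - 80_000 - 80_000
    #=> 320_000

so a new chunk starts every 20 seconds of audio and consecutive chunks overlap by 10 seconds.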

defp merge_overlapping_sequences(sequences) do
# We have a number of consecutive, overlapping sequences and we
# want to merge them into a single sequence. To merge a pair of
# consecutive sequences we slide the sequences and compare the
# overlap:
#
#     abcd    (left)
#        cde  (right)
#     => compare d = c
#
#     abcd   (left)
#       cde  (right)
#     => compare cd = cd
#
# We find the best alignment, then cut the overlap in half and
# concatenate the left and right part accordingly. In the example
# above, we would use the second alignment, taking `abc` from the
# left sequence and `de` from the right one.

{[left_sequence], right_sequences} = Enum.split(sequences, 1)

{acc, left_sequence} =
for right_sequence <- right_sequences, reduce: {[], left_sequence} do
{acc, left_sequence} ->
left_length = Nx.size(left_sequence)
right_length = Nx.size(right_sequence)

{_max_match_score, overlap_indices} =
for i <- 1..(left_length + right_length - 1),
reduce: {0.0, {left_length, left_length, 0, 0}} do
{max_match_score, overlap_indices} ->
left_start = max(0, left_length - i)
left_stop = min(left_length, left_length + right_length - i)
left_overlap = left_sequence[left_start..(left_stop - 1)]

right_start = max(0, i - left_length)
right_stop = min(right_length, i)
right_overlap = right_sequence[right_start..(right_stop - 1)]

num_matches = Nx.equal(left_overlap, right_overlap) |> Nx.sum() |> Nx.to_number()

# Epsilon to favor long perfect matches
eps = i / 10000.0
match_score = num_matches / i + eps

if num_matches > 1 and match_score > max_match_score do
overlap_indices = {left_start, left_stop, right_start, right_stop}
{match_score, overlap_indices}
else
{max_match_score, overlap_indices}
end
end

# Cut in the middle of the overlap
{left_start, left_stop, right_start, right_stop} = overlap_indices
left_mid = div(left_stop + left_start, 2)
right_mid = div(right_stop + right_start, 2)
{[left_sequence[0..(left_mid - 1)] | acc], right_sequence[right_mid..-1//1]}
end

- decoded
- |> Enum.map(&%{results: [%{text: normalize_text(&1)}]})
- |> Shared.normalize_output(multi?)
Enum.reduce([left_sequence | acc], [], fn sequence, acc ->
Nx.to_flat_list(sequence) ++ acc
end)
end
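A small worked example of the merge (`merge_overlapping_sequences/1` is private, so this is purely illustrative, with made-up token ids): for `left = [1, 2, 3, 4]` and `right = [3, 4, 5, 6]` the best alignment is `i = 2`, overlapping `[3, 4]` with `[3, 4]` (2 matches, score `2 / 2 + 0.0002`), hence `overlap_indices = {2, 4, 0, 2}`, `left_mid = 3` and `right_mid = 1`:

    merge_overlapping_sequences([Nx.tensor([1, 2, 3, 4]), Nx.tensor([3, 4, 5, 6])])
    #=> [1, 2, 3, 4, 5, 6]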

39 changes: 33 additions & 6 deletions lib/bumblebee/shared.ex
@@ -340,16 +340,23 @@ defmodule Bumblebee.Shared do
def load_special_tokens(special_tokens, data) do
for {key, default_token} <- special_tokens, into: %{} do
token =
- case data["#{key}_token"] do
- nil -> default_token
- %{"content" => token} when is_binary(token) -> token
- token when is_binary(token) -> token
if token = data["#{key}_token"] do
load_token(token)
else
default_token
end

{key, token}
end
end
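A quick sketch of the refactored lookup (hypothetical data; `:eos` stands in for any special token type):

    load_special_tokens(%{eos: "</s>"}, %{"eos_token" => %{"content" => "<|endoftext|>"}})
    #=> %{eos: "<|endoftext|>"}

    load_special_tokens(%{eos: "</s>"}, %{})
    #=> %{eos: "</s>"}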

@doc """
Normalizes a persisted token into a token string.
"""
@spec load_token(String.t() | map()) :: String.t()
def load_token(token) when is_binary(token), do: token
def load_token(%{"content" => token}) when is_binary(token), do: token
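
For example, both persisted forms normalize to the same string (the map form, as found in Hugging Face `special_tokens_map.json` files, carries extra metadata that `load_token/1` ignores; the metadata keys shown are illustrative):

    load_token("<|endoftext|>")
    #=> "<|endoftext|>"
    load_token(%{"content" => "<|endoftext|>", "lstrip" => false})
    #=> "<|endoftext|>"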

@doc """
Converts logits to scores as per the given scores function.
@@ -427,7 +434,8 @@
quote do
defstruct [
:tokenizer,
- special_tokens: unquote(special_tokens)
special_tokens: unquote(special_tokens),
additional_special_tokens: []
]

@behaviour Bumblebee.Tokenizer
@@ -457,6 +465,11 @@
tokenizer.special_tokens
end

@impl true
def additional_special_tokens(tokenizer) do
tokenizer.additional_special_tokens
end

defimpl Bumblebee.HuggingFace.Transformers.Config do
def load(tokenizer, %{
"tokenizer_file" => path,
@@ -467,7 +480,21 @@
special_tokens =
Bumblebee.Shared.load_special_tokens(tokenizer.special_tokens, special_tokens_map)

- %{tokenizer | tokenizer: native_tokenizer, special_tokens: special_tokens}
additional_special_tokens =
case special_tokens_map do
%{"additional_special_tokens" => tokens} ->
for token <- tokens, do: Bumblebee.Shared.load_token(token), into: MapSet.new()

_ ->
[]
end

%{
tokenizer
| tokenizer: native_tokenizer,
special_tokens: special_tokens,
additional_special_tokens: additional_special_tokens
}
end
end
end
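For reference, the decoded `special_tokens_map.json` can carry `"additional_special_tokens"` entries either as plain strings or as maps with a `"content"` key, which is why both are funneled through `Bumblebee.Shared.load_token/1`. A sketch of the parsed map, with Whisper-style values as an illustrative assumption:

    %{
      "eos_token" => %{"content" => "<|endoftext|>", "lstrip" => false},
      "additional_special_tokens" => [
        "<|translate|>",
        %{"content" => "<|transcribe|>", "lstrip" => false}
      ]
    }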
16 changes: 16 additions & 0 deletions lib/bumblebee/tokenizer.ex
@@ -64,6 +64,12 @@ defmodule Bumblebee.Tokenizer do
"""
@callback special_tokens(t()) :: %{special_token_type() => token()}

@doc """
Returns a list with extra special tokens, in addition to the named
`special_tokens/1`.
"""
@callback additional_special_tokens(t()) :: MapSet.t(token())

@doc """
Decodes a list of token ids into a sentence.
"""
@@ -111,4 +117,14 @@
token_to_id(tokenizer, token)
end
end

@doc """
Returns all special tokens, including any extra tokens.
"""
@spec all_special_tokens(t()) :: list(token())
def all_special_tokens(%module{} = tokenizer) do
special_tokens = module.special_tokens(tokenizer)
additional_special_tokens = module.additional_special_tokens(tokenizer)
for {_type, token} <- special_tokens, do: token, into: additional_special_tokens
end
end
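The new speech-to-text postprocessing uses this to drop special tokens before decoding; a sketch (it assumes a loaded `tokenizer` and a `sequence` of generated token ids):

    all_special_tokens = Bumblebee.Tokenizer.all_special_tokens(tokenizer)

    sequence
    |> Enum.filter(fn token_id ->
      if token = Bumblebee.Tokenizer.id_to_token(tokenizer, token_id) do
        token not in all_special_tokens
      end
    end)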
25 changes: 24 additions & 1 deletion test/bumblebee/audio/speech_to_text_test.exs
@@ -8,7 +8,7 @@ defmodule Bumblebee.Audio.SpeechToTextTest do
@audio_dir Path.expand("../../fixtures/audio", __DIR__)

describe "integration" do
test "returns top scored labels" do
test "generates transcription" do
{:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"})
@@ -29,5 +29,28 @@

assert %{results: [%{text: "Tower of strength."}]} = Nx.Serving.run(serving, audio)
end

test "long-form transcription with chunking" do
{:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"})

serving =
Bumblebee.Audio.speech_to_text(model_info, featurizer, tokenizer, generation_config,
chunk_num_seconds: 30,
defn_options: [compiler: EXLA]
)

audio =
Path.join(@audio_dir, "librivox/46s_pcm_f32le_16000.bin")
|> File.read!()
|> Nx.from_binary(:f32)

transcription =
"An awakening from the book of Irish poetry part 1, read for LibriVox.org by Sonja. An awakening by Alice Pirlong. O spring will wake in the heart of me with the rapture of blown violets, when the green bud quickens on every tree to spring will wake in the heart of me, and queues of honey will reign on the lee, tangling the grasses in silver nets. Yes, spring will awaken the heart of me with the rapture of blown violets. End of an awakening, this recording is in the public domain."

assert %{results: [%{text: ^transcription}]} = Nx.Serving.run(serving, audio)
end
end
end
6 changes: 0 additions & 6 deletions test/fixtures/audio/common_voice/generate.sh

This file was deleted.

2 changes: 0 additions & 2 deletions test/fixtures/audio/common_voice/info.md
@@ -1,3 +1 @@
Source: https://huggingface.co/datasets/common_voice

- Decoded binary formats generated using `generate.sh`.
8 changes: 8 additions & 0 deletions test/fixtures/audio/generate.sh
@@ -0,0 +1,8 @@
#!/bin/bash

cd "$(dirname "$0")"

for source in $(ls **/*.{wav,mp3}); do
name="${source%.*}"
ffmpeg -i "$source" -ac 1 -ar 16000 -f f32le -hide_banner -loglevel quiet "${name}_pcm_f32le_16000.bin"
done
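The generated `*_pcm_f32le_16000.bin` fixtures are raw mono 32-bit float PCM at 16 kHz, so they load back into a tensor the same way the new test does:

    audio =
      "test/fixtures/audio/librivox/46s_pcm_f32le_16000.bin"
      |> File.read!()
      |> Nx.from_binary(:f32)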
Binary file added test/fixtures/audio/librivox/46s.mp3
Binary file added test/fixtures/audio/librivox/46s_pcm_f32le_16000.bin
1 change: 1 addition & 0 deletions test/fixtures/audio/librivox/info.md
@@ -0,0 +1 @@
Source: https://librivox.org/the-book-of-irish-poetry-by-various
