diff --git a/lib/bumblebee/audio.ex b/lib/bumblebee/audio.ex index 61a30273..e37e8eee 100644 --- a/lib/bumblebee/audio.ex +++ b/lib/bumblebee/audio.ex @@ -26,6 +26,19 @@ defmodule Bumblebee.Audio do ## Options + * `:chunk_num_seconds` - enables long-form transcription by splitting + the input into chunks of the given length. Models generally have + a limit on the input length, so by chunking we can feed smaller + bits into the model, then merge the individual outputs into a + single result at the end. By default chunking is disabled + + * `:context_num_seconds` - specifies the amount of overlap between + chunks on both sides of split points. The context is effectively + discarded when merging the chunks at the end, but it improves + the results at the chunk edges. Note that the context is included + in the total `:chunk_num_seconds`. Defaults to 1/6 of + `:chunk_num_seconds` + * `:seed` - random seed to use when sampling. By default the current timestamp is used diff --git a/lib/bumblebee/audio/speech_to_text.ex b/lib/bumblebee/audio/speech_to_text.ex index 4392e1d8..f6915ffa 100644 --- a/lib/bumblebee/audio/speech_to_text.ex +++ b/lib/bumblebee/audio/speech_to_text.ex @@ -11,12 +11,22 @@ defmodule Bumblebee.Audio.SpeechToText do %Text.GenerationConfig{} = generation_config, opts \\ [] ) do - opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false]) + opts = + Keyword.validate!(opts, [ + :chunk_num_seconds, + :context_num_seconds, + :seed, + :compile, + defn_options: [], + preallocate_params: false + ]) %{model: model, params: params, spec: spec} = model_info Shared.validate_architecture!(spec, [:for_conditional_generation]) + chunk_num_seconds = opts[:chunk_num_seconds] + context_num_seconds = opts[:context_num_seconds] preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] @@ -68,15 +78,142 @@ defmodule Bumblebee.Audio.SpeechToText do {:error, "expected a 1-dimensional tensor or {:file, path}, got: #{inspect(other)}"} end) - inputs = Bumblebee.apply_featurizer(featurizer, inputs, defn_options: defn_options) - {Nx.Batch.concatenate([inputs]), multi?} + all_chunks = + for input <- inputs do + if chunk_num_seconds do + chunk_input(input, sampling_rate, chunk_num_seconds, context_num_seconds) + else + [input] + end + end + + all_num_chunks = Enum.map(all_chunks, &length/1) + + all_chunks = List.flatten(all_chunks) + inputs = Bumblebee.apply_featurizer(featurizer, all_chunks, defn_options: defn_options) + {Nx.Batch.concatenate([inputs]), {multi?, all_num_chunks}} + end) + |> Nx.Serving.client_postprocessing(fn {results, _metadata}, {multi?, all_num_chunks} -> + all_special_tokens = Bumblebee.Tokenizer.all_special_tokens(tokenizer) + + sequences = + results + |> Bumblebee.Utils.Nx.to_list() + |> Enum.map(fn sequence -> + sequence + |> Enum.filter(fn token_id -> + if token = Bumblebee.Tokenizer.id_to_token(tokenizer, token_id) do + token not in all_special_tokens + end + end) + |> Nx.tensor() + end) + + {outputs, []} = + Enum.map_reduce(all_num_chunks, sequences, fn num_chunks, sequences -> + {sequences, rest} = Enum.split(sequences, num_chunks) + token_ids = merge_overlapping_sequences(sequences) + text = Bumblebee.Tokenizer.decode(tokenizer, token_ids) + output = %{results: [%{text: normalize_text(text)}]} + {output, rest} + end) + + Shared.normalize_output(outputs, multi?) end) - |> Nx.Serving.client_postprocessing(fn {token_ids, _metadata}, multi? 
-> - decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids) + end + + defp chunk_input(input, sampling_rate, chunk_num_seconds, context_num_seconds) do + context_num_seconds = context_num_seconds || chunk_num_seconds / 6 + + chunk_length = floor(chunk_num_seconds * sampling_rate) + context_left = floor(context_num_seconds * sampling_rate) + context_right = context_left + + input_length = Nx.axis_size(input, 0) + step = chunk_length - context_left - context_right + + 0..(input_length - 1)//step + |> Enum.reduce_while([], fn chunk_start_idx, chunks -> + chunk_end_idx = chunk_start_idx + chunk_length + + # All right contexts must be full, otherwise it is the last item + last? = + if context_right > 0 do + chunk_end_idx > input_length + else + chunk_end_idx >= input_length + end + + chunk = input[chunk_start_idx..(min(chunk_end_idx, input_length) - 1)] + chunks = [chunk | chunks] + + {if(last?, do: :halt, else: :cont), chunks} + end) + |> Enum.reverse() + end + + defp merge_overlapping_sequences(sequences) do + # We have a number of consecutive, overlapping sequences and we + # want to merge them into a single sequence. To merge a pair of + # consecutive sequences we slide the sequences and compare the + # overlap: + # + # abcd (left) + # cde (right) + # => compare c = d + # + # abcd (left) + # cde (right) + # => compare cd = cd + # + # We find the best alignment, then cut the overlap in half and + # concatenate the left and right parts accordingly. In the example + # above, we would use the second alignment, taking `abc` from the + # left sequence and `de` from the right one. + + {[left_sequence], right_sequences} = Enum.split(sequences, 1) + + {acc, left_sequence} = + for right_sequence <- right_sequences, reduce: {[], left_sequence} do + {acc, left_sequence} -> + left_length = Nx.size(left_sequence) + right_length = Nx.size(right_sequence) + + {_max_match_score, overlap_indices} = + for i <- 1..(left_length + right_length - 1), + reduce: {0.0, {left_length, left_length, 0, 0}} do + {max_match_score, overlap_indices} -> + left_start = max(0, left_length - i) + left_stop = min(left_length, left_length + right_length - i) + left_overlap = left_sequence[left_start..(left_stop - 1)] + + right_start = max(0, i - left_length) + right_stop = min(right_length, i) + right_overlap = right_sequence[right_start..(right_stop - 1)] + + num_matches = Nx.equal(left_overlap, right_overlap) |> Nx.sum() |> Nx.to_number() + + # Epsilon to favor long perfect matches + eps = i / 10000.0 + match_score = num_matches / i + eps + + if num_matches > 1 and match_score > max_match_score do + overlap_indices = {left_start, left_stop, right_start, right_stop} + {match_score, overlap_indices} + else + {max_match_score, overlap_indices} + end + end + + # Cut in the middle of the overlap + {left_start, left_stop, right_start, right_stop} = overlap_indices + left_mid = div(left_stop + left_start, 2) + right_mid = div(right_stop + right_start, 2) + {[left_sequence[0..(left_mid - 1)] | acc], right_sequence[right_mid..-1//1]} end - decoded - |> Enum.map(&%{results: [%{text: normalize_text(&1)}]}) - |> Shared.normalize_output(multi?)
+ Enum.reduce([left_sequence | acc], [], fn sequence, acc -> + Nx.to_flat_list(sequence) ++ acc end) end diff --git a/lib/bumblebee/audio/whisper_featurizer.ex b/lib/bumblebee/audio/whisper_featurizer.ex index 2885b8cd..c92ac21d 100644 --- a/lib/bumblebee/audio/whisper_featurizer.ex +++ b/lib/bumblebee/audio/whisper_featurizer.ex @@ -56,7 +56,7 @@ defmodule Bumblebee.Audio.WhisperFeaturizer do def apply(featurizer, raw_samples, defn_options) do max_length = featurizer.num_seconds * featurizer.sampling_rate - transformed_samples = + samples = for sample <- List.wrap(raw_samples) do unless Nx.rank(sample) == 1 do raise ArgumentError, @@ -64,17 +64,20 @@ defmodule Bumblebee.Audio.WhisperFeaturizer do end pad_size = max_length - Nx.axis_size(sample, 0) - sample = Nx.pad(sample, featurizer.padding_value, [{0, pad_size, 0}]) - - Nx.Defn.jit(&extract_fbank_features/2, defn_options).(sample, - fft_length: featurizer.fft_length, - sampling_rate: featurizer.sampling_rate, - mel_bins: featurizer.feature_size, - hop_length: featurizer.hop_length - ) + Nx.pad(sample, featurizer.padding_value, [{0, pad_size, 0}]) end - samples = Nx.stack(transformed_samples) + samples = samples |> Nx.stack() |> Nx.vectorize(:batch) + + samples = + Nx.Defn.jit(&extract_fbank_features/2, defn_options).(samples, + fft_length: featurizer.fft_length, + sampling_rate: featurizer.sampling_rate, + mel_bins: featurizer.feature_size, + hop_length: featurizer.hop_length + ) + + samples = Nx.devectorize(samples) %{"input_features" => samples} end @@ -92,6 +95,8 @@ defmodule Bumblebee.Audio.WhisperFeaturizer do window_padding: :reflect ) + stft = stft[0..-2//1] + # Magic numbers taken from the reference implementation. This yields # max_mel ~ 3016 frequency_spacing = 200.0 / 3 diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex index c4f652dd..dbd72869 100644 --- a/lib/bumblebee/shared.ex +++ b/lib/bumblebee/shared.ex @@ -340,16 +340,23 @@ defmodule Bumblebee.Shared do def load_special_tokens(special_tokens, data) do for {key, default_token} <- special_tokens, into: %{} do token = - case data["#{key}_token"] do - nil -> default_token - %{"content" => token} when is_binary(token) -> token - token when is_binary(token) -> token + if token = data["#{key}_token"] do + load_token(token) + else + default_token end {key, token} end end + @doc """ + Normalizes a persisted token into token string. + """ + @spec load_token(String.t() | map()) :: String.t() + def load_token(token) when is_binary(token), do: token + def load_token(%{"content" => token}) when is_binary(token), do: token + @doc """ Converts logits to scores as per the given scores function. 
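The alignment performed by `merge_overlapping_sequences/1` above is easier to see on a tiny example. The following is a toy walkthrough of the idea in plain Elixir, not a call into the private helper itself, and the token ids are made up:

```elixir
# Two decoded chunks whose token ids overlap at the chunk boundary.
left = [1, 2, 3, 4]
right = [3, 4, 5]

# Sliding `right` across `left` scores every possible overlap:
#
#   overlap of 1: [4] vs [3]       -> 0 matches (also rejected: needs > 1 match)
#   overlap of 2: [3, 4] vs [3, 4] -> 2 matches, the best alignment
#
# The best overlap is then cut in half: the left sequence keeps
# everything up to the middle of the overlap and the right sequence
# contributes everything after it.
Enum.take(left, 3) ++ Enum.drop(right, 1)
# => [1, 2, 3, 4, 5]
```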
@@ -427,7 +434,8 @@ defmodule Bumblebee.Shared do quote do defstruct [ :tokenizer, - special_tokens: unquote(special_tokens) + special_tokens: unquote(special_tokens), + additional_special_tokens: [] ] @behaviour Bumblebee.Tokenizer @@ -457,6 +465,11 @@ defmodule Bumblebee.Shared do tokenizer.special_tokens end + @impl true + def additional_special_tokens(tokenizer) do + tokenizer.additional_special_tokens + end + defimpl Bumblebee.HuggingFace.Transformers.Config do def load(tokenizer, %{ "tokenizer_file" => path, @@ -467,7 +480,21 @@ defmodule Bumblebee.Shared do special_tokens = Bumblebee.Shared.load_special_tokens(tokenizer.special_tokens, special_tokens_map) - %{tokenizer | tokenizer: native_tokenizer, special_tokens: special_tokens} + additional_special_tokens = + case special_tokens_map do + %{"additional_special_tokens" => tokens} -> + for token <- tokens, do: Bumblebee.Shared.load_token(token), into: MapSet.new() + + _ -> + [] + end + + %{ + tokenizer + | tokenizer: native_tokenizer, + special_tokens: special_tokens, + additional_special_tokens: additional_special_tokens + } end end end diff --git a/lib/bumblebee/tokenizer.ex b/lib/bumblebee/tokenizer.ex index df6ab8ec..21318540 100644 --- a/lib/bumblebee/tokenizer.ex +++ b/lib/bumblebee/tokenizer.ex @@ -64,6 +64,12 @@ defmodule Bumblebee.Tokenizer do """ @callback special_tokens(t()) :: %{special_token_type() => token()} + @doc """ + Returns a set of extra special tokens, in addition to the named + `special_tokens/1`. + """ + @callback additional_special_tokens(t()) :: MapSet.t(token()) + @doc """ Decodes a list of token ids into a sentence. """ @@ -111,4 +117,14 @@ defmodule Bumblebee.Tokenizer do token_to_id(tokenizer, token) end end + + @doc """ + Returns all special tokens, including any extra tokens.
+ """ + @spec all_special_tokens(t()) :: list(token_id()) + def all_special_tokens(%module{} = tokenizer) do + special_tokens = module.special_tokens(tokenizer) + additional_special_tokens = module.additional_special_tokens(tokenizer) + for {_type, token} <- special_tokens, do: token, into: additional_special_tokens + end end diff --git a/mix.exs b/mix.exs index 65fe73e1..5cdf4a6b 100644 --- a/mix.exs +++ b/mix.exs @@ -32,12 +32,12 @@ defmodule Bumblebee.MixProject do [ {:axon, "~> 0.6.0", axon_opts()}, {:tokenizers, "~> 0.4"}, - {:nx, "~> 0.6.0"}, - {:exla, "~> 0.6.0", only: [:dev, :test]}, - {:torchx, "~> 0.6.0", only: [:dev, :test]}, - # {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}, - # {:exla, github: "elixir-nx/nx", sparse: "exla", override: true, only: [:dev, :test]}, - # {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]}, + # {:nx, "~> 0.6.0"}, + # {:exla, "~> 0.6.0", only: [:dev, :test]}, + # {:torchx, "~> 0.6.0", only: [:dev, :test]}, + {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}, + {:exla, github: "elixir-nx/nx", sparse: "exla", override: true, only: [:dev, :test]}, + {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]}, {:nx_image, "~> 0.1.0"}, {:unpickler, "~> 0.1.0"}, {:safetensors, "~> 0.1.1"}, @@ -48,7 +48,7 @@ defmodule Bumblebee.MixProject do {:stb_image, "~> 0.6.0", only: :test}, {:bypass, "~> 2.1", only: :test}, {:ex_doc, "~> 0.28", only: :dev, runtime: false}, - {:nx_signal, "~> 0.1.0"} + {:nx_signal, "~> 0.2.0"} ] end diff --git a/mix.lock b/mix.lock index 5428f949..75bf0b1b 100644 --- a/mix.lock +++ b/mix.lock @@ -8,20 +8,21 @@ "cowboy_telemetry": {:hex, :cowboy_telemetry, "0.4.0", "f239f68b588efa7707abce16a84d0d2acf3a0f50571f8bb7f56a15865aae820c", [:rebar3], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:telemetry, "~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "7d98bac1ee4565d31b62d59f8823dfd8356a169e7fcbb83831b8a5397404c9de"}, "cowlib": {:hex, :cowlib, "2.11.0", "0b9ff9c346629256c42ebe1eeb769a83c6cb771a6ee5960bd110ab0b9b872063", [:make, :rebar3], [], "hexpm", "2b3e9da0b21c4565751a6d4901c20d1b4cc25cbb7fd50d91d2ab6dd287bc86a9"}, "decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"}, - "dll_loader_helper": {:hex, :dll_loader_helper, "1.0.0", "8b960745743845ab1e3559eb3ebace7cc9c621fb61965603fa2c3499ec1c22d2", [:make, :mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}, {:cc_precompiler, "~> 0.1", [hex: :cc_precompiler, repo: "hexpm", optional: false]}], "hexpm", "3f1e257072f57ce502c00b250f01210217d690203dacfa7c0311ceeec91d897c"}, + "dll_loader_helper": {:hex, :dll_loader_helper, "1.1.0", "e7d015e980942a0d67e306827ec907e7e853a21186bd92bb968d986698591a0f", [:mix], [{:dll_loader_helper_beam, "~> 1.1", [hex: :dll_loader_helper_beam, repo: "hexpm", optional: false]}], "hexpm", "2b6c11ee7bb48f6a132ce8f872202f9e828c019988da1e2d40ad41496195df0c"}, + "dll_loader_helper_beam": {:hex, :dll_loader_helper_beam, "1.1.0", "d51232663985dbc998c59b5d080feecd5398d5b75a9f0293a9855db774c2684d", [:rebar3], [], "hexpm", "aa85d0d0e9398916a80b2fd751885877934ae3ea008288f99ff829c0b8ef1f55"}, "earmark_parser": {:hex, :earmark_parser, "1.4.33", "3c3fd9673bb5dcc9edc28dd90f50c87ce506d1f71b70e3de69aa8154bc695d44", [:mix], [], "hexpm", 
"2d526833729b59b9fdb85785078697c72ac5e5066350663e5be6a1182da61b8f"}, "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, "ex_doc": {:hex, :ex_doc, "0.30.5", "aa6da96a5c23389d7dc7c381eba862710e108cee9cfdc629b7ec021313900e9e", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "88a1e115dcb91cefeef7e22df4a6ebbe4634fbf98b38adcbc25c9607d6d9d8e6"}, - "exla": {:hex, :exla, "0.6.0", "af63e45ce41ad25630967923147d14292a0cc48e507b8a3cf3bf3d5483099a28", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.5.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "5f6a4a105ea9ab207b9aa4de5a294730e2bfe9639f4b8d37a7c00da131090d7a"}, + "exla": {:git, "https://github.com/elixir-nx/nx.git", "0865ecd376dba75ddc0d604fa0bcfa8e74a0ff28", [sparse: "exla"]}, "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, "mime": {:hex, :mime, "2.0.3", "3676436d3d1f7b81b5a2d2bd8405f412c677558c81b1c92be58c00562bb59095", [:mix], [], "hexpm", "27a30bf0db44d25eecba73755acf4068cbfe26a4372f9eb3e4ea3a45956bff6b"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:hex, :nx, "0.6.0", "37c86eae824125a7e298dd1ee896953d9d671ce3630dcff74c77db17d734a85f", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e1ad3cc70a5828a1aedb156b71e90863d9623a2dc9b35a5588f8627a07ee6cb4"}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "0865ecd376dba75ddc0d604fa0bcfa8e74a0ff28", [sparse: "nx"]}, "nx_image": {:hex, :nx_image, "0.1.1", 
"69cf0d2fd873d12b028583aa49b5e0a25f6aca307afc337a5d871851a20fba1d", [:mix], [{:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "55c8206a822237f6027168f11214e3887263c5b8a1f8e0634eea82c96e5093e3"}, - "nx_signal": {:hex, :nx_signal, "0.1.0", "403ac73140e2f368e827e0aca1a3035abaf6d890b00376742b359a6838e00d7f", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "1c68f2f0d186700819287f37ee6154a11e06bf5dbb30b73fcc92776293309a05"}, + "nx_signal": {:hex, :nx_signal, "0.2.0", "e1ca0318877b17c81ce8906329f5125f1e2361e4c4235a5baac8a95ee88ea98e", [:mix], [{:nx, "~> 0.6", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "7247e5e18a177a59c4cb5355952900c62fdeadeb2bad02a9a34237b68744e2bb"}, "plug": {:hex, :plug, "1.14.2", "cff7d4ec45b4ae176a227acd94a7ab536d9b37b942c8e8fa6dfc0fff98ff4d80", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "842fc50187e13cf4ac3b253d47d9474ed6c296a8732752835ce4a86acdf68d13"}, "plug_cowboy": {:hex, :plug_cowboy, "2.6.1", "9a3bbfceeb65eff5f39dab529e5cd79137ac36e913c02067dba3963a26efe9b2", [:mix], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:cowboy_telemetry, "~> 0.3", [hex: :cowboy_telemetry, repo: "hexpm", optional: false]}, {:plug, "~> 1.14", [hex: :plug, repo: "hexpm", optional: false]}], "hexpm", "de36e1a21f451a18b790f37765db198075c25875c64834bcc82d90b309eb6613"}, "plug_crypto": {:hex, :plug_crypto, "1.2.5", "918772575e48e81e455818229bf719d4ab4181fcbf7f85b68a35620f78d89ced", [:mix], [], "hexpm", "26549a1d6345e2172eb1c233866756ae44a9609bd33ee6f99147ab3fd87fd842"}, @@ -33,7 +34,7 @@ "stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "tokenizers": {:hex, :tokenizers, "0.4.0", "140283ca74a971391ddbd83cd8cbdb9bd03736f37a1b6989b82d245a95e1eb97", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "ef1a9824f5a893cd3b831c0e5b3d72caa250d2ec462035cc6afef6933b13a82e"}, - "torchx": {:hex, :torchx, "0.6.0", "e4a5f545e245c15aceeafcf9f22ac2ae0a87720c4a6b2f132e9909635f434e93", [:make, :mix], [{:dll_loader_helper, "~> 0.1 or ~> 1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "35365dc51ee28dc86ca87c150dd3869bc83b207b2574bb2310c1be39e3867550"}, + "torchx": {:git, "https://github.com/elixir-nx/nx.git", "0865ecd376dba75ddc0d604fa0bcfa8e74a0ff28", [sparse: "torchx"]}, "unpickler": {:hex, 
:unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"}, "unzip": {:hex, :unzip, "0.8.0", "ee21d87c21b01567317387dab4228ac570ca15b41cfc221a067354cbf8e68c4d", [:mix], [], "hexpm", "ffa67a483efcedcb5876971a50947222e104d5f8fea2c4a0441e6f7967854827"}, "xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"}, diff --git a/test/bumblebee/audio/speech_to_text_test.exs b/test/bumblebee/audio/speech_to_text_test.exs index c31e8e40..66dcf923 100644 --- a/test/bumblebee/audio/speech_to_text_test.exs +++ b/test/bumblebee/audio/speech_to_text_test.exs @@ -8,7 +8,7 @@ defmodule Bumblebee.Audio.SpeechToTextTest do @audio_dir Path.expand("../../fixtures/audio", __DIR__) describe "integration" do - test "returns top scored labels" do + test "generates transcription" do {:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny"}) {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"}) {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"}) @@ -29,5 +29,28 @@ defmodule Bumblebee.Audio.SpeechToTextTest do assert %{results: [%{text: "Tower of strength."}]} = Nx.Serving.run(serving, audio) end + + test "long-form transcription with chunking" do + {:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny"}) + {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"}) + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"}) + {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"}) + + serving = + Bumblebee.Audio.speech_to_text(model_info, featurizer, tokenizer, generation_config, + chunk_num_seconds: 30, + defn_options: [compiler: EXLA] + ) + + audio = + Path.join(@audio_dir, "librivox/46s_pcm_f32le_16000.bin") + |> File.read!() + |> Nx.from_binary(:f32) + + transcription = + "An awakening from the book of Irish poetry part 1, read for LibriVox.org by Sonja. An awakening by Alice Pirlong. O spring will wake in the heart of me with the rapture of blown violets, when the green bud quickens on every tree to spring will wake in the heart of me, and queues of honey will reign on the lee, tangling the grasses in silver nets. Yes, spring will awaken the heart of me with the rapture of blown violets. End of an awakening, this recording is in the public domain." + + assert %{results: [%{text: ^transcription}]} = Nx.Serving.run(serving, audio) + end end end diff --git a/test/fixtures/audio/common_voice/generate.sh b/test/fixtures/audio/common_voice/generate.sh deleted file mode 100755 index b6ffcd7c..00000000 --- a/test/fixtures/audio/common_voice/generate.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -for source in $(ls *.wav); do - id="${source%.wav}" - ffmpeg -i "${id}.wav" -ac 1 -ar 16000 -f f32le -hide_banner -loglevel quiet "${id}_pcm_f32le_16000.bin" -done diff --git a/test/fixtures/audio/common_voice/info.md b/test/fixtures/audio/common_voice/info.md index c09d9c60..b18b013b 100644 --- a/test/fixtures/audio/common_voice/info.md +++ b/test/fixtures/audio/common_voice/info.md @@ -1,3 +1 @@ Source: https://huggingface.co/datasets/common_voice - -Decoded binary formats generated using `generate.sh`. 
diff --git a/test/fixtures/audio/generate.sh b/test/fixtures/audio/generate.sh new file mode 100755 index 00000000..de66ac12 --- /dev/null +++ b/test/fixtures/audio/generate.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +cd "$(dirname "$0")" + +for source in $(ls **/*.{wav,mp3}); do + name="${source%.*}" + ffmpeg -i $source -ac 1 -ar 16000 -f f32le -hide_banner -loglevel quiet "${name}_pcm_f32le_16000.bin" +done diff --git a/test/fixtures/audio/librivox/46s.mp3 b/test/fixtures/audio/librivox/46s.mp3 new file mode 100644 index 00000000..0b875cf1 Binary files /dev/null and b/test/fixtures/audio/librivox/46s.mp3 differ diff --git a/test/fixtures/audio/librivox/46s_pcm_f32le_16000.bin b/test/fixtures/audio/librivox/46s_pcm_f32le_16000.bin new file mode 100644 index 00000000..f75aeab4 Binary files /dev/null and b/test/fixtures/audio/librivox/46s_pcm_f32le_16000.bin differ diff --git a/test/fixtures/audio/librivox/info.md b/test/fixtures/audio/librivox/info.md new file mode 100644 index 00000000..3083bdf6 --- /dev/null +++ b/test/fixtures/audio/librivox/info.md @@ -0,0 +1 @@ +Source: https://librivox.org/the-book-of-irish-poetry-by-various
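On the tokenizer side, the new `load_token/1` and `additional_special_tokens` plumbing is what lets the chunked postprocessing drop Whisper's task and language tags before decoding. A quick sketch of how the helpers behave; the exact token strings depend on the checkpoint's `special_tokens_map.json`, so treat them as illustrative:

```elixir
# Persisted special tokens come in two shapes; load_token/1 normalizes both.
Bumblebee.Shared.load_token("<|endoftext|>")
# => "<|endoftext|>"
Bumblebee.Shared.load_token(%{"content" => "<|endoftext|>", "lstrip" => false})
# => "<|endoftext|>"

# all_special_tokens/1 combines the named special tokens with the extra
# ones loaded from "additional_special_tokens", and the serving uses the
# result to filter generated tokens before decoding each chunk.
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"})
special = Bumblebee.Tokenizer.all_special_tokens(tokenizer)
"<|transcribe|>" in special
# => true, assuming the checkpoint lists it under "additional_special_tokens"
```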