Add support for Whisper timestamps and task/language configuration #238

Merged · 4 commits · Sep 11, 2023
examples/phoenix/speech_to_text.exs (1 addition, 1 deletion)

@@ -315,7 +315,7 @@ end
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"})

serving =
-  Bumblebee.Audio.speech_to_text(model_info, featurizer, tokenizer, generation_config,
+  Bumblebee.Audio.speech_to_text_whisper(model_info, featurizer, tokenizer, generation_config,
compile: [batch_size: 10],
defn_options: [compiler: EXLA]
)
lib/bumblebee.ex (48 additions, 9 deletions)

@@ -777,6 +777,14 @@ defmodule Bumblebee do

See `Bumblebee.Text.GenerationConfig` for all the available options.

+  ## Options
+
+    * `:spec_module` - the model specification module. By default it
+      is inferred from the configuration file; if that is not possible,
+      it must be specified explicitly. Some models have extra options
+      related to generation, which are loaded into a separate struct
+      stored under the `:extra_config` attribute

## Examples

{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "gpt2"})
@@ -786,20 +794,51 @@
"""
@spec load_generation_config(repository()) ::
{:ok, Bumblebee.Text.GenerationConfig.t()} | {:error, String.t()}
def load_generation_config(repository) do
def load_generation_config(repository, opts \\ []) do
opts = Keyword.validate!(opts, [:spec_module])

repository = normalize_repository!(repository)

file_result =
with {:error, _} <- download(repository, @generation_filename) do
download(repository, @config_filename)
with {:ok, path} <- download(repository, @config_filename),
{:ok, spec_data} <- decode_config(path) do
spec_module = opts[:spec_module]

{inferred_module, inference_error} =
case infer_model_type(spec_data) do
{:ok, module, _architecture} -> {module, nil}
{:error, error} -> {nil, error}
end

spec_module = spec_module || inferred_module

unless spec_module do
raise "#{inference_error}, please specify the :spec_module option"
end

with {:ok, path} <- file_result,
{:ok, generation_data} <- decode_config(path) do
config = struct!(Bumblebee.Text.GenerationConfig)
config = HuggingFace.Transformers.Config.load(config, generation_data)
generation_data_result =
case download(repository, @generation_filename) do
{:ok, path} -> decode_config(path)
# Fallback to the spec data, since it used to include
# generation attributes
{:error, _} -> {:ok, spec_data}
end

with {:ok, generation_data} <- generation_data_result do
config = struct!(Bumblebee.Text.GenerationConfig)
config = HuggingFace.Transformers.Config.load(config, generation_data)

extra_config_module = Bumblebee.Text.Generation.extra_config_module(struct!(spec_module))

{:ok, config}
extra_config =
if extra_config_module do
extra_config = struct!(extra_config_module)
HuggingFace.Transformers.Config.load(extra_config, generation_data)
end

config = %{config | extra_config: extra_config}

{:ok, config}
end
end
end

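An illustrative sketch of the new optional argument (not taken from the diff above): forcing the spec module when it cannot be inferred from the repository's config file, with model-specific generation options landing under `:extra_config`. The `Bumblebee.Audio.Whisper` and `Bumblebee.Audio.WhisperGenerationConfig` names are assumptions for the sake of the example.

    {:ok, generation_config} =
      Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"},
        # Assumed here for illustration; normally inferred from config.json
        spec_module: Bumblebee.Audio.Whisper
      )

    # Whisper-specific generation attributes are loaded into a separate struct
    generation_config.extra_config
    #=> %Bumblebee.Audio.WhisperGenerationConfig{...} (illustrative)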
lib/bumblebee/audio.ex (96 additions, 11 deletions)

@@ -3,6 +3,12 @@ defmodule Bumblebee.Audio do
High-level tasks related to audio processing.
"""

+  # TODO: remove in v0.5
+  @deprecated "Use Bumblebee.Audio.speech_to_text_whisper/5 instead."
+  def speech_to_text(model_info, featurizer, tokenizer, generation_config, opts \\ []) do
+    speech_to_text_whisper(model_info, featurizer, tokenizer, generation_config, opts)
+  end

@typedoc """
A term representing audio.

@@ -14,15 +20,23 @@
requires `ffmpeg` installed)

"""
-  @type speech_to_text_input :: Nx.t() | {:file, String.t()}
-  @type speech_to_text_output :: %{results: list(speech_to_text_result())}
-  @type speech_to_text_result :: %{text: String.t()}
+  @type speech_to_text_whisper_input :: Nx.t() | {:file, String.t()}
+  @type speech_to_text_whisper_output :: %{results: list(speech_to_text_whisper_result())}
+  @type speech_to_text_whisper_result :: %{
+          text: String.t(),
+          chunks:
+            list(%{
+              text: String.t(),
+              start_timestamp_seconds: number() | nil,
+              end_timestamp_seconds: number() | nil
+            })
+        }

@doc """
-  Builds serving for speech-to-text generation.
+  Builds serving for speech-to-text generation with Whisper models.

-  The serving accepts `t:speech_to_text_input/0` and returns
-  `t:speech_to_text_output/0`. A list of inputs is also supported.
+  The serving accepts `t:speech_to_text_whisper_input/0` and returns
+  `t:speech_to_text_whisper_output/0`. A list of inputs is also supported.

## Options

@@ -39,6 +53,23 @@
in the total `:chunk_num_seconds`. Defaults to 1/6 of
`:chunk_num_seconds`

+    * `:language` - the language of the speech, when known upfront.
+      Should be given as an ISO alpha-2 code (as a string). By default
+      no language is assumed and it is inferred from the input
+
+    * `:task` - either of:
+
+      * `:transcribe` (default) - generate audio transcription in
+        the same language as the speech
+
+      * `:translate` - generate translation of the given speech in
+        English
+
+    * `:timestamps` - when set, the model predicts timestamps and each
+      annotated segment becomes an output chunk. Currently the only
+      supported value is `:segments`; the length of each segment is up
+      to the model

* `:seed` - random seed to use when sampling. By default the current
timestamp is used

@@ -68,21 +99,75 @@
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"})

serving =
-        Bumblebee.Audio.speech_to_text(whisper, featurizer, tokenizer, generation_config,
+        Bumblebee.Audio.speech_to_text_whisper(whisper, featurizer, tokenizer, generation_config,
defn_options: [compiler: EXLA]
)

Nx.Serving.run(serving, {:file, "/path/to/audio.wav"})
-      #=> %{results: [%{text: "There is a cat outside the window."}]}
+      #=> %{
+      #=>   results: [
+      #=>     %{
+      #=>       chunks: [
+      #=>         %{
+      #=>           text: " There is a cat outside the window.",
+      #=>           start_timestamp_seconds: nil,
+      #=>           end_timestamp_seconds: nil
+      #=>         }
+      #=>       ],
+      #=>       text: "There is a cat outside the window."
+      #=>     }
+      #=>   ]
+      #=> }

+  And with timestamps:
+
+      serving =
+        Bumblebee.Audio.speech_to_text_whisper(whisper, featurizer, tokenizer, generation_config,
+          defn_options: [compiler: EXLA],
+          chunk_num_seconds: 30,
+          timestamps: :segments
+        )
+
+      Nx.Serving.run(serving, {:file, "/path/to/colouredstars_08_mathers_128kb.mp3"})
+      #=> %{
+      #=>   results: [
+      #=>     %{
+      #=>       chunks: [
+      #=>         %{
+      #=>           text: " Such an eight of colored stars, versions of fifty isiatic love poems by Edward Powis-Mathers.",
+      #=>           start_timestamp_seconds: 0.0,
+      #=>           end_timestamp_seconds: 7.0
+      #=>         },
+      #=>         %{
+      #=>           text: " This the revocs recording is in the public domain. Doubt. From the Japanese of Hori-Kawa,",
+      #=>           start_timestamp_seconds: 7.0,
+      #=>           end_timestamp_seconds: 14.0
+      #=>         },
+      #=>         %{
+      #=>           text: " will he be true to me that I do not know. But since the dawn, I have had as much disorder in my thoughts as in my black hair, and of doubt.",
+      #=>           start_timestamp_seconds: 14.0,
+      #=>           end_timestamp_seconds: 27.0
+      #=>         }
+      #=>       ],
+      #=>       text: "Such an eight of colored stars, versions of fifty isiatic love poems by Edward Powis-Mathers. This the revocs recording is in the public domain. Doubt. From the Japanese of Hori-Kawa, will he be true to me that I do not know. But since the dawn, I have had as much disorder in my thoughts as in my black hair, and of doubt."
+      #=>     }
+      #=>   ]
+      #=> }

"""
-  @spec speech_to_text(
+  @spec speech_to_text_whisper(
Bumblebee.model_info(),
Bumblebee.Featurizer.t(),
Bumblebee.Tokenizer.t(),
Bumblebee.Text.GenerationConfig.t(),
keyword()
) :: Nx.Serving.t()
-  defdelegate speech_to_text(model_info, featurizer, tokenizer, generation_config, opts \\ []),
-    to: Bumblebee.Audio.SpeechToText
+  defdelegate speech_to_text_whisper(
+                model_info,
+                featurizer,
+                tokenizer,
+                generation_config,
+                opts \\ []
+              ),
+              to: Bumblebee.Audio.SpeechToTextWhisper
end
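An illustrative sketch of the new `:language` and `:task` options (not one of the PR's own examples; the audio path is a placeholder): translating speech that is known to be French into English.

    {:ok, whisper} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny"})
    {:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny"})

    serving =
      Bumblebee.Audio.speech_to_text_whisper(whisper, featurizer, tokenizer, generation_config,
        defn_options: [compiler: EXLA],
        # The speech is known to be French; ask for an English translation
        language: "fr",
        task: :translate
      )

    Nx.Serving.run(serving, {:file, "/path/to/french_audio.wav"})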
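And a short illustrative sketch (assuming a serving built with `timestamps: :segments`, as in the docstring example above) of one way to turn the chunked output into caption-style lines:

    %{results: [%{chunks: chunks}]} =
      Nx.Serving.run(serving, {:file, "/path/to/audio.mp3"})

    for %{text: text, start_timestamp_seconds: start, end_timestamp_seconds: stop} <- chunks do
      "[#{start}s -> #{stop}s]#{text}"
    end
    #=> ["[0.0s -> 7.0s] Such an eight of colored stars, ...", ...]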