Migrate optional outputs to use global layer options (#360)
jonatanklosko authored Mar 6, 2024
1 parent c50e4f3 commit b9b73f8
Showing 35 changed files with 221 additions and 281 deletions.
14 changes: 5 additions & 9 deletions lib/bumblebee/audio/whisper.ex
@@ -83,11 +83,7 @@ defmodule Bumblebee.Audio.Whisper do
doc:
"the standard deviation of the normal initializer used for initializing kernel parameters"
]
] ++
Shared.common_options([
:output_hidden_states,
:output_attentions
])
]

@moduledoc """
Whisper model family.
@@ -161,6 +157,10 @@ defmodule Bumblebee.Audio.Whisper do
pass. The cache should be treated as opaque and initialized with
`Bumblebee.Text.Generation.init_cache/4`.
## Global layer options
#{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
## Configuration
#{Shared.options_doc(options)}
@@ -436,8 +436,6 @@ defmodule Bumblebee.Audio.Whisper do
activation: spec.activation
],
block_type: :norm_first,
output_hidden_states: spec.output_hidden_states,
output_attentions: spec.output_attentions,
name: join(name, "blocks")
)

Expand Down Expand Up @@ -485,8 +483,6 @@ defmodule Bumblebee.Audio.Whisper do
activation: spec.activation
],
block_type: :norm_first,
output_hidden_states: spec.output_hidden_states,
output_attentions: spec.output_attentions,
name: join(name, "blocks")
)

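With the spec-level flags gone, callers request these outputs when the model graph is built rather than via `Bumblebee.configure/2`. A minimal sketch, assuming Axon exposes global layer options through a `:global_layer_options` build option (that option name and the `openai/whisper-tiny` checkpoint are illustrative assumptions, not taken from this diff):

    # Hedged sketch: enabling opt-in outputs at build time.
    # The :global_layer_options key is an assumption; check the Axon version
    # in use for the actual build option.
    {:ok, %{model: model, params: params}} =
      Bumblebee.load_model({:hf, "openai/whisper-tiny"})

    {_init_fn, predict_fn} =
      Axon.build(model,
        global_layer_options: [output_hidden_states: true, output_attentions: true]
      )

    # predict_fn.(params, inputs) would then include :hidden_states and
    # :attentions in its output map instead of %Axon.None{}.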
44 changes: 33 additions & 11 deletions lib/bumblebee/layers.ex
@@ -895,15 +895,22 @@ defmodule Bumblebee.Layers do
end

@doc """
Returns a container layer if `condition` is truthy, otherwise returns
a none layer.
Adds a layer that passes the input through only if the given global
layer option is set.
"""
def maybe_container(container, condition) do
if condition do
Axon.container(container)
else
none()
end
def global_opt_in(%Axon{} = input, global_option_name) do
Axon.layer(
fn input, opts ->
if opts[global_option_name] do
input
else
%Axon.None{}
end
end,
[input],
op_name: :global_opt_in,
global_options: [global_option_name]
)
end

@doc """
@@ -933,17 +940,32 @@
All values are wrapped with `Axon.optional/2`, so if any of them is
missing, it gets returned as `%Axon.None{}`.
Also, guards known optional outputs behind a global layer option
using `global_opt_in/2`.
"""
@spec output(map()) :: Axon.t()
def output(outputs) do
outputs
|> Map.new(fn
{key, %Axon{} = val} -> {key, Axon.optional(val)}
{key, val} -> {key, val}
|> Map.new(fn {key, %Axon{} = val} ->
{key, val |> maybe_opt_in_output(key) |> Axon.optional()}
end)
|> Axon.container()
end

@opt_in_outputs %{
:hidden_states => :output_hidden_states,
:attentions => :output_attentions
}

defp maybe_opt_in_output(%Axon{} = input, key) do
if option_name = @opt_in_outputs[key] do
global_opt_in(input, option_name)
else
input
end
end

@doc """
Computes a 1-full mask matching the first two dimensions of `input`
(batch size and sequence length).
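Taken together, `output/1` and `global_opt_in/2` move the gating out of the model bodies: the containers are always built, and pruning happens in one place. A sketch of how a model body might assemble its outputs after this change; `hidden_state`, `hidden_states`, and `attentions` are placeholder Axon nodes, not part of the diff:

    alias Bumblebee.Layers

    # :hidden_states and :attentions are listed in @opt_in_outputs, so output/1
    # wraps them with global_opt_in/2 - they resolve to %Axon.None{} unless the
    # matching global layer option (:output_hidden_states / :output_attentions)
    # is enabled at build time.
    Layers.output(%{
      hidden_state: hidden_state,
      hidden_states: hidden_states,
      attentions: attentions
    })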
20 changes: 4 additions & 16 deletions lib/bumblebee/layers/transformer.ex
@@ -21,14 +21,6 @@ defmodule Bumblebee.Layers.Transformer do
is configured, this option controls whether the bias from the
first block is used for all other blocks. Defaults to `false`
* `:output_hidden_states` - when `true`, the output includes a
tuple with intermediate hidden states from each transformer
block. Defaults to `false`
* `:output_attentions` - when `true`, the output includes a tuple
with attention weights from each transformer block. Defaults
to `false`
* `:name` - the prefix for layer names
For all other options (including required options) see `block/2`.
@@ -75,16 +67,12 @@ defmodule Bumblebee.Layers.Transformer do
cross_hidden_state: nil,
cross_attention_mask: Layers.none(),
cross_attention_head_mask: Layers.none(),
cache: Layers.none(),
output_hidden_states: false,
output_attentions: false
cache: Layers.none()
]
)

name = opts[:name]
num_blocks = opts[:num_blocks]
output_hidden_states = opts[:output_hidden_states]
output_attentions = opts[:output_attentions]

attention_mask = opts[:attention_mask]
attention_head_mask = opts[:attention_head_mask]
@@ -100,9 +88,9 @@

state = %{
hidden_state: hidden_state,
hidden_states: Layers.maybe_container({hidden_state}, output_hidden_states),
attentions: Layers.maybe_container({}, output_attentions),
cross_attentions: Layers.maybe_container({}, output_attentions),
hidden_states: Axon.container({hidden_state}),
attentions: Axon.container({}),
cross_attentions: Axon.container({}),
cache: cache,
attention_relative_bias: Layers.none()
}
18 changes: 5 additions & 13 deletions lib/bumblebee/multimodal/blip.ex
@@ -20,11 +20,7 @@ defmodule Bumblebee.Multimodal.Blip do
default: 2.6592,
doc: "the initial value for the scaling layer used to scale similarity logits"
]
] ++
Shared.common_options([
:output_hidden_states,
:output_attentions
])
]

@moduledoc """
The BLIP model for text-image similarity.
@@ -72,6 +68,10 @@ defmodule Bumblebee.Multimodal.Blip do
pass. The cache should be treated as opaque and initialized with
`Bumblebee.Text.Generation.init_cache/4`.
## Global layer options
#{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
## Configuration
#{Shared.options_doc(options)}
@@ -128,10 +128,6 @@ defmodule Bumblebee.Multimodal.Blip do

vision_model =
vision_spec
|> Bumblebee.configure(
output_hidden_states: spec.output_hidden_states,
output_attentions: spec.output_hidden_states
)
|> Bumblebee.build_model()
|> Bumblebee.Utils.Axon.prefix_names("vision_model.")
|> Bumblebee.Utils.Axon.plug_inputs(%{
@@ -155,10 +151,6 @@

text_decoder =
text_spec
|> Bumblebee.configure(
output_hidden_states: spec.output_hidden_states,
output_attentions: spec.output_hidden_states
)
|> Bumblebee.build_model()
|> Bumblebee.Utils.Axon.prefix_names("text_decoder.")
|> Bumblebee.Utils.Axon.plug_inputs(%{
18 changes: 5 additions & 13 deletions lib/bumblebee/multimodal/clip.ex
@@ -20,11 +20,7 @@ defmodule Bumblebee.Multimodal.Clip do
default: 2.6592,
doc: "the initial value for the scaling layer used to scale similarity logits"
]
] ++
Shared.common_options([
:output_hidden_states,
:output_attentions
])
]

@moduledoc """
The CLIP model for text-image similarity.
@@ -54,6 +50,10 @@ defmodule Bumblebee.Multimodal.Clip do
Featurized image pixel values.
## Global layer options
#{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
## Configuration
#{Shared.options_doc(options)}
@@ -108,10 +108,6 @@ defmodule Bumblebee.Multimodal.Clip do

text_model =
text_spec
|> Bumblebee.configure(
output_hidden_states: spec.output_hidden_states,
output_attentions: spec.output_hidden_states
)
|> Bumblebee.build_model()
|> Bumblebee.Utils.Axon.prefix_names("text_model.")
|> Bumblebee.Utils.Axon.plug_inputs(%{
@@ -122,10 +118,6 @@

vision_model =
vision_spec
|> Bumblebee.configure(
output_hidden_states: spec.output_hidden_states,
output_attentions: spec.output_hidden_states
)
|> Bumblebee.build_model()
|> Bumblebee.Utils.Axon.prefix_names("vision_model.")
|> Bumblebee.Utils.Axon.plug_inputs(%{
14 changes: 5 additions & 9 deletions lib/bumblebee/multimodal/layout_lm.ex
@@ -77,13 +77,7 @@ defmodule Bumblebee.Multimodal.LayoutLm do
default: 1.0e-12,
doc: "the epsilon used by the layer normalization layers"
]
] ++
Shared.common_options([
:output_hidden_states,
:output_attentions,
:num_labels,
:id_to_label
])
] ++ Shared.common_options([:num_labels, :id_to_label])

@moduledoc """
LayoutLM Model family.
@@ -140,6 +134,10 @@ defmodule Bumblebee.Multimodal.LayoutLm do
`{x0, y0, x1, y1}` where `{x0, y0}` is the upper left corner and
`{x1, y1}` is the lower right corner.
## Global layer options
#{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
## Configuration
#{Shared.options_doc(options)}
@@ -426,8 +424,6 @@ defmodule Bumblebee.Multimodal.LayoutLm do
intermediate_size: spec.intermediate_size,
activation: spec.activation
],
output_hidden_states: spec.output_hidden_states,
output_attentions: spec.output_attentions,
name: join(name, "blocks")
)
end
16 changes: 16 additions & 0 deletions lib/bumblebee/shared.ex
@@ -64,6 +64,22 @@ defmodule Bumblebee.Shared do
Enum.join(items, "\n\n")
end

@doc """
Generates documentation string for the given global layer options.
"""
@spec global_layer_options_doc(list(atom())) :: String.t()
def global_layer_options_doc(names) do
docs = [
output_hidden_states: "when `true`, the model output includes all hidden states",
output_attentions: "when `true`, the model output includes all attention weights"
]

Enum.map_join(names, "\n\n", fn name ->
doc = Keyword.fetch!(docs, name)
" * `#{inspect(name)}` - #{doc}"
end)
end

@doc """
Returns option defaults form the options specification.
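For reference, the string the new helper renders, derived from the implementation above (`Bumblebee.Shared` is an internal module, so this is illustrative rather than a public API call):

    Bumblebee.Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])
    #=> "  * `:output_hidden_states` - when `true`, the model output includes all hidden states\n\n  * `:output_attentions` - when `true`, the model output includes all attention weights"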
16 changes: 7 additions & 9 deletions lib/bumblebee/text/albert.ex
@@ -83,13 +83,7 @@ defmodule Bumblebee.Text.Albert do
doc:
"the standard deviation of the normal initializer used for initializing kernel parameters"
]
] ++
Shared.common_options([
:output_hidden_states,
:output_attentions,
:num_labels,
:id_to_label
])
] ++ Shared.common_options([:num_labels, :id_to_label])

@moduledoc """
ALBERT model family.
@@ -148,6 +142,10 @@ defmodule Bumblebee.Text.Albert do
The `:for_multiple_choice` model accepts groups of sequences, so the
expected sequence shape is `{batch_size, num_choices, sequence_length}`.
## Global layer options
#{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
## Configuration
#{Shared.options_doc(options)}
@@ -389,8 +387,8 @@ defmodule Bumblebee.Text.Albert do
name: join(name, "embedding_projection")
)

hidden_states = Layers.maybe_container({hidden_state}, spec.output_hidden_states)
attentions = Layers.maybe_container({}, spec.output_attentions)
hidden_states = Axon.container({hidden_state})
attentions = Axon.container({})

for block_idx <- 0..(spec.num_blocks - 1),
inner_idx <- 0..(spec.block_depth - 1),
15 changes: 5 additions & 10 deletions lib/bumblebee/text/bart.ex
@@ -78,12 +78,7 @@ defmodule Bumblebee.Text.Bart do
"the standard deviation of the normal initializer used for initializing kernel parameters"
]
] ++
Shared.common_options([
:output_hidden_states,
:output_attentions,
:num_labels,
:id_to_label
]) ++
Shared.common_options([:num_labels, :id_to_label]) ++
Shared.token_options(
eos_token_id: 2,
decoder_start_token_id: 2
@@ -197,6 +192,10 @@ defmodule Bumblebee.Text.Bart do
`"position_ids"`, `"attention_head_mask"`, `"input_embeddings"`, `"encoder_hidden_state"`,
`"encoder_attention_mask"`, `"cross_attention_head_mask"`, `"cache"`.
## Global layer options
#{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
## Configuration
#{Shared.options_doc(options)}
@@ -563,8 +562,6 @@ defmodule Bumblebee.Text.Bart do
intermediate_size: spec.encoder_intermediate_size,
activation: spec.activation
],
output_hidden_states: spec.output_hidden_states,
output_attentions: spec.output_attentions,
name: join(name, "blocks")
)
end
Expand Down Expand Up @@ -603,8 +600,6 @@ defmodule Bumblebee.Text.Bart do
intermediate_size: spec.decoder_intermediate_size,
activation: spec.activation
],
output_hidden_states: spec.output_hidden_states,
output_attentions: spec.output_attentions,
name: join(name, "blocks")
)
end