Implement attention sliding window for Mistral #341

Merged 3 commits on Feb 23, 2024

60 changes: 43 additions & 17 deletions lib/bumblebee/layers.ex
@@ -216,6 +216,9 @@ defmodule Bumblebee.Layers do
effectively makes each input token use information exclusively
from prior tokens. Defaults to `false`

* `:window_size` - when set, enables sliding window attention.
Should be a `{left, right}` tuple with window size on each side

* `:scale` - whether to scale attention weights by $\frac{1}{\sqrt{d}}$.
Defaults to `true`

@@ -228,7 +231,7 @@

"""
def attention(query, key, value, key_mask, head_mask, bias, offset, opts \\ []) do
opts = Keyword.validate!(opts, causal: false, scale: true, dropout_rate: 0.0)
opts = Keyword.validate!(opts, [:window_size, causal: false, scale: true, dropout_rate: 0.0])

weights =
Axon.layer(
@@ -242,6 +245,7 @@
Axon.optional(offset)
],
causal: opts[:causal],
window_size: opts[:window_size],
scale: opts[:scale]
)
|> Axon.dropout(rate: opts[:dropout_rate])
@@ -252,7 +256,7 @@
end

defnp attention_weights_impl(query, key, key_mask, head_mask, bias, offset, opts \\ []) do
opts = keyword!(opts, mode: :inference, scale: true, causal: false)
opts = keyword!(opts, [:window_size, mode: :inference, scale: true, causal: false])

query = Nx.transpose(query, axes: [0, 2, 1, 3])
key = Nx.transpose(key, axes: [0, 2, 1, 3])
@@ -273,23 +277,28 @@
key_mask -> key_mask |> Nx.new_axis(1) |> Nx.new_axis(1)
end

causal_mask =
if opts[:causal] do
query_sequence_length = Nx.axis_size(query, 2)
key_sequence_length = Nx.axis_size(key, 2)
offset = ensure_offset(offset)

Nx.greater_equal(
Nx.iota({query_sequence_length, 1}) + offset,
Nx.iota({1, key_sequence_length})
)
|> Nx.new_axis(0)
|> Nx.new_axis(0)
else
Nx.broadcast(1, {1, 1, 1, 1})
query_sequence_length = Nx.axis_size(query, 2)
key_sequence_length = Nx.axis_size(key, 2)
offset = ensure_offset(offset)

causal_and_window_mask =
case {opts[:causal], opts[:window_size]} do
{false, nil} ->
Nx.broadcast(1, {1, 1})

{false, {left_size, right_size}} ->
window_mask(query_sequence_length, key_sequence_length, offset, left_size, right_size)

{true, nil} ->
causal_mask(query_sequence_length, key_sequence_length, offset)

{true, {left_size, _right_size}} ->
window_mask(query_sequence_length, key_sequence_length, offset, left_size, 0)
end
|> Nx.new_axis(0)
|> Nx.new_axis(0)

mask = Nx.logical_and(key_mask, causal_mask)
mask = key_mask and causal_and_window_mask

bias =
case bias do
@@ -322,6 +331,23 @@
end
end

defnp causal_mask(query_sequence_length, key_sequence_length, offset) do
Nx.greater_equal(
Nx.iota({query_sequence_length, 1}) + offset,
Nx.iota({1, key_sequence_length})
)
end

defnp window_mask(query_sequence_length, key_sequence_length, offset, left_size, right_size) do
position_diff =
Nx.subtract(
Nx.iota({query_sequence_length, 1}) + offset,
Nx.iota({1, key_sequence_length})
)

left_size >= position_diff and position_diff >= -right_size
end

defnp attention_output_impl(weights, value, _opts \\ []) do
value = Nx.transpose(value, axes: [0, 2, 1, 3])
out = Nx.dot(weights, [3], [0, 1], value, [2], [0, 1])
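
Below is an illustrative sketch (not part of the diff) of what the new `window_mask/5` helper computes for a 4-token sequence with offset 0 and a window of `{left, right} = {2, 0}`. It uses plain `Nx` calls instead of the `defn` operators above, and the concrete sizes are made up for the example.

```elixir
query_len = 4
key_len = 4
offset = 0
{left_size, right_size} = {2, 0}

# position_diff[i][j] = (i + offset) - j
position_diff =
  Nx.subtract(
    Nx.add(Nx.iota({query_len, 1}), offset),
    Nx.iota({1, key_len})
  )

# Query position i may attend to key position j iff
# left_size >= i - j >= -right_size, i.e. at most `left_size` tokens
# behind and `right_size` tokens ahead.
mask =
  Nx.logical_and(
    Nx.greater_equal(left_size, position_diff),
    Nx.greater_equal(position_diff, -right_size)
  )

# mask ==
#   [[1, 0, 0, 0],
#    [1, 1, 0, 0],
#    [1, 1, 1, 0],
#    [0, 1, 1, 1]]
```

Note that in the rewritten weights implementation, `causal: true` combined with a window collapses to `window_mask(..., left_size, 0)`, since causal masking is just a sliding window with no right context.
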
17 changes: 17 additions & 0 deletions lib/bumblebee/layers/transformer.ex
@@ -56,6 +56,7 @@ defmodule Bumblebee.Layers.Transformer do
:output_use_bias,
:layer_norm,
:block_type,
:attention_window_size,
:scale_attention_weights,
:rotary_embedding
]
@@ -269,6 +270,9 @@ defmodule Bumblebee.Layers.Transformer do
receives the input hidden state, a map with block steps and a
name to prefix any additional layers.

* `:attention_window_size` - when set, enables sliding window attention.
Should be a `{left, right}` tuple with window size on each side

* `:scale_attention_weights` - whether to scale query in the traditional style of
multi-headed attention. Defaults to `true`

@@ -323,6 +327,7 @@
output_use_bias: true,
block_type: :standard,
layer_norm: [],
attention_window_size: nil,
scale_attention_weights: true,
rotary_embedding: nil
])
@@ -351,6 +356,7 @@
offset = opts[:offset]
layer_norm = opts[:layer_norm]
block_type = opts[:block_type]
attention_window_size = opts[:attention_window_size]
scale_attention_weights = opts[:scale_attention_weights]
rotary_embedding = opts[:rotary_embedding]

@@ -408,6 +414,7 @@
key_use_bias: key_use_bias,
value_use_bias: value_use_bias,
output_use_bias: output_use_bias,
attention_window_size: attention_window_size,
scale_attention_weights: scale_attention_weights,
rotary_embedding: rotary_embedding,
name: join(name, "self_attention")
@@ -452,6 +459,7 @@
key_use_bias: key_use_bias,
value_use_bias: value_use_bias,
output_use_bias: output_use_bias,
attention_window_size: attention_window_size,
scale_attention_weights: scale_attention_weights,
rotary_embedding: rotary_embedding,
name: join(name, "cross_attention")
@@ -679,6 +687,12 @@
* `:output_use_bias` - whether to use bias in the output projection.
Defaults to `true`

* `:attention_window_size` - when set, enables sliding window attention.
Should be a `{left, right}` tuple with window size on each side

* `:scale_attention_weights` - whether to scale query in the traditional style of
multi-headed attention. Defaults to `true`

* `:rotary_embedding` - configuration of rotary embedding. If set,
will apply rotary position embedding with the given options. Valid
options are:
@@ -710,6 +724,7 @@
attention_cache: Layers.none(),
offset: Layers.none(),
causal: false,
attention_window_size: nil,
scale_attention_weights: true,
kernel_initializer: :glorot_uniform,
dropout_rate: 0.0,
@@ -732,6 +747,7 @@
hidden_size = opts[:hidden_size]
kernel_initializer = opts[:kernel_initializer]
causal = opts[:causal]
attention_window_size = opts[:attention_window_size]
scale_attention_weights = opts[:scale_attention_weights]
dropout_rate = opts[:dropout_rate]
rotary_embedding = opts[:rotary_embedding]
@@ -858,6 +874,7 @@
offset,
scale: scale_attention_weights,
causal: causal,
window_size: attention_window_size,
dropout_rate: dropout_rate
)

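As a rough sketch of how a model implementation threads the new option through the shared transformer layers: only `:causal` and `:attention_window_size` below reflect this diff, while the remaining arguments are illustrative placeholders rather than the full Mistral call (the real call is in the next file).

```elixir
# `hidden_state` and `spec` stand in for the model's decoder input and
# configuration struct; option names follow lib/bumblebee/layers/transformer.ex.
Bumblebee.Layers.Transformer.blocks(hidden_state,
  num_blocks: spec.num_blocks,
  num_attention_heads: spec.num_attention_heads,
  hidden_size: spec.hidden_size,
  causal: true,
  # a {left, right} tuple; with causal: true only the left size takes effect,
  # since the causal branch uses window_mask(..., left_size, 0)
  attention_window_size: {spec.attention_window_size, spec.attention_window_size},
  name: "decoder.blocks"
)
```

The multi-head attention helper lower in this file forwards the option unchanged to `Layers.attention` as `:window_size`.
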
7 changes: 7 additions & 0 deletions lib/bumblebee/text/mistral.ex
@@ -43,6 +43,10 @@ defmodule Bumblebee.Text.Mistral do
Attention
"""
],
attention_window_size: [
default: 4096,
doc: "window size for both sides of the sliding attention window"
],
activation: [
default: :silu,
doc: "the activation function"
@@ -329,6 +333,8 @@
),
block_type: :norm_first,
causal: true,
attention_window_size:
spec.attention_window_size && {spec.attention_window_size, spec.attention_window_size},
rotary_embedding: [
position_ids: position_ids,
max_positions: spec.max_positions,
@@ -387,6 +393,7 @@
num_blocks: {"num_hidden_layers", number()},
num_attention_heads: {"num_attention_heads", number()},
num_key_value_heads: {"num_key_value_heads", number()},
attention_window_size: {"sliding_window", number()},
intermediate_size: {"intermediate_size", number()},
activation: {"hidden_act", activation()},
rotary_embedding_base: {"rope_theta", number()},
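
For end users the change surfaces as a regular spec option. A usage sketch mirroring the test below (the tiny `hf-internal-testing` checkpoint is the one used in the test suite): the HF `"sliding_window"` config value is imported as `:attention_window_size` (default 4096) and can be overridden with `Bumblebee.configure/2` before loading the model.

```elixir
{:ok, spec} =
  Bumblebee.load_spec({:hf, "hf-internal-testing/tiny-random-MistralModel"})

# Shrink the window for demonstration; the imported default is 4096.
spec = Bumblebee.configure(spec, attention_window_size: 2)

{:ok, %{model: model, params: params}} =
  Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralModel"}, spec: spec)
```
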
30 changes: 30 additions & 0 deletions test/bumblebee/text/mistral_test.exs
@@ -28,6 +28,36 @@ defmodule Bumblebee.Text.MistralTest do
)
end

test ":base with attention sliding window" do
assert {:ok, spec} =
Bumblebee.load_spec({:hf, "hf-internal-testing/tiny-random-MistralModel"})

spec = Bumblebee.configure(spec, attention_window_size: 2)

assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralModel"},
spec: spec
)

assert %Bumblebee.Text.Mistral{architecture: :base} = spec

inputs = %{
"input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
}

outputs = Axon.predict(model, params, inputs)

assert Nx.shape(outputs.hidden_state) == {1, 10, 32}

assert_all_close(
outputs.hidden_state[[.., 1..3, 1..3]],
Nx.tensor([
[[0.9450, -1.3945, 0.7331], [-2.1118, -1.3091, -0.7834], [-1.3033, -1.3374, 0.8919]]
])
)
end

test ":for_sequence_classification" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(