Implement attention sliding window for Mistral
jonatanklosko committed Feb 21, 2024
1 parent f739e0a commit 23700cf
Showing 4 changed files with 99 additions and 17 deletions.
60 changes: 43 additions & 17 deletions lib/bumblebee/layers.ex
@@ -216,6 +216,9 @@ defmodule Bumblebee.Layers do
effectively makes each input token use information exclusively
from prior tokens. Defaults to `false`
* `:window_size` - when set, enables sliding window attention.
Should be a `{left, right}` tuple with window size on each side
* `:scale` - whether to scale attention weights by $\frac{1}{\sqrt{d}}$.
Defaults to `true`
@@ -228,7 +231,7 @@ defmodule Bumblebee.Layers do
"""
def attention(query, key, value, key_mask, head_mask, bias, offset, opts \\ []) do
opts = Keyword.validate!(opts, causal: false, scale: true, dropout_rate: 0.0)
opts = Keyword.validate!(opts, [:window_size, causal: false, scale: true, dropout_rate: 0.0])

weights =
Axon.layer(
@@ -242,6 +245,7 @@
Axon.optional(offset)
],
causal: opts[:causal],
window_size: opts[:window_size],
scale: opts[:scale]
)
|> Axon.dropout(rate: opts[:dropout_rate])
@@ -252,7 +256,7 @@
end

defnp attention_weights_impl(query, key, key_mask, head_mask, bias, offset, opts \\ []) do
opts = keyword!(opts, mode: :inference, scale: true, causal: false)
opts = keyword!(opts, [:window_size, mode: :inference, scale: true, causal: false])

query = Nx.transpose(query, axes: [0, 2, 1, 3])
key = Nx.transpose(key, axes: [0, 2, 1, 3])
@@ -273,23 +277,28 @@
key_mask -> key_mask |> Nx.new_axis(1) |> Nx.new_axis(1)
end

causal_mask =
if opts[:causal] do
query_sequence_length = Nx.axis_size(query, 2)
key_sequence_length = Nx.axis_size(key, 2)
offset = ensure_offset(offset)

Nx.greater_equal(
Nx.iota({query_sequence_length, 1}) + offset,
Nx.iota({1, key_sequence_length})
)
|> Nx.new_axis(0)
|> Nx.new_axis(0)
else
Nx.broadcast(1, {1, 1, 1, 1})
query_sequence_length = Nx.axis_size(query, 2)
key_sequence_length = Nx.axis_size(key, 2)
offset = ensure_offset(offset)

causal_and_window_mask =
case {opts[:causal], opts[:window_size]} do
{false, nil} ->
Nx.broadcast(1, {1, 1})

{false, {left_size, right_size}} ->
window_mask(query_sequence_length, key_sequence_length, offset, left_size, right_size)

{true, nil} ->
causal_mask(query_sequence_length, key_sequence_length, offset)

{true, {left_size, _right_size}} ->
window_mask(query_sequence_length, key_sequence_length, offset, left_size, 0)
end
|> Nx.new_axis(0)
|> Nx.new_axis(0)

mask = Nx.logical_and(key_mask, causal_mask)
mask = key_mask and causal_and_window_mask

bias =
case bias do
@@ -322,6 +331,23 @@
end
end

defnp causal_mask(query_sequence_length, key_sequence_length, offset) do
Nx.greater_equal(
Nx.iota({query_sequence_length, 1}) + offset,
Nx.iota({1, key_sequence_length})
)
end

defnp window_mask(query_sequence_length, key_sequence_length, offset, left_size, right_size) do
position_diff =
Nx.subtract(
Nx.iota({query_sequence_length, 1}) + offset,
Nx.iota({1, key_sequence_length})
)

left_size >= position_diff and position_diff >= -right_size
end

defnp attention_output_impl(weights, value, _opts \\ []) do
value = Nx.transpose(value, axes: [0, 2, 1, 3])
out = Nx.dot(weights, [3], [0, 1], value, [2], [0, 1])
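For reference, the mask math in a nutshell: `window_mask/5` keeps a query/key pair only when the signed distance `(query position + offset) - key position` lies within `[-right_size, left_size]`, and `causal_mask/3` is the special case that keeps all non-negative distances. A minimal sketch outside `defn`, assuming only `Nx` (the sequence length, offset, and window sizes below are illustrative and not part of the commit):

    q = k = 5
    offset = 0
    {left_size, right_size} = {1, 1}

    # Signed distance between each query position and each key position.
    position_diff = Nx.subtract(Nx.add(Nx.iota({q, 1}), offset), Nx.iota({1, k}))

    # Keep keys inside the window: -right_size <= position_diff <= left_size.
    mask =
      Nx.logical_and(
        Nx.greater_equal(left_size, position_diff),
        Nx.greater_equal(position_diff, -right_size)
      )

    # mask ==
    #   [[1, 1, 0, 0, 0],
    #    [1, 1, 1, 0, 0],
    #    [0, 1, 1, 1, 0],
    #    [0, 0, 1, 1, 1],
    #    [0, 0, 0, 1, 1]]

This corresponds to the `{false, {left_size, right_size}}` clause above; the final attention mask is this window mask combined with `key_mask` via a logical AND.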
17 changes: 17 additions & 0 deletions lib/bumblebee/layers/transformer.ex
@@ -56,6 +56,7 @@ defmodule Bumblebee.Layers.Transformer do
:output_use_bias,
:layer_norm,
:block_type,
:attention_window_size,
:scale_attention_weights,
:rotary_embedding
]
@@ -269,6 +270,9 @@ defmodule Bumblebee.Layers.Transformer do
receives the input hidden state, a map with block steps and a
name to prefix any additional layers.
* `:attention_window_size` - when set, enables sliding window attention.
Should be a `{left, right}` tuple with window size on each side
* `:scale_attention_weights` - whether to scale query in the traditional style of
multi-headed attention. Defaults to `true`
@@ -323,6 +327,7 @@
output_use_bias: true,
block_type: :standard,
layer_norm: [],
attention_window_size: nil,
scale_attention_weights: true,
rotary_embedding: nil
])
@@ -351,6 +356,7 @@
offset = opts[:offset]
layer_norm = opts[:layer_norm]
block_type = opts[:block_type]
attention_window_size = opts[:attention_window_size]
scale_attention_weights = opts[:scale_attention_weights]
rotary_embedding = opts[:rotary_embedding]

@@ -408,6 +414,7 @@
key_use_bias: key_use_bias,
value_use_bias: value_use_bias,
output_use_bias: output_use_bias,
attention_window_size: attention_window_size,
scale_attention_weights: scale_attention_weights,
rotary_embedding: rotary_embedding,
name: join(name, "self_attention")
@@ -452,6 +459,7 @@
key_use_bias: key_use_bias,
value_use_bias: value_use_bias,
output_use_bias: output_use_bias,
attention_window_size: attention_window_size,
scale_attention_weights: scale_attention_weights,
rotary_embedding: rotary_embedding,
name: join(name, "cross_attention")
@@ -679,6 +687,12 @@ defmodule Bumblebee.Layers.Transformer do
* `:output_use_bias` - whether to use bias in the output projection.
Defaults to `true`
* `:attention_window_size` - when set, enables sliding window attention.
Should be a `{left, right}` tuple with window size on each side
* `:scale_attention_weights` - whether to scale query in the traditional style of
multi-headed attention. Defaults to `true`
* `:rotary_embedding` - configuration of rotary embedding. If set,
will apply rotary position embedding with the given options. Valid
options are:
@@ -710,6 +724,7 @@ defmodule Bumblebee.Layers.Transformer do
attention_cache: Layers.none(),
offset: Layers.none(),
causal: false,
attention_window_size: nil,
scale_attention_weights: true,
kernel_initializer: :glorot_uniform,
dropout_rate: 0.0,
@@ -732,6 +747,7 @@ defmodule Bumblebee.Layers.Transformer do
hidden_size = opts[:hidden_size]
kernel_initializer = opts[:kernel_initializer]
causal = opts[:causal]
attention_window_size = opts[:attention_window_size]
scale_attention_weights = opts[:scale_attention_weights]
dropout_rate = opts[:dropout_rate]
rotary_embedding = opts[:rotary_embedding]
@@ -858,6 +874,7 @@ defmodule Bumblebee.Layers.Transformer do
offset,
scale: scale_attention_weights,
causal: causal,
window_size: attention_window_size,
dropout_rate: dropout_rate
)

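In these blocks, `:attention_window_size` is forwarded unchanged to `Layers.attention` as `:window_size`. For a decoder block with `causal: true`, the `{true, {left_size, _right_size}}` clause in `layers.ex` ignores the right half of the tuple, so each token attends to itself and at most `left_size` previous tokens. A small sketch of the resulting pattern, assuming only `Nx` (illustrative values, offset omitted, not part of the commit):

    q = k = 5
    left_size = 2

    position_diff = Nx.subtract(Nx.iota({q, 1}), Nx.iota({1, k}))

    # Causal sliding window: 0 <= position_diff <= left_size.
    mask =
      Nx.logical_and(
        Nx.greater_equal(left_size, position_diff),
        Nx.greater_equal(position_diff, 0)
      )

    # mask ==
    #   [[1, 0, 0, 0, 0],
    #    [1, 1, 0, 0, 0],
    #    [1, 1, 1, 0, 0],
    #    [0, 1, 1, 1, 0],
    #    [0, 0, 1, 1, 1]]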
7 changes: 7 additions & 0 deletions lib/bumblebee/text/mistral.ex
@@ -43,6 +43,10 @@ defmodule Bumblebee.Text.Mistral do
Attention
"""
],
attention_window_size: [
default: 4096,
doc: "window size for both sides of the sliding attention window"
],
activation: [
default: :silu,
doc: "the activation function"
@@ -329,6 +333,8 @@ defmodule Bumblebee.Text.Mistral do
),
block_type: :norm_first,
causal: true,
attention_window_size:
spec.attention_window_size && {spec.attention_window_size, spec.attention_window_size},
rotary_embedding: [
position_ids: position_ids,
max_positions: spec.max_positions,
@@ -387,6 +393,7 @@
num_blocks: {"num_hidden_layers", number()},
num_attention_heads: {"num_attention_heads", number()},
num_key_value_heads: {"num_key_value_heads", number()},
attention_window_size: {"sliding_window", number()},
intermediate_size: {"intermediate_size", number()},
activation: {"hidden_act", activation()},
rotary_embedding_base: {"rope_theta", number()},
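On the spec side, Mistral keeps a single integer (imported from the Hugging Face "sliding_window" config entry) and expands it into a symmetric `{size, size}` tuple before handing it to the decoder, while `nil` leaves sliding window attention disabled. A tiny sketch of that conversion (the value is illustrative):

    attention_window_size = 4096

    window =
      attention_window_size && {attention_window_size, attention_window_size}

    # window == {4096, 4096}; with attention_window_size = nil, window == nil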
32 changes: 32 additions & 0 deletions test/bumblebee/text/mistral_test.exs
@@ -28,6 +28,38 @@ defmodule Bumblebee.Text.MistralTest do
)
end

test ":base with attention sliding window" do
assert {:ok, spec} =
Bumblebee.load_spec({:hf, "hf-internal-testing/tiny-random-MistralModel"})

# TODO test once we know the expected behaviour
# spec = Bumblebee.configure(spec, attention_window_size: 2)
spec = Bumblebee.configure(spec, attention_window_size: 1)

assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralModel"},
spec: spec
)

assert %Bumblebee.Text.Mistral{architecture: :base} = spec

inputs = %{
"input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
}

outputs = Axon.predict(model, params, inputs)

assert Nx.shape(outputs.hidden_state) == {1, 10, 32}

assert_all_close(
outputs.hidden_state[[.., 1..3, 1..3]],
Nx.tensor([
[[0.9450, -1.3945, 0.7331], [-1.4422, -1.4622, -0.9143], [-1.5628, -1.0444, 0.9262]]
])
)
end

test ":for_sequence_classification" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
