Skip to content

Commit

Permalink
Support configurable attention head size for Llama
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Sep 17, 2024
1 parent 9388b28 commit ed2f177
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions lib/bumblebee/text/llama.ex
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ defmodule Bumblebee.Text.Llama do
default: 11008,
doc: "the dimensionality of intermediate layers"
],
attention_head_size: [
default: nil,
doc: """
the size of the key, value, and query projection per attention head.
Defaults to `div(hidden_size, num_attention_heads)
"""
],
num_blocks: [
default: 32,
doc: "the number of Transformer blocks in the model"
Expand Down Expand Up @@ -169,6 +176,7 @@ defmodule Bumblebee.Text.Llama do
def init_cache(spec, batch_size, max_length, _inputs) do
Layers.Decoder.init_cache(batch_size, max_length,
hidden_size: spec.hidden_size,
attention_head_size: spec.attention_head_size,
decoder_num_attention_heads: spec.num_attention_heads,
decoder_num_blocks: spec.num_blocks
)
Expand Down Expand Up @@ -321,6 +329,7 @@ defmodule Bumblebee.Text.Llama do
Layers.Transformer.blocks(hidden_state,
attention_mask: attention_mask,
attention_head_mask: attention_head_mask,
attention_head_size: spec.attention_head_size,
cache: cache,
num_blocks: spec.num_blocks,
num_attention_heads: spec.num_attention_heads,
Expand Down Expand Up @@ -431,6 +440,7 @@ defmodule Bumblebee.Text.Llama do
num_blocks: {"num_hidden_layers", number()},
num_attention_heads: {"num_attention_heads", number()},
num_key_value_heads: {"num_key_value_heads", number()},
attention_head_size: {"head_dim", number()},
intermediate_size: {"intermediate_size", number()},
activation: {"hidden_act", activation()},
rotary_embedding_base: {"rope_theta", number()},
Expand Down

0 comments on commit ed2f177

Please sign in to comment.