Support llama3 checkpoints with tied word embeddings
jonatanklosko committed Sep 20, 2024
Parent: ed2f177 · Commit: b01e0da
Showing 2 changed files with 16 additions and 11 deletions.

lib/bumblebee/text/llama.ex (10 additions, 3 deletions)

@@ -75,6 +75,10 @@ defmodule Bumblebee.Text.Llama do
         default: 0.02,
         doc:
           "the standard deviation of the normal initializer used for initializing kernel parameters"
+      ],
+      tie_word_embeddings: [
+        default: false,
+        doc: "whether to tie input and output embedding weights"
       ]
     ] ++
       Shared.common_options([:num_labels, :id_to_label]) ++ Shared.token_options(pad_token_id: 0)
@@ -435,6 +439,7 @@ defmodule Bumblebee.Text.Llama do
       opts =
         convert!(data,
           vocab_size: {"vocab_size", number()},
+          tie_word_embeddings: {"tie_word_embeddings", boolean()},
           max_positions: {"max_position_embeddings", number()},
           hidden_size: {"hidden_size", number()},
           num_blocks: {"num_hidden_layers", number()},
@@ -447,15 +452,16 @@
           rotary_embedding_scaling_strategy:
             {"rope_scaling", optional(scaling_strategy_converter)},
           initializer_scale: {"initializer_range", number()},
-          layer_norm_epsilon: {"rms_norm_eps", number()}
+          layer_norm_epsilon: {"rms_norm_eps", number()},
+          tie_word_embeddings: {"tie_word_embeddings", boolean()}
         ) ++ Shared.common_options_from_transformers(data, spec)

       @for.config(spec, opts)
     end
   end

   defimpl Bumblebee.HuggingFace.Transformers.Model do
-    def params_mapping(_spec) do
+    def params_mapping(spec) do
       %{
         "embedder.token_embedding" => "model.embed_tokens",
         "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj",
@@ -470,7 +476,8 @@
         "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj",
         "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm",
         "output_norm" => "model.norm",
-        "language_modeling_head.output" => "lm_head",
+        "language_modeling_head.output" =>
+          if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head"),
         "sequence_classification_head.output" => "score"
       }
     end
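
Note: the practical effect of the params_mapping change above is that checkpoints whose config.json sets "tie_word_embeddings": true (and which therefore typically ship no separate "lm_head" tensor) now resolve the language modeling head to the shared token embedding weights. A minimal usage sketch, assuming a tied-embeddings Llama checkpoint; "meta-llama/Llama-3.2-1B" below is only an example and is gated, so substitute any repository you have access to:

```elixir
# Example repository name only; any Llama checkpoint whose config.json
# sets "tie_word_embeddings": true should exercise the same code path.
{:ok, model_info} = Bumblebee.load_model({:hf, "meta-llama/Llama-3.2-1B"})

# The loader now picks the flag up from config.json ...
model_info.spec.tie_word_embeddings
#=> true

# ... and maps "language_modeling_head.output" onto the shared
# "model.embed_tokens" weights, so loading succeeds even though the
# checkpoint contains no "lm_head" tensor.
```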

lib/bumblebee/text/t5.ex (6 additions, 8 deletions)

@@ -10,12 +10,6 @@ defmodule Bumblebee.Text.T5 do
        tokens that can be represented in model input and output
        """
      ],
-      tie_word_embeddings: [
-        default: true,
-        doc: """
-        whether or not to tie encoder and decoder token embedding
-        """
-      ],
      hidden_size: [
        default: 512,
        doc: "the dimensionality of hidden layers"
@@ -74,6 +68,10 @@ defmodule Bumblebee.Text.T5 do
      layer_norm_epsilon: [
        default: 1.0e-6,
        doc: "the epsilon used by the layer normalization layers"
+      ],
+      tie_word_embeddings: [
+        default: true,
+        doc: "whether or not to tie encoder and decoder token embedding"
      ]
    ] ++
      Shared.common_options([:num_labels, :id_to_label]) ++
@@ -538,7 +536,6 @@
       opts =
         convert!(data,
           vocab_size: {"vocab_size", number()},
-          tie_word_embeddings: {"tie_word_embeddings", boolean()},
           hidden_size: {"d_model", number()},
           attention_head_size: {"d_kv", number()},
           encoder_num_blocks: {"num_layers", number()},
@@ -551,7 +548,8 @@
           activation: {"feed_forward_proj", t5_activation()},
           ffn_gated_activation: {"feed_forward_proj", ffn_gated_activation()},
           dropout_rate: {"dropout", number()},
-          initializer_scale: {"initializer_factor", number()}
+          initializer_scale: {"initializer_factor", number()},
+          tie_word_embeddings: {"tie_word_embeddings", boolean()}
         ) ++ Shared.common_options_from_transformers(data, spec)

       @for.config(spec, opts)
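
Note: as a conceptual aside (this is an illustration, not Bumblebee internals), tying word embeddings means a single {vocab_size, hidden_size} matrix serves both as the input embedding lookup table and as the output projection back onto the vocabulary, which is why pointing the Llama LM head at "model.embed_tokens" above is sufficient. A small Nx sketch with toy dimensions:

```elixir
# Toy dimensions, for illustration only.
vocab_size = 8
hidden_size = 4

# One shared weight matrix of shape {vocab_size, hidden_size}.
embedding = Nx.iota({vocab_size, hidden_size}, type: :f32)

# Input side: embedding a token id reads one row of the matrix.
token_id = 3
embedded = embedding[token_id]
IO.inspect(Nx.shape(embedded))
#=> {4}

# Output side: the tied LM head projects a hidden state back onto the
# vocabulary with the very same matrix, so no separate lm_head tensor exists.
hidden_state = Nx.broadcast(0.5, {hidden_size})
logits = Nx.dot(embedding, hidden_state)
IO.inspect(Nx.shape(logits))
#=> {8}
```

Because both uses read from one tensor, the loader only has to redirect the head's parameter source; no extra weights are created or copied.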
