diff --git a/lib/bumblebee/text/llama.ex b/lib/bumblebee/text/llama.ex index 05995957..3fed0524 100644 --- a/lib/bumblebee/text/llama.ex +++ b/lib/bumblebee/text/llama.ex @@ -75,6 +75,10 @@ defmodule Bumblebee.Text.Llama do default: 0.02, doc: "the standard deviation of the normal initializer used for initializing kernel parameters" + ], + tie_word_embeddings: [ + default: false, + doc: "whether to tie input and output embedding weights" ] ] ++ Shared.common_options([:num_labels, :id_to_label]) ++ Shared.token_options(pad_token_id: 0) @@ -435,6 +439,7 @@ defmodule Bumblebee.Text.Llama do opts = convert!(data, vocab_size: {"vocab_size", number()}, + tie_word_embeddings: {"tie_word_embeddings", boolean()}, max_positions: {"max_position_embeddings", number()}, hidden_size: {"hidden_size", number()}, num_blocks: {"num_hidden_layers", number()}, @@ -447,7 +452,8 @@ defmodule Bumblebee.Text.Llama do rotary_embedding_scaling_strategy: {"rope_scaling", optional(scaling_strategy_converter)}, initializer_scale: {"initializer_range", number()}, - layer_norm_epsilon: {"rms_norm_eps", number()} + layer_norm_epsilon: {"rms_norm_eps", number()}, + tie_word_embeddings: {"tie_word_embeddings", boolean()} ) ++ Shared.common_options_from_transformers(data, spec) @for.config(spec, opts) @@ -455,7 +461,7 @@ defmodule Bumblebee.Text.Llama do end defimpl Bumblebee.HuggingFace.Transformers.Model do - def params_mapping(_spec) do + def params_mapping(spec) do %{ "embedder.token_embedding" => "model.embed_tokens", "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj", @@ -470,7 +476,8 @@ defmodule Bumblebee.Text.Llama do "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj", "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm", "output_norm" => "model.norm", - "language_modeling_head.output" => "lm_head", + "language_modeling_head.output" => + if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head"), "sequence_classification_head.output" => "score" } end diff --git a/lib/bumblebee/text/t5.ex b/lib/bumblebee/text/t5.ex index fb865c2c..3f9cabb5 100644 --- a/lib/bumblebee/text/t5.ex +++ b/lib/bumblebee/text/t5.ex @@ -10,12 +10,6 @@ defmodule Bumblebee.Text.T5 do tokens that can be represented in model input and output """ ], - tie_word_embeddings: [ - default: true, - doc: """ - whether or not to tie encoder and decoder token embedding - """ - ], hidden_size: [ default: 512, doc: "the dimensionality of hidden layers" @@ -74,6 +68,10 @@ defmodule Bumblebee.Text.T5 do layer_norm_epsilon: [ default: 1.0e-6, doc: "the epsilon used by the layer normalization layers" + ], + tie_word_embeddings: [ + default: true, + doc: "whether or not to tie encoder and decoder token embedding" ] ] ++ Shared.common_options([:num_labels, :id_to_label]) ++ @@ -538,7 +536,6 @@ defmodule Bumblebee.Text.T5 do opts = convert!(data, vocab_size: {"vocab_size", number()}, - tie_word_embeddings: {"tie_word_embeddings", boolean()}, hidden_size: {"d_model", number()}, attention_head_size: {"d_kv", number()}, encoder_num_blocks: {"num_layers", number()}, @@ -551,7 +548,8 @@ defmodule Bumblebee.Text.T5 do activation: {"feed_forward_proj", t5_activation()}, ffn_gated_activation: {"feed_forward_proj", ffn_gated_activation()}, dropout_rate: {"dropout", number()}, - initializer_scale: {"initializer_factor", number()} + initializer_scale: {"initializer_factor", number()}, + tie_word_embeddings: {"tie_word_embeddings", boolean()} ) ++ Shared.common_options_from_transformers(data, spec) @for.config(spec, opts)