Skip to content

Commit

Permalink
Add gemma
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Mar 2, 2024
1 parent a09b230 commit f1ab76f
Show file tree
Hide file tree
Showing 5 changed files with 545 additions and 3 deletions.
4 changes: 4 additions & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ defmodule Bumblebee do
"DistilBertForQuestionAnswering" => {Bumblebee.Text.Distilbert, :for_question_answering},
"DistilBertForTokenClassification" => {Bumblebee.Text.Distilbert, :for_token_classification},
"DistilBertForMultipleChoice" => {Bumblebee.Text.Distilbert, :for_multiple_choice},
"GemmaModel" => {Bumblebee.Text.Gemma, :base},
"GemmaForCausalLM" => {Bumblebee.Text.Gemma, :for_causal_language_modeling},
"GemmaForSequenceClassification" => {Bumblebee.Text.Gemma, :for_sequence_classification},
"GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification},
"GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
"GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
Expand Down Expand Up @@ -231,6 +234,7 @@ defmodule Bumblebee do
"distilbert" => :distilbert,
"camembert" => :camembert,
"clip" => :clip,
"gemma" => :gemma,
"gpt_neox" => :gpt_neo_x,
"gpt2" => :gpt2,
"gpt_bigcode" => :gpt2,
Expand Down
13 changes: 10 additions & 3 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,13 @@ defmodule Bumblebee.Layers do
# TODO: Add to Axon
def rms_norm(input, opts \\ []) do
opts =
Keyword.validate!(opts, [:name, channel_index: -1, epsilon: 1.0e-6, initializer: :ones])
Keyword.validate!(opts, [
:name,
shift: 0.0,
channel_index: -1,
epsilon: 1.0e-6,
initializer: :ones
])

weight =
Axon.param("weight", &Axon.Shape.norm_param(&1, opts[:channel_index]),
Expand All @@ -1100,13 +1106,14 @@ defmodule Bumblebee.Layers do

Axon.layer(&rms_norm_impl/3, [input, weight],
name: opts[:name],
shift: opts[:shift],
epsilon: opts[:epsilon],
op_name: :rms_norm
)
end

defnp rms_norm_impl(input, weight, opts \\ []) do
opts = keyword!(opts, epsilon: 1.0e-6, channel_index: -1, mode: :train)
opts = keyword!(opts, shift: 0.0, epsilon: 1.0e-6, channel_index: -1, mode: :train)

variance =
input
Expand All @@ -1117,7 +1124,7 @@ defmodule Bumblebee.Layers do
input
|> Nx.multiply(Nx.rsqrt(variance + opts[:epsilon]))

x * weight
x * (opts[:shift] + weight)
end

@doc """
Expand Down
Loading

0 comments on commit f1ab76f

Please sign in to comment.