From 93b5580c8927f124aa95a081d6554ae3035b22b3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Tue, 12 Mar 2024 14:52:25 +0100
Subject: [PATCH] Gemma fixes (#362)

---
 lib/bumblebee/layers.ex     | 70 +++++++++++++++++++++++++++++++++----
 lib/bumblebee/text/gemma.ex | 17 ++++++---
 2 files changed, 76 insertions(+), 11 deletions(-)

diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex
index 6255eb4a..ffdf88e0 100644
--- a/lib/bumblebee/layers.ex
+++ b/lib/bumblebee/layers.ex
@@ -1109,6 +1109,31 @@ defmodule Bumblebee.Layers do
 
   @doc """
   Adds an RMS Normalization layer to the network.
+
+  ## Options
+
+    * `:name` - layer name
+
+    * `:initializer` - initializer for the standard deviation parameter.
+      Defaults to `:ones`
+
+    * `:channel_index` - input feature index used for calculating
+      variance. Defaults to `-1`
+
+    * `:epsilon` - numerical stability term
+
+    * `:shift` - numeric shift in the scaling expression. Defaults to
+      `0.0`
+
+    * `:upcast` - adds explicit type casting to make sure the norm
+      is computed in high numerical precision. Either of:
+
+      * `:normalization` (default) - upcasts only the input normalization
+        part
+
+      * `:all` - upcasts both input normalization and the scaling
+        expression
+
   """
   # TODO: Add to Axon
   def rms_norm(input, opts \\ []) do
@@ -1118,15 +1143,29 @@ defmodule Bumblebee.Layers do
         shift: 0.0,
         channel_index: -1,
         epsilon: 1.0e-6,
+        upcast: :normalization,
         initializer: :ones
       ])
 
+    impl =
+      case opts[:upcast] do
+        :normalization ->
+          &rms_norm_impl_upcast_normalization/3
+
+        :all ->
+          &rms_norm_impl_upcast_all/3
+
+        other ->
+          raise ArgumentError,
+                "expected :upcast to be either :all or :normalization, got: #{other}"
+      end
+
     weight =
       Axon.param("weight", &Axon.Shape.norm_param(&1, opts[:channel_index]),
        initializer: opts[:initializer]
       )
 
-    Axon.layer(&rms_norm_impl/3, [input, weight],
+    Axon.layer(impl, [input, weight],
       name: opts[:name],
       shift: opts[:shift],
       epsilon: opts[:epsilon],
@@ -1134,19 +1173,36 @@ defmodule Bumblebee.Layers do
     )
   end
 
-  defnp rms_norm_impl(input, weight, opts \\ []) do
+  defnp rms_norm_impl_upcast_normalization(input, weight, opts \\ []) do
+    opts = keyword!(opts, shift: 0.0, epsilon: 1.0e-6, channel_index: -1, mode: :train)
+
+    normalized_input =
+      input
+      |> Nx.as_type(:f32)
+      |> rms_normalize(opts)
+      |> Nx.as_type(Nx.type(input))
+
+    normalized_input * (opts[:shift] + weight)
+  end
+
+  defnp rms_norm_impl_upcast_all(input, weight, opts \\ []) do
     opts = keyword!(opts, shift: 0.0, epsilon: 1.0e-6, channel_index: -1, mode: :train)
 
+    input = Nx.as_type(input, :f32)
+    weight = Nx.as_type(weight, :f32)
+
+    normalized_input = rms_normalize(input, opts)
+
+    normalized_input * (opts[:shift] + weight)
+  end
+
+  defnp rms_normalize(input, opts) do
     variance =
       input
       |> Nx.pow(2)
       |> Nx.mean(axes: [opts[:channel_index]], keep_axes: true)
 
-    x =
-      input
-      |> Nx.multiply(Nx.rsqrt(variance + opts[:epsilon]))
-
-    x * (opts[:shift] + weight)
+    input * Nx.rsqrt(variance + opts[:epsilon])
   end
 
   @doc """
diff --git a/lib/bumblebee/text/gemma.ex b/lib/bumblebee/text/gemma.ex
index f0f5fea3..567f76bb 100644
--- a/lib/bumblebee/text/gemma.ex
+++ b/lib/bumblebee/text/gemma.ex
@@ -39,7 +39,7 @@ defmodule Bumblebee.Text.Gemma do
         doc: "the number of key value heads for each attention layer in the model"
       ],
       activation: [
-        default: :gelu,
+        default: :gelu_approx_tanh,
         doc: "the activation function"
       ],
       rotary_embedding_base: [
@@ -289,7 +289,8 @@ defmodule Bumblebee.Text.Gemma do
       Layers.rms_norm(decoder_outputs.hidden_state,
         name: "output_norm",
         shift: 1.0,
-        epsilon: spec.layer_norm_epsilon
+        epsilon: spec.layer_norm_epsilon,
+        upcast: :all
       )
 
     %{
@@ -309,7 +310,14 @@ defmodule Bumblebee.Text.Gemma do
         name: join(name, "token_embedding")
       )
     end
-    |> Axon.nx(fn x -> Nx.multiply(x, Nx.sqrt(spec.hidden_size)) end)
+    |> Axon.nx(fn x ->
+      normalization_factor =
+        spec.hidden_size
+        |> Nx.tensor(type: Nx.type(x))
+        |> Nx.sqrt()
+
+      Nx.multiply(x, normalization_factor)
+    end)
   end
 
   defp decoder(
@@ -332,7 +340,8 @@ defmodule Bumblebee.Text.Gemma do
       num_key_value_heads: spec.num_key_value_heads,
       hidden_size: spec.hidden_size,
       kernel_initializer: kernel_initializer(spec),
-      layer_norm: &Layers.rms_norm(&1, shift: 1.0, name: &2, epsilon: spec.layer_norm_epsilon),
+      layer_norm:
+        &Layers.rms_norm(&1, shift: 1.0, name: &2, epsilon: spec.layer_norm_epsilon, upcast: :all),
       ffn:
         &gated_ffn(&1, spec.intermediate_size, spec.hidden_size,
          name: &2,