diff --git a/lib/bumblebee/layers.ex b/lib/bumblebee/layers.ex index e6a61794..60dd0730 100644 --- a/lib/bumblebee/layers.ex +++ b/lib/bumblebee/layers.ex @@ -1254,12 +1254,11 @@ defmodule Bumblebee.Layers do positions_cos_sin(position, inv_frequency) %{ - type: type, + type: :longrope, short_factor: short_factor, long_factor: long_factor, original_max_positions: original_max_positions - } - when type in [:su, :yarn] -> + } -> factor = if sequence_length > original_max_positions do Nx.tensor(long_factor, type: :f32) @@ -1270,18 +1269,12 @@ defmodule Bumblebee.Layers do scale = max_positions / original_max_positions cos_sin_factor = - cond do - scale <= 1.0 -> - 1.0 - - type == :su -> - Nx.divide(Nx.log(scale), Nx.log(original_max_positions)) - |> Nx.add(1) - |> Nx.sqrt() - - type == :yarn -> - Nx.multiply(0.1, Nx.log(scale)) - |> Nx.add(1.0) + if scale <= 1.0 do + 1.0 + else + Nx.divide(Nx.log(scale), Nx.log(original_max_positions)) + |> Nx.add(1) + |> Nx.sqrt() end inv_frequency = inv_frequency(base, range) |> Nx.divide(factor) diff --git a/lib/bumblebee/text/phi3.ex b/lib/bumblebee/text/phi3.ex index 025ffd51..04593a57 100644 --- a/lib/bumblebee/text/phi3.ex +++ b/lib/bumblebee/text/phi3.ex @@ -60,9 +60,7 @@ defmodule Bumblebee.Text.Phi3 do doc: """ scaling configuration for rotary embedding. Currently the supported values are: - * `%{type: :su, short_factor: list(number()), long_factor: list(number()), original_max_positions: pos_integer()}` - - * `%{type: :yarn, short_factor: list(number()), long_factor: list(number()), original_max_positions: pos_integer()}` + * `%{type: :longrope, short_factor: list(number()), long_factor: list(number()), original_max_positions: pos_integer()}` """ ], @@ -428,11 +426,12 @@ defmodule Bumblebee.Text.Phi3 do case value do %{"type" => type, "long_factor" => long_factor, "short_factor" => short_factor} - when type in ["su", "yarn"] and is_list(long_factor) and is_list(short_factor) and + when type in ["longrope", "su", "yarn"] and + is_list(long_factor) and is_list(short_factor) and is_number(original_max_positions) -> {:ok, %{ - type: String.to_atom(type), + type: :longrope, long_factor: long_factor, short_factor: short_factor, original_max_positions: original_max_positions diff --git a/test/bumblebee/text/phi3_test.exs b/test/bumblebee/text/phi3_test.exs index 6ac45a9e..ce07171f 100644 --- a/test/bumblebee/text/phi3_test.exs +++ b/test/bumblebee/text/phi3_test.exs @@ -28,11 +28,11 @@ defmodule Bumblebee.Text.Phi3Test do ) end - test ":base rotary embedding scaling strategy :su" do + test ":base rotary embedding scaling strategy :longrope" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model( {:hf, - "bumblebee-testing/tiny-random-Phi3Model-rope_scaling-su-original_max_position_embeddings-256"} + "bumblebee-testing/tiny-random-Phi3Model-rope_scaling-longrope-original_max_position_embeddings-256"} ) assert %Bumblebee.Text.Phi3{architecture: :base} = spec @@ -54,32 +54,6 @@ defmodule Bumblebee.Text.Phi3Test do ) end - test ":base rotary embedding scaling strategy :yarn" do - assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model( - {:hf, - "bumblebee-testing/tiny-random-Phi3Model-rope_scaling-yarn-original_max_position_embeddings-256"} - ) - - assert %Bumblebee.Text.Phi3{architecture: :base} = spec - - inputs = %{ - "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), - "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) - } - - outputs = Axon.predict(model, params, inputs) - - assert Nx.shape(outputs.hidden_state) == {1, 10, 32} - - assert_all_close( - outputs.hidden_state[[.., 1..3, 1..3]], - Nx.tensor([ - [[-1.4530, 0.5995, 0.1574], [-0.2663, 1.9339, 0.5336], [1.1052, -0.1642, 0.5989]] - ]) - ) - end - test ":for_sequence_classification" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model(