Skip to content

Commit

Permalink
Update Phi-3 RoPE scaling strategies
Browse files · Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Sep 9, 2024
1 parent 45a265b commit 9aaeb13
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 48 deletions.
23 changes: 8 additions & 15 deletions lib/bumblebee/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1254,12 +1254,11 @@ defmodule Bumblebee.Layers do
positions_cos_sin(position, inv_frequency)

%{
type: type,
type: :longrope,
short_factor: short_factor,
long_factor: long_factor,
original_max_positions: original_max_positions
}
when type in [:su, :yarn] ->
} ->
factor =
if sequence_length > original_max_positions do
Nx.tensor(long_factor, type: :f32)
Expand All @@ -1270,18 +1269,12 @@ defmodule Bumblebee.Layers do
scale = max_positions / original_max_positions

cos_sin_factor =
cond do
scale <= 1.0 ->
1.0

type == :su ->
Nx.divide(Nx.log(scale), Nx.log(original_max_positions))
|> Nx.add(1)
|> Nx.sqrt()

type == :yarn ->
Nx.multiply(0.1, Nx.log(scale))
|> Nx.add(1.0)
if scale <= 1.0 do
1.0
else
Nx.divide(Nx.log(scale), Nx.log(original_max_positions))
|> Nx.add(1)
|> Nx.sqrt()
end

inv_frequency = inv_frequency(base, range) |> Nx.divide(factor)
Expand Down
9 changes: 4 additions & 5 deletions lib/bumblebee/text/phi3.ex
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ defmodule Bumblebee.Text.Phi3 do
doc: """
scaling configuration for rotary embedding. Currently the supported values are:
* `%{type: :su, short_factor: list(number()), long_factor: list(number()), original_max_positions: pos_integer()}`
* `%{type: :yarn, short_factor: list(number()), long_factor: list(number()), original_max_positions: pos_integer()}`
* `%{type: :longrope, short_factor: list(number()), long_factor: list(number()), original_max_positions: pos_integer()}`
"""
],
Expand Down Expand Up @@ -428,11 +426,12 @@ defmodule Bumblebee.Text.Phi3 do

case value do
%{"type" => type, "long_factor" => long_factor, "short_factor" => short_factor}
when type in ["su", "yarn"] and is_list(long_factor) and is_list(short_factor) and
when type in ["longrope", "su", "yarn"] and
is_list(long_factor) and is_list(short_factor) and
is_number(original_max_positions) ->
{:ok,
%{
type: String.to_atom(type),
type: :longrope,
long_factor: long_factor,
short_factor: short_factor,
original_max_positions: original_max_positions
Expand Down
30 changes: 2 additions & 28 deletions test/bumblebee/text/phi3_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ defmodule Bumblebee.Text.Phi3Test do
)
end

test ":base rotary embedding scaling strategy :su" do
test ":base rotary embedding scaling strategy :longrope" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
{:hf,
"bumblebee-testing/tiny-random-Phi3Model-rope_scaling-su-original_max_position_embeddings-256"}
"bumblebee-testing/tiny-random-Phi3Model-rope_scaling-longrope-original_max_position_embeddings-256"}
)

assert %Bumblebee.Text.Phi3{architecture: :base} = spec
Expand All @@ -54,32 +54,6 @@ defmodule Bumblebee.Text.Phi3Test do
)
end

test ":base rotary embedding scaling strategy :yarn" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
{:hf,
"bumblebee-testing/tiny-random-Phi3Model-rope_scaling-yarn-original_max_position_embeddings-256"}
)

assert %Bumblebee.Text.Phi3{architecture: :base} = spec

inputs = %{
"input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
"attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
}

outputs = Axon.predict(model, params, inputs)

assert Nx.shape(outputs.hidden_state) == {1, 10, 32}

assert_all_close(
outputs.hidden_state[[.., 1..3, 1..3]],
Nx.tensor([
[[-1.4530, 0.5995, 0.1574], [-0.2663, 1.9339, 0.5336], [1.1052, -0.1642, 0.5989]]
])
)
end

test ":for_sequence_classification" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model(
Expand Down

0 comments on commit 9aaeb13

Please sign in to comment.