Use model state (#375)
Co-authored-by: Jonatan Kłosko <[email protected]>
seanmor5 and jonatanklosko authored Jul 30, 2024
1 parent a75f6b7 commit b2781bc
Showing 6 changed files with 23 additions and 15 deletions.
2 changes: 1 addition & 1 deletion lib/bumblebee.ex
@@ -299,7 +299,7 @@ defmodule Bumblebee do
"""
@type model_info :: %{
model: Axon.t(),
- params: map(),
+ params: %Axon.ModelState{},
spec: Bumblebee.ModelSpec.t()
}

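With this change `model_info.params` holds an `%Axon.ModelState{}` rather than a bare parameter map, so callers reach the tensors through the struct's `:data` field. A minimal sketch of the updated access pattern, reusing the repo and layer names that appear in the tests further down in this diff; the rest is assumed usage, not code from this commit:

```elixir
# Sketch only: the repo and layer names are taken from the tests below,
# the surrounding usage is an assumption.
{:ok, %{model: _model, params: %Axon.ModelState{data: data}, spec: _spec}} =
  Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"})

# What used to be indexed straight off the params map now sits under :data.
data["decoder.blocks.0.ffn.output"]["kernel"] |> Nx.type()
```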
14 changes: 8 additions & 6 deletions lib/bumblebee/conversion/pytorch_params.ex
@@ -29,7 +29,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
`Bumblebee.Conversion.PyTorchLoader.load!/1`
"""
- @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: map()
+ @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: %Axon.ModelState{}
def load_params!(model, input_template, path, opts \\ []) do
opts =
opts
@@ -55,25 +55,27 @@ defmodule Bumblebee.Conversion.PyTorchParams do
end)
|> Enum.reduce(&Map.merge/2)

- params_expr = Axon.trace_init(model, input_template)
+ model_state = Axon.trace_init(model, input_template)

+ params_expr = model_state.data
{params, diff} = init_params(model, params_expr, pytorch_state, opts[:params_mapping])
+ model_state = %{model_state | data: params}

params_complete? = diff.missing == [] and diff.mismatched == []

- params =
+ model_state =
if params_complete? do
- params
+ model_state
else
{init_fun, _} = Axon.build(model, compiler: Nx.Defn.Evaluator)
- init_fun.(input_template, params)
+ init_fun.(input_template, model_state)
end

if Keyword.get(opts, :log_params_diff, not params_complete?) do
log_params_diff(diff)
end

- params
+ model_state
end)
end

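`load_params!/4` now threads an `%Axon.ModelState{}` through the whole flow: `Axon.trace_init/2` produces the state, its `:data` map is matched against the PyTorch state dict, and when parameters turn out missing or mismatched the built init function is re-run with the partially filled state so only the gaps get freshly initialized. A standalone sketch of that seeding pattern, using a toy model and a hand-faked checkpoint (both assumptions; only the `trace_init`/`Axon.build`/`init_fun` calls mirror the diff):

```elixir
# Toy model and template are assumptions, not code from this repository.
model = Axon.input("x", shape: {nil, 4}) |> Axon.dense(2, name: "out")
template = Nx.template({1, 4}, :f32)

# As in the diff above, trace_init returns an %Axon.ModelState{}.
model_state = Axon.trace_init(model, template)

# Stand-in for the checkpoint-loading step: pretend only the kernel was found
# in the PyTorch state dict, so the bias is missing.
loaded = %{"out" => %{"kernel" => Nx.broadcast(0.0, {4, 2})}}
model_state = %{model_state | data: loaded}

# Re-running init seeded with the partial state fills in what is still missing.
{init_fun, _predict_fun} = Axon.build(model, compiler: Nx.Defn.Evaluator)
model_state = init_fun.(template, model_state)
```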
2 changes: 1 addition & 1 deletion mix.lock
@@ -1,5 +1,5 @@
%{
"axon": {:git, "https://github.com/elixir-nx/axon.git", "7e0e5930ac4b8d2a89f48106b8121e103e597c89", []},
"axon": {:git, "https://github.com/elixir-nx/axon.git", "054eb4c1c224582528e8d1603ad08e7c4088f21c", []},
"bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"},
"castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"},
"cc_precompiler": {:hex, :cc_precompiler, "0.1.8", "933a5f4da3b19ee56539a076076ce4d7716d64efc8db46fd066996a7e46e2bfd", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "176bdf4366956e456bf761b54ad70bc4103d0269ca9558fd7cee93d1b3f116db"},
6 changes: 3 additions & 3 deletions test/bumblebee/conversion/pytorch_params_test.exs
@@ -39,7 +39,7 @@ defmodule Bumblebee.Conversion.PyTorchParamsTest do

log =
ExUnit.CaptureLog.capture_log(fn ->
- params =
+ %Axon.ModelState{data: params} =
PyTorchParams.load_params!(model, input_template(), path,
params_mapping: params_mapping()
)
@@ -89,7 +89,7 @@ defmodule Bumblebee.Conversion.PyTorchParamsTest do

log =
ExUnit.CaptureLog.capture_log(fn ->
- params =
+ %Axon.ModelState{data: params} =
PyTorchParams.load_params!(model, input_template(), path,
params_mapping: params_mapping()
)
@@ -107,7 +107,7 @@ defmodule Bumblebee.Conversion.PyTorchParamsTest do

log =
ExUnit.CaptureLog.capture_log(fn ->
- params =
+ %Axon.ModelState{data: params} =
PyTorchParams.load_params!(model, input_template(), path,
params_mapping: params_mapping()
)
10 changes: 8 additions & 2 deletions test/bumblebee/text/roberta_test.exs
@@ -33,7 +33,10 @@ defmodule Bumblebee.Text.RobertaTest do
assert %Bumblebee.Text.Roberta{architecture: :for_masked_language_modeling} = spec

# TODO: remove once we load tied embeddings
- params = put_in(params["language_modeling_head.output"], params["embedder.token_embedding"])
+ params =
+   update_in(params, [Access.key!(:data)], fn data ->
+     put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
+   end)

inputs = %{
"input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
@@ -157,7 +160,10 @@ defmodule Bumblebee.Text.RobertaTest do
assert %Bumblebee.Text.Roberta{architecture: :for_causal_language_modeling} = spec

# TODO: remove once we load tied embeddings
- params = put_in(params["language_modeling_head.output"], params["embedder.token_embedding"])
+ params =
+   update_in(params, [Access.key!(:data)], fn data ->
+     put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
+   end)

inputs = %{
"input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
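The test change above follows from `params` now being a struct: structs do not implement the Access behaviour, so `put_in(params["..."], ...)` can no longer index into them directly and the update has to step through `Access.key!(:data)` first. A generic illustration of that pattern with a made-up struct; nothing below comes from the commit:

```elixir
defmodule Demo.State do
  # Hypothetical stand-in for a struct that carries a nested params map.
  defstruct data: %{}
end

state = %Demo.State{data: %{"embedder.token_embedding" => :weights}}

# update_in needs Access.key!/1 to step into a struct field; inside the
# returned :data map, plain string keys work as usual.
state =
  update_in(state, [Access.key!(:data)], fn data ->
    put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
  end)

# state.data now has both keys pointing at :weights.
```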
4 changes: 2 additions & 2 deletions test/bumblebee_test.exs
@@ -76,15 +76,15 @@ defmodule BumblebeeTest do
end

test "passing :type casts params accordingly" do
- assert {:ok, %{params: params}} =
+ assert {:ok, %{params: %Axon.ModelState{data: params}}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
type: :bf16
)

assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:bf, 16}
assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:bf, 16}

- assert {:ok, %{params: params}} =
+ assert {:ok, %{params: %Axon.ModelState{data: params}}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2Model"},
type: Axon.MixedPrecision.create_policy(params: :f16)
)
