Commit b508401
jonatanklosko committed May 31, 2024
1 parent c8f63da commit b508401
Showing 3 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion lib/bumblebee.ex
@@ -299,7 +299,7 @@ defmodule Bumblebee do
   """
   @type model_info :: %{
           model: Axon.t(),
-          params: map(),
+          params: %Axon.ModelState{},
           spec: Bumblebee.ModelSpec.t()
         }

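Note on the `lib/bumblebee.ex` change: the `:params` entry of `model_info` is now an `Axon.ModelState` struct rather than a bare parameter map, so callers that index into the loaded parameters need to go through the struct's `:data` field. A minimal caller-side sketch, assuming Axon 0.7+; the repository and parameter names are illustrative, not taken from this commit:

```elixir
# Sketch only: the repository and parameter names below are illustrative.
{:ok, %{model: _model, params: params, spec: _spec}} =
  Bumblebee.load_model({:hf, "bert-base-uncased"})

# :params is now an Axon.ModelState struct, so per-layer tensors are
# reached through its :data field instead of indexing the value directly.
%Axon.ModelState{data: data} = params
data["embedder.token_embedding"]["kernel"]
```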
21 changes: 11 additions & 10 deletions lib/bumblebee/conversion/pytorch_params.ex
@@ -29,7 +29,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
     `Bumblebee.Conversion.PyTorchLoader.load!/1`
   """
-  @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: map()
+  @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: %Axon.ModelState{}
   def load_params!(model, input_template, path, opts \\ []) do
     opts =
       opts
@@ -55,25 +55,27 @@ defmodule Bumblebee.Conversion.PyTorchParams do
         end)
         |> Enum.reduce(&Map.merge/2)

-      params_expr = Axon.trace_init(model, input_template)
+      model_state = Axon.trace_init(model, input_template)

+      params_expr = model_state.data
       {params, diff} = init_params(model, params_expr, pytorch_state, opts[:params_mapping])
+      model_state = %{model_state | data: params}

       params_complete? = diff.missing == [] and diff.mismatched == []

-      params =
+      model_state =
         if params_complete? do
-          params
+          model_state
         else
           {init_fun, _} = Axon.build(model, compiler: Nx.Defn.Evaluator)
-          init_fun.(input_template, params)
+          init_fun.(input_template, model_state)
         end

       if Keyword.get(opts, :log_params_diff, not params_complete?) do
         log_params_diff(diff)
       end

-      params
+      model_state
     end)
   end

@@ -100,13 +102,13 @@ defmodule Bumblebee.Conversion.PyTorchParams do

     {params, diff} =
       layers
-      |> Enum.filter(fn {_layer, layer_name} -> params_expr.data[layer_name] end)
+      |> Enum.filter(fn {_layer, layer_name} -> params_expr[layer_name] end)
       |> Enum.map_reduce(diff, fn {layer, layer_name}, diff ->
         params_source = params_source(layer_name, prefixes, params_mapping)

         {params, diff} =
           Enum.reduce(layer.parameters, {[], diff}, fn param, {params, diff} ->
-            param_expr = params_expr.data[layer_name][param.name]
+            param_expr = params_expr[layer_name][param.name]

             {sources, builder_fun} =
               case params_source do
@@ -168,8 +170,7 @@ defmodule Bumblebee.Conversion.PyTorchParams do
         {{layer_name, Map.new(params)}, diff}
       end)

-    params_data = Map.new(params)
-    params = %{params_expr | data: params_data}
+    params = Map.new(params)

     diff = %{
       missing: Enum.reverse(diff.missing),
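Inside `load_params!/4`, the same shape change is threaded through: `Axon.trace_init/2` now yields an `Axon.ModelState`, the conversion still operates on the plain parameter-expression map taken from its `:data` field, and the converted parameters are written back into the struct before it is returned. A rough sketch of that pattern with a toy Axon model (not the Bumblebee internals), assuming Axon 0.7+:

```elixir
# Toy model standing in for the traced Bumblebee model.
model = Axon.input("x", shape: {nil, 4}) |> Axon.dense(2)
template = %{"x" => Nx.template({1, 4}, :f32)}

# trace_init returns an Axon.ModelState whose :data holds the
# per-layer parameter expressions.
model_state = Axon.trace_init(model, template)
params_expr = model_state.data

# ... convert/replace tensors keyed like params_expr here ...
converted = params_expr

# Put the converted map back into the struct and return the struct
# rather than a bare map, mirroring the change above.
%Axon.ModelState{} = %{model_state | data: converted}
```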
14 changes: 8 additions & 6 deletions test/bumblebee/text/roberta_test.exs
@@ -33,9 +33,10 @@ defmodule Bumblebee.Text.RobertaTest do
     assert %Bumblebee.Text.Roberta{architecture: :for_masked_language_modeling} = spec

     # TODO: remove once we load tied embeddings
-    params = update_in(params, [Access.key!(:data)], fn data ->
-      put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
-    end)
+    params =
+      update_in(params, [Access.key!(:data)], fn data ->
+        put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
+      end)

     inputs = %{
       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
@@ -159,9 +160,10 @@ defmodule Bumblebee.Text.RobertaTest do
     assert %Bumblebee.Text.Roberta{architecture: :for_causal_language_modeling} = spec

     # TODO: remove once we load tied embeddings
-    params = update_in(params, [Access.key!(:data)], fn data ->
-      put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
-    end)
+    params =
+      update_in(params, [Access.key!(:data)], fn data ->
+        put_in(data["language_modeling_head.output"], data["embedder.token_embedding"])
+      end)

     inputs = %{
       "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
