Supports loading .safetensors params (#231)
Co-authored-by: Jonatan Kłosko <[email protected]>
grzuy and jonatanklosko committed Aug 4, 2023
1 parent 333ba09 commit 3f691f2
Showing 6 changed files with 79 additions and 3 deletions.
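
As context for the diff below, a minimal usage sketch of what this commit enables: passing `params_filename: "model.safetensors"` to `Bumblebee.load_model/2` loads params with the Safetensors library instead of the PyTorch pickle loader. The repository and filename here are taken from the tests added in this commit; whether a given Hugging Face repo actually hosts a `model.safetensors` file is an assumption.

    # Sketch only: "openai/whisper-tiny" hosting model.safetensors is an
    # assumption borrowed from the tests added in this commit.
    {:ok, %{params: params}} =
      Bumblebee.load_model(
        {:hf, "openai/whisper-tiny"},
        params_filename: "model.safetensors"
      )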
17 changes: 16 additions & 1 deletion lib/bumblebee.ex
@@ -494,7 +494,10 @@ defmodule Bumblebee do
         model,
         input_template,
         paths,
-        [params_mapping: params_mapping] ++ Keyword.take(opts, [:backend, :log_params_diff])
+        [
+          params_mapping: params_mapping,
+          loader_fun: filename |> Path.extname() |> params_file_loader_fun()
+        ] ++ Keyword.take(opts, [:backend, :log_params_diff])
       )

     {:ok, params}
@@ -525,6 +528,18 @@ defmodule Bumblebee do
     end
   end

+  defp params_file_loader_fun(".safetensors") do
+    fn path ->
+      path
+      |> File.read!()
+      |> Safetensors.load!()
+    end
+  end
+
+  defp params_file_loader_fun(_) do
+    &Bumblebee.Conversion.PyTorch.Loader.load!/1
+  end
+
   @doc """
   Featurizes `input` with the given featurizer.
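Note on the dispatch above: the loader is chosen from the params filename extension via `Path.extname/1`, so any filename ending in `.safetensors` goes through the Safetensors loader and everything else falls back to the PyTorch pickle loader. For example:

    Path.extname("model.safetensors")
    #=> ".safetensors"

    Path.extname("pytorch_model.bin")
    #=> ".bin"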
15 changes: 13 additions & 2 deletions lib/bumblebee/conversion/pytorch.ex
@@ -24,17 +24,28 @@ defmodule Bumblebee.Conversion.PyTorch do
       between the Axon model and the PyTorch state. For more details see
       `Bumblebee.HuggingFace.Transformers.Model.params_mapping/1`

+    * `:loader_fun` - a 1-arity function that takes a path argument
+      and loads the params file. Defaults to
+      `Bumblebee.Conversion.PyTorch.Loader.load!/1`
+
   """
   @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: map()
   def load_params!(model, input_template, path, opts \\ []) do
-    opts = Keyword.validate!(opts, [:log_params_diff, :backend, params_mapping: %{}])
+    opts =
+      opts
+      |> Keyword.validate!([
+        :log_params_diff,
+        :backend,
+        params_mapping: %{},
+        loader_fun: &Bumblebee.Conversion.PyTorch.Loader.load!/1
+      ])

     with_default_backend(opts[:backend], fn ->
       pytorch_state =
         path
         |> List.wrap()
         |> Enum.map(fn path ->
-          pytorch_state = Bumblebee.Conversion.PyTorch.Loader.load!(path)
+          pytorch_state = opts[:loader_fun].(path)

           unless state_dict?(pytorch_state) do
             raise "expected a serialized model state dictionary at #{path}, but got: #{inspect(pytorch_state)}"
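The new `:loader_fun` option also makes `load_params!/4` usable with a custom loader directly. A minimal sketch, assuming `model`, `input_template`, and `path` are already in scope (the option defaults to `Bumblebee.Conversion.PyTorch.Loader.load!/1` as documented above):

    # Hypothetical direct call; the loader function mirrors the .safetensors
    # branch added in lib/bumblebee.ex.
    params =
      Bumblebee.Conversion.PyTorch.load_params!(model, input_template, path,
        loader_fun: fn path -> path |> File.read!() |> Safetensors.load!() end
      )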
1 change: 1 addition & 0 deletions mix.exs
@@ -42,6 +42,7 @@ defmodule Bumblebee.MixProject do
       {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]},
       {:nx_image, "~> 0.1.0"},
       {:unpickler, "~> 0.1.0"},
+      {:safetensors, "~> 0.1.1"},
       {:castore, "~> 0.1 or ~> 1.0"},
       {:jason, "~> 1.4.0"},
       {:unzip, "0.8.0"},
1 change: 1 addition & 0 deletions mix.lock
@@ -29,6 +29,7 @@
   "ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"},
   "rustler": {:hex, :rustler, "0.29.1", "880f20ae3027bd7945def6cea767f5257bc926f33ff50c0d5d5a5315883c084d", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "109497d701861bfcd26eb8f5801fe327a8eef304f56a5b63ef61151ff44ac9b6"},
   "rustler_precompiled": {:hex, :rustler_precompiled, "0.6.2", "d2218ba08a43fa331957f30481d00b666664d7e3861431b02bd3f4f30eec8e5b", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "b9048eaed8d7d14a53f758c91865cc616608a438d2595f621f6a4b32a5511709"},
+  "safetensors": {:hex, :safetensors, "0.1.0", "3f0d9af32c5c2d43ad8bd483e0db2182ce539d4a16119a35d5047e5dee5a1f2f", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5.3", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "2bc6122b8bbeec4efbea14f2e3fe00ebeaf83773f0323af99df5d685e8e6824f"},
   "stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"},
   "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
   "tokenizers": {:git, "https://github.com/elixir-nx/tokenizers.git", "90dd590d5a64863e61666c3c5ebaec2d3e51841c", []},
36 changes: 36 additions & 0 deletions test/bumblebee/audio/whisper_test.exs
@@ -38,6 +38,42 @@ defmodule Bumblebee.Text.WhisperTest do
     )
   end

+  test "base model with safetensors" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model(
+               {:hf, "openai/whisper-tiny"},
+               architecture: :base,
+               params_filename: "model.safetensors"
+             )
+
+    assert %Bumblebee.Audio.Whisper{architecture: :base} = spec
+
+    input_features = Nx.sin(Nx.iota({1, 3000, 80}, type: :f32))
+    decoder_input_ids = Nx.tensor([[50258, 50259, 50359, 50363]])
+
+    inputs = %{
+      "input_features" => input_features,
+      "decoder_input_ids" => decoder_input_ids
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.hidden_state) == {1, 4, 384}
+
+    assert_all_close(
+      outputs.hidden_state[[.., .., 1..3]],
+      Nx.tensor([
+        [
+          [9.1349, 0.5695, 8.7758],
+          [0.0160, -7.0785, 1.1313],
+          [6.1074, -2.0481, -1.5687],
+          [5.6247, -10.3924, 7.2008]
+        ]
+      ]),
+      atol: 1.0e-4
+    )
+  end
+
   test "for conditional generation model" do
     assert {:ok, %{model: model, params: params, spec: spec}} =
              Bumblebee.load_model({:hf, "openai/whisper-tiny"})
12 changes: 12 additions & 0 deletions test/bumblebee_test.exs
@@ -19,5 +19,17 @@ defmodule BumblebeeTest do

       assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(sharded_params))
     end
+
+    test "supports .safetensors params file" do
+      assert {:ok, %{params: params}} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
+
+      assert {:ok, %{params: safetensors_params}} =
+               Bumblebee.load_model(
+                 {:hf, "openai/whisper-tiny"},
+                 params_filename: "model.safetensors"
+               )
+
+      assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(safetensors_params))
+    end
   end
 end
