diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 1f034e6d..2ea19fca 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -494,7 +494,10 @@ defmodule Bumblebee do model, input_template, paths, - [params_mapping: params_mapping] ++ Keyword.take(opts, [:backend, :log_params_diff]) + [ + params_mapping: params_mapping, + loader_fun: filename |> Path.extname() |> params_file_loader_fun() + ] ++ Keyword.take(opts, [:backend, :log_params_diff]) ) {:ok, params} @@ -525,6 +528,18 @@ defmodule Bumblebee do end end + defp params_file_loader_fun(".safetensors") do + fn path -> + path + |> File.read!() + |> Safetensors.load!() + end + end + + defp params_file_loader_fun(_) do + &Bumblebee.Conversion.PyTorch.Loader.load!/1 + end + @doc """ Featurizes `input` with the given featurizer. diff --git a/lib/bumblebee/conversion/pytorch.ex b/lib/bumblebee/conversion/pytorch.ex index 7390c7d9..45f59a01 100644 --- a/lib/bumblebee/conversion/pytorch.ex +++ b/lib/bumblebee/conversion/pytorch.ex @@ -24,17 +24,28 @@ defmodule Bumblebee.Conversion.PyTorch do between the Axon model and the PyTorch state. For more details see `Bumblebee.HuggingFace.Transformers.Model.params_mapping/1` + * `:loader_fun` - a 1-arity function that takes a path argument + and loads the params file. 
Defaults to + `Bumblebee.Conversion.PyTorch.Loader.load!/1` + """ @spec load_params!(Axon.t(), map(), Path.t() | list(Path.t()), keyword()) :: map() def load_params!(model, input_template, path, opts \\ []) do - opts = Keyword.validate!(opts, [:log_params_diff, :backend, params_mapping: %{}]) + opts = + opts + |> Keyword.validate!([ + :log_params_diff, + :backend, + params_mapping: %{}, + loader_fun: &Bumblebee.Conversion.PyTorch.Loader.load!/1 + ]) with_default_backend(opts[:backend], fn -> pytorch_state = path |> List.wrap() |> Enum.map(fn path -> - pytorch_state = Bumblebee.Conversion.PyTorch.Loader.load!(path) + pytorch_state = opts[:loader_fun].(path) unless state_dict?(pytorch_state) do raise "expected a serialized model state dictionary at #{path}, but got: #{inspect(pytorch_state)}" diff --git a/mix.exs b/mix.exs index 7daf2abc..9374b05a 100644 --- a/mix.exs +++ b/mix.exs @@ -42,6 +42,7 @@ defmodule Bumblebee.MixProject do {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]}, {:nx_image, "~> 0.1.0"}, {:unpickler, "~> 0.1.0"}, + {:safetensors, "~> 0.1.0"}, {:castore, "~> 0.1 or ~> 1.0"}, {:jason, "~> 1.4.0"}, {:unzip, "0.8.0"}, diff --git a/mix.lock b/mix.lock index ffcc214b..d03f5fa3 100644 --- a/mix.lock +++ b/mix.lock @@ -29,6 +29,7 @@ "ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"}, "rustler": {:hex, :rustler, "0.29.1", "880f20ae3027bd7945def6cea767f5257bc926f33ff50c0d5d5a5315883c084d", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "109497d701861bfcd26eb8f5801fe327a8eef304f56a5b63ef61151ff44ac9b6"}, "rustler_precompiled": {:hex, :rustler_precompiled, "0.6.2", "d2218ba08a43fa331957f30481d00b666664d7e3861431b02bd3f4f30eec8e5b", [:mix], [{:castore, "~> 0.1 or ~> 1.0", 
[hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "b9048eaed8d7d14a53f758c91865cc616608a438d2595f621f6a4b32a5511709"}, + "safetensors": {:hex, :safetensors, "0.1.0", "3f0d9af32c5c2d43ad8bd483e0db2182ce539d4a16119a35d5047e5dee5a1f2f", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5.3", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "2bc6122b8bbeec4efbea14f2e3fe00ebeaf83773f0323af99df5d685e8e6824f"}, "stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "tokenizers": {:git, "https://github.com/elixir-nx/tokenizers.git", "90dd590d5a64863e61666c3c5ebaec2d3e51841c", []}, diff --git a/test/bumblebee/audio/whisper_test.exs b/test/bumblebee/audio/whisper_test.exs index 0b866d6a..7152aa27 100644 --- a/test/bumblebee/audio/whisper_test.exs +++ b/test/bumblebee/audio/whisper_test.exs @@ -38,6 +38,42 @@ defmodule Bumblebee.Text.WhisperTest do ) end + test "base model with safetensors" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:hf, "openai/whisper-tiny"}, + architecture: :base, + params_filename: "model.safetensors" + ) + + assert %Bumblebee.Audio.Whisper{architecture: :base} = spec + + input_features = Nx.sin(Nx.iota({1, 3000, 80}, type: :f32)) + decoder_input_ids = 
Nx.tensor([[50258, 50259, 50359, 50363]]) + + inputs = %{ + "input_features" => input_features, + "decoder_input_ids" => decoder_input_ids + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 4, 384} + + assert_all_close( + outputs.hidden_state[[.., .., 1..3]], + Nx.tensor([ + [ + [9.1349, 0.5695, 8.7758], + [0.0160, -7.0785, 1.1313], + [6.1074, -2.0481, -1.5687], + [5.6247, -10.3924, 7.2008] + ] + ]), + atol: 1.0e-4 + ) + end + test "for conditional generation model" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model({:hf, "openai/whisper-tiny"}) diff --git a/test/bumblebee_test.exs b/test/bumblebee_test.exs index 8f8a816e..c689d938 100644 --- a/test/bumblebee_test.exs +++ b/test/bumblebee_test.exs @@ -19,5 +19,17 @@ defmodule BumblebeeTest do assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(sharded_params)) end + + test "supports .safetensors params file" do + assert {:ok, %{params: params}} = Bumblebee.load_model({:hf, "openai/whisper-tiny"}) + + assert {:ok, %{params: safetensors_params}} = + Bumblebee.load_model( + {:hf, "openai/whisper-tiny"}, + params_filename: "model.safetensors" + ) + + assert Enum.sort(Map.keys(params)) == Enum.sort(Map.keys(safetensors_params)) + end end end