diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 63150243..1f0ea9b4 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -492,6 +492,10 @@ defmodule Bumblebee do * `:spec` - the model specification to use when building the model. By default the specification is loaded using `load_spec/2` + * `:spec_overrides` - additional options to configure the model + specification with. This is a shorthand for using `load_spec/2`, + `configure/2` and passing as `:spec` + * `:module` - the model specification module. By default it is inferred from the configuration file, if that is not possible, it must be specified explicitly @@ -534,6 +538,11 @@ defmodule Bumblebee do spec = Bumblebee.configure(spec, num_labels: 10) {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec: spec) + Or as a shorthand, you can pass just the options to override: + + {:ok, resnet} = + Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec_overrides: [num_labels: 10]) + """ @doc type: :model @spec load_model(repository(), keyword()) :: {:ok, model_info()} | {:error, String.t()} @@ -543,6 +552,7 @@ defmodule Bumblebee do opts = Keyword.validate!(opts, [ :spec, + :spec_overrides, :module, :architecture, :params_variant, @@ -561,10 +571,19 @@ defmodule Bumblebee do end defp maybe_load_model_spec(opts, repository, repo_files) do - if spec = opts[:spec] do - {:ok, spec} - else - do_load_spec(repository, repo_files, opts[:module], opts[:architecture]) + spec_result = + if spec = opts[:spec] do + {:ok, spec} + else + do_load_spec(repository, repo_files, opts[:module], opts[:architecture]) + end + + with {:ok, spec} <- spec_result do + if options = opts[:spec_overrides] do + {:ok, configure(spec, options)} + else + {:ok, spec} + end end end diff --git a/test/bumblebee/vision/dino_v2_test.exs b/test/bumblebee/vision/dino_v2_test.exs index c9140cc1..3151f63b 100644 --- a/test/bumblebee/vision/dino_v2_test.exs +++ b/test/bumblebee/vision/dino_v2_test.exs @@ -143,14 +143,9 @@ defmodule Bumblebee.Vision.DinoV2Test do end test ":backbone with different feature map subset" do - assert {:ok, spec} = - Bumblebee.load_spec({:hf, "hf-internal-testing/tiny-random-Dinov2Backbone"}) - - spec = Bumblebee.configure(spec, backbone_output_indices: [0, 2]) - assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-Dinov2Backbone"}, - spec: spec + spec_overrides: [backbone_output_indices: [0, 2]] ) assert %Bumblebee.Vision.DinoV2{architecture: :backbone} = spec