Skip to content

Commit

Permalink
Add :spec_overrides to Bumblebee.load_model/2 (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Feb 21, 2024
1 parent f739e0a commit beca321
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
27 changes: 23 additions & 4 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,10 @@ defmodule Bumblebee do
* `:spec` - the model specification to use when building the model.
By default the specification is loaded using `load_spec/2`
* `:spec_overrides` - additional options to configure the model
specification with. This is a shorthand for using `load_spec/2`,
`configure/2` and passing as `:spec`
* `:module` - the model specification module. By default it is
inferred from the configuration file, if that is not possible,
it must be specified explicitly
Expand Down Expand Up @@ -534,6 +538,11 @@ defmodule Bumblebee do
spec = Bumblebee.configure(spec, num_labels: 10)
{:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec: spec)
Or as a shorthand, you can pass just the options to override:
{:ok, resnet} =
Bumblebee.load_model({:hf, "microsoft/resnet-50"}, spec_overrides: [num_labels: 10])
"""
@doc type: :model
@spec load_model(repository(), keyword()) :: {:ok, model_info()} | {:error, String.t()}
Expand All @@ -543,6 +552,7 @@ defmodule Bumblebee do
opts =
Keyword.validate!(opts, [
:spec,
:spec_overrides,
:module,
:architecture,
:params_variant,
Expand All @@ -561,10 +571,19 @@ defmodule Bumblebee do
end

defp maybe_load_model_spec(opts, repository, repo_files) do
if spec = opts[:spec] do
{:ok, spec}
else
do_load_spec(repository, repo_files, opts[:module], opts[:architecture])
spec_result =
if spec = opts[:spec] do
{:ok, spec}
else
do_load_spec(repository, repo_files, opts[:module], opts[:architecture])
end

with {:ok, spec} <- spec_result do
if options = opts[:spec_overrides] do
{:ok, configure(spec, options)}
else
{:ok, spec}
end
end
end

Expand Down
7 changes: 1 addition & 6 deletions test/bumblebee/vision/dino_v2_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,9 @@ defmodule Bumblebee.Vision.DinoV2Test do
end

test ":backbone with different feature map subset" do
assert {:ok, spec} =
Bumblebee.load_spec({:hf, "hf-internal-testing/tiny-random-Dinov2Backbone"})

spec = Bumblebee.configure(spec, backbone_output_indices: [0, 2])

assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-Dinov2Backbone"},
spec: spec
spec_overrides: [backbone_output_indices: [0, 2]]
)

assert %Bumblebee.Vision.DinoV2{architecture: :backbone} = spec
Expand Down

0 comments on commit beca321

Please sign in to comment.