diff --git a/lib/bumblebee/audio.ex b/lib/bumblebee/audio.ex index 0963b94c..61a30273 100644 --- a/lib/bumblebee/audio.ex +++ b/lib/bumblebee/audio.ex @@ -42,6 +42,11 @@ defmodule Bumblebee.Audio do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples {:ok, whisper} = Bumblebee.load_model({:hf, "openai/whisper-tiny"}) diff --git a/lib/bumblebee/audio/speech_to_text.ex b/lib/bumblebee/audio/speech_to_text.ex index 99e71a02..4392e1d8 100644 --- a/lib/bumblebee/audio/speech_to_text.ex +++ b/lib/bumblebee/audio/speech_to_text.ex @@ -11,12 +11,13 @@ defmodule Bumblebee.Audio.SpeechToText do %Text.GenerationConfig{} = generation_config, opts \\ [] ) do - opts = Keyword.validate!(opts, [:seed, :compile, defn_options: []]) + opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false]) %{model: model, params: params, spec: spec} = model_info Shared.validate_architecture!(spec, [:for_conditional_generation]) + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] compile = @@ -35,6 +36,8 @@ defmodule Bumblebee.Audio.SpeechToText do Nx.Serving.new( fn defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + generate_fun = Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> inputs = %{ diff --git a/lib/bumblebee/diffusion/stable_diffusion.ex b/lib/bumblebee/diffusion/stable_diffusion.ex index bad50d07..e2072b86 100644 --- a/lib/bumblebee/diffusion/stable_diffusion.ex +++ b/lib/bumblebee/diffusion/stable_diffusion.ex @@ -62,6 +62,11 @@ defmodule Bumblebee.Diffusion.StableDiffusion do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples repository_id = "CompVis/stable-diffusion-v1-4" @@ -135,13 +140,15 @@ defmodule Bumblebee.Diffusion.StableDiffusion do num_images_per_prompt: 1, guidance_scale: 7.5, seed: 0, - defn_options: [] + defn_options: [], + preallocate_params: false ]) safety_checker = opts[:safety_checker] safety_checker_featurizer = opts[:safety_checker_featurizer] num_steps = opts[:num_steps] num_images_per_prompt = opts[:num_images_per_prompt] + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] if safety_checker != nil and safety_checker_featurizer == nil do @@ -203,11 +210,12 @@ defmodule Bumblebee.Diffusion.StableDiffusion do {safety_checker?, safety_checker[:spec], safety_checker[:params]}, safety_checker_featurizer, {compile != nil, batch_size, sequence_length}, - num_images_per_prompt + num_images_per_prompt, + preallocate_params ] Nx.Serving.new( - fn defn_options -> apply(&init/9, init_args ++ [defn_options]) end, + fn defn_options -> apply(&init/10, init_args ++ [defn_options]) end, defn_options ) |> Nx.Serving.process_options(batch_size: batch_size) @@ -224,8 +232,13 @@ defmodule Bumblebee.Diffusion.StableDiffusion do safety_checker_featurizer, {compile?, batch_size, sequence_length}, num_images_per_prompt, + preallocate_params, defn_options ) do + encoder_params = Shared.maybe_preallocate(encoder_params, preallocate_params, defn_options) + unet_params = Shared.maybe_preallocate(unet_params, preallocate_params, defn_options) + vae_params = Shared.maybe_preallocate(vae_params, preallocate_params, defn_options) + image_fun = Shared.compile_or_jit(image_fun, defn_options, compile?, fn -> text_inputs = %{ @@ -250,6 +263,10 @@ defmodule Bumblebee.Diffusion.StableDiffusion do [safety_checker_params, inputs] end) + safety_checker_params = + safety_checker_params && + Shared.maybe_preallocate(safety_checker_params, preallocate_params, defn_options) + fn inputs -> inputs = Shared.maybe_pad(inputs, batch_size) @@ -275,18 +292,22 @@ defmodule Bumblebee.Diffusion.StableDiffusion do negative_prompts = Enum.map(inputs, & &1.negative_prompt) conditional = - Bumblebee.apply_tokenizer(tokenizer, prompts, - length: sequence_length, - return_token_type_ids: false, - return_attention_mask: false - ) + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, prompts, + length: sequence_length, + return_token_type_ids: false, + return_attention_mask: false + ) + end) unconditional = - Bumblebee.apply_tokenizer(tokenizer, negative_prompts, - length: Nx.axis_size(conditional["input_ids"], 1), - return_attention_mask: false, - return_token_type_ids: false - ) + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, negative_prompts, + length: Nx.axis_size(conditional["input_ids"], 1), + return_attention_mask: false, + return_token_type_ids: false + ) + end) inputs = %{"unconditional" => unconditional, "conditional" => conditional} diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex index f8d47166..c4f652dd 100644 --- a/lib/bumblebee/shared.ex +++ b/lib/bumblebee/shared.ex @@ -406,6 +406,18 @@ defmodule Bumblebee.Shared do end end + @doc """ + If `preallocate?` is `true`, allocates `params` using `defn_options`. + """ + @spec maybe_preallocate(map(), boolean(), keyword()) :: map() + def maybe_preallocate(params, preallocate?, defn_options) do + if preallocate? do + Nx.Defn.jit_apply(&Function.identity/1, [params], defn_options) + else + params + end + end + @doc """ Generates tokenizer implementation. """ diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex index 71b120e3..1612a422 100644 --- a/lib/bumblebee/text.ex +++ b/lib/bumblebee/text.ex @@ -91,6 +91,11 @@ defmodule Bumblebee.Text do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples {:ok, bert} = Bumblebee.load_model({:hf, "dslim/bert-base-NER"}) @@ -152,6 +157,11 @@ defmodule Bumblebee.Text do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + * `:stream` - when `true`, the serving immediately returns a stream that emits text chunks as they are generated. Note that when using streaming, only a single input can be given to the @@ -242,6 +252,11 @@ defmodule Bumblebee.Text do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/blenderbot-400M-distill"}) @@ -311,6 +326,11 @@ defmodule Bumblebee.Text do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples {:ok, bertweet} = Bumblebee.load_model({:hf, "finiteautomata/bertweet-base-sentiment-analysis"}) @@ -379,6 +399,11 @@ defmodule Bumblebee.Text do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples {:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"}) @@ -444,6 +469,11 @@ defmodule Bumblebee.Text do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples {:ok, bert} = Bumblebee.load_model({:hf, "bert-base-uncased"}) @@ -513,6 +543,11 @@ defmodule Bumblebee.Text do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples {:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"}) @@ -581,6 +616,11 @@ defmodule Bumblebee.Text do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples {:ok, model} = Bumblebee.load_model({:hf, "facebook/bart-large-mnli"}) diff --git a/lib/bumblebee/text/conversation.ex b/lib/bumblebee/text/conversation.ex index ddff98b4..8e4ed38d 100644 --- a/lib/bumblebee/text/conversation.ex +++ b/lib/bumblebee/text/conversation.ex @@ -19,7 +19,7 @@ defmodule Bumblebee.Text.Conversation do %Text.GenerationConfig{} = generation_config, opts \\ [] ) do - opts = Keyword.validate!(opts, [:seed, :compile, defn_options: []]) + opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false]) %{model: model, params: params, spec: spec} = model_info @@ -28,6 +28,7 @@ defmodule Bumblebee.Text.Conversation do :for_conditional_generation ]) + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] compile = @@ -49,6 +50,8 @@ defmodule Bumblebee.Text.Conversation do Nx.Serving.new( fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + generate_fun = Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key @@ -88,12 +91,14 @@ defmodule Bumblebee.Text.Conversation do end inputs = - Bumblebee.apply_tokenizer(tokenizer, texts, - length: sequence_length, - pad_direction: :left, - truncate_direction: :left, - return_token_type_ids: false - ) + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, texts, + length: sequence_length, + pad_direction: :left, + truncate_direction: :left, + return_token_type_ids: false + ) + end) batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) diff --git a/lib/bumblebee/text/fill_mask.ex b/lib/bumblebee/text/fill_mask.ex index 36956f87..46b1cfac 100644 --- a/lib/bumblebee/text/fill_mask.ex +++ b/lib/bumblebee/text/fill_mask.ex @@ -6,9 +6,12 @@ defmodule Bumblebee.Text.FillMask do def fill_mask(model_info, tokenizer, opts \\ []) do %{model: model, params: params, spec: spec} = model_info Shared.validate_architecture!(spec, :for_masked_language_modeling) - opts = Keyword.validate!(opts, [:compile, top_k: 5, defn_options: []]) + + opts = + Keyword.validate!(opts, [:compile, top_k: 5, defn_options: [], preallocate_params: false]) top_k = opts[:top_k] + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] compile = @@ -51,6 +54,8 @@ defmodule Bumblebee.Text.FillMask do Nx.Serving.new( fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scores_fun = Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key @@ -77,10 +82,12 @@ defmodule Bumblebee.Text.FillMask do texts = for text <- texts, do: validate_text!(text, mask_token) inputs = - Bumblebee.apply_tokenizer(tokenizer, texts, - length: sequence_length, - return_token_type_ids: false - ) + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, texts, + length: sequence_length, + return_token_type_ids: false + ) + end) batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index d3fa6c52..45709765 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -798,7 +798,14 @@ defmodule Bumblebee.Text.Generation do @doc false def generation(model_info, tokenizer, %Text.GenerationConfig{} = generation_config, opts \\ []) do - opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], stream: false]) + opts = + Keyword.validate!(opts, [ + :seed, + :compile, + defn_options: [], + preallocate_params: false, + stream: false + ]) %{model: model, params: params, spec: spec} = model_info @@ -807,6 +814,7 @@ defmodule Bumblebee.Text.Generation do :for_causal_language_modeling ]) + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] compile = @@ -825,6 +833,8 @@ defmodule Bumblebee.Text.Generation do Nx.Serving.new( fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + generate_fun = Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key @@ -853,11 +863,13 @@ defmodule Bumblebee.Text.Generation do {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) inputs = - Bumblebee.apply_tokenizer(tokenizer, texts, - length: sequence_length, - pad_direction: :left, - return_token_type_ids: false - ) + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, texts, + length: sequence_length, + pad_direction: :left, + return_token_type_ids: false + ) + end) batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) diff --git a/lib/bumblebee/text/question_answering.ex b/lib/bumblebee/text/question_answering.ex index 2e462671..c4f8a8e2 100644 --- a/lib/bumblebee/text/question_answering.ex +++ b/lib/bumblebee/text/question_answering.ex @@ -8,8 +8,9 @@ defmodule Bumblebee.Text.QuestionAnswering do %{model: model, params: params, spec: spec} = model_info Shared.validate_architecture!(spec, :for_question_answering) - opts = Keyword.validate!(opts, [:compile, defn_options: []]) + opts = Keyword.validate!(opts, [:compile, defn_options: [], preallocate_params: false]) + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] compile = @@ -35,6 +36,8 @@ defmodule Bumblebee.Text.QuestionAnswering do Nx.Serving.new( fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + predict_fun = Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key @@ -70,11 +73,13 @@ defmodule Bumblebee.Text.QuestionAnswering do end) all_inputs = - Bumblebee.apply_tokenizer(tokenizer, raw_inputs, - length: sequence_length, - return_token_type_ids: true, - return_offsets: true - ) + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, raw_inputs, + length: sequence_length, + return_token_type_ids: true, + return_offsets: true + ) + end) inputs = Map.take(all_inputs, ["input_ids", "attention_mask", "token_type_ids"]) diff --git a/lib/bumblebee/text/text_classification.ex b/lib/bumblebee/text/text_classification.ex index 9f4074f9..7f2b3652 100644 --- a/lib/bumblebee/text/text_classification.ex +++ b/lib/bumblebee/text/text_classification.ex @@ -8,10 +8,17 @@ defmodule Bumblebee.Text.TextClassification do Shared.validate_architecture!(spec, :for_sequence_classification) opts = - Keyword.validate!(opts, [:compile, top_k: 5, scores_function: :softmax, defn_options: []]) + Keyword.validate!(opts, [ + :compile, + top_k: 5, + scores_function: :softmax, + defn_options: [], + preallocate_params: false + ]) top_k = opts[:top_k] scores_function = opts[:scores_function] + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] compile = @@ -35,6 +42,8 @@ defmodule Bumblebee.Text.TextClassification do Nx.Serving.new( fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scores_fun = Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key @@ -59,10 +68,12 @@ defmodule Bumblebee.Text.TextClassification do {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) inputs = - Bumblebee.apply_tokenizer(tokenizer, texts, - length: sequence_length, - return_token_type_ids: false - ) + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, texts, + length: sequence_length, + return_token_type_ids: false + ) + end) batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex index a293b8c0..3932f24c 100644 --- a/lib/bumblebee/text/text_embedding.ex +++ b/lib/bumblebee/text/text_embedding.ex @@ -12,12 +12,14 @@ defmodule Bumblebee.Text.TextEmbedding do output_attribute: :pooled_state, output_pool: nil, embedding_processor: nil, - defn_options: [] + defn_options: [], + preallocate_params: false ]) output_attribute = opts[:output_attribute] output_pool = opts[:output_pool] embedding_processor = opts[:embedding_processor] + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] compile = @@ -80,6 +82,8 @@ defmodule Bumblebee.Text.TextEmbedding do Nx.Serving.new( fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + embedding_fun = Shared.compile_or_jit(embedding_fun, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key @@ -104,10 +108,12 @@ defmodule Bumblebee.Text.TextEmbedding do {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) inputs = - Bumblebee.apply_tokenizer(tokenizer, texts, - length: sequence_length, - return_token_type_ids: false - ) + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, texts, + length: sequence_length, + return_token_type_ids: false + ) + end) batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) diff --git a/lib/bumblebee/text/token_classification.ex b/lib/bumblebee/text/token_classification.ex index 25fd0f2f..0e8f8264 100644 --- a/lib/bumblebee/text/token_classification.ex +++ b/lib/bumblebee/text/token_classification.ex @@ -14,12 +14,14 @@ defmodule Bumblebee.Text.TokenClassification do :compile, scores_function: :softmax, ignored_labels: ["O"], - defn_options: [] + defn_options: [], + preallocate_params: false ]) aggregation = opts[:aggregation] ignored_labels = opts[:ignored_labels] scores_function = opts[:scores_function] + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] compile = @@ -43,6 +45,8 @@ defmodule Bumblebee.Text.TokenClassification do Nx.Serving.new( fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scores_fun = Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key @@ -67,11 +71,13 @@ defmodule Bumblebee.Text.TokenClassification do {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) all_inputs = - Bumblebee.apply_tokenizer(tokenizer, texts, - length: sequence_length, - return_special_tokens_mask: true, - return_offsets: true - ) + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, texts, + length: sequence_length, + return_special_tokens_mask: true, + return_offsets: true + ) + end) inputs = Map.take(all_inputs, ["input_ids", "attention_mask"]) diff --git a/lib/bumblebee/text/zero_shot_classification.ex b/lib/bumblebee/text/zero_shot_classification.ex index 0f5e11d0..af3390ad 100644 --- a/lib/bumblebee/text/zero_shot_classification.ex +++ b/lib/bumblebee/text/zero_shot_classification.ex @@ -13,11 +13,13 @@ defmodule Bumblebee.Text.ZeroShotClassification do :compile, hypothesis_template: &default_hypothesis_template/1, top_k: 5, - defn_options: [] + defn_options: [], + preallocate_params: false ]) hypothesis_template = opts[:hypothesis_template] top_k = opts[:top_k] + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] hypotheses = Enum.map(labels, hypothesis_template) @@ -56,6 +58,8 @@ defmodule Bumblebee.Text.ZeroShotClassification do Nx.Serving.new( fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scores_fun = Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key @@ -85,10 +89,12 @@ defmodule Bumblebee.Text.ZeroShotClassification do pairs = for text <- texts, hypothesis <- hypotheses, do: {text, hypothesis} inputs = - Bumblebee.apply_tokenizer(tokenizer, pairs, - length: sequence_length, - return_token_type_ids: false - ) + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, pairs, + length: sequence_length, + return_token_type_ids: false + ) + end) batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) diff --git a/lib/bumblebee/vision.ex b/lib/bumblebee/vision.ex index d9765b54..f8d61d5a 100644 --- a/lib/bumblebee/vision.ex +++ b/lib/bumblebee/vision.ex @@ -50,6 +50,11 @@ defmodule Bumblebee.Vision do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples {:ok, resnet} = Bumblebee.load_model({:hf, "microsoft/resnet-50"}) @@ -106,6 +111,11 @@ defmodule Bumblebee.Vision do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples {:ok, blip} = Bumblebee.load_model({:hf, "Salesforce/blip-image-captioning-base"}) @@ -166,6 +176,11 @@ defmodule Bumblebee.Vision do * `:defn_options` - the options for JIT compilation. Defaults to `[]` + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. Defaults to `false` + ## Examples {:ok, clip} = diff --git a/lib/bumblebee/vision/image_classification.ex b/lib/bumblebee/vision/image_classification.ex index 48d9eb11..cf44a15b 100644 --- a/lib/bumblebee/vision/image_classification.ex +++ b/lib/bumblebee/vision/image_classification.ex @@ -12,10 +12,17 @@ defmodule Bumblebee.Vision.ImageClassification do ]) opts = - Keyword.validate!(opts, [:compile, top_k: 5, scores_function: :softmax, defn_options: []]) + Keyword.validate!(opts, [ + :compile, + top_k: 5, + scores_function: :softmax, + defn_options: [], + preallocate_params: false + ]) top_k = opts[:top_k] scores_function = opts[:scores_function] + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] compile = @@ -36,6 +43,8 @@ defmodule Bumblebee.Vision.ImageClassification do Nx.Serving.new( fn defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scores_fun = Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> inputs = %{ diff --git a/lib/bumblebee/vision/image_embedding.ex b/lib/bumblebee/vision/image_embedding.ex index ba975647..40a2add7 100644 --- a/lib/bumblebee/vision/image_embedding.ex +++ b/lib/bumblebee/vision/image_embedding.ex @@ -11,11 +11,13 @@ defmodule Bumblebee.Vision.ImageEmbedding do :compile, output_attribute: :pooled_state, embedding_processor: nil, - defn_options: [] + defn_options: [], + preallocate_params: false ]) output_attribute = opts[:output_attribute] embedding_processor = opts[:embedding_processor] + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] compile = @@ -57,6 +59,8 @@ defmodule Bumblebee.Vision.ImageEmbedding do Nx.Serving.new( fn defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + embedding_fun = Shared.compile_or_jit(embedding_fun, defn_options, compile != nil, fn -> inputs = %{ diff --git a/lib/bumblebee/vision/image_to_text.ex b/lib/bumblebee/vision/image_to_text.ex index 80de4e92..dbdeb93e 100644 --- a/lib/bumblebee/vision/image_to_text.ex +++ b/lib/bumblebee/vision/image_to_text.ex @@ -11,12 +11,13 @@ defmodule Bumblebee.Vision.ImageToText do %Text.GenerationConfig{} = generation_config, opts \\ [] ) do - opts = Keyword.validate!(opts, [:seed, :compile, defn_options: []]) + opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false]) %{model: model, params: params, spec: spec} = model_info Shared.validate_architecture!(spec, [:for_conditional_generation]) + preallocate_params = opts[:preallocate_params] defn_options = opts[:defn_options] compile = @@ -33,6 +34,8 @@ defmodule Bumblebee.Vision.ImageToText do Nx.Serving.new( fn defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + generate_fun = Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> inputs = %{