From 9388b28be388474e895a12f4ffbb3b299503cace Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 10 Sep 2024 16:51:23 +0200 Subject: [PATCH] Suffix defn cache path in servings (#397) --- lib/bumblebee/audio/speech_to_text_whisper.ex | 2 +- lib/bumblebee/diffusion/stable_diffusion.ex | 4 ++-- .../diffusion/stable_diffusion_controlnet.ex | 4 ++-- lib/bumblebee/shared.ex | 23 ++++++++++++++++++- lib/bumblebee/text/fill_mask.ex | 4 +++- lib/bumblebee/text/question_answering.ex | 4 +++- lib/bumblebee/text/text_classification.ex | 4 +++- lib/bumblebee/text/text_embedding.ex | 4 +++- lib/bumblebee/text/text_generation.ex | 4 +++- lib/bumblebee/text/token_classification.ex | 4 +++- lib/bumblebee/text/translation.ex | 4 +++- .../text/zero_shot_classification.ex | 4 +++- lib/bumblebee/vision/image_classification.ex | 2 +- lib/bumblebee/vision/image_embedding.ex | 2 +- lib/bumblebee/vision/image_to_text.ex | 2 +- 15 files changed, 54 insertions(+), 17 deletions(-) diff --git a/lib/bumblebee/audio/speech_to_text_whisper.ex b/lib/bumblebee/audio/speech_to_text_whisper.ex index d9518678..e52e8ce2 100644 --- a/lib/bumblebee/audio/speech_to_text_whisper.ex +++ b/lib/bumblebee/audio/speech_to_text_whisper.ex @@ -74,7 +74,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do params = Shared.maybe_preallocate(params, preallocate_params, defn_options) generate_fun = - Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(generate_fun, :generate, defn_options, compile != nil, fn -> inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size) seed = Nx.template({batch_size}, :s64) [params, {inputs, seed}] diff --git a/lib/bumblebee/diffusion/stable_diffusion.ex b/lib/bumblebee/diffusion/stable_diffusion.ex index 5abf5663..71fd5d34 100644 --- a/lib/bumblebee/diffusion/stable_diffusion.ex +++ b/lib/bumblebee/diffusion/stable_diffusion.ex @@ -242,7 +242,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do vae_params = Shared.maybe_preallocate(vae_params, preallocate_params, defn_options) image_fun = - Shared.compile_or_jit(image_fun, defn_options, compile?, fn -> + Shared.compile_or_jit(image_fun, :image, defn_options, compile?, fn -> inputs = %{ "conditional_and_unconditional" => %{ "input_ids" => Nx.template({batch_size, 2, sequence_length}, :u32) @@ -255,7 +255,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do safety_checker_fun = safety_checker_fun && - Shared.compile_or_jit(safety_checker_fun, defn_options, compile?, fn -> + Shared.compile_or_jit(safety_checker_fun, :safety_checker, defn_options, compile?, fn -> inputs = Bumblebee.Featurizer.batch_template( safety_checker_featurizer, diff --git a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex index a10fc670..eaf57813 100644 --- a/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex +++ b/lib/bumblebee/diffusion/stable_diffusion_controlnet.ex @@ -273,7 +273,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do Shared.maybe_preallocate(controlnet_params, preallocate_params, defn_options) image_fun = - Shared.compile_or_jit(image_fun, defn_options, compile?, fn -> + Shared.compile_or_jit(image_fun, :image, defn_options, compile?, fn -> inputs = %{ "conditional_and_unconditional" => %{ "input_ids" => Nx.template({batch_size, 2, sequence_length}, :u32) @@ -292,7 +292,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do safety_checker_fun = safety_checker_fun && - Shared.compile_or_jit(safety_checker_fun, defn_options, compile?, fn -> + Shared.compile_or_jit(safety_checker_fun, :safety_checker, defn_options, compile?, fn -> inputs = %{ "pixel_values" => Shared.input_template(safety_checker_spec, "pixel_values", [ diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex index 8bad120d..d5b81f19 100644 --- a/lib/bumblebee/shared.ex +++ b/lib/bumblebee/shared.ex @@ -312,14 +312,29 @@ defmodule Bumblebee.Shared do and calls compiles the function upfront. The template function may return a mix of tensors and templates, all arguments are automatically converter to templates. + + If `defn_options[:cache]` is set, the given `scope` is used to create + a suffix. """ @spec compile_or_jit( function(), + scope, keyword(), boolean(), (-> list(Nx.Tensor.t())) ) :: function() - def compile_or_jit(fun, defn_options, compile?, template_fun) do + when scope: String.Chars.t() | {scope, scope} + def compile_or_jit(fun, scope, defn_options, compile?, template_fun) do + defn_options = + case defn_options[:cache] do + cache when is_binary(cache) -> + suffix = "__bumblebee_" <> scope_to_string(scope) + Keyword.replace!(defn_options, :cache, cache <> suffix) + + _ -> + defn_options + end + if compile? do template_args = template_fun.() |> templates() Nx.Defn.compile(fun, template_args, defn_options) @@ -328,6 +343,12 @@ defmodule Bumblebee.Shared do end end + defp scope_to_string({left, right}) do + scope_to_string(left) <> "_" <> scope_to_string(right) + end + + defp scope_to_string(scope), do: to_string(scope) + @doc """ Returns at template for the given model input. diff --git a/lib/bumblebee/text/fill_mask.ex b/lib/bumblebee/text/fill_mask.ex index 68652cbf..91ea8dce 100644 --- a/lib/bumblebee/text/fill_mask.ex +++ b/lib/bumblebee/text/fill_mask.ex @@ -64,8 +64,10 @@ defmodule Bumblebee.Text.FillMask do fn batch_key, defn_options -> params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scope = {:scores, batch_key} + scores_fun = - Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key inputs = %{ diff --git a/lib/bumblebee/text/question_answering.ex b/lib/bumblebee/text/question_answering.ex index 69bccc45..20c8d077 100644 --- a/lib/bumblebee/text/question_answering.ex +++ b/lib/bumblebee/text/question_answering.ex @@ -58,8 +58,10 @@ defmodule Bumblebee.Text.QuestionAnswering do fn batch_key, defn_options -> params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scope = {:predict, batch_key} + predict_fun = - Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key inputs = %{ diff --git a/lib/bumblebee/text/text_classification.ex b/lib/bumblebee/text/text_classification.ex index 0c510d35..22f0541c 100644 --- a/lib/bumblebee/text/text_classification.ex +++ b/lib/bumblebee/text/text_classification.ex @@ -50,8 +50,10 @@ defmodule Bumblebee.Text.TextClassification do fn batch_key, defn_options -> params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scope = {:scores, batch_key} + scores_fun = - Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key inputs = %{ diff --git a/lib/bumblebee/text/text_embedding.ex b/lib/bumblebee/text/text_embedding.ex index 41e284e8..34f41279 100644 --- a/lib/bumblebee/text/text_embedding.ex +++ b/lib/bumblebee/text/text_embedding.ex @@ -105,8 +105,10 @@ defmodule Bumblebee.Text.TextEmbedding do fn batch_key, defn_options -> params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scope = {:embedding, batch_key} + embedding_fun = - Shared.compile_or_jit(embedding_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(embedding_fun, scope, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key inputs = %{ diff --git a/lib/bumblebee/text/text_generation.ex b/lib/bumblebee/text/text_generation.ex index c2169071..2ae33102 100644 --- a/lib/bumblebee/text/text_generation.ex +++ b/lib/bumblebee/text/text_generation.ex @@ -53,8 +53,10 @@ defmodule Bumblebee.Text.TextGeneration do fn batch_key, defn_options -> params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scope = {:generate, batch_key} + generate_fun = - Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(generate_fun, scope, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key inputs = %{ diff --git a/lib/bumblebee/text/token_classification.ex b/lib/bumblebee/text/token_classification.ex index 31433ea8..b0f97cd8 100644 --- a/lib/bumblebee/text/token_classification.ex +++ b/lib/bumblebee/text/token_classification.ex @@ -54,8 +54,10 @@ defmodule Bumblebee.Text.TokenClassification do fn batch_key, defn_options -> params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scope = {:scores, batch_key} + scores_fun = - Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key inputs = %{ diff --git a/lib/bumblebee/text/translation.ex b/lib/bumblebee/text/translation.ex index 79082fc4..64a3dbff 100644 --- a/lib/bumblebee/text/translation.ex +++ b/lib/bumblebee/text/translation.ex @@ -50,8 +50,10 @@ defmodule Bumblebee.Text.Translation do fn batch_key, defn_options -> params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scope = {:generate, batch_key} + generate_fun = - Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(generate_fun, scope, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key inputs = %{ diff --git a/lib/bumblebee/text/zero_shot_classification.ex b/lib/bumblebee/text/zero_shot_classification.ex index 05952eaf..4f18fdb8 100644 --- a/lib/bumblebee/text/zero_shot_classification.ex +++ b/lib/bumblebee/text/zero_shot_classification.ex @@ -63,8 +63,10 @@ defmodule Bumblebee.Text.ZeroShotClassification do fn batch_key, defn_options -> params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + scope = {:logits, batch_key} + logits_fun = - Shared.compile_or_jit(logits_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(logits_fun, scope, defn_options, compile != nil, fn -> {:sequence_length, sequence_length} = batch_key inputs = %{ diff --git a/lib/bumblebee/vision/image_classification.ex b/lib/bumblebee/vision/image_classification.ex index fca3291d..79fa936c 100644 --- a/lib/bumblebee/vision/image_classification.ex +++ b/lib/bumblebee/vision/image_classification.ex @@ -50,7 +50,7 @@ defmodule Bumblebee.Vision.ImageClassification do params = Shared.maybe_preallocate(params, preallocate_params, defn_options) scores_fun = - Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(scores_fun, :scores, defn_options, compile != nil, fn -> inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size) [params, inputs] end) diff --git a/lib/bumblebee/vision/image_embedding.ex b/lib/bumblebee/vision/image_embedding.ex index d1975ca6..345ac14e 100644 --- a/lib/bumblebee/vision/image_embedding.ex +++ b/lib/bumblebee/vision/image_embedding.ex @@ -73,7 +73,7 @@ defmodule Bumblebee.Vision.ImageEmbedding do params = Shared.maybe_preallocate(params, preallocate_params, defn_options) embedding_fun = - Shared.compile_or_jit(embedding_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(embedding_fun, :embedding, defn_options, compile != nil, fn -> inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size) [params, inputs] end) diff --git a/lib/bumblebee/vision/image_to_text.ex b/lib/bumblebee/vision/image_to_text.ex index 53e83cd9..055e13a3 100644 --- a/lib/bumblebee/vision/image_to_text.ex +++ b/lib/bumblebee/vision/image_to_text.ex @@ -43,7 +43,7 @@ defmodule Bumblebee.Vision.ImageToText do params = Shared.maybe_preallocate(params, preallocate_params, defn_options) generate_fun = - Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn -> + Shared.compile_or_jit(generate_fun, :generate, defn_options, compile != nil, fn -> inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size) seed = Nx.template({batch_size}, :s64) [params, {inputs, seed}]