Skip to content

Commit

Permalink
Suffix defn cache path in servings (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Sep 10, 2024
1 parent 9aaeb13 commit 9388b28
Show file tree
Hide file tree
Showing 15 changed files with 54 additions and 17 deletions.
2 changes: 1 addition & 1 deletion lib/bumblebee/audio/speech_to_text_whisper.ex
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(generate_fun, :generate, defn_options, compile != nil, fn ->
inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
seed = Nx.template({batch_size}, :s64)
[params, {inputs, seed}]
Expand Down
4 changes: 2 additions & 2 deletions lib/bumblebee/diffusion/stable_diffusion.ex
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
vae_params = Shared.maybe_preallocate(vae_params, preallocate_params, defn_options)

image_fun =
Shared.compile_or_jit(image_fun, defn_options, compile?, fn ->
Shared.compile_or_jit(image_fun, :image, defn_options, compile?, fn ->
inputs = %{
"conditional_and_unconditional" => %{
"input_ids" => Nx.template({batch_size, 2, sequence_length}, :u32)
Expand All @@ -255,7 +255,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do

safety_checker_fun =
safety_checker_fun &&
Shared.compile_or_jit(safety_checker_fun, defn_options, compile?, fn ->
Shared.compile_or_jit(safety_checker_fun, :safety_checker, defn_options, compile?, fn ->
inputs =
Bumblebee.Featurizer.batch_template(
safety_checker_featurizer,
Expand Down
4 changes: 2 additions & 2 deletions lib/bumblebee/diffusion/stable_diffusion_controlnet.ex
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do
Shared.maybe_preallocate(controlnet_params, preallocate_params, defn_options)

image_fun =
Shared.compile_or_jit(image_fun, defn_options, compile?, fn ->
Shared.compile_or_jit(image_fun, :image, defn_options, compile?, fn ->
inputs = %{
"conditional_and_unconditional" => %{
"input_ids" => Nx.template({batch_size, 2, sequence_length}, :u32)
Expand All @@ -292,7 +292,7 @@ defmodule Bumblebee.Diffusion.StableDiffusionControlNet do

safety_checker_fun =
safety_checker_fun &&
Shared.compile_or_jit(safety_checker_fun, defn_options, compile?, fn ->
Shared.compile_or_jit(safety_checker_fun, :safety_checker, defn_options, compile?, fn ->
inputs = %{
"pixel_values" =>
Shared.input_template(safety_checker_spec, "pixel_values", [
Expand Down
23 changes: 22 additions & 1 deletion lib/bumblebee/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,29 @@ defmodule Bumblebee.Shared do
and compiles the function upfront. The template function may
return a mix of tensors and templates, all arguments are automatically
converted to templates.
If `defn_options[:cache]` is set to a binary path, the given `scope`
is used to create a suffix that is appended to that path.
"""
@spec compile_or_jit(
function(),
scope,
keyword(),
boolean(),
(-> list(Nx.Tensor.t()))
) :: function()
def compile_or_jit(fun, defn_options, compile?, template_fun) do
when scope: String.Chars.t() | {scope, scope}
def compile_or_jit(fun, scope, defn_options, compile?, template_fun) do
defn_options =
case defn_options[:cache] do
cache when is_binary(cache) ->
suffix = "__bumblebee_" <> scope_to_string(scope)
Keyword.replace!(defn_options, :cache, cache <> suffix)

_ ->
defn_options
end

if compile? do
template_args = template_fun.() |> templates()
Nx.Defn.compile(fun, template_args, defn_options)
Expand All @@ -328,6 +343,12 @@ defmodule Bumblebee.Shared do
end
end

# Flattens a scope into a string. A two-element tuple is rendered as
# its two parts joined with "_", recursively, so nested scopes such
# as `{:scores, {:sequence_length, 128}}` are supported.
defp scope_to_string({first, second}) do
  "#{scope_to_string(first)}_#{scope_to_string(second)}"
end

# Any non-tuple scope is converted with `to_string/1`.
defp scope_to_string(scope) do
  to_string(scope)
end

@doc """
Returns a template for the given model input.
Expand Down
4 changes: 3 additions & 1 deletion lib/bumblebee/text/fill_mask.ex
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ defmodule Bumblebee.Text.FillMask do
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

scope = {:scores, batch_key}

scores_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key

inputs = %{
Expand Down
4 changes: 3 additions & 1 deletion lib/bumblebee/text/question_answering.ex
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ defmodule Bumblebee.Text.QuestionAnswering do
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

scope = {:predict, batch_key}

predict_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key

inputs = %{
Expand Down
4 changes: 3 additions & 1 deletion lib/bumblebee/text/text_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ defmodule Bumblebee.Text.TextClassification do
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

scope = {:scores, batch_key}

scores_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key

inputs = %{
Expand Down
4 changes: 3 additions & 1 deletion lib/bumblebee/text/text_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ defmodule Bumblebee.Text.TextEmbedding do
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

scope = {:embedding, batch_key}

embedding_fun =
Shared.compile_or_jit(embedding_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(embedding_fun, scope, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key

inputs = %{
Expand Down
4 changes: 3 additions & 1 deletion lib/bumblebee/text/text_generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ defmodule Bumblebee.Text.TextGeneration do
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

scope = {:generate, batch_key}

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(generate_fun, scope, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key

inputs = %{
Expand Down
4 changes: 3 additions & 1 deletion lib/bumblebee/text/token_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ defmodule Bumblebee.Text.TokenClassification do
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

scope = {:scores, batch_key}

scores_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key

inputs = %{
Expand Down
4 changes: 3 additions & 1 deletion lib/bumblebee/text/translation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ defmodule Bumblebee.Text.Translation do
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

scope = {:generate, batch_key}

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(generate_fun, scope, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key

inputs = %{
Expand Down
4 changes: 3 additions & 1 deletion lib/bumblebee/text/zero_shot_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ defmodule Bumblebee.Text.ZeroShotClassification do
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

scope = {:logits, batch_key}

logits_fun =
Shared.compile_or_jit(logits_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(logits_fun, scope, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key

inputs = %{
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision/image_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ defmodule Bumblebee.Vision.ImageClassification do
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

scores_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(scores_fun, :scores, defn_options, compile != nil, fn ->
inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
[params, inputs]
end)
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision/image_embedding.ex
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ defmodule Bumblebee.Vision.ImageEmbedding do
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

embedding_fun =
Shared.compile_or_jit(embedding_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(embedding_fun, :embedding, defn_options, compile != nil, fn ->
inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
[params, inputs]
end)
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision/image_to_text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ defmodule Bumblebee.Vision.ImageToText do
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
Shared.compile_or_jit(generate_fun, :generate, defn_options, compile != nil, fn ->
inputs = Bumblebee.Featurizer.batch_template(featurizer, batch_size)
seed = Nx.template({batch_size}, :s64)
[params, {inputs, seed}]
Expand Down

0 comments on commit 9388b28

Please sign in to comment.