Skip to content

Commit

Permalink
Improve tensor allocations in servings (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Aug 3, 2023
1 parent 68d5f48 commit 333ba09
Show file tree
Hide file tree
Showing 17 changed files with 232 additions and 62 deletions.
5 changes: 5 additions & 0 deletions lib/bumblebee/audio.ex
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ defmodule Bumblebee.Audio do
* `:defn_options` - the options for JIT compilation. Defaults to `[]`
* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`
## Examples
{:ok, whisper} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
Expand Down
5 changes: 4 additions & 1 deletion lib/bumblebee/audio/speech_to_text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ defmodule Bumblebee.Audio.SpeechToText do
%Text.GenerationConfig{} = generation_config,
opts \\ []
) do
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: []])
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false])

%{model: model, params: params, spec: spec} = model_info

Shared.validate_architecture!(spec, [:for_conditional_generation])

preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

compile =
Expand All @@ -35,6 +36,8 @@ defmodule Bumblebee.Audio.SpeechToText do

Nx.Serving.new(
fn defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
inputs = %{
Expand Down
47 changes: 34 additions & 13 deletions lib/bumblebee/diffusion/stable_diffusion.ex
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
* `:defn_options` - the options for JIT compilation. Defaults to `[]`
* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`
## Examples
repository_id = "CompVis/stable-diffusion-v1-4"
Expand Down Expand Up @@ -135,13 +140,15 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
num_images_per_prompt: 1,
guidance_scale: 7.5,
seed: 0,
defn_options: []
defn_options: [],
preallocate_params: false
])

safety_checker = opts[:safety_checker]
safety_checker_featurizer = opts[:safety_checker_featurizer]
num_steps = opts[:num_steps]
num_images_per_prompt = opts[:num_images_per_prompt]
preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

if safety_checker != nil and safety_checker_featurizer == nil do
Expand Down Expand Up @@ -203,11 +210,12 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
{safety_checker?, safety_checker[:spec], safety_checker[:params]},
safety_checker_featurizer,
{compile != nil, batch_size, sequence_length},
num_images_per_prompt
num_images_per_prompt,
preallocate_params
]

Nx.Serving.new(
fn defn_options -> apply(&init/9, init_args ++ [defn_options]) end,
fn defn_options -> apply(&init/10, init_args ++ [defn_options]) end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size)
Expand All @@ -224,8 +232,13 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
safety_checker_featurizer,
{compile?, batch_size, sequence_length},
num_images_per_prompt,
preallocate_params,
defn_options
) do
encoder_params = Shared.maybe_preallocate(encoder_params, preallocate_params, defn_options)
unet_params = Shared.maybe_preallocate(unet_params, preallocate_params, defn_options)
vae_params = Shared.maybe_preallocate(vae_params, preallocate_params, defn_options)

image_fun =
Shared.compile_or_jit(image_fun, defn_options, compile?, fn ->
text_inputs = %{
Expand All @@ -250,6 +263,10 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
[safety_checker_params, inputs]
end)

safety_checker_params =
safety_checker_params &&
Shared.maybe_preallocate(safety_checker_params, preallocate_params, defn_options)

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)

Expand All @@ -275,18 +292,22 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
negative_prompts = Enum.map(inputs, & &1.negative_prompt)

conditional =
Bumblebee.apply_tokenizer(tokenizer, prompts,
length: sequence_length,
return_token_type_ids: false,
return_attention_mask: false
)
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, prompts,
length: sequence_length,
return_token_type_ids: false,
return_attention_mask: false
)
end)

unconditional =
Bumblebee.apply_tokenizer(tokenizer, negative_prompts,
length: Nx.axis_size(conditional["input_ids"], 1),
return_attention_mask: false,
return_token_type_ids: false
)
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, negative_prompts,
length: Nx.axis_size(conditional["input_ids"], 1),
return_attention_mask: false,
return_token_type_ids: false
)
end)

inputs = %{"unconditional" => unconditional, "conditional" => conditional}

Expand Down
12 changes: 12 additions & 0 deletions lib/bumblebee/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,18 @@ defmodule Bumblebee.Shared do
end
end

@doc """
Optionally allocates `params` according to `defn_options`.

When `preallocate?` is `false` (or `nil`), `params` are returned
untouched. For any other value, `params` are passed through an
identity function with `Nx.Defn.jit_apply/3`, using `defn_options`,
and the result of that call is returned.
"""
@spec maybe_preallocate(map(), boolean(), keyword()) :: map()
def maybe_preallocate(params, preallocate?, _defn_options)
    when preallocate? in [false, nil] do
  # Nothing to do, hand back the very same params.
  params
end

def maybe_preallocate(params, _preallocate?, defn_options) do
  # Running the identity function under the given defn options yields
  # the params as produced by the configured compiler/backend.
  Nx.Defn.jit_apply(&Function.identity/1, [params], defn_options)
end

@doc """
Generates tokenizer implementation.
"""
Expand Down
40 changes: 40 additions & 0 deletions lib/bumblebee/text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ defmodule Bumblebee.Text do
* `:defn_options` - the options for JIT compilation. Defaults to `[]`
* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`
## Examples
{:ok, bert} = Bumblebee.load_model({:hf, "dslim/bert-base-NER"})
Expand Down Expand Up @@ -152,6 +157,11 @@ defmodule Bumblebee.Text do
* `:defn_options` - the options for JIT compilation. Defaults to `[]`
* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`
* `:stream` - when `true`, the serving immediately returns a
stream that emits text chunks as they are generated. Note that
when using streaming, only a single input can be given to the
Expand Down Expand Up @@ -242,6 +252,11 @@ defmodule Bumblebee.Text do
* `:defn_options` - the options for JIT compilation. Defaults to `[]`
* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`
## Examples
{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/blenderbot-400M-distill"})
Expand Down Expand Up @@ -311,6 +326,11 @@ defmodule Bumblebee.Text do
* `:defn_options` - the options for JIT compilation. Defaults to `[]`
* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`
## Examples
{:ok, bertweet} = Bumblebee.load_model({:hf, "finiteautomata/bertweet-base-sentiment-analysis"})
Expand Down Expand Up @@ -379,6 +399,11 @@ defmodule Bumblebee.Text do
* `:defn_options` - the options for JIT compilation. Defaults to `[]`
* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`
## Examples
{:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"})
Expand Down Expand Up @@ -444,6 +469,11 @@ defmodule Bumblebee.Text do
* `:defn_options` - the options for JIT compilation. Defaults to `[]`
* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`
## Examples
{:ok, bert} = Bumblebee.load_model({:hf, "bert-base-uncased"})
Expand Down Expand Up @@ -513,6 +543,11 @@ defmodule Bumblebee.Text do
* `:defn_options` - the options for JIT compilation. Defaults to `[]`
* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`
## Examples
{:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"})
Expand Down Expand Up @@ -581,6 +616,11 @@ defmodule Bumblebee.Text do
* `:defn_options` - the options for JIT compilation. Defaults to `[]`
* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`
## Examples
{:ok, model} = Bumblebee.load_model({:hf, "facebook/bart-large-mnli"})
Expand Down
19 changes: 12 additions & 7 deletions lib/bumblebee/text/conversation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ defmodule Bumblebee.Text.Conversation do
%Text.GenerationConfig{} = generation_config,
opts \\ []
) do
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: []])
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false])

%{model: model, params: params, spec: spec} = model_info

Expand All @@ -28,6 +28,7 @@ defmodule Bumblebee.Text.Conversation do
:for_conditional_generation
])

preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

compile =
Expand All @@ -49,6 +50,8 @@ defmodule Bumblebee.Text.Conversation do

Nx.Serving.new(
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key
Expand Down Expand Up @@ -88,12 +91,14 @@ defmodule Bumblebee.Text.Conversation do
end

inputs =
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
pad_direction: :left,
truncate_direction: :left,
return_token_type_ids: false
)
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
pad_direction: :left,
truncate_direction: :left,
return_token_type_ids: false
)
end)

batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
Expand Down
17 changes: 12 additions & 5 deletions lib/bumblebee/text/fill_mask.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ defmodule Bumblebee.Text.FillMask do
def fill_mask(model_info, tokenizer, opts \\ []) do
%{model: model, params: params, spec: spec} = model_info
Shared.validate_architecture!(spec, :for_masked_language_modeling)
opts = Keyword.validate!(opts, [:compile, top_k: 5, defn_options: []])

opts =
Keyword.validate!(opts, [:compile, top_k: 5, defn_options: [], preallocate_params: false])

top_k = opts[:top_k]
preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

compile =
Expand Down Expand Up @@ -51,6 +54,8 @@ defmodule Bumblebee.Text.FillMask do

Nx.Serving.new(
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

scores_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key
Expand All @@ -77,10 +82,12 @@ defmodule Bumblebee.Text.FillMask do
texts = for text <- texts, do: validate_text!(text, mask_token)

inputs =
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
return_token_type_ids: false
)
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
return_token_type_ids: false
)
end)

batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
Expand Down
24 changes: 18 additions & 6 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,14 @@ defmodule Bumblebee.Text.Generation do

@doc false
def generation(model_info, tokenizer, %Text.GenerationConfig{} = generation_config, opts \\ []) do
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], stream: false])
opts =
Keyword.validate!(opts, [
:seed,
:compile,
defn_options: [],
preallocate_params: false,
stream: false
])

%{model: model, params: params, spec: spec} = model_info

Expand All @@ -807,6 +814,7 @@ defmodule Bumblebee.Text.Generation do
:for_causal_language_modeling
])

preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

compile =
Expand All @@ -825,6 +833,8 @@ defmodule Bumblebee.Text.Generation do

Nx.Serving.new(
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key
Expand Down Expand Up @@ -853,11 +863,13 @@ defmodule Bumblebee.Text.Generation do
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

inputs =
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
pad_direction: :left,
return_token_type_ids: false
)
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
pad_direction: :left,
return_token_type_ids: false
)
end)

batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
Expand Down
Loading

0 comments on commit 333ba09

Please sign in to comment.