Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tensor allocations in servings #233

Merged
merged 2 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/bumblebee/audio.ex
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ defmodule Bumblebee.Audio do

* `:defn_options` - the options for JIT compilation. Defaults to `[]`

* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`

## Examples

{:ok, whisper} = Bumblebee.load_model({:hf, "openai/whisper-tiny"})
Expand Down
5 changes: 4 additions & 1 deletion lib/bumblebee/audio/speech_to_text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ defmodule Bumblebee.Audio.SpeechToText do
%Text.GenerationConfig{} = generation_config,
opts \\ []
) do
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: []])
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false])

%{model: model, params: params, spec: spec} = model_info

Shared.validate_architecture!(spec, [:for_conditional_generation])

preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

compile =
Expand All @@ -35,6 +36,8 @@ defmodule Bumblebee.Audio.SpeechToText do

Nx.Serving.new(
fn defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
inputs = %{
Expand Down
47 changes: 34 additions & 13 deletions lib/bumblebee/diffusion/stable_diffusion.ex
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ defmodule Bumblebee.Diffusion.StableDiffusion do

* `:defn_options` - the options for JIT compilation. Defaults to `[]`

* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`

## Examples

repository_id = "CompVis/stable-diffusion-v1-4"
Expand Down Expand Up @@ -135,13 +140,15 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
num_images_per_prompt: 1,
guidance_scale: 7.5,
seed: 0,
defn_options: []
defn_options: [],
preallocate_params: false
])

safety_checker = opts[:safety_checker]
safety_checker_featurizer = opts[:safety_checker_featurizer]
num_steps = opts[:num_steps]
num_images_per_prompt = opts[:num_images_per_prompt]
preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

if safety_checker != nil and safety_checker_featurizer == nil do
Expand Down Expand Up @@ -203,11 +210,12 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
{safety_checker?, safety_checker[:spec], safety_checker[:params]},
safety_checker_featurizer,
{compile != nil, batch_size, sequence_length},
num_images_per_prompt
num_images_per_prompt,
preallocate_params
]

Nx.Serving.new(
fn defn_options -> apply(&init/9, init_args ++ [defn_options]) end,
fn defn_options -> apply(&init/10, init_args ++ [defn_options]) end,
defn_options
)
|> Nx.Serving.process_options(batch_size: batch_size)
Expand All @@ -224,8 +232,13 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
safety_checker_featurizer,
{compile?, batch_size, sequence_length},
num_images_per_prompt,
preallocate_params,
defn_options
) do
encoder_params = Shared.maybe_preallocate(encoder_params, preallocate_params, defn_options)
unet_params = Shared.maybe_preallocate(unet_params, preallocate_params, defn_options)
vae_params = Shared.maybe_preallocate(vae_params, preallocate_params, defn_options)

image_fun =
Shared.compile_or_jit(image_fun, defn_options, compile?, fn ->
text_inputs = %{
Expand All @@ -250,6 +263,10 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
[safety_checker_params, inputs]
end)

safety_checker_params =
safety_checker_params &&
Shared.maybe_preallocate(safety_checker_params, preallocate_params, defn_options)

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)

Expand All @@ -275,18 +292,22 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
negative_prompts = Enum.map(inputs, & &1.negative_prompt)

conditional =
Bumblebee.apply_tokenizer(tokenizer, prompts,
length: sequence_length,
return_token_type_ids: false,
return_attention_mask: false
)
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, prompts,
length: sequence_length,
return_token_type_ids: false,
return_attention_mask: false
)
end)

unconditional =
Bumblebee.apply_tokenizer(tokenizer, negative_prompts,
length: Nx.axis_size(conditional["input_ids"], 1),
return_attention_mask: false,
return_token_type_ids: false
)
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, negative_prompts,
length: Nx.axis_size(conditional["input_ids"], 1),
return_attention_mask: false,
return_token_type_ids: false
)
end)

inputs = %{"unconditional" => unconditional, "conditional" => conditional}

Expand Down
12 changes: 12 additions & 0 deletions lib/bumblebee/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,18 @@ defmodule Bumblebee.Shared do
end
end

@doc """
Optionally allocates `params` ahead of time.

When `preallocate?` is truthy, `params` are passed through
`Nx.Defn.jit_apply/3` with the identity function and the given
`defn_options`, and the result of that call is returned. Servings
use this to explicitly allocate params on the device configured by
`defn_options`.

When `preallocate?` is `false` (or `nil`), `params` are returned
untouched.
"""
@spec maybe_preallocate(map(), boolean(), keyword()) :: map()
def maybe_preallocate(params, preallocate?, defn_options) do
  case preallocate? do
    # Same falsy values an `if` would treat as "do not preallocate".
    disabled when disabled in [false, nil] ->
      params

    _enabled ->
      # Identity under JIT: the computation is a no-op, only the
      # `defn_options` applied to the call matter here.
      Nx.Defn.jit_apply(&Function.identity/1, [params], defn_options)
  end
end

@doc """
Generates tokenizer implementation.
"""
Expand Down
40 changes: 40 additions & 0 deletions lib/bumblebee/text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ defmodule Bumblebee.Text do

* `:defn_options` - the options for JIT compilation. Defaults to `[]`

* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`

## Examples

{:ok, bert} = Bumblebee.load_model({:hf, "dslim/bert-base-NER"})
Expand Down Expand Up @@ -152,6 +157,11 @@ defmodule Bumblebee.Text do

* `:defn_options` - the options for JIT compilation. Defaults to `[]`

* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`

* `:stream` - when `true`, the serving immediately returns a
stream that emits text chunks as they are generated. Note that
when using streaming, only a single input can be given to the
Expand Down Expand Up @@ -242,6 +252,11 @@ defmodule Bumblebee.Text do

* `:defn_options` - the options for JIT compilation. Defaults to `[]`

* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`

## Examples

{:ok, model_info} = Bumblebee.load_model({:hf, "facebook/blenderbot-400M-distill"})
Expand Down Expand Up @@ -311,6 +326,11 @@ defmodule Bumblebee.Text do

* `:defn_options` - the options for JIT compilation. Defaults to `[]`

* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`

## Examples

{:ok, bertweet} = Bumblebee.load_model({:hf, "finiteautomata/bertweet-base-sentiment-analysis"})
Expand Down Expand Up @@ -379,6 +399,11 @@ defmodule Bumblebee.Text do

* `:defn_options` - the options for JIT compilation. Defaults to `[]`

* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`

## Examples

{:ok, model_info} = Bumblebee.load_model({:hf, "intfloat/e5-large"})
Expand Down Expand Up @@ -444,6 +469,11 @@ defmodule Bumblebee.Text do

* `:defn_options` - the options for JIT compilation. Defaults to `[]`

* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`

## Examples

{:ok, bert} = Bumblebee.load_model({:hf, "bert-base-uncased"})
Expand Down Expand Up @@ -513,6 +543,11 @@ defmodule Bumblebee.Text do

* `:defn_options` - the options for JIT compilation. Defaults to `[]`

* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`

## Examples

{:ok, roberta} = Bumblebee.load_model({:hf, "deepset/roberta-base-squad2"})
Expand Down Expand Up @@ -581,6 +616,11 @@ defmodule Bumblebee.Text do

* `:defn_options` - the options for JIT compilation. Defaults to `[]`

* `:preallocate_params` - when `true`, explicitly allocates params
on the device configured by `:defn_options`. You may want to set
this option when using partitioned serving, to allocate params
on each of the devices. Defaults to `false`

## Examples

{:ok, model} = Bumblebee.load_model({:hf, "facebook/bart-large-mnli"})
Expand Down
19 changes: 12 additions & 7 deletions lib/bumblebee/text/conversation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ defmodule Bumblebee.Text.Conversation do
%Text.GenerationConfig{} = generation_config,
opts \\ []
) do
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: []])
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], preallocate_params: false])

%{model: model, params: params, spec: spec} = model_info

Expand All @@ -28,6 +28,7 @@ defmodule Bumblebee.Text.Conversation do
:for_conditional_generation
])

preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

compile =
Expand All @@ -49,6 +50,8 @@ defmodule Bumblebee.Text.Conversation do

Nx.Serving.new(
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key
Expand Down Expand Up @@ -88,12 +91,14 @@ defmodule Bumblebee.Text.Conversation do
end

inputs =
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
pad_direction: :left,
truncate_direction: :left,
return_token_type_ids: false
)
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
pad_direction: :left,
truncate_direction: :left,
return_token_type_ids: false
)
end)

batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
Expand Down
17 changes: 12 additions & 5 deletions lib/bumblebee/text/fill_mask.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ defmodule Bumblebee.Text.FillMask do
def fill_mask(model_info, tokenizer, opts \\ []) do
%{model: model, params: params, spec: spec} = model_info
Shared.validate_architecture!(spec, :for_masked_language_modeling)
opts = Keyword.validate!(opts, [:compile, top_k: 5, defn_options: []])

opts =
Keyword.validate!(opts, [:compile, top_k: 5, defn_options: [], preallocate_params: false])

top_k = opts[:top_k]
preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

compile =
Expand Down Expand Up @@ -51,6 +54,8 @@ defmodule Bumblebee.Text.FillMask do

Nx.Serving.new(
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

scores_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key
Expand All @@ -77,10 +82,12 @@ defmodule Bumblebee.Text.FillMask do
texts = for text <- texts, do: validate_text!(text, mask_token)

inputs =
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
return_token_type_ids: false
)
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
return_token_type_ids: false
)
end)

batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
Expand Down
24 changes: 18 additions & 6 deletions lib/bumblebee/text/generation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,14 @@ defmodule Bumblebee.Text.Generation do

@doc false
def generation(model_info, tokenizer, %Text.GenerationConfig{} = generation_config, opts \\ []) do
opts = Keyword.validate!(opts, [:seed, :compile, defn_options: [], stream: false])
opts =
Keyword.validate!(opts, [
:seed,
:compile,
defn_options: [],
preallocate_params: false,
stream: false
])

%{model: model, params: params, spec: spec} = model_info

Expand All @@ -807,6 +814,7 @@ defmodule Bumblebee.Text.Generation do
:for_causal_language_modeling
])

preallocate_params = opts[:preallocate_params]
defn_options = opts[:defn_options]

compile =
Expand All @@ -825,6 +833,8 @@ defmodule Bumblebee.Text.Generation do

Nx.Serving.new(
fn batch_key, defn_options ->
params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

generate_fun =
Shared.compile_or_jit(generate_fun, defn_options, compile != nil, fn ->
{:sequence_length, sequence_length} = batch_key
Expand Down Expand Up @@ -853,11 +863,13 @@ defmodule Bumblebee.Text.Generation do
{texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1)

inputs =
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
pad_direction: :left,
return_token_type_ids: false
)
Nx.with_default_backend(Nx.BinaryBackend, fn ->
Bumblebee.apply_tokenizer(tokenizer, texts,
length: sequence_length,
pad_direction: :left,
return_token_type_ids: false
)
end)

batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)
Expand Down
Loading
Loading