Skip to content

Commit

Permalink
Load parameter tensors lazily (#344)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Feb 23, 2024
1 parent a72c5dd commit 78c7694
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 97 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

This release changes the directory structure of the models cache, such that cached files from the same HuggingFace Hub repository are grouped in a separate subdirectory. This change is meant to simplify the process of manually removing specific models from the cache to free up space. As a result, the cache contents from prior versions are invalidated, so you most likely want to remove the current cache contents. To find the cache location run `elixir -e 'Mix.install([{:bumblebee, "0.4.2"}]); IO.puts(Bumblebee.cache_dir())'` (defaults to the standard cache location for the given operating system).

We also reduced memory usage during parameter loading (both when loading onto the CPU and GPU directly). Previously, larger models sometimes required loading parameters using CPU and only then transfering to the GPU, in order to avoid running out of GPU memory during parameter transformations. With this release this should no longer be the case. Loading parameters now has barely any memory footprint other than the parameters themselves.

### Added

* Notebook on LLaMA 2 to the docs ([#259](https://github.com/elixir-nx/bumblebee/pull/259))
Expand Down
9 changes: 5 additions & 4 deletions examples/phoenix/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,15 @@ config :nx, :default_backend, {EXLA.Backend, client: :host}

Then, for any expensive computations you can use [`Nx.Defn.compile/3`](https://hexdocs.pm/nx/Nx.Defn.html#compile/3) (or [`Axon.compile/4`](https://hexdocs.pm/axon/Axon.html#compile/4)) passing `compiler: EXLA` as an option. When you use a Bumblebee serving the compilation is handled for you, just make sure to pass `:compile` and `defn_options: [compiler: EXLA]` when creating the serving.

There's a final important detail related to parameters. With the above configuration, a model will run on the GPU, however parameters will be loaded onto the CPU (due to the default backend), so they will need to be copied onto the GPU every time. To avoid that, you want to make sure that parameters are allocated on the same device that the computation runs on. The simplest way to achieve that is passing `preallocate_params: true` to when creating the serving, in addition to `:defn_options`.
There's a final important detail related to parameters. With the above configuration, a model will run on the GPU, however parameters will be loaded onto the CPU (due to the default backend), so they will need to be copied onto the GPU every time. To avoid that, you can load the parameters onto the GPU directly using `Bumblebee.load_model(..., backend: EXLA.Backend)`.

When building the Bumblebee serving, make sure to specify the compiler and `:compile` shapes, so that the computation is compiled upfront when the serving boots.

```elixir
serving =
Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
Bumblebee.Text.text_embedding(model_info, tokenizer,
compile: [batch_size: 1, sequence_length: 512],
defn_options: [compiler: EXLA],
preallocate_params: true
defn_options: [compiler: EXLA]
)
```

Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ defmodule Bumblebee do
end
end

defp params_file_loader_fun(".safetensors"), do: &Safetensors.read!/1
defp params_file_loader_fun(".safetensors"), do: &Safetensors.read!(&1, lazy: true)
defp params_file_loader_fun(_), do: &Bumblebee.Conversion.PyTorch.Loader.load!/1

@doc """
Expand Down
5 changes: 4 additions & 1 deletion lib/bumblebee/conversion/pytorch.ex
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ defmodule Bumblebee.Conversion.PyTorch do
defp with_default_backend(backend, fun), do: Nx.with_default_backend(backend, fun)

defp state_dict?(%{} = dict) when not is_struct(dict) do
Enum.all?(dict, fn {key, value} -> is_binary(key) and is_struct(value, Nx.Tensor) end)
Enum.all?(dict, fn {key, value} ->
is_binary(key) and Nx.LazyContainer.impl_for(value) != nil
end)
end

defp state_dict?(_other), do: false
Expand Down Expand Up @@ -141,6 +143,7 @@ defmodule Bumblebee.Conversion.PyTorch do

{value, diff} =
if all_sources_found? do
source_values = Enum.map(source_values, &Nx.to_tensor/1)
value = builder_fun.(Enum.reverse(source_values))

case verify_param_shape(param_expr, value) do
Expand Down
76 changes: 76 additions & 0 deletions lib/bumblebee/conversion/pytorch/file_tensor.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
defmodule Bumblebee.Conversion.PyTorch.FileTensor do
@moduledoc false

defstruct [:shape, :type, :offset, :strides, :storage]
end

defimpl Nx.LazyContainer, for: Bumblebee.Conversion.PyTorch.FileTensor do
alias Bumblebee.Conversion.PyTorch.Loader

def traverse(lazy_tensor, acc, fun) do
template = Nx.template(lazy_tensor.shape, lazy_tensor.type)

load = fn ->
binary =
case lazy_tensor.storage do
{:zip, path, file_name} ->
Loader.open_zip!(path, fn unzip ->
Loader.read_zip_file(unzip, file_name)
end)

{:file, path, offset, size} ->
File.open!(path, [:read, :raw], fn file ->
{:ok, binary} = :file.pread(file, offset, size)
binary
end)
end

%{offset: offset, shape: shape, type: type, strides: strides} = lazy_tensor

{_, bit_unit} = type
byte_unit = div(bit_unit, 8)
size = Tuple.product(shape)
binary = binary_part(binary, offset * byte_unit, size * byte_unit)
binary |> Nx.from_binary(type) |> to_contiguous(shape, strides)
end

fun.(template, load, acc)
end

defp to_contiguous(tensor, shape, strides) do
# PyTorch tensors may not be contiguous in memory, so strides are
# used to indicate jumps necessary when traversing each axis.
# Since Nx doesn't have the notion of strides, we transpose the
# tensor, in a way that makes it contiguous, which is equivalent
# to strides being decreasing

memory_axes_order =
strides
|> Tuple.to_list()
|> Enum.with_index()
|> Enum.sort_by(&elem(&1, 0), :desc)
|> Enum.map(&elem(&1, 1))

if memory_axes_order == Nx.axes(shape) do
Nx.reshape(tensor, shape)
else
memory_shape =
memory_axes_order
|> Enum.map(fn axis -> elem(shape, axis) end)
|> List.to_tuple()

tensor
|> Nx.reshape(memory_shape)
|> Nx.transpose(axes: inverse_permutation(memory_axes_order))
end
end

defp inverse_permutation(list) do
list
|> Enum.with_index()
|> Enum.reduce(List.to_tuple(list), fn {src_idx, dest_idx}, inverse ->
put_elem(inverse, src_idx, dest_idx)
end)
|> Tuple.to_list()
end
end
134 changes: 57 additions & 77 deletions lib/bumblebee/conversion/pytorch/loader.ex
Original file line number Diff line number Diff line change
Expand Up @@ -26,47 +26,52 @@ defmodule Bumblebee.Conversion.PyTorch.Loader do
end

defp load_zip!(path) do
zip_file = Unzip.LocalFile.open(path)

try do
{:ok, unzip} = Unzip.new(zip_file)

contents =
open_zip!(path, fn unzip ->
file_name_map =
unzip
|> Unzip.list_entries()
|> Map.new(fn %Unzip.Entry{file_name: file_name} ->
content =
unzip
|> Unzip.file_stream!(file_name)
|> Enum.to_list()
|> IO.iodata_to_binary()

# Strip the root dir from the file name
name =
file_name
|> Path.split()
|> Enum.drop(1)
|> Enum.join("/")

{name, content}
name = file_name |> Path.split() |> Enum.drop(1) |> Enum.join("/")
{name, file_name}
end)

binary = read_zip_file(unzip, Map.fetch!(file_name_map, "data.pkl"))

{term, ""} =
Unpickler.load!(Map.fetch!(contents, "data.pkl"),
Unpickler.load!(binary,
object_resolver: &object_resolver/1,
persistent_id_resolver: fn
{"storage", storage_type, key, _location, _size} ->
binary = Map.fetch!(contents, "data/#{key}")
{:storage, storage_type, binary}
file_name = Map.fetch!(file_name_map, "data/#{key}")
{:storage, storage_type, {:zip, path, file_name}}
end
)

term
end)
end

@doc false
def open_zip!(path, fun) do
zip_file = Unzip.LocalFile.open(path)

try do
{:ok, unzip} = Unzip.new(zip_file)
fun.(unzip)
after
Unzip.LocalFile.close(zip_file)
end
end

@doc false
def read_zip_file(unzip, file_name) do
unzip
|> Unzip.file_stream!(file_name)
|> Enum.to_list()
|> IO.iodata_to_binary()
end

defp object_resolver(%{constructor: "collections.OrderedDict", set_items: items}) do
{:ok, Map.new(items)}
end
Expand All @@ -76,13 +81,18 @@ defmodule Bumblebee.Conversion.PyTorch.Loader do
constructor: "torch._utils._rebuild_tensor_v2",
args: [storage, offset, shape, strides, _requires_grad, _backward_hooks]
}) do
{:storage, storage_type, binary} = storage
{_, bit_unit} = type = storage_type_to_nx(storage_type)
byte_unit = div(bit_unit, 8)
size = Tuple.product(shape)
binary = binary_part(binary, offset * byte_unit, size * byte_unit)
tensor = binary |> Nx.from_binary(type) |> to_contiguous(shape, strides)
{:ok, tensor}
{:storage, storage_type, storage} = storage
type = storage_type_to_nx(storage_type)

lazy_tensor = %Bumblebee.Conversion.PyTorch.FileTensor{
shape: shape,
type: type,
offset: offset,
strides: strides,
storage: storage
}

{:ok, lazy_tensor}
end

# See https://github.com/pytorch/pytorch/blob/v1.12.1/torch/_tensor.py#L222-L226
Expand Down Expand Up @@ -173,53 +183,17 @@ defmodule Bumblebee.Conversion.PyTorch.Loader do
end
end

defp to_contiguous(tensor, shape, strides) do
# PyTorch tensors may not be contiguous in memory, so strides are
# used to indicate jumps necessary when traversing each axis.
# Since Nx doesn't have the notion of strides, we transpose the
# tensor, in a way that makes it contiguous, which is equivalent
# to strides being decreasing

memory_axes_order =
strides
|> Tuple.to_list()
|> Enum.with_index()
|> Enum.sort_by(&elem(&1, 0), :desc)
|> Enum.map(&elem(&1, 1))

if memory_axes_order == Nx.axes(shape) do
Nx.reshape(tensor, shape)
else
memory_shape =
memory_axes_order
|> Enum.map(fn axis -> elem(shape, axis) end)
|> List.to_tuple()

tensor
|> Nx.reshape(memory_shape)
|> Nx.transpose(axes: inverse_permutation(memory_axes_order))
end
end

defp inverse_permutation(list) do
list
|> Enum.with_index()
|> Enum.reduce(List.to_tuple(list), fn {src_idx, dest_idx}, inverse ->
put_elem(inverse, src_idx, dest_idx)
end)
|> Tuple.to_list()
end

@legacy_magic_number 119_547_037_146_038_801_333_356

defp load_legacy!(path) do
data = File.read!(path)
full_size = byte_size(data)

{@legacy_magic_number, data} = Unpickler.load!(data)
{_protocol_version, data} = Unpickler.load!(data)
{_system_info, data} = Unpickler.load!(data)

binaries = storage_binaries(data)
binaries = storage_binaries(data, full_size)

{term, _} =
Unpickler.load!(data,
Expand All @@ -229,16 +203,18 @@ defmodule Bumblebee.Conversion.PyTorch.Loader do
{_, bit_unit} = storage_type_to_nx(storage_type)
byte_unit = div(bit_unit, 8)

binary =
{file_offset, size} = Map.fetch!(binaries, root_key)

storage =
case view_metadata do
nil ->
binaries[root_key]
{:file, path, file_offset, size}

{_view_key, offset, view_size} ->
binary_part(binaries[root_key], offset * byte_unit, view_size * byte_unit)
{:file, path, file_offset + offset * byte_unit, view_size * byte_unit}
end

{:storage, storage_type, binary}
{:storage, storage_type, storage}

{"module", module, _source_file, _source} ->
module
Expand All @@ -248,7 +224,7 @@ defmodule Bumblebee.Conversion.PyTorch.Loader do
term
end

defp storage_binaries(data) do
defp storage_binaries(data, full_size) do
# We first do a dry run on the pickle and extract storage metadata,
# then we use that metadata to read the storage binaries that follow

Expand All @@ -269,16 +245,20 @@ defmodule Bumblebee.Conversion.PyTorch.Loader do

{storage_keys, data} = Unpickler.load!(data)

{pairs, ""} =
Enum.map_reduce(storage_keys, data, fn key, data ->
offset = full_size - byte_size(data)

{pairs, _offset} =
Enum.map_reduce(storage_keys, offset, fn key, offset ->
{size, byte_unit} = Map.fetch!(storage_infos, key)
bytes = size * byte_unit

# Each storage binary is prefixed with the number of elements.
# Each storage binary is prefixed with the number of elements,
# stored as integer-little-signed-size(64), hence the 8 bytes.
# See https://github.com/pytorch/pytorch/blob/v1.11.0/torch/csrc/generic/serialization.cpp#L93-L134
<<^size::integer-little-signed-size(64), chunk::binary-size(bytes), data::binary>> = data
start_offset = offset + 8
offset = start_offset + bytes

{{key, chunk}, data}
{{key, {start_offset, bytes}}, offset}
end)

Map.new(pairs)
Expand Down
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ defmodule Bumblebee.MixProject do
{:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]},
{:nx_image, "~> 0.1.0"},
{:unpickler, "~> 0.1.0"},
{:safetensors, "~> 0.1.2"},
{:safetensors, "~> 0.1.3"},
{:castore, "~> 0.1 or ~> 1.0"},
{:jason, "~> 1.4.0"},
{:unzip, "~> 0.10.0"},
Expand Down
2 changes: 1 addition & 1 deletion mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"},
"ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"},
"rustler_precompiled": {:hex, :rustler_precompiled, "0.6.2", "d2218ba08a43fa331957f30481d00b666664d7e3861431b02bd3f4f30eec8e5b", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "b9048eaed8d7d14a53f758c91865cc616608a438d2595f621f6a4b32a5511709"},
"safetensors": {:hex, :safetensors, "0.1.2", "849434fea20b2ed14b92e74205a925d86039c4ef53efe861e5c7b574c3ba8fa6", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "298a5c82e34fc3b955464b89c080aa9a2625a47d69148d51113771e19166d4e0"},
"safetensors": {:hex, :safetensors, "0.1.3", "7ff3c22391e213289c713898481d492c9c28a49ab1d0705b72630fb8360426b2", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "fe50b53ea59fde4e723dd1a2e31cfdc6013e69343afac84c6be86d6d7c562c14"},
"stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"tokenizers": {:hex, :tokenizers, "0.4.0", "140283ca74a971391ddbd83cd8cbdb9bd03736f37a1b6989b82d245a95e1eb97", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "ef1a9824f5a893cd3b831c0e5b3d72caa250d2ec462035cc6afef6933b13a82e"},
Expand Down
Loading

0 comments on commit 78c7694

Please sign in to comment.