Skip to content

Commit

Permalink
Update tokenizers
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Aug 9, 2023
1 parent c80c6b4 commit 9c5a9f8
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 22 deletions.
47 changes: 31 additions & 16 deletions lib/bumblebee/utils/tokenizers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,29 @@ defmodule Bumblebee.Utils.Tokenizers do
pad_id = Tokenizer.token_to_id(tokenizer, pad_token)

encodings =
Enum.map(encodings, fn seq ->
seq =
Encoding.pad(seq, pad_length,
pad_id: pad_id,
pad_token: pad_token,
direction: opts[:pad_direction]
)

if truncate_length do
Encoding.truncate(seq, truncate_length, direction: opts[:truncate_direction])
else
seq
end
Enum.map(encodings, fn encoding ->
transformations =
[
Encoding.Transformation.pad(pad_length,
pad_id: pad_id,
pad_token: pad_token,
direction: opts[:pad_direction]
)
]

transformations =
transformations ++
if truncate_length do
[
Encoding.Transformation.truncate(truncate_length,
direction: opts[:truncate_direction]
)
]
else
[]
end

Encoding.transform(encoding, transformations)
end)

input_ids = encodings |> Enum.map(&Encoding.get_u32_ids/1) |> u32_binaries_to_tensor()
Expand Down Expand Up @@ -174,9 +184,14 @@ defmodule Bumblebee.Utils.Tokenizers do
end

def load!(path) do
case Tokenizers.Tokenizer.from_file(path, padding: :none, truncation: :none) do
{:ok, tokenizer} -> tokenizer
{:error, error} -> raise "failed to read tokenizer from file, reason: #{error}"
case Tokenizer.from_file(path) do
{:ok, tokenizer} ->
tokenizer
|> Tokenizer.disable_padding()
|> Tokenizer.disable_truncation()

{:error, error} ->
raise "failed to read tokenizer from file, reason: #{error}"
end
end
end
4 changes: 1 addition & 3 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ defmodule Bumblebee.MixProject do
defp deps do
[
{:axon, "~> 0.5.0", axon_opts()},
# {:tokenizers, "~> 0.3"},
{:tokenizers, github: "elixir-nx/tokenizers", override: true},
{:rustler, ">= 0.0.0", optional: true},
{:tokenizers, "~> 0.4"},
# {:nx, "~> 0.5.0"},
# {:exla, "~> 0.5.0", only: [:dev, :test]},
# {:torchx, "~> 0.5.0", only: [:dev, :test]},
Expand Down
4 changes: 1 addition & 3 deletions mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,11 @@
"plug_crypto": {:hex, :plug_crypto, "1.2.5", "918772575e48e81e455818229bf719d4ab4181fcbf7f85b68a35620f78d89ced", [:mix], [], "hexpm", "26549a1d6345e2172eb1c233866756ae44a9609bd33ee6f99147ab3fd87fd842"},
"progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"},
"ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"},
"rustler": {:hex, :rustler, "0.29.1", "880f20ae3027bd7945def6cea767f5257bc926f33ff50c0d5d5a5315883c084d", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "109497d701861bfcd26eb8f5801fe327a8eef304f56a5b63ef61151ff44ac9b6"},
"rustler_precompiled": {:hex, :rustler_precompiled, "0.6.2", "d2218ba08a43fa331957f30481d00b666664d7e3861431b02bd3f4f30eec8e5b", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "b9048eaed8d7d14a53f758c91865cc616608a438d2595f621f6a4b32a5511709"},
"safetensors": {:hex, :safetensors, "0.1.1", "b5859a010fb56249ecfba4799d316e96b89152576af2db7657786c55dcf2f5b6", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "dfbb525bf3debb2e2d90f840728af70da5d55f6caa091cac4d0891a4eb4c52d5"},
"stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"tokenizers": {:git, "https://github.com/elixir-nx/tokenizers.git", "90dd590d5a64863e61666c3c5ebaec2d3e51841c", []},
"toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"},
"tokenizers": {:hex, :tokenizers, "0.4.0", "140283ca74a971391ddbd83cd8cbdb9bd03736f37a1b6989b82d245a95e1eb97", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "ef1a9824f5a893cd3b831c0e5b3d72caa250d2ec462035cc6afef6933b13a82e"},
"torchx": {:git, "https://github.com/elixir-nx/nx.git", "2c6a9d48890d70fb3937cd19b0cb3e2356008488", [sparse: "torchx"]},
"unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"},
"unzip": {:hex, :unzip, "0.8.0", "ee21d87c21b01567317387dab4228ac570ca15b41cfc221a067354cbf8e68c4d", [:mix], [], "hexpm", "ffa67a483efcedcb5876971a50947222e104d5f8fea2c4a0441e6f7967854827"},
Expand Down

0 comments on commit 9c5a9f8

Please sign in to comment.