diff --git a/lib/bumblebee/utils/tokenizers.ex b/lib/bumblebee/utils/tokenizers.ex index 505478dc..5a0b723e 100644 --- a/lib/bumblebee/utils/tokenizers.ex +++ b/lib/bumblebee/utils/tokenizers.ex @@ -48,19 +48,29 @@ defmodule Bumblebee.Utils.Tokenizers do pad_id = Tokenizer.token_to_id(tokenizer, pad_token) encodings = - Enum.map(encodings, fn seq -> - seq = - Encoding.pad(seq, pad_length, - pad_id: pad_id, - pad_token: pad_token, - direction: opts[:pad_direction] - ) - - if truncate_length do - Encoding.truncate(seq, truncate_length, direction: opts[:truncate_direction]) - else - seq - end + Enum.map(encodings, fn encoding -> + transformations = + [ + Encoding.Transformation.pad(pad_length, + pad_id: pad_id, + pad_token: pad_token, + direction: opts[:pad_direction] + ) + ] + + transformations = + transformations ++ + if truncate_length do + [ + Encoding.Transformation.truncate(truncate_length, + direction: opts[:truncate_direction] + ) + ] + else + [] + end + + Encoding.transform(encoding, transformations) end) input_ids = encodings |> Enum.map(&Encoding.get_u32_ids/1) |> u32_binaries_to_tensor() @@ -174,9 +184,14 @@ defmodule Bumblebee.Utils.Tokenizers do end def load!(path) do - case Tokenizers.Tokenizer.from_file(path, padding: :none, truncation: :none) do - {:ok, tokenizer} -> tokenizer - {:error, error} -> raise "failed to read tokenizer from file, reason: #{error}" + case Tokenizer.from_file(path) do + {:ok, tokenizer} -> + tokenizer + |> Tokenizer.disable_padding() + |> Tokenizer.disable_truncation() + + {:error, error} -> + raise "failed to read tokenizer from file, reason: #{error}" end end end diff --git a/mix.exs b/mix.exs index 9374b05a..cf634dc3 100644 --- a/mix.exs +++ b/mix.exs @@ -31,9 +31,7 @@ defmodule Bumblebee.MixProject do defp deps do [ {:axon, "~> 0.5.0", axon_opts()}, - # {:tokenizers, "~> 0.3"}, - {:tokenizers, github: "elixir-nx/tokenizers", override: true}, - {:rustler, ">= 0.0.0", optional: true}, + {:tokenizers, "~> 0.4"}, # {:nx, "~> 0.5.0"}, # {:exla, "~> 0.5.0", only: [:dev, :test]}, # {:torchx, "~> 0.5.0", only: [:dev, :test]}, diff --git a/mix.lock b/mix.lock index 278b05e8..255aff73 100644 --- a/mix.lock +++ b/mix.lock @@ -27,13 +27,11 @@ "plug_crypto": {:hex, :plug_crypto, "1.2.5", "918772575e48e81e455818229bf719d4ab4181fcbf7f85b68a35620f78d89ced", [:mix], [], "hexpm", "26549a1d6345e2172eb1c233866756ae44a9609bd33ee6f99147ab3fd87fd842"}, "progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"}, "ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"}, - "rustler": {:hex, :rustler, "0.29.1", "880f20ae3027bd7945def6cea767f5257bc926f33ff50c0d5d5a5315883c084d", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "109497d701861bfcd26eb8f5801fe327a8eef304f56a5b63ef61151ff44ac9b6"}, "rustler_precompiled": {:hex, :rustler_precompiled, "0.6.2", "d2218ba08a43fa331957f30481d00b666664d7e3861431b02bd3f4f30eec8e5b", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "b9048eaed8d7d14a53f758c91865cc616608a438d2595f621f6a4b32a5511709"}, "safetensors": {:hex, :safetensors, "0.1.1", "b5859a010fb56249ecfba4799d316e96b89152576af2db7657786c55dcf2f5b6", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "dfbb525bf3debb2e2d90f840728af70da5d55f6caa091cac4d0891a4eb4c52d5"}, "stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, - "tokenizers": {:git, "https://github.com/elixir-nx/tokenizers.git", "90dd590d5a64863e61666c3c5ebaec2d3e51841c", []}, - "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"}, + "tokenizers": {:hex, :tokenizers, "0.4.0", "140283ca74a971391ddbd83cd8cbdb9bd03736f37a1b6989b82d245a95e1eb97", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "ef1a9824f5a893cd3b831c0e5b3d72caa250d2ec462035cc6afef6933b13a82e"}, "torchx": {:git, "https://github.com/elixir-nx/nx.git", "2c6a9d48890d70fb3937cd19b0cb3e2356008488", [sparse: "torchx"]}, "unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"}, "unzip": {:hex, :unzip, "0.8.0", "ee21d87c21b01567317387dab4228ac570ca15b41cfc221a067354cbf8e68c4d", [:mix], [], "hexpm", "ffa67a483efcedcb5876971a50947222e104d5f8fea2c4a0441e6f7967854827"},