diff --git a/lib/bumblebee/utils/tokenizers.ex b/lib/bumblebee/utils/tokenizers.ex
index 505478dc..5a0b723e 100644
--- a/lib/bumblebee/utils/tokenizers.ex
+++ b/lib/bumblebee/utils/tokenizers.ex
@@ -48,19 +48,29 @@ defmodule Bumblebee.Utils.Tokenizers do
     pad_id = Tokenizer.token_to_id(tokenizer, pad_token)
 
     encodings =
-      Enum.map(encodings, fn seq ->
-        seq =
-          Encoding.pad(seq, pad_length,
-            pad_id: pad_id,
-            pad_token: pad_token,
-            direction: opts[:pad_direction]
-          )
-
-        if truncate_length do
-          Encoding.truncate(seq, truncate_length, direction: opts[:truncate_direction])
-        else
-          seq
-        end
+      Enum.map(encodings, fn encoding ->
+        transformations =
+          [
+            Encoding.Transformation.pad(pad_length,
+              pad_id: pad_id,
+              pad_token: pad_token,
+              direction: opts[:pad_direction]
+            )
+          ]
+
+        transformations =
+          transformations ++
+            if truncate_length do
+              [
+                Encoding.Transformation.truncate(truncate_length,
+                  direction: opts[:truncate_direction]
+                )
+              ]
+            else
+              []
+            end
+
+        Encoding.transform(encoding, transformations)
       end)
 
     input_ids = encodings |> Enum.map(&Encoding.get_u32_ids/1) |> u32_binaries_to_tensor()
@@ -174,9 +184,14 @@ defmodule Bumblebee.Utils.Tokenizers do
   end
 
   def load!(path) do
-    case Tokenizers.Tokenizer.from_file(path, padding: :none, truncation: :none) do
-      {:ok, tokenizer} -> tokenizer
-      {:error, error} -> raise "failed to read tokenizer from file, reason: #{error}"
+    case Tokenizer.from_file(path) do
+      {:ok, tokenizer} ->
+        tokenizer
+        |> Tokenizer.disable_padding()
+        |> Tokenizer.disable_truncation()
+
+      {:error, error} ->
+        raise "failed to read tokenizer from file, reason: #{error}"
     end
   end
 end
diff --git a/mix.lock b/mix.lock
index 278b05e8..dabb1dfc 100644
--- a/mix.lock
+++ b/mix.lock
@@ -32,7 +32,7 @@
   "safetensors": {:hex, :safetensors, "0.1.1", "b5859a010fb56249ecfba4799d316e96b89152576af2db7657786c55dcf2f5b6", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "dfbb525bf3debb2e2d90f840728af70da5d55f6caa091cac4d0891a4eb4c52d5"},
   "stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"},
   "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
-  "tokenizers": {:git, "https://github.com/elixir-nx/tokenizers.git", "90dd590d5a64863e61666c3c5ebaec2d3e51841c", []},
+  "tokenizers": {:git, "https://github.com/elixir-nx/tokenizers.git", "d698b6df15db388daa870c57fc7668b277f4b4d8", []},
   "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"},
   "torchx": {:git, "https://github.com/elixir-nx/nx.git", "2c6a9d48890d70fb3937cd19b0cb3e2356008488", [sparse: "torchx"]},
   "unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"},
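
For reference, a minimal standalone sketch of the transformation-based API this diff migrates to, assuming the elixir-nx/tokenizers revision pinned above. The file path, input text, pad token, and length values are illustrative, not taken from Bumblebee:

    # Sketch only: path, input, "[PAD]" token, and the length 16 are
    # illustrative assumptions, not values from this diff.
    alias Tokenizers.{Tokenizer, Encoding}

    {:ok, tokenizer} = Tokenizer.from_file("tokenizer.json")

    # Disable built-in padding/truncation up front, as load!/1 now does,
    # so the per-encoding transformations below fully control the output.
    tokenizer =
      tokenizer
      |> Tokenizer.disable_padding()
      |> Tokenizer.disable_truncation()

    {:ok, encoding} = Tokenizer.encode(tokenizer, "hello world")

    pad_id = Tokenizer.token_to_id(tokenizer, "[PAD]")

    # Transformations are plain data, so they can be built conditionally
    # (as in the Enum.map above) and applied to an encoding in one pass.
    encoding =
      Encoding.transform(encoding, [
        Encoding.Transformation.pad(16,
          pad_id: pad_id,
          pad_token: "[PAD]",
          direction: :right
        ),
        Encoding.Transformation.truncate(16, direction: :right)
      ])

    Encoding.get_u32_ids(encoding)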