Skip to content

Commit

Permalink
Update tokenizers
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Aug 9, 2023
1 parent c80c6b4 commit 5197d5d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 17 deletions.
47 changes: 31 additions & 16 deletions lib/bumblebee/utils/tokenizers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,29 @@ defmodule Bumblebee.Utils.Tokenizers do
pad_id = Tokenizer.token_to_id(tokenizer, pad_token)

encodings =
Enum.map(encodings, fn seq ->
seq =
Encoding.pad(seq, pad_length,
pad_id: pad_id,
pad_token: pad_token,
direction: opts[:pad_direction]
)

if truncate_length do
Encoding.truncate(seq, truncate_length, direction: opts[:truncate_direction])
else
seq
end
Enum.map(encodings, fn encoding ->
transformations =
[
Encoding.Transformation.pad(pad_length,
pad_id: pad_id,
pad_token: pad_token,
direction: opts[:pad_direction]
)
]

transformations =
transformations ++
if truncate_length do
[
Encoding.Transformation.truncate(truncate_length,
direction: opts[:truncate_direction]
)
]
else
[]
end

Encoding.transform(encoding, transformations)
end)

input_ids = encodings |> Enum.map(&Encoding.get_u32_ids/1) |> u32_binaries_to_tensor()
Expand Down Expand Up @@ -174,9 +184,14 @@ defmodule Bumblebee.Utils.Tokenizers do
end

def load!(path) do
case Tokenizers.Tokenizer.from_file(path, padding: :none, truncation: :none) do
{:ok, tokenizer} -> tokenizer
{:error, error} -> raise "failed to read tokenizer from file, reason: #{error}"
case Tokenizer.from_file(path) do
{:ok, tokenizer} ->
tokenizer
|> Tokenizer.disable_padding()
|> Tokenizer.disable_truncation()

{:error, error} ->
raise "failed to read tokenizer from file, reason: #{error}"
end
end
end
2 changes: 1 addition & 1 deletion mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"safetensors": {:hex, :safetensors, "0.1.1", "b5859a010fb56249ecfba4799d316e96b89152576af2db7657786c55dcf2f5b6", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "dfbb525bf3debb2e2d90f840728af70da5d55f6caa091cac4d0891a4eb4c52d5"},
"stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"tokenizers": {:git, "https://github.com/elixir-nx/tokenizers.git", "90dd590d5a64863e61666c3c5ebaec2d3e51841c", []},
"tokenizers": {:git, "https://github.com/elixir-nx/tokenizers.git", "d698b6df15db388daa870c57fc7668b277f4b4d8", []},
"toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"},
"torchx": {:git, "https://github.com/elixir-nx/nx.git", "2c6a9d48890d70fb3937cd19b0cb3e2356008488", [sparse: "torchx"]},
"unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"},
Expand Down

0 comments on commit 5197d5d

Please sign in to comment.