Skip to content

Commit

Permalink
Use eos token for padding lazily (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Apr 22, 2024
1 parent f0b4155 commit c0d069c
Showing 1 changed file with 7 additions and 16 deletions.
23 changes: 7 additions & 16 deletions lib/bumblebee/text/pre_trained_tokenizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,7 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
special_tokens: %{
unk: "<|endoftext|>",
bos: "<|endoftext|>",
eos: "<|endoftext|>",
# CodeGen doesn't originally have a pad token, however when necessary
# we pad with the EOS token
pad: "<|endoftext|>"
eos: "<|endoftext|>"
}
},
distilbert: %{
Expand All @@ -152,20 +149,14 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
special_tokens: %{
unk: "<|endoftext|>",
bos: "<|endoftext|>",
eos: "<|endoftext|>",
# GPT-NeoX doesn't originally have a pad token, however when necessary
# we pad with the EOS token
pad: "<|endoftext|>"
eos: "<|endoftext|>"
}
},
gpt2: %{
special_tokens: %{
unk: "<|endoftext|>",
bos: "<|endoftext|>",
eos: "<|endoftext|>",
# GPT-2 doesn't originally have a pad token, however when necessary
# we pad with the EOS token
pad: "<|endoftext|>"
eos: "<|endoftext|>"
}
},
layout_lm: %{
Expand All @@ -175,10 +166,7 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
special_tokens: %{
eos: "</s>",
unk: "<unk>",
sep: "</s>",
# Llama doesn't originally have a pad token, however when necessary
# we pad with the EOS token
pad: "</s>"
sep: "</s>"
}
},
mbart: %{
Expand Down Expand Up @@ -275,8 +263,11 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
def apply(tokenizer, input) do
input = List.wrap(input)

# Some tokenizers don't specify a PAD token, in which case we use
# the EOS token for padding by default
pad_token =
tokenizer.special_tokens[:pad] ||
tokenizer.special_tokens[:eos] ||
raise ArgumentError,
"expected the tokenizer to defined a padding token, but none was found"

Expand Down

0 comments on commit c0d069c

Please sign in to comment.