Fix loading Gemma tokenizer
jonatanklosko committed Mar 14, 2024
1 parent 93b5580 commit 21a2533
Showing 2 changed files with 42 additions and 4 deletions.
10 changes: 6 additions & 4 deletions lib/bumblebee/text/pre_trained_tokenizer.ex
@@ -141,10 +141,12 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
       special_tokens: %{unk: "[UNK]", sep: "[SEP]", pad: "[PAD]", cls: "[CLS]", mask: "[MASK]"}
     },
     gemma: %{
-      unk: "<unk>",
-      bos: "<bos>",
-      eos: "<eos>",
-      pad: "<pad>"
+      special_tokens: %{
+        unk: "<unk>",
+        bos: "<bos>",
+        eos: "<eos>",
+        pad: "<pad>"
+      }
     },
     gpt_neo_x: %{
       special_tokens: %{
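
Before this change, the :gemma entry listed its special tokens at the top level of the type map, whereas every other type (compare :bert and :gpt_neo_x above) nests them under a special_tokens key, so loading a Gemma tokenizer failed. A minimal usage sketch of the fixed behavior, mirroring the new test below (the Hugging Face repo name is taken from that test; fetching it requires network access):

    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "unsloth/gemma-7b-it"})
    inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello World"])

    # inputs is a map of Nx tensors; expected values from the test:
    inputs["input_ids"]      #=> [[2, 4521, 3855]]
    inputs["attention_mask"] #=> [[1, 1, 1]]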
36 changes: 36 additions & 0 deletions test/bumblebee/text/pre_trained_tokenizer_test.exs
@@ -211,6 +211,42 @@ defmodule Bumblebee.Text.PreTrainedTokenizerTest do
       )
     end
 
+    test ":gemma" do
+      assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "unsloth/gemma-7b-it"})
+
+      assert %Bumblebee.Text.PreTrainedTokenizer{type: :gemma} = tokenizer
+
+      inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello World"])
+
+      assert_equal(
+        inputs["input_ids"],
+        Nx.tensor([[2, 4521, 3855]])
+      )
+
+      assert_equal(
+        inputs["attention_mask"],
+        Nx.tensor([[1, 1, 1]])
+      )
+    end
+
+    test ":gpt_neo_x" do
+      assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "EleutherAI/gpt-neox-20b"})
+
+      assert %Bumblebee.Text.PreTrainedTokenizer{type: :gpt_neo_x} = tokenizer
+
+      inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello World"])
+
+      assert_equal(
+        inputs["input_ids"],
+        Nx.tensor([[12092, 3645]])
+      )
+
+      assert_equal(
+        inputs["attention_mask"],
+        Nx.tensor([[1, 1]])
+      )
+    end
+
     test ":gpt2" do
       assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai-community/gpt2"})
 
