Add ESM tokenizer tests
xenova committed Sep 10, 2024
1 parent fd8b9a2 commit 710816e
Showing 3 changed files with 331 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/models/all_tokenization_tests.js
@@ -3,4 +3,5 @@ export * as GPT2Tokenizer from "./gpt2/tokenization.js";
export * as T5Tokenizer from "./t5/tokenization.js";
export * as WhisperTokenizer from "./whisper/tokenization.js";
export * as FalconTokenizer from "./falcon/tokenization.js";
export * as EsmTokenizer from "./esm/tokenization.js";
export * as BlenderbotSmallTokenizer from "./blenderbot_small/tokenization.js";
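
For context, a minimal sketch of how these aggregated exports might be consumed (the repository's actual test harness may differ): each per-model module exposes a TOKENIZER_CLASS and a TEST_CONFIG object keyed by Hub model id, so a generic runner can iterate over them.

import * as TOKENIZER_TESTS from "./all_tokenization_tests.js";

// Hypothetical runner loop: enumerate every tokenizer test module and count
// the fixtures registered for each model id. The real harness may instead
// load each tokenizer and compare the produced ids/decoded output.
for (const [name, { TOKENIZER_CLASS, TEST_CONFIG }] of Object.entries(TOKENIZER_TESTS)) {
  for (const [model_id, cases] of Object.entries(TEST_CONFIG)) {
    console.log(`${name} (${TOKENIZER_CLASS.name}) -> ${model_id}: ${Object.keys(cases).length} cases`);
  }
}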
322 changes: 322 additions & 0 deletions tests/models/esm/tokenization.js
@@ -0,0 +1,322 @@
import { EsmTokenizer } from "../../../src/tokenizers.js";
import { BASE_TEST_STRINGS, ESM_TEST_STRINGS } from "../test_strings.js";

export const TOKENIZER_CLASS = EsmTokenizer;
export const TEST_CONFIG = {
"Xenova/nucleotide-transformer-500m-human-ref": {
SIMPLE: {
text: BASE_TEST_STRINGS.SIMPLE,
// "tokens": ["How", "are", "you", "doing?"],
ids: [3, 0, 0, 0, 0],
decoded: "<cls> <unk> <unk> <unk> <unk>",
},
SIMPLE_WITH_PUNCTUATION: {
text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
// "tokens": ["You", "should've", "done", "this"],
ids: [3, 0, 0, 0, 0],
decoded: "<cls> <unk> <unk> <unk> <unk>",
},
NUMBERS: {
text: BASE_TEST_STRINGS.NUMBERS,
// "tokens": ["0123456789", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "100", "1000"],
ids: [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
decoded: "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>",
},
TEXT_WITH_NUMBERS: {
text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
// "tokens": ["T", "he", "company", "was", "founded", "in", "2016."],
ids: [3, 4101, 0, 0, 0, 0, 0, 0],
decoded: "<cls> T <unk> <unk> <unk> <unk> <unk> <unk>",
},
PUNCTUATION: {
text: BASE_TEST_STRINGS.PUNCTUATION,
// "tokens": ["A", "'ll", "!!to?'d''d", "of,", "can't."],
ids: [3, 4100, 0, 0, 0, 0],
decoded: "<cls> A <unk> <unk> <unk> <unk>",
},
PYTHON_CODE: {
text: BASE_TEST_STRINGS.PYTHON_CODE,
// "tokens": ["def", "main():", "pass"],
ids: [3, 0, 0, 0],
decoded: "<cls> <unk> <unk> <unk>",
},
JAVASCRIPT_CODE: {
text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
// "tokens": ["let", "a", "=", "obj.toString();", "toString();"],
ids: [3, 0, 0, 0, 0, 0],
decoded: "<cls> <unk> <unk> <unk> <unk> <unk>",
},
NEWLINES: {
text: BASE_TEST_STRINGS.NEWLINES,
// "tokens": ["T", "his", "is", "a", "test."],
ids: [3, 4101, 0, 0, 0, 0],
decoded: "<cls> T <unk> <unk> <unk> <unk>",
},
BASIC: {
text: BASE_TEST_STRINGS.BASIC,
// "tokens": ["U", "N", "want\u00e9d,running"],
ids: [3, 0, 4104, 0],
decoded: "<cls> <unk> N <unk>",
},
CONTROL_TOKENS: {
text: BASE_TEST_STRINGS.CONTROL_TOKENS,
// "tokens": ["1\u00002\ufffd3"],
ids: [3, 0],
decoded: "<cls> <unk>",
},
HELLO_WORLD_TITLECASE: {
text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
// "tokens": ["Hello", "World"],
ids: [3, 0, 0],
decoded: "<cls> <unk> <unk>",
},
HELLO_WORLD_LOWERCASE: {
text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
// "tokens": ["hello", "world"],
ids: [3, 0, 0],
decoded: "<cls> <unk> <unk>",
},
CHINESE_ONLY: {
text: BASE_TEST_STRINGS.CHINESE_ONLY,
// "tokens": ["\u751f\u6d3b\u7684\u771f\u8c1b\u662f"],
ids: [3, 0],
decoded: "<cls> <unk>",
},
LEADING_SPACE: {
text: BASE_TEST_STRINGS.LEADING_SPACE,
// "tokens": ["leading", "space"],
ids: [3, 0, 0],
decoded: "<cls> <unk> <unk>",
},
TRAILING_SPACE: {
text: BASE_TEST_STRINGS.TRAILING_SPACE,
// "tokens": ["trailing", "space"],
ids: [3, 0, 0],
decoded: "<cls> <unk> <unk>",
},
DOUBLE_SPACE: {
text: BASE_TEST_STRINGS.DOUBLE_SPACE,
// "tokens": ["Hi", "Hello"],
ids: [3, 0, 0],
decoded: "<cls> <unk> <unk>",
},
CURRENCY: {
text: BASE_TEST_STRINGS.CURRENCY,
// "tokens": ["test", "$1", "R2", "#3", "\u20ac4", "\u00a35", "\u00a56", "\u20a37", "\u20b98", "\u20b19", "test"],
ids: [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
decoded: "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>",
},
CURRENCY_WITH_DECIMALS: {
text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
// "tokens": ["I", "bought", "an", "apple", "for", "$1.00", "at", "the", "store."],
ids: [3, 0, 0, 0, 0, 0, 0, 0, 0, 0],
decoded: "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>",
},
ELLIPSIS: {
text: BASE_TEST_STRINGS.ELLIPSIS,
// "tokens": ["you\u2026"],
ids: [3, 0],
decoded: "<cls> <unk>",
},
TEXT_WITH_ESCAPE_CHARACTERS: {
text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
// "tokens": ["you\u2026"],
ids: [3, 0],
decoded: "<cls> <unk>",
},
TEXT_WITH_ESCAPE_CHARACTERS_2: {
text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
// "tokens": ["you\u2026", "you\u2026"],
ids: [3, 0, 0],
decoded: "<cls> <unk> <unk>",
},
TILDE_NORMALIZATION: {
text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
// "tokens": ["weird", "\uff5e", "edge", "\uff5e", "case"],
ids: [3, 0, 0, 0, 0, 0],
decoded: "<cls> <unk> <unk> <unk> <unk> <unk>",
},
SPIECE_UNDERSCORE: {
text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
// "tokens": ["\u2581", "T", "his", "\u2581is", "\u2581a", "\u2581test", "\u2581."],
ids: [3, 0, 4101, 0, 0, 0, 0, 0],
decoded: "<cls> <unk> T <unk> <unk> <unk> <unk> <unk>",
},
SPECIAL_TOKENS: {
text: ESM_TEST_STRINGS.SPECIAL_TOKENS,
tokens: ["<unk>", "<pad>", "<mask>", "<cls>", "<eos>", "<bos>"],
ids: [3, 0, 1, 2, 3, 4105, 4106],
decoded: "<cls> <unk> <pad> <mask> <cls> <eos> <bos>",
},
PROTEIN_SEQUENCES_1: {
text: ESM_TEST_STRINGS.PROTEIN_SEQUENCES_1,
tokens: ["ATTCCG", "ATTCCG", "ATTCCG"],
ids: [3, 367, 367, 367],
decoded: "<cls> ATTCCG ATTCCG ATTCCG",
},
PROTEIN_SEQUENCES_2: {
text: ESM_TEST_STRINGS.PROTEIN_SEQUENCES_2,
tokens: ["ATTTCT", "CTCTCT", "CTCTGA", "GATCGA", "TCGATC", "G", "A", "T"],
ids: [3, 349, 2461, 2464, 3184, 1738, 4103, 4100, 4101],
decoded: "<cls> ATTTCT CTCTCT CTCTGA GATCGA TCGATC G A T",
},
},
"Xenova/esm2_t12_35M_UR50D": {
SIMPLE: {
text: BASE_TEST_STRINGS.SIMPLE,
// "tokens": ["H", "ow", "are", "you", "doing?"],
ids: [0, 21, 3, 3, 3, 3, 2],
decoded: "<cls> H <unk> <unk> <unk> <unk> <eos>",
},
SIMPLE_WITH_PUNCTUATION: {
text: BASE_TEST_STRINGS.SIMPLE_WITH_PUNCTUATION,
// "tokens": ["Y", "ou", "should've", "done", "this"],
ids: [0, 19, 3, 3, 3, 3, 2],
decoded: "<cls> Y <unk> <unk> <unk> <unk> <eos>",
},
NUMBERS: {
text: BASE_TEST_STRINGS.NUMBERS,
// "tokens": ["0123456789", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "100", "1000"],
ids: [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2],
decoded: "<cls> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <eos>",
},
TEXT_WITH_NUMBERS: {
text: BASE_TEST_STRINGS.TEXT_WITH_NUMBERS,
// "tokens": ["T", "he", "company", "was", "founded", "in", "2016", "."],
ids: [0, 11, 3, 3, 3, 3, 3, 3, 29, 2],
decoded: "<cls> T <unk> <unk> <unk> <unk> <unk> <unk>. <eos>",
},
PUNCTUATION: {
text: BASE_TEST_STRINGS.PUNCTUATION,
// "tokens": ["A", "'ll", "!!to?'d''d", "of,", "can't", "."],
ids: [0, 5, 3, 3, 3, 3, 29, 2],
decoded: "<cls> A <unk> <unk> <unk> <unk>. <eos>",
},
PYTHON_CODE: {
text: BASE_TEST_STRINGS.PYTHON_CODE,
// "tokens": ["def", "main():", "pass"],
ids: [0, 3, 3, 3, 2],
decoded: "<cls> <unk> <unk> <unk> <eos>",
},
JAVASCRIPT_CODE: {
text: BASE_TEST_STRINGS.JAVASCRIPT_CODE,
// "tokens": ["let", "a", "=", "obj", ".", "to", "S", "tring();", "to", "S", "tring();"],
ids: [0, 3, 3, 3, 3, 29, 3, 8, 3, 3, 8, 3, 2],
decoded: "<cls> <unk> <unk> <unk> <unk>. <unk> S <unk> <unk> S <unk> <eos>",
},
NEWLINES: {
text: BASE_TEST_STRINGS.NEWLINES,
// "tokens": ["T", "his", "is", "a", "test", "."],
ids: [0, 11, 3, 3, 3, 3, 29, 2],
decoded: "<cls> T <unk> <unk> <unk> <unk>. <eos>",
},
BASIC: {
text: BASE_TEST_STRINGS.BASIC,
// "tokens": ["U", "N", "want\u00e9d,running"],
ids: [0, 26, 17, 3, 2],
decoded: "<cls> U N <unk> <eos>",
},
CONTROL_TOKENS: {
text: BASE_TEST_STRINGS.CONTROL_TOKENS,
// "tokens": ["1\u00002\ufffd3"],
ids: [0, 3, 2],
decoded: "<cls> <unk> <eos>",
},
HELLO_WORLD_TITLECASE: {
text: BASE_TEST_STRINGS.HELLO_WORLD_TITLECASE,
// "tokens": ["H", "ello", "W", "orld"],
ids: [0, 21, 3, 22, 3, 2],
decoded: "<cls> H <unk> W <unk> <eos>",
},
HELLO_WORLD_LOWERCASE: {
text: BASE_TEST_STRINGS.HELLO_WORLD_LOWERCASE,
// "tokens": ["hello", "world"],
ids: [0, 3, 3, 2],
decoded: "<cls> <unk> <unk> <eos>",
},
CHINESE_ONLY: {
text: BASE_TEST_STRINGS.CHINESE_ONLY,
// "tokens": ["\u751f\u6d3b\u7684\u771f\u8c1b\u662f"],
ids: [0, 3, 2],
decoded: "<cls> <unk> <eos>",
},
LEADING_SPACE: {
text: BASE_TEST_STRINGS.LEADING_SPACE,
// "tokens": ["leading", "space"],
ids: [0, 3, 3, 2],
decoded: "<cls> <unk> <unk> <eos>",
},
TRAILING_SPACE: {
text: BASE_TEST_STRINGS.TRAILING_SPACE,
// "tokens": ["trailing", "space"],
ids: [0, 3, 3, 2],
decoded: "<cls> <unk> <unk> <eos>",
},
DOUBLE_SPACE: {
text: BASE_TEST_STRINGS.DOUBLE_SPACE,
// "tokens": ["H", "i", "H", "ello"],
ids: [0, 21, 3, 21, 3, 2],
decoded: "<cls> H <unk> H <unk> <eos>",
},
CURRENCY: {
text: BASE_TEST_STRINGS.CURRENCY,
// "tokens": ["test", "$1", "R", "2", "#3", "\u20ac4", "\u00a35", "\u00a56", "\u20a37", "\u20b98", "\u20b19", "test"],
ids: [0, 3, 3, 10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2],
decoded: "<cls> <unk> <unk> R <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <eos>",
},
CURRENCY_WITH_DECIMALS: {
text: BASE_TEST_STRINGS.CURRENCY_WITH_DECIMALS,
// "tokens": ["I", "bought", "an", "apple", "for", "$1", ".", "00", "at", "the", "store", "."],
ids: [0, 12, 3, 3, 3, 3, 3, 29, 3, 3, 3, 3, 29, 2],
decoded: "<cls> I <unk> <unk> <unk> <unk> <unk>. <unk> <unk> <unk> <unk>. <eos>",
},
ELLIPSIS: {
text: BASE_TEST_STRINGS.ELLIPSIS,
// "tokens": ["you\u2026"],
ids: [0, 3, 2],
decoded: "<cls> <unk> <eos>",
},
TEXT_WITH_ESCAPE_CHARACTERS: {
text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS,
// "tokens": ["you\u2026"],
ids: [0, 3, 2],
decoded: "<cls> <unk> <eos>",
},
TEXT_WITH_ESCAPE_CHARACTERS_2: {
text: BASE_TEST_STRINGS.TEXT_WITH_ESCAPE_CHARACTERS_2,
// "tokens": ["you\u2026", "you\u2026"],
ids: [0, 3, 3, 2],
decoded: "<cls> <unk> <unk> <eos>",
},
TILDE_NORMALIZATION: {
text: BASE_TEST_STRINGS.TILDE_NORMALIZATION,
// "tokens": ["weird", "\uff5e", "edge", "\uff5e", "case"],
ids: [0, 3, 3, 3, 3, 3, 2],
decoded: "<cls> <unk> <unk> <unk> <unk> <unk> <eos>",
},
SPIECE_UNDERSCORE: {
text: BASE_TEST_STRINGS.SPIECE_UNDERSCORE,
// "tokens": ["\u2581", "T", "his", "\u2581is", "\u2581a", "\u2581test", "\u2581", "."],
ids: [0, 3, 11, 3, 3, 3, 3, 3, 29, 2],
decoded: "<cls> <unk> T <unk> <unk> <unk> <unk> <unk>. <eos>",
},
SPECIAL_TOKENS: {
text: ESM_TEST_STRINGS.SPECIAL_TOKENS,
// "tokens": ["<unk>", "<pad>", "<mask>", "<cls>", "<eos>", "<bos>"],
ids: [0, 3, 1, 32, 0, 2, 3, 2],
decoded: "<cls> <unk> <pad> <mask> <cls> <eos> <unk> <eos>",
},
PROTEIN_SEQUENCES_1: {
text: ESM_TEST_STRINGS.PROTEIN_SEQUENCES_1,
tokens: ["A", "T", "T", "C", "C", "G", "A", "T", "T", "C", "C", "G", "A", "T", "T", "C", "C", "G"],
ids: [0, 5, 11, 11, 23, 23, 6, 5, 11, 11, 23, 23, 6, 5, 11, 11, 23, 23, 6, 2],
decoded: "<cls> A T T C C G A T T C C G A T T C C G <eos>",
},
PROTEIN_SEQUENCES_2: {
text: ESM_TEST_STRINGS.PROTEIN_SEQUENCES_2,
tokens: ["A", "T", "T", "T", "C", "T", "C", "T", "C", "T", "C", "T", "C", "T", "C", "T", "G", "A", "G", "A", "T", "C", "G", "A", "T", "C", "G", "A", "T", "C", "G", "A", "T"],
ids: [0, 5, 11, 11, 11, 23, 11, 23, 11, 23, 11, 23, 11, 23, 11, 23, 11, 6, 5, 6, 5, 11, 23, 6, 5, 11, 23, 6, 5, 11, 23, 6, 5, 11, 2],
decoded: "<cls> A T T T C T C T C T C T C T C T G A G A T C G A T C G A T C G A T <eos>",
},
},
};
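
As a sanity check, the fixtures above can be reproduced directly against the tokenizer. A minimal sketch, assuming the library's pretrained-tokenizer API (from_pretrained, encode, decode); the exact call shape here is an assumption, not the repository's test runner:

import { EsmTokenizer } from "../../../src/tokenizers.js";

// Load the 6-mer nucleotide-transformer vocabulary from the Hub.
const tokenizer = await EsmTokenizer.from_pretrained("Xenova/nucleotide-transformer-500m-human-ref");

// PROTEIN_SEQUENCES_1 above: "ATTCCG" repeated three times is expected to map
// to [3, 367, 367, 367] (<cls> followed by three identical 6-mer ids).
const ids = tokenizer.encode("ATTCCGATTCCGATTCCG");
console.log(ids);

// Decoding should round-trip to "<cls> ATTCCG ATTCCG ATTCCG".
console.log(tokenizer.decode(ids));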
8 changes: 8 additions & 0 deletions tests/models/test_strings.js
@@ -90,3 +90,11 @@ export const FALCON_TEST_STRINGS = {
// Special case for splitting on 3 numbers
NUMBERS_SPLIT: "12 and 123 and 1234",
};

export const ESM_TEST_STRINGS = {
// Special tokens
SPECIAL_TOKENS: "<unk><pad><mask><cls><eos><bos>",
// Actual protein sequences
PROTEIN_SEQUENCES_1: "ATTCCGATTCCGATTCCG",
PROTEIN_SEQUENCES_2: "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT",
};
