Encode() returns tokens' string representation too

Daulet Zhanguzin committed Jul 7, 2023
1 parent c220985 commit 19c010f
Showing 7 changed files with 124 additions and 57 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -36,9 +36,9 @@ Encode text and decode tokens:
fmt.Println("Vocab size:", tk.VocabSize())
// Vocab size: 30522
fmt.Println(tk.Encode("brown fox jumps over the lazy dog", false))
-// [2829 4419 14523 2058 1996 13971 3899]
+// [2829 4419 14523 2058 1996 13971 3899] [brown fox jumps over the lazy dog]
fmt.Println(tk.Encode("brown fox jumps over the lazy dog", true))
-// [101 2829 4419 14523 2058 1996 13971 3899 102]
+// [101 2829 4419 14523 2058 1996 13971 3899 102] [[CLS] brown fox jumps over the lazy dog [SEP]]
fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true))
// brown fox jumps over the lazy dog
```
4 changes: 2 additions & 2 deletions example/main.go
@@ -16,9 +16,9 @@ func main() {
fmt.Println("Vocab size:", tk.VocabSize())
// Vocab size: 30522
fmt.Println(tk.Encode("brown fox jumps over the lazy dog", false))
-// [2829 4419 14523 2058 1996 13971 3899]
+// [2829 4419 14523 2058 1996 13971 3899] [brown fox jumps over the lazy dog]
fmt.Println(tk.Encode("brown fox jumps over the lazy dog", true))
-// [101 2829 4419 14523 2058 1996 13971 3899 102]
+// [101 2829 4419 14523 2058 1996 13971 3899 102] [[CLS] brown fox jumps over the lazy dog [SEP]]
fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true))
// brown fox jumps over the lazy dog
}
74 changes: 57 additions & 17 deletions lib/src/lib.rs
@@ -3,6 +3,13 @@ use std::path::PathBuf;
use std::ptr;
use tokenizers::tokenizer::Tokenizer;

+#[repr(C)]
+pub struct Buffer {
+ids: *mut u32,
+tokens: *mut *mut libc::c_char,
+len: usize,
+}

#[no_mangle]
pub extern "C" fn from_bytes(bytes: *const u8, len: u32) -> *mut Tokenizer {
let bytes_slice = unsafe { std::slice::from_raw_parts(bytes, len as usize) };
@@ -44,15 +51,7 @@ pub extern "C" fn from_file(config: *const libc::c_char) -> *mut libc::c_void {
}

-#[no_mangle]
-pub extern "C" fn free_tokenizer(ptr: *mut ::libc::c_void) {
-if ptr.is_null() {
-return;
-}
-ptr.cast::<Tokenizer>();
-}
-
#[no_mangle]
-pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, len: *mut u32, add_special_tokens: bool) -> *mut u32 {
+pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, add_special_tokens: bool) -> Buffer {
let tokenizer: &Tokenizer;
unsafe {
tokenizer = ptr.cast::<Tokenizer>().as_ref().expect("failed to cast tokenizer");
@@ -61,14 +60,23 @@ pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, len: *mut u32, add_special_tokens: bool) -> *mut u32 {
let message = message_cstr.to_str().unwrap();

let encoding = tokenizer.encode(message, add_special_tokens).expect("failed to encode input");
-let mut vec = encoding.get_ids().to_vec();
-vec.shrink_to_fit();
-unsafe {
-*len = vec.len() as u32;
-}
-let vec_ptr = vec.as_mut_ptr();
-std::mem::forget(vec);
-vec_ptr
+let mut vec_ids = encoding.get_ids().to_vec();
+let mut vec_tokens = encoding.get_tokens()
+.to_vec().into_iter()
+.map(|s| std::ffi::CString::new(s).unwrap().into_raw())
+.collect::<Vec<_>>();
+
+vec_ids.shrink_to_fit();
+vec_tokens.shrink_to_fit();
+
+let ids = vec_ids.as_mut_ptr();
+let tokens = vec_tokens.as_mut_ptr();
+let len = vec_ids.len();
+
+std::mem::forget(vec_ids);
+std::mem::forget(vec_tokens);
+
+Buffer { ids, tokens, len }
}

#[no_mangle]
@@ -92,3 +100,35 @@ pub extern "C" fn vocab_size(ptr: *mut libc::c_void) -> u32 {
}
tokenizer.get_vocab_size(true) as u32
}
+
+#[no_mangle]
+pub extern "C" fn free_tokenizer(ptr: *mut ::libc::c_void) {
+if ptr.is_null() {
+return;
+}
+unsafe { drop(Box::from_raw(ptr.cast::<Tokenizer>())) };
+}
+
+#[no_mangle]
+pub extern "C" fn free_buffer(buf: Buffer) {
+if buf.ids.is_null() {
+return;
+}
+unsafe {
+Vec::from_raw_parts(buf.ids, buf.len, buf.len);
+let strings = Vec::from_raw_parts(buf.tokens, buf.len, buf.len);
+for s in strings {
+drop(std::ffi::CString::from_raw(s));
+}
+}
+}
+
+#[no_mangle]
+pub extern "C" fn free_string(ptr: *mut libc::c_char) {
+if ptr.is_null() {
+return;
+}
+unsafe {
+drop(std::ffi::CString::from_raw(ptr));
+}
+}
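
Note: the handoff above is the crux of the change. `encode` decomposes two `Vec`s into raw pointers and `std::mem::forget`s them so Rust does not free memory the caller is still reading, and `free_buffer` later rebuilds the `Vec`s (and `CString`s) so everything is dropped exactly once. The caller's side of the contract is copy-then-free: copy everything out of the buffer, then release it once. A minimal cgo sketch of that discipline, using a hypothetical C helper in place of the real tokenizer:

```go
package main

/*
#include <stdlib.h>
#include <string.h>

// Hypothetical stand-in for an FFI call that returns callee-owned memory.
static char *make_message() { return strdup("owned by the callee until freed"); }
*/
import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	p := C.make_message()
	msg := C.GoString(p)      // copy into Go-managed memory first...
	C.free(unsafe.Pointer(p)) // ...then release the foreign allocation exactly once
	fmt.Println(msg)          // safe: msg no longer aliases freed memory
}
```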
4 changes: 2 additions & 2 deletions release/main.go
@@ -16,9 +16,9 @@ func main() {
fmt.Println("Vocab size:", tk.VocabSize())
// Vocab size: 30522
fmt.Println(tk.Encode("brown fox jumps over the lazy dog", false))
-// [2829 4419 14523 2058 1996 13971 3899]
+// [2829 4419 14523 2058 1996 13971 3899] [brown fox jumps over the lazy dog]
fmt.Println(tk.Encode("brown fox jumps over the lazy dog", true))
-// [101 2829 4419 14523 2058 1996 13971 3899 102]
+// [101 2829 4419 14523 2058 1996 13971 3899 102] [[CLS] brown fox jumps over the lazy dog [SEP]]
fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true))
// brown fox jumps over the lazy dog
}
25 changes: 15 additions & 10 deletions tokenizer.go
@@ -54,22 +54,27 @@ func (t *Tokenizer) Close() error {
return nil
}

-func (t *Tokenizer) Encode(str string, addSpecialTokens bool) []uint32 {
+func (t *Tokenizer) Encode(str string, addSpecialTokens bool) ([]uint32, []string) {
cStr := C.CString(str)
defer C.free(unsafe.Pointer(cStr))
-var len C.uint
-res := C.encode(t.tokenizer, cStr, &len, C.bool(addSpecialTokens))
-if len > 0 {
-// can't dealloc nil
-defer C.free(unsafe.Pointer(res))
+res := C.encode(t.tokenizer, cStr, C.bool(addSpecialTokens))
+len := int(res.len)
+if len == 0 {
+return nil, nil
}
-slice := unsafe.Slice(res, len)
+defer C.free_buffer(res)
+
+ids := unsafe.Slice(res.ids, len)
tokenIDs := make([]uint32, len)
-for i, v := range slice {
+for i, v := range ids {
tokenIDs[i] = uint32(v)
}
-return tokenIDs
+
+tokens := make([]string, len)
+for i, s := range unsafe.Slice(res.tokens, len) {
+tokens[i] = C.GoString(s)
+}
+return tokenIDs, tokens
}

func (t *Tokenizer) Decode(tokenIDs []uint32, skipSpecialTokens bool) string {
Expand All @@ -78,7 +83,7 @@ func (t *Tokenizer) Decode(tokenIDs []uint32, skipSpecialTokens bool) string {
}
len := C.uint(len(tokenIDs))
res := C.decode(t.tokenizer, (*C.uint)(unsafe.Pointer(&tokenIDs[0])), len, C.bool(skipSpecialTokens))
-defer C.free(unsafe.Pointer(res))
+defer C.free_string(res)
return C.GoString(res)
}

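With the signature change above, callers get token IDs and their string forms from a single FFI round trip, copied into Go-managed slices before the foreign buffer is released. A caller-side sketch (the model path is illustrative, and `FromFile` is assumed from the package's existing constructors):

```go
package main

import (
	"fmt"

	"github.com/daulet/tokenizers"
)

func main() {
	// Illustrative path; any HuggingFace tokenizer.json should work here.
	tk, err := tokenizers.FromFile("./bert-base-uncased.json")
	if err != nil {
		panic(err)
	}
	defer tk.Close()

	ids, tokens := tk.Encode("brown fox jumps over the lazy dog", true)
	for i := range ids {
		fmt.Printf("%6d  %s\n", ids[i], tokens[i])
	}
	// ids and tokens are Go-managed copies; the Rust-side Buffer was freed inside Encode.
}
```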
54 changes: 33 additions & 21 deletions tokenizer_test.go
@@ -30,26 +30,30 @@ func TestEmbeddingConfig(t *testing.T) {
name string
str string
addSpecial bool
-want []uint32
+wantIDs []uint32
+wantTokens []string
}{
{
name: "without special tokens",
str: "brown fox jumps over the lazy dog",
addSpecial: false,
-want: []uint32{0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4},
+wantIDs: []uint32{0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4},
+wantTokens: []string{"brown", "fox", "jumps", "over", "the", "lazy", "dog"},
},
{
name: "with special tokens",
str: "brown fox jumps over the lazy dog",
addSpecial: true,
-want: []uint32{0x65, 0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4, 0x66},
+wantIDs: []uint32{0x65, 0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4, 0x66},
+wantTokens: []string{"[CLS]", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "[SEP]"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tk.Encode(tt.str, tt.addSpecial)
-got := tk.Encode(tt.str, tt.addSpecial)
-assert.Equal(t, tt.want, got)
+gotIDs, gotTokens := tk.Encode(tt.str, tt.addSpecial)
+assert.Equal(t, tt.wantIDs, gotIDs)
+assert.Equal(t, tt.wantTokens, gotTokens)
})
}
}
@@ -62,37 +66,39 @@ func TestEncode(t *testing.T) {
name string
str string
addSpecial bool
-want []uint32
+wantIDs []uint32
+wantTokens []string
}{
{
name: "without special tokens",
str: "brown fox jumps over the lazy dog",
addSpecial: false,
-want: []uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899},
+wantIDs: []uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899},
+wantTokens: []string{"brown", "fox", "jumps", "over", "the", "lazy", "dog"},
},
{
name: "with special tokens",
str: "brown fox jumps over the lazy dog",
addSpecial: true,
-want: []uint32{101, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102},
+wantIDs: []uint32{101, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102},
+wantTokens: []string{"[CLS]", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "[SEP]"},
},
{
name: "empty string",
str: "",
addSpecial: false,
-want: []uint32{},
},
{
name: "empty string with special tokens",
str: "",
addSpecial: false,
-want: []uint32{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
-got := tk.Encode(tt.str, tt.addSpecial)
-assert.Equal(t, tt.want, got)
+gotIDs, gotTokens := tk.Encode(tt.str, tt.addSpecial)
+assert.Equal(t, tt.wantIDs, gotIDs)
+assert.Equal(t, tt.wantTokens, gotTokens)
})
}
}
@@ -104,39 +110,44 @@ func TestEncodeWithTruncation(t *testing.T) {
addSpecial bool
maxLen int
dir tokenizers.TruncationDirection
-want []uint32
+wantIDs []uint32
+wantTokens []string
}{
{
name: "without special tokens, left truncation",
str: "brown fox jumps over the lazy dog",
addSpecial: false,
maxLen: 5,
dir: tokenizers.TruncationDirectionLeft,
-want: []uint32{0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4},
+wantIDs: []uint32{0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4},
+wantTokens: []string{"jumps", "over", "the", "lazy", "dog"},
},
{
name: "without special tokens, right truncation",
str: "brown fox jumps over the lazy dog",
addSpecial: false,
maxLen: 5,
dir: tokenizers.TruncationDirectionRight,
-want: []uint32{0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89},
+wantIDs: []uint32{0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89},
+wantTokens: []string{"brown", "fox", "jumps", "over", "the"},
},
{
name: "with special tokens, left truncation",
str: "brown fox jumps over the lazy dog",
addSpecial: true,
maxLen: 5,
dir: tokenizers.TruncationDirectionLeft,
-want: []uint32{0x65, 0x3a89, 0x35fc3, 0x57b4, 0x66},
+wantIDs: []uint32{0x65, 0x3a89, 0x35fc3, 0x57b4, 0x66},
+wantTokens: []string{"[CLS]", "the", "lazy", "dog", "[SEP]"},
},
{
name: "with special tokens, right truncation",
str: "brown fox jumps over the lazy dog",
addSpecial: true,
maxLen: 5,
dir: tokenizers.TruncationDirectionRight,
-want: []uint32{0x65, 0xca3f, 0x2f304, 0x5185b, 0x66},
+wantIDs: []uint32{0x65, 0xca3f, 0x2f304, 0x5185b, 0x66},
+wantTokens: []string{"[CLS]", "brown", "fox", "jumps", "[SEP]"},
},
}
for _, tt := range tests {
@@ -146,8 +157,9 @@
defer tk.Close()

tk.Encode(tt.str, tt.addSpecial)
-got := tk.Encode(tt.str, tt.addSpecial)
-assert.Equal(t, tt.want, got)
+gotIDs, gotTokens := tk.Encode(tt.str, tt.addSpecial)
+assert.Equal(t, tt.wantIDs, gotIDs)
+assert.Equal(t, tt.wantTokens, gotTokens)
})
}
}
@@ -215,7 +227,7 @@ func BenchmarkEncodeNTimes(b *testing.B) {
expected := []uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}
b.ResetTimer()
for i := 0; i < b.N; i++ {
-tokens := tk.Encode("brown fox jumps over the lazy dog", false)
+tokens, _ := tk.Encode("brown fox jumps over the lazy dog", false)
assert.Equal(b, expected, tokens)
}
}
@@ -230,7 +242,7 @@ func BenchmarkEncodeNChars(b *testing.B) {
}
str := string(input)
b.ResetTimer()
-tokens := tk.Encode(str, false)
+tokens, _ := tk.Encode(str, false)
assert.Greater(b, len(tokens), 0)
}

Expand Down
16 changes: 13 additions & 3 deletions tokenizers.h
@@ -1,16 +1,26 @@
#include <stdbool.h>
#include <stdint.h>

+struct Buffer {
+uint32_t *ids;
+char **tokens; // array of len NUL-terminated strings
+uintptr_t len; // mirrors the Rust usize field
+};
+
void *from_bytes(const uint8_t *config, uint32_t len);

void *from_bytes_with_truncation(const uint8_t *config, uint32_t len, uint32_t max_len, uint8_t direction);

void *from_file(const char *config);

-void free_tokenizer(void *ptr);

-uint32_t *encode(void *ptr, const char *message, uint32_t *len, bool add_special_tokens);
+struct Buffer encode(void *ptr, const char *message, bool add_special_tokens);

char *decode(void *ptr, const uint32_t *ids, uint32_t len, bool skip_special_tokens);

uint32_t vocab_size(void *ptr);

+void free_tokenizer(void *ptr);
+
+void free_buffer(struct Buffer buffer);
+
+void free_string(char *string);
