Load the tokenizer data from memory (#836)
wenbingl authored Nov 9, 2024
1 parent 14f280a commit 3da0d3c
Showing 11 changed files with 377 additions and 334,714 deletions.
36 changes: 36 additions & 0 deletions include/ortx_tokenizer.h
@@ -13,6 +13,33 @@ typedef OrtxObject OrtxStringArray;
typedef OrtxObject OrtxTokenId2DArray;
typedef OrtxObject OrtxDetokenizerCache;

struct OrtxTokenizerBlob {
const char* config_json_blob;
const char* vocab_json_blob;
const char* token_module_blob;
const char* raw_model_blob;
const char* reserved_blob_1;

const size_t config_blob_len;
const size_t vocab_blob_len;
const size_t token_module_blob_len;
const size_t raw_model_blob_len;
const size_t reserved_blob_1_len;

#ifdef __cplusplus
OrtxTokenizerBlob(const std::string_view& config_json_blob,
const std::string_view& vocab_json_blob,
const std::string_view& token_module_blob,
const std::string_view& raw_model_blob)
: config_json_blob(config_json_blob.data()), vocab_json_blob(vocab_json_blob.data()),
token_module_blob(token_module_blob.data()), raw_model_blob(raw_model_blob.data()),
config_blob_len(config_json_blob.size()),
vocab_blob_len(vocab_json_blob.size()), token_module_blob_len(token_module_blob.size()),
raw_model_blob_len(raw_model_blob.size()), reserved_blob_1(nullptr),
reserved_blob_1_len(0) {}
#endif
};


#ifdef __cplusplus
extern "C" {
@@ -26,6 +53,15 @@ extern "C" {
*/
extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const char* tokenizer_path);

/** \brief Create a tokenizer object with the specified tokenizer blob
*
* \param tokenizer Pointer to store the created tokenizer object
* \param tokenizer_blob Pointer to the tokenizer blob
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxCreateTokenizerFromBlob(OrtxTokenizer** tokenizer, const struct OrtxTokenizerBlob* tokenizer_blob);


/** \brief Tokenize the input using the specified tokenizer
*
* \param tokenizer Pointer to the tokenizer object
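For context, here is a minimal sketch of how a caller might use the new in-memory API. It is not part of this commit; the file names and the kOrtxOK success check are assumptions. Note that OrtxTokenizerBlob only points at the buffers, so they must stay alive at least until OrtxCreateTokenizerFromBlob returns.

#include <fstream>
#include <sstream>
#include <string>

#include "ortx_tokenizer.h"

// Read a whole file into memory; in practice the blobs could come from
// any source, e.g. resources embedded in the application binary.
static std::string ReadAll(const char* path) {
  std::ifstream ifs(path, std::ios::binary);
  std::ostringstream oss;
  oss << ifs.rdbuf();
  return oss.str();
}

int main() {
  std::string config = ReadAll("tokenizer_config.json");
  std::string vocab = ReadAll("tokenizer.json");

  // The token-module and raw-model blobs are optional here and left empty.
  OrtxTokenizerBlob blob(config, vocab, {}, {});

  OrtxTokenizer* tokenizer = nullptr;
  extError_t err = OrtxCreateTokenizerFromBlob(&tokenizer, &blob);
  if (err != kOrtxOK) {  // kOrtxOK: assumed success code of extError_t
    return 1;
  }
  // ... tokenize, then release the tokenizer via the library's dispose API ...
  return 0;
}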
154 changes: 60 additions & 94 deletions operators/tokenizer/bpe_kernels.cc
@@ -650,53 +650,6 @@ std::string JsonFastTokenizer::TokenBytesToString(std::vector<uint8_t>& bytes) {
return result;
}

// Custom hash function for the vector key
struct VectorHash {
size_t operator()(const std::vector<uint8_t>& v) const {
std::hash<uint8_t> hasher;
size_t seed = 0;
for (uint8_t i : v) {
seed ^= hasher(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};

// Custom equality function for the vector key
struct VectorEqual {
bool operator()(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) const {
return a == b;
}
};

OrtxStatus JsonFastTokenizer::LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
for (const auto& token : *added_tokens) {
AddedToken added_token;
added_token.id_ = token.value("id", 0);
added_token.token_type_ = token.value("__type", "");
added_token.content_ = token.value("content", "");
added_token.lstrip_ = token.value("lstrip", false);
added_token.normalized_ = token.value("normalized", false);
added_token.rstrip_ = token.value("rstrip", false);
added_token.single_word_ = token.value("single_word", false);
added_token.special_ = token.value("special", false);

added_tokens_.emplace_back(added_token);
if (added_token.content_ == config.bos_token_) {
bos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.eos_token_) {
eos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.pad_token_) {
pad_token_id_ = added_token.id_;
}
}
}

return bbpe_tokenizer_->LoadAddedTokens(added_tokens_);
}

// Helper methods (to be added to the class declaration)
void JsonFastTokenizer::LoadSpmModelParams(const json& tok_json) {
auto decoder_node = tok_json.find("decoder");
@@ -722,7 +675,29 @@ void JsonFastTokenizer::LoadSpmModelParams(const json& tok_json) {
}
}

void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config) {
void JsonFastTokenizer::UpdateTokenizer(const TokenJsonConfig& config, const json& tok_json) {
added_tokens_ = config.added_tokens_;
auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
for (const auto& token : *added_tokens) {
added_tokens_.emplace_back(TokenJsonConfig::ParseAddedToken(token));
}
}

for (const auto& added_token : added_tokens_) {
if (added_token.content_ == config.bos_token_) {
bos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.eos_token_) {
eos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.pad_token_) {
pad_token_id_ = added_token.id_;
}
}

bbpe_tokenizer_->LoadAddedTokens(added_tokens_);
add_bos_token_ = config.add_bos_token_;
add_eos_token_ = config.add_eos_token_;

if (!config.add_bos_token_ && !config.bos_token_.empty()) {
auto post_processor = tok_json.find("post_processor");
if (post_processor != tok_json.end()) {
@@ -738,14 +713,14 @@ void JsonFastTokenizer::UpdateTokenAdditionFlags(const json& tok_json, const ort
}
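
TokenJsonConfig::ParseAddedToken is referenced above, but its body is not among the visible hunks. Judging from the fields the removed LoadAddedTokens used to read, a plausible sketch (an assumption, not the committed implementation) is:

// Sketch: mirrors the per-token fields the old LoadAddedTokens parsed.
AddedToken TokenJsonConfig::ParseAddedToken(const json& token) {
  AddedToken added_token;
  added_token.id_ = token.value("id", 0);
  added_token.token_type_ = token.value("__type", "");
  added_token.content_ = token.value("content", "");
  added_token.lstrip_ = token.value("lstrip", false);
  added_token.normalized_ = token.value("normalized", false);
  added_token.rstrip_ = token.value("rstrip", false);
  added_token.single_word_ = token.value("single_word", false);
  added_token.special_ = token.value("special", false);
  return added_token;
}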

OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs = path(voc_file).open();
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
std::unique_ptr<std::istream> vocab_stream;
auto status = config.OpenVocabFile(vocab_stream);
if (!status.IsOk()) {
return status;
}

nlohmann::json tok_json;
ifs >> tok_json;
*vocab_stream >> tok_json;

const char token_sub[] = "Tokenizer";
model_name_ = config.tokenizer_class_.substr(0, config.tokenizer_class_.find(token_sub));
@@ -767,38 +742,48 @@ OrtxStatus JsonFastTokenizer::Load(const ort_extensions::TokenJsonConfig& config
}

bbpe_tokenizer_ = std::make_unique<BpeModel>();
OrtxStatus status = bbpe_tokenizer_->Load(*model_node,
status = bbpe_tokenizer_->Load(*model_node,
bpe_conf_.get().GetSpecialTokens().c_str(),
bpe_conf_.get().spm_model_);
if (!status.IsOk()) {
return status;
}

status = LoadAddedTokens(tok_json, config);
if (!status.IsOk()) {
return status;
if (status.IsOk()) {
UpdateTokenizer(config, tok_json);
}

add_bos_token_ = config.add_bos_token_;
add_eos_token_ = config.add_eos_token_;
UpdateTokenAdditionFlags(tok_json, config);

return status;
}
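
OpenVocabFile is the indirection that lets this same Load path consume either a file on disk or a buffer supplied through OrtxTokenizerBlob. Its implementation is not among the hunks shown; a minimal sketch of the idea, in which every member name except GetVocabDataFile is an assumption:

OrtxStatus TokenJsonConfig::OpenVocabFile(std::unique_ptr<std::istream>& stream) const {
  if (vocab_json_blob_ != nullptr) {
    // In-memory path: wrap the blob in an istringstream (copies the buffer).
    stream = std::make_unique<std::istringstream>(
        std::string(vocab_json_blob_, vocab_blob_len_));
  } else {
    // File path: same behavior as the ifstream code this commit replaces.
    auto ifs = std::make_unique<std::ifstream>(GetVocabDataFile());
    if (!ifs->is_open()) {
      return OrtxStatus(kOrtxErrorInvalidFile,
                        "Failed to open json file: " + GetVocabDataFile());
    }
    stream = std::move(ifs);
  }
  return {};
}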

// Custom hash function for the vector key
struct VectorHash {
size_t operator()(const std::vector<uint8_t>& v) const {
std::hash<uint8_t> hasher;
size_t seed = 0;
for (uint8_t i : v) {
seed ^= hasher(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};

// Custom equality function for the vector key
struct VectorEqual {
bool operator()(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b) const {
return a == b;
}
};

OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJsonConfig& config) {
std::string voc_file = config.GetVocabDataFile();
std::ifstream ifs = path(voc_file).open();
if (!ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open json file: " + voc_file);
std::unique_ptr<std::istream> vocab_stream;
auto status = config.OpenVocabFile(vocab_stream);
if (!status.IsOk()) {
return status;
}

std::unordered_map<std::string, uint32_t> vocab;
std::vector<std::pair<std::string, std::string>> merges;
std::unordered_map<std::vector<uint8_t>, uint32_t, VectorHash, VectorEqual> bpe_ranks;

std::string line;
while (std::getline(ifs, line)) {
while (std::getline(*vocab_stream, line)) {
if (!line.empty()) {
std::istringstream lineStream(line);
std::string token;
@@ -857,7 +842,8 @@ OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJson

// Populate merges
for (auto& val : byte_merges) {
merges.push_back({JsonFastTokenizer::TokenBytesToString(std::get<0>(val)), JsonFastTokenizer::TokenBytesToString(std::get<1>(val))});
merges.push_back({JsonFastTokenizer::TokenBytesToString(std::get<0>(val)),
JsonFastTokenizer::TokenBytesToString(std::get<1>(val))});
}

const char token_sub[] = "Tokenizer";
@@ -871,32 +857,12 @@ OrtxStatus JsonFastTokenizer::LoadTikTokenBase64(const ort_extensions::TokenJson
// re-bind the configuration object
bpe_conf_ = json_conf_;

OrtxStatus status = bbpe_tokenizer_->Load(vocab,
merges,
bpe_conf_.get().GetSpecialTokens().c_str(),
false);

if (!status.IsOk()) {
return status;
}

std::string module_file = config.GetTikTokenModuleFile();
std::ifstream module_ifs = path(module_file).open();
if (!module_ifs.is_open()) {
return OrtxStatus(kOrtxErrorInvalidFile, "Failed to open module file: " + module_file);
}
status = bbpe_tokenizer_->Load(vocab, merges, bpe_conf_.get().GetSpecialTokens().c_str(), false);

nlohmann::json tok_json;
module_ifs >> tok_json;
status = LoadAddedTokens(tok_json, config);
if (!status.IsOk()) {
return status;
if (status.IsOk()) {
UpdateTokenizer(config, json());
}

add_bos_token_ = config.add_bos_token_;
add_eos_token_ = config.add_eos_token_;
UpdateTokenAdditionFlags(tok_json, config);

return status;
}
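
For reference, the tiktoken-style vocabulary consumed by LoadTikTokenBase64 is a plain-text file in which each line pairs a base64-encoded token with its integer rank, along these lines (an illustrative sample, not actual model data):

IQ== 0
Ig== 1
Iw== 2

Here IQ== is base64 for "!"; the parsing loop decodes each token back to its raw bytes, and the rank serves as the token id and determines merge priority.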

3 changes: 1 addition & 2 deletions operators/tokenizer/bpe_kernels.h
@@ -128,8 +128,7 @@ class JsonFastTokenizer : public KernelBpeTokenizer {
private:
std::string TokenBytesToString(std::vector<uint8_t>& bytes);
void LoadSpmModelParams(const json& tok_json);
void UpdateTokenAdditionFlags(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
OrtxStatus LoadAddedTokens(const json& tok_json, const ort_extensions::TokenJsonConfig& config);
void UpdateTokenizer(const ort_extensions::TokenJsonConfig& config, const json& tok_json);

BpeModelConf json_conf_;
std::vector<ort_extensions::AddedToken> added_tokens_;
4 changes: 1 addition & 3 deletions operators/tokenizer/bpe_tokenizer_model.hpp
@@ -258,12 +258,10 @@ class BpeModel {
return {};
}

OrtxStatus LoadAddedTokens(const std::vector<AddedToken>& added_tokens) {
void LoadAddedTokens(const std::vector<AddedToken>& added_tokens) {
for (const auto& token : added_tokens) {
added_tokens_.Add(ustring(token.content_), 0, token.id_);
}

return {};
}

std::vector<std::string> BuildDecoder() const { return id2token_map_; }