Skip to content

Commit

Permalink
[init] Refine Python docs (#17)
Browse files Browse the repository at this point in the history
This PR provides detailed Python documentation.

It also slightly changes method and parameter names for GrammarMatcher:
- matcher.vocab_size -> matcher.mask_vocab_size
- max_rollback_steps -> max_rollback_tokens
- rollback_steps -> rollback_tokens
  • Loading branch information
Ubospica authored Oct 14, 2024
1 parent c32653b commit 53e5174
Show file tree
Hide file tree
Showing 20 changed files with 436 additions and 244 deletions.
46 changes: 23 additions & 23 deletions cpp/grammar_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ class GrammarMatcher::Impl : public GrammarMatcherBase {
std::optional<std::vector<int>> stop_token_ids = std::nullopt,
bool terminate_without_stop_token = false,
std::optional<int> mask_vocab_size = std::nullopt,
int max_rollback_steps = 0
int max_rollback_tokens = 0
)
: GrammarMatcherBase(init_ctx->grammar),
init_ctx_(init_ctx),
stop_token_ids_(stop_token_ids.value_or(init_ctx->detected_stop_token_ids)),
terminate_without_stop_token_(terminate_without_stop_token),
mask_vocab_size_(mask_vocab_size.value_or(init_ctx_->vocab_size)),
max_rollback_steps_(max_rollback_steps),
max_rollback_tokens_(max_rollback_tokens),
tmp_accepted_bitset_(mask_vocab_size_) {
XGRAMMAR_CHECK(!stop_token_ids.has_value() || !stop_token_ids->empty())
<< "The stop_token_ids should not be empty";
Expand All @@ -145,16 +145,16 @@ class GrammarMatcher::Impl : public GrammarMatcherBase {
void FindNextTokenBitmask(DLTensor* next_token_bitmask);

static void GetRejectedTokensFromBitMask(
const DLTensor& token_bitmask, size_t vocab_size, std::vector<int>* rejected_tokens
const DLTensor& token_bitmask, size_t mask_vocab_size, std::vector<int>* rejected_tokens
);

std::string FindJumpForwardString();

void Rollback(int num_tokens);

int GetMaxRollbackSteps() const { return max_rollback_steps_; }
int GetMaxRollbackTokens() const { return max_rollback_tokens_; }

size_t GetVocabSize() const { return mask_vocab_size_; }
size_t GetMaskVocabSize() const { return mask_vocab_size_; }

bool IsTerminated() const;

Expand All @@ -179,7 +179,7 @@ class GrammarMatcher::Impl : public GrammarMatcherBase {
const std::vector<bool>& uncertain_tokens_bitset
);

static void CheckTokenBitmaskValidity(const DLTensor& token_bitmask, size_t vocab_size);
static void CheckTokenBitmaskValidity(const DLTensor& token_bitmask, size_t mask_vocab_size);

/*! \brief Set the acceptable next token in next_token_bitmask. */
void SetTokenBitmask(
Expand All @@ -202,7 +202,7 @@ class GrammarMatcher::Impl : public GrammarMatcherBase {
std::vector<int> stop_token_ids_;
bool terminate_without_stop_token_;
size_t mask_vocab_size_;
int max_rollback_steps_;
int max_rollback_tokens_;
std::deque<int> token_length_history;

// Temporary data for FindNextTokenBitmask. They are stored here to avoid repeated allocation.
Expand Down Expand Up @@ -279,7 +279,7 @@ bool GrammarMatcher::Impl::AcceptToken(int32_t token_id, bool verbose) {
++pos;
}
token_length_history.push_back(token.size());
if (static_cast<int>(token_length_history.size()) > max_rollback_steps_) {
if (static_cast<int>(token_length_history.size()) > max_rollback_tokens_) {
DiscardEarliestChars(token_length_history.front());
token_length_history.pop_front();
}
Expand Down Expand Up @@ -311,7 +311,7 @@ bool GrammarMatcher::Impl::AcceptString(const std::string& input_str, bool verbo
++accepted_cnt;
}
token_length_history.push_back(input_str.size());
if (static_cast<int>(token_length_history.size()) > max_rollback_steps_) {
if (static_cast<int>(token_length_history.size()) > max_rollback_tokens_) {
DiscardEarliestChars(token_length_history.front());
token_length_history.pop_front();
}
Expand All @@ -324,15 +324,15 @@ bool GrammarMatcher::Impl::AcceptString(const std::string& input_str, bool verbo
}

void GrammarMatcher::Impl::CheckTokenBitmaskValidity(
const DLTensor& token_bitmask, size_t vocab_size
const DLTensor& token_bitmask, size_t mask_vocab_size
) {
XGRAMMAR_CHECK(
token_bitmask.dtype.code == kDLInt && token_bitmask.dtype.bits == 32 && token_bitmask.data &&
token_bitmask.ndim == 1 && token_bitmask.shape
) << "The provied bitmask's shape or dtype is not valid.";
XGRAMMAR_CHECK(token_bitmask.shape[0] >= DynamicBitset::CalculateBufferSize(vocab_size))
XGRAMMAR_CHECK(token_bitmask.shape[0] >= DynamicBitset::CalculateBufferSize(mask_vocab_size))
<< "The provided bitmask is not large enough to store the token set. The length should be "
<< DynamicBitset::CalculateBufferSize(vocab_size) << " at least";
<< DynamicBitset::CalculateBufferSize(mask_vocab_size) << " at least";
}

void GrammarMatcher::Impl::FindNextTokenBitmask(DLTensor* next_token_bitmask) {
Expand Down Expand Up @@ -449,10 +449,10 @@ void GrammarMatcher::Impl::FindNextTokenBitmask(DLTensor* next_token_bitmask) {
}

void GrammarMatcher::Impl::GetRejectedTokensFromBitMask(
const DLTensor& token_bitmask, size_t vocab_size, std::vector<int>* rejected_tokens
const DLTensor& token_bitmask, size_t mask_vocab_size, std::vector<int>* rejected_tokens
) {
CheckTokenBitmaskValidity(token_bitmask, vocab_size);
DynamicBitset bitset(vocab_size, reinterpret_cast<uint32_t*>(token_bitmask.data));
CheckTokenBitmaskValidity(token_bitmask, mask_vocab_size);
DynamicBitset bitset(mask_vocab_size, reinterpret_cast<uint32_t*>(token_bitmask.data));
rejected_tokens->clear();
for (int i = bitset.FindFirstZero(); i != -1; i = bitset.FindNextZero(i)) {
rejected_tokens->push_back(i);
Expand Down Expand Up @@ -631,14 +631,14 @@ GrammarMatcher::GrammarMatcher(
std::optional<std::vector<int>> stop_token_ids,
bool terminate_without_stop_token,
std::optional<int> mask_vocab_size,
int max_rollback_steps
int max_rollback_tokens
)
: pimpl_(std::make_shared<GrammarMatcher::Impl>(
init_ctx,
stop_token_ids,
terminate_without_stop_token,
mask_vocab_size,
max_rollback_steps
max_rollback_tokens
)) {}

bool GrammarMatcher::AcceptToken(int32_t token_id, bool verbose) {
Expand All @@ -649,27 +649,27 @@ bool GrammarMatcher::AcceptString(const std::string& input_str, bool verbose) {
return pimpl_->AcceptString(input_str, verbose);
}

uint32_t GrammarMatcher::GetBufferSize(size_t vocab_size) {
return DynamicBitset::CalculateBufferSize(vocab_size);
uint32_t GrammarMatcher::GetBufferSize(size_t mask_vocab_size) {
return DynamicBitset::CalculateBufferSize(mask_vocab_size);
}

void GrammarMatcher::FindNextTokenBitmask(DLTensor* next_token_bitmask) {
pimpl_->FindNextTokenBitmask(next_token_bitmask);
}

void GrammarMatcher::GetRejectedTokensFromBitMask(
const DLTensor& token_bitmask, size_t vocab_size, std::vector<int>* rejected_tokens
const DLTensor& token_bitmask, size_t mask_vocab_size, std::vector<int>* rejected_tokens
) {
return Impl::GetRejectedTokensFromBitMask(token_bitmask, vocab_size, rejected_tokens);
return Impl::GetRejectedTokensFromBitMask(token_bitmask, mask_vocab_size, rejected_tokens);
}

std::string GrammarMatcher::FindJumpForwardString() { return pimpl_->FindJumpForwardString(); }

void GrammarMatcher::Rollback(int num_tokens) { pimpl_->Rollback(num_tokens); }

int GrammarMatcher::GetMaxRollbackSteps() const { return pimpl_->GetMaxRollbackSteps(); }
int GrammarMatcher::GetMaxRollbackTokens() const { return pimpl_->GetMaxRollbackTokens(); }

size_t GrammarMatcher::GetVocabSize() const { return pimpl_->GetVocabSize(); }
size_t GrammarMatcher::GetMaskVocabSize() const { return pimpl_->GetMaskVocabSize(); }

bool GrammarMatcher::IsTerminated() const { return pimpl_->IsTerminated(); }

Expand Down
6 changes: 3 additions & 3 deletions cpp/grammar_matcher_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
* \file xgrammar/grammar_matcher_base.h
* \brief The base class of GrammarMatcher. It implements a character-based matching automata.
*/
#ifndef XGRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_
#define XGRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_
#ifndef XGRAMMAR_GRAMMAR_MATCHER_BASE_H_
#define XGRAMMAR_GRAMMAR_MATCHER_BASE_H_

#include <xgrammar/xgrammar.h>

Expand Down Expand Up @@ -413,4 +413,4 @@ inline bool GrammarMatcherBase::ExpandRulePosition(

} // namespace xgrammar

#endif // XGRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_
#endif // XGRAMMAR_GRAMMAR_MATCHER_BASE_H_
6 changes: 3 additions & 3 deletions cpp/grammar_matcher_preproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
* \file xgrammar/grammar_matcher_preproc.h
* \brief The header for the preprocessing of the grammar matcher.
*/
#ifndef XGRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_
#define XGRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_
#ifndef XGRAMMAR_GRAMMAR_MATCHER_PREPROC_H_
#define XGRAMMAR_GRAMMAR_MATCHER_PREPROC_H_

#include <xgrammar/xgrammar.h>

Expand Down Expand Up @@ -508,4 +508,4 @@ void GrammarMatcherInitContextCache::Clear() { pimpl_->Clear(); }

} // namespace xgrammar

#endif // XGRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_
#endif // XGRAMMAR_GRAMMAR_MATCHER_PREPROC_H_
14 changes: 7 additions & 7 deletions cpp/grammar_matcher_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
* \file xgrammar/grammar_matcher_state.h
* \brief The header for the definition of the state used in the grammar matcher.
*/
#ifndef XGRAMMAR_GRAMMAR_STATE_MATCHER_STATE_H_
#define XGRAMMAR_GRAMMAR_STATE_MATCHER_STATE_H_
#ifndef XGRAMMAR_GRAMMAR_MATCHER_STATE_H_
#define XGRAMMAR_GRAMMAR_MATCHER_STATE_H_

#include <xgrammar/xgrammar.h>

Expand Down Expand Up @@ -224,15 +224,15 @@ class RulePositionTree {
* \details This class helps to maintain nodes by automatically maintaining the attached references.
* If a node is not existing in any stack in the history record, it will be freed.
*
* It can store up to the previous max_rollback_steps + 1 steps of history, and thus supports
* rolling back up to max_rollback_steps steps.
* It can store up to the previous max_rollback_tokens + 1 steps of history, and thus supports
* rolling back up to max_rollback_tokens steps.
*/
class StackTopsHistory {
public:
/*!
* \param tree The RulePositionTree to be associated with. Possibly modify the tree by attaching
* and removing references to the stack top nodes.
* \param max_rollback_steps The maximum number of rollback steps to be supported.
* \param max_rollback_tokens The maximum number of rollback tokens to be supported.
*/
StackTopsHistory(RulePositionTree* tree) : tree_(tree) {}

Expand All @@ -254,7 +254,7 @@ class StackTopsHistory {
* any more. */
void Rollback(int rollback_steps) {
XGRAMMAR_DCHECK(rollback_steps < static_cast<int>(stack_tops_history_.size()))
<< "The number of requested rollback steps is greater than or equal to the current "
<< "The number of requested rollback tokens is greater than or equal to the current "
"history "
<< "size: " << rollback_steps << " vs " << stack_tops_history_.size() << ".";
while (rollback_steps--) {
Expand Down Expand Up @@ -443,4 +443,4 @@ inline void StackTopsHistory::CheckWellFormed() const {

} // namespace xgrammar

#endif // XGRAMMAR_GRAMMAR_STATE_MATCHER_STATE_H_
#endif // XGRAMMAR_GRAMMAR_MATCHER_STATE_H_
4 changes: 2 additions & 2 deletions cpp/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ PYBIND11_MODULE(xgrammar_bindings, m) {
.def_static("get_rejected_tokens_from_bitmask", &GrammarMatcher_GetRejectedTokensFromBitMask)
.def("is_terminated", &GrammarMatcher::IsTerminated)
.def("reset", &GrammarMatcher::Reset)
.def_property_readonly("vocab_size", &GrammarMatcher::GetVocabSize)
.def_property_readonly("mask_vocab_size", &GrammarMatcher::GetMaskVocabSize)
.def("find_jump_forward_string", &GrammarMatcher::FindJumpForwardString)
.def("rollback", &GrammarMatcher::Rollback)
.def_property_readonly("max_rollback_steps", &GrammarMatcher::GetMaxRollbackSteps);
.def_property_readonly("max_rollback_tokens", &GrammarMatcher::GetMaxRollbackTokens);
}
6 changes: 3 additions & 3 deletions cpp/pybind/python_methods.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ std::vector<pybind11::bytes> TokenizerInfo_GetRawVocab(TokenizerInfo& tokenizer)
}

torch::Tensor GrammarMatcher_FindNextTokenBitmask(GrammarMatcher& matcher) {
auto buffer_size = GrammarMatcher::GetBufferSize(matcher.GetVocabSize());
auto buffer_size = GrammarMatcher::GetBufferSize(matcher.GetMaskVocabSize());
auto result = torch::empty({buffer_size}, torch::dtype(torch::kInt32).device(torch::kCPU, 0));
auto result_dltensor = at::toDLPack(result)->dl_tensor;
matcher.FindNextTokenBitmask(&result_dltensor);
return result;
}

std::vector<int> GrammarMatcher_GetRejectedTokensFromBitMask(
torch::Tensor token_bitmask, size_t vocab_size
torch::Tensor token_bitmask, size_t mask_vocab_size
) {
std::vector<int> result;
auto token_bitmask_dltensor = at::toDLPack(token_bitmask)->dl_tensor;
GrammarMatcher::GetRejectedTokensFromBitMask(token_bitmask_dltensor, vocab_size, &result);
GrammarMatcher::GetRejectedTokensFromBitMask(token_bitmask_dltensor, mask_vocab_size, &result);
return result;
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/pybind/python_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ std::vector<pybind11::bytes> TokenizerInfo_GetRawVocab(TokenizerInfo& tokenizer)
torch::Tensor GrammarMatcher_FindNextTokenBitmask(GrammarMatcher& matcher);

std::vector<int> GrammarMatcher_GetRejectedTokensFromBitMask(
torch::Tensor token_bitmask, size_t vocab_size
torch::Tensor token_bitmask, size_t mask_vocab_size
);

} // namespace xgrammar
Expand Down
18 changes: 9 additions & 9 deletions include/xgrammar/xgrammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ class GrammarMatcherInitContext {
* GrammarMatcher matcher(init_ctx, 10);
* matcher->AcceptToken(67);
*
* // Construct a DLTensor with shape (tokenizer.GetVocabSize() + 31) / 32, and dtype uint32.
* // Construct a DLTensor with shape (tokenizer.GetMaskVocabSize() + 31) / 32, and dtype uint32.
* DLTensor next_token_bitmask = ...;
* matcher->FindNextTokenBitmask(&next_token_bitmask);
*
Expand All @@ -238,7 +238,7 @@ class GrammarMatcher {
std::optional<std::vector<int>> stop_token_ids = std::nullopt,
bool terminate_without_stop_token = false,
std::optional<int> mask_vocab_size = std::nullopt,
int max_rollback_steps = 0
int max_rollback_tokens = 0
);

/*!
Expand All @@ -255,18 +255,18 @@ class GrammarMatcher {

bool AcceptString(const std::string& input_str, bool verbose = false);

static uint32_t GetBufferSize(size_t vocab_size);
static uint32_t GetBufferSize(size_t mask_vocab_size);

/*!
* \brief Find the set of tokens that are acceptable for the next step and store them in a
* bitmask.
* \param next_token_bitmask The bitmask to store the result. The bitmask must be pre-allocated
* and with shape (GetBufferSize(vocab_size),) and dtype uint32.
* and with shape (GetBufferSize(mask_vocab_size),) and dtype uint32.
*/
void FindNextTokenBitmask(DLTensor* next_token_bitmask);

static void GetRejectedTokensFromBitMask(
const DLTensor& token_bitmask, size_t vocab_size, std::vector<int>* rejected_tokens
const DLTensor& token_bitmask, size_t mask_vocab_size, std::vector<int>* rejected_tokens
);

/*!
Expand All @@ -279,14 +279,14 @@ class GrammarMatcher {
/*!
* \brief Rollback the matcher to a previous state.
* \param num_tokens The number of tokens to rollback. It cannot exceed the current number of
* steps, nor can it exceed the specified maximum number of rollback steps.
* steps, nor can it exceed the specified maximum number of rollback tokens.
*/
void Rollback(int num_tokens = 1);

/*! \brief Get the maximum number of rollback steps allowed. */
int GetMaxRollbackSteps() const;
/*! \brief Get the maximum number of rollback tokens allowed. */
int GetMaxRollbackTokens() const;

size_t GetVocabSize() const;
size_t GetMaskVocabSize() const;

/*!
* \brief Check if the matcher has accepted the stop token and terminated.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ max-args = 10

[tool.isort]
profile = "black"
src_paths = ["python"]
Loading

0 comments on commit 53e5174

Please sign in to comment.