Skip to content

Commit

Permalink
[EM] CPU categorical feature support.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 22, 2024
1 parent e9c1784 commit a1d0bda
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 4 deletions.
7 changes: 7 additions & 0 deletions python-package/xgboost/testing/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def __init__( # pylint: disable=too-many-arguments
*,
n_batches: int,
n_cats: int,
sparsity: float,
onehot: bool,
device: str,
) -> None:
Expand Down Expand Up @@ -402,6 +403,7 @@ def _create_dmatrix( # pylint: disable=too-many-arguments
*,
n_cats: int,
device: str,
sparsity: float,
tree_method: str,
onehot: bool,
extmem: bool,
Expand All @@ -413,6 +415,7 @@ def _create_dmatrix( # pylint: disable=too-many-arguments
n_samples // n_batches,
n_features,
n_batches=n_batches,
sparsity=sparsity,
n_cats=n_cats,
onehot=onehot,
device=device,
Expand All @@ -430,6 +433,7 @@ def _create_dmatrix( # pylint: disable=too-many-arguments
n_samples,
n_features=n_features,
n_categories=n_cats,
sparsity=sparsity,
onehot=onehot,
)
Xy = xgb.DMatrix(cat, label, enable_categorical=enable_categorical)
Expand Down Expand Up @@ -463,6 +467,7 @@ def check_categorical_ohe( # pylint: disable=too-many-arguments
cols,
n_cats=cats,
device=device,
sparsity=0.0,
onehot=True,
tree_method=tree_method,
extmem=extmem,
Expand All @@ -481,6 +486,7 @@ def check_categorical_ohe( # pylint: disable=too-many-arguments
cols,
n_cats=cats,
device=device,
sparsity=0.0,
tree_method=tree_method,
onehot=False,
extmem=extmem,
Expand Down Expand Up @@ -550,6 +556,7 @@ def check_categorical_missing( # pylint: disable=too-many-arguments
rows,
cols,
n_cats=cats,
sparsity=0.5,
device=device,
tree_method=tree_method,
onehot=False,
Expand Down
2 changes: 0 additions & 2 deletions src/common/column_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
#include <limits>
#include <memory>
#include <type_traits> // for enable_if_t, is_same_v, is_signed_v
#include <utility> // for move

#include "../data/adapter.h"
#include "../data/gradient_index.h"
#include "algorithm.h"
#include "bitfield.h" // for RBitField8
#include "hist_util.h"
#include "ref_resource_view.h" // for RefResourceView
Expand Down
2 changes: 1 addition & 1 deletion src/common/partition_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class PartitionBuilder {
// Analog of std::stable_partition, but in no-inplace manner
template <bool default_left, bool any_missing, typename ColumnType, typename Predicate>
std::pair<size_t, size_t> PartitionKernel(ColumnType* p_column,
common::Span<const bst_idx_t> row_indices,
common::Span<bst_idx_t const> row_indices,
common::Span<bst_idx_t> left_part,
common::Span<bst_idx_t> right_part,
bst_idx_t base_rowid, Predicate&& pred) {
Expand Down
1 change: 1 addition & 0 deletions src/predictor/cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ struct GHistIndexMatrixView {
base_rowid{_page.base_rowid} {}

SparsePage::Inst operator[](size_t r) {
r += base_rowid;
auto t = omp_get_thread_num();
auto const beg = (n_features_ * kUnroll * t) + (current_unroll_[t] * n_features_);
size_t non_missing{static_cast<std::size_t>(beg)};
Expand Down
27 changes: 26 additions & 1 deletion tests/python/test_data_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from xgboost import testing as tm
from xgboost.data import SingleBatchInternalIter as SingleBatch
from xgboost.testing import IteratorForTest, make_batches, non_increasing
from xgboost.testing.updater import check_extmem_qdm, check_quantile_loss_extmem
from xgboost.testing.updater import (
check_categorical_missing,
check_categorical_ohe,
check_extmem_qdm,
check_quantile_loss_extmem,
)

pytestmark = tm.timeout(30)

Expand Down Expand Up @@ -324,3 +329,23 @@ def test_extmem_qdm(
device="cpu",
on_host=False,
)


@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_categorical_missing(tree_method: str) -> None:
    """Run ``check_categorical_missing`` on CPU with ``extmem=True``.

    The test is parametrized over the ``hist`` and ``approx`` tree methods.
    """
    # Forwarded positionally, in this exact order. Presumably these are
    # (rows, cols, cats) -- confirm against the helper's signature.
    dims = (1024, 4, 5)
    check_categorical_missing(
        *dims, device="cpu", tree_method=tree_method, extmem=True
    )


@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_categorical_ohe(tree_method: str) -> None:
    """Run ``check_categorical_ohe`` on CPU with ``extmem=True``.

    The test is parametrized over the ``hist`` and ``approx`` tree methods.
    """
    # Problem-size constants, passed through by keyword below.
    n_rows, n_cols = 1024, 16
    n_rounds, n_cats = 4, 5
    check_categorical_ohe(
        rows=n_rows,
        cols=n_cols,
        rounds=n_rounds,
        cats=n_cats,
        device="cpu",
        tree_method=tree_method,
        extmem=True,
    )

0 comments on commit a1d0bda

Please sign in to comment.