Skip to content

Commit

Permalink
lint.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 22, 2024
1 parent e7df79f commit cfad13a
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions python-package/xgboost/testing/data_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import numpy as np

import xgboost
from xgboost import testing as tm

from ..core import DataIter, ExtMemQuantileDMatrix, QuantileDMatrix


def run_mixed_sparsity(device: str) -> None:
"""Check QDM with mixed batches."""
Expand All @@ -25,22 +26,22 @@ def run_mixed_sparsity(device: str) -> None:
X = [cp.array(batch) for batch in X]

it = tm.IteratorForTest(X, y, None, cache=None, on_host=False)
Xy_0 = xgboost.QuantileDMatrix(it)
Xy_0 = QuantileDMatrix(it)

X_1, y_1 = tm.make_sparse_regression(256, 16, 0.1, True)
X = [X_0, X_1, X_2]
y = [y_0, y_1, y_2]
X_arr = np.concatenate(X, axis=0)
y_arr = np.concatenate(y, axis=0)
Xy_1 = xgboost.QuantileDMatrix(X_arr, y_arr)
Xy_1 = QuantileDMatrix(X_arr, y_arr)

assert tm.predictor_equal(Xy_0, Xy_1)


def check_invalid_cat_batches(device: str) -> None:
"""Check error message for inconsistent feature types."""

class _InvalidCatIter(xgboost.DataIter):
class _InvalidCatIter(DataIter):
def __init__(self) -> None:
super().__init__(cache_prefix=None)
self._it = 0
Expand All @@ -57,8 +58,8 @@ def next(self, input_data: Callable) -> bool:
cat_ratio=1.0 if self._it == 0 else 0.5,
)
if device == "cuda":
import cudf
import cupy
import cudf # pylint: disable=import-error
import cupy # pylint: disable=import-error

X = cudf.DataFrame(X)
y = cupy.array(y)
Expand All @@ -74,10 +75,10 @@ def reset(self) -> None:
import pytest

with pytest.raises(ValueError, match="Inconsistent feature types between batches"):
xgboost.ExtMemQuantileDMatrix(it, enable_categorical=True)
ExtMemQuantileDMatrix(it, enable_categorical=True)


class CatIter(xgboost.DataIter): # pylint: disable=too-many-instance-attributes
class CatIter(DataIter): # pylint: disable=too-many-instance-attributes
"""An iterator for testing categorical features."""

def __init__( # pylint: disable=too-many-arguments,too-many-locals
Expand Down

0 comments on commit cfad13a

Please sign in to comment.