Skip to content

Commit

Permalink
Support doc link for the sklearn module. (#10287)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Aug 5, 2024
1 parent a269055 commit 3d8107a
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 13 deletions.
32 changes: 22 additions & 10 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,27 @@ def _register_log_callback(lib: ctypes.CDLL) -> None:
raise XGBoostError(lib.XGBGetLastError())


def _parse_version(ver: str) -> Tuple[Tuple[int, int, int], str]:
"""Avoid dependency on packaging (PEP 440)."""
# 2.0.0-dev, 2.0.0, 2.0.0.post1, or 2.0.0rc1
if ver.find("post") != -1:
major, minor, patch = ver.split(".")[:-1]
postfix = ver.split(".")[-1]
elif "-dev" in ver:
major, minor, patch = ver.split("-")[0].split(".")
postfix = "dev"
else:
major, minor, patch = ver.split(".")
rc = patch.find("rc")
if rc != -1:
postfix = patch[rc:]
patch = patch[:rc]
else:
postfix = ""

return (int(major), int(minor), int(patch)), postfix


def _load_lib() -> ctypes.CDLL:
"""Load xgboost Library."""
lib_paths = find_lib_path()
Expand Down Expand Up @@ -237,17 +258,8 @@ def _load_lib() -> ctypes.CDLL:
)
_register_log_callback(lib)

def parse(ver: str) -> Tuple[int, int, int]:
"""Avoid dependency on packaging (PEP 440)."""
# 2.0.0-dev, 2.0.0, or 2.0.0rc1
major, minor, patch = ver.split("-")[0].split(".")
rc = patch.find("rc")
if rc != -1:
patch = patch[:rc]
return int(major), int(minor), int(patch)

libver = _lib_version(lib)
pyver = parse(_py_version())
pyver, _ = _parse_version(_py_version())

# verify that we are loading the correct binary.
if pyver != libver:
Expand Down
28 changes: 28 additions & 0 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
XGBoostError,
_deprecate_positional_args,
_parse_eval_str,
_parse_version,
_py_version,
)
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_alike, _is_pandas_df
from .training import train
Expand Down Expand Up @@ -795,6 +797,32 @@ def _more_tags(self) -> Dict[str, bool]:
def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster")

@property
def _doc_link_module(self) -> str:
return "xgboost"

@property
def _doc_link_template(self) -> str:
ver = _py_version()
(major, minor, _), post = _parse_version(ver)

if post == "dev":
rel = "latest"
else:
# RTD tracks the release branch. We don't have independent branches for
# patch releases.
rel = f"release_{major}.{minor}.0"

module = self.__class__.__module__
# All sklearn estimators are forwarded to the top level module in both source
# code and sphinx api doc.
if module == "xgboost.sklearn":
module = module.split(".")[0]
name = self.__class__.__name__

base = "https://xgboost.readthedocs.io/en"
return f"{base}/{rel}/python/python_api.html#{module}.{name}"

def get_booster(self) -> Booster:
"""Get the underlying xgboost Booster of this model.
Expand Down
2 changes: 1 addition & 1 deletion tests/ci_build/conda_env/linux_cpu_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- pylint
- numpy
- scipy
- scikit-learn
- scikit-learn>=1.4.1
- pandas
- matplotlib
- dask
Expand Down
2 changes: 1 addition & 1 deletion tests/ci_build/conda_env/macos_cpu_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- numpy
- scipy
- llvm-openmp
- scikit-learn
- scikit-learn>=1.4.1
- pandas
- matplotlib
- dask
Expand Down
2 changes: 1 addition & 1 deletion tests/ci_build/conda_env/win64_cpu_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- wheel
- numpy
- scipy
- scikit-learn
- scikit-learn>=1.4.1
- pandas
- matplotlib
- dask
Expand Down
12 changes: 12 additions & 0 deletions tests/python/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import xgboost as xgb
from xgboost import testing as tm
from xgboost.core import _parse_version

dpath = "demo/data/"
rng = np.random.RandomState(1994)
Expand Down Expand Up @@ -315,3 +316,14 @@ def test_Booster_init_invalid_path(self):
"""An invalid model_file path should raise XGBoostError."""
with pytest.raises(xgb.core.XGBoostError):
xgb.Booster(model_file=Path("invalidpath"))


def test_parse_ver() -> None:
(major, minor, patch), post = _parse_version("2.1.0")
assert post == ""
(major, minor, patch), post = _parse_version("2.1.0-dev")
assert post == "dev"
(major, minor, patch), post = _parse_version("2.1.0rc1")
assert post == "rc1"
(major, minor, patch), post = _parse_version("2.1.0.post1")
assert post == "post1"
13 changes: 13 additions & 0 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,3 +1484,16 @@ def test_tags() -> None:

tags = xgb.XGBRanker()._more_tags()
assert "multioutput" not in tags


def test_doc_link() -> None:
for est in [
xgb.XGBRegressor(),
xgb.XGBClassifier(),
xgb.XGBRanker(),
xgb.XGBRFRegressor(),
xgb.XGBRFClassifier(),
]:
name = est.__class__.__name__
link = est._get_doc_link()
assert f"xgboost.{name}" in link
14 changes: 14 additions & 0 deletions tests/test_distributed/test_with_dask/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from sklearn.datasets import make_classification, make_regression

import xgboost as xgb
from xgboost import dask as dxgb
from xgboost import testing as tm
from xgboost.data import _is_cudf_df
from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
Expand Down Expand Up @@ -2324,3 +2325,16 @@ async def test_worker_restarted(c, s, a, b):
d_train,
evals=[(d_train, "train")],
)


def test_doc_link() -> None:
for est in [
dxgb.DaskXGBRegressor(),
dxgb.DaskXGBClassifier(),
dxgb.DaskXGBRanker(),
dxgb.DaskXGBRFRegressor(),
dxgb.DaskXGBRFClassifier(),
]:
name = est.__class__.__name__
link = est._get_doc_link()
assert f"xgboost.dask.{name}" in link

0 comments on commit 3d8107a

Please sign in to comment.