diff --git a/onnxmltools/convert/xgboost/_parse.py b/onnxmltools/convert/xgboost/_parse.py
index 89b1a37a..af4df543 100644
--- a/onnxmltools/convert/xgboost/_parse.py
+++ b/onnxmltools/convert/xgboost/_parse.py
@@ -69,24 +69,16 @@ def _get_attributes(booster):
     except AttributeError:
         ntrees = trees // num_class if num_class > 0 else trees
     else:
+        trees = len(res)
+        ntrees = getattr(booster, "best_ntree_limit", trees)
         config = json.loads(booster.save_config())["learner"]["learner_model_param"]
-        if "num_class" in config:
-            num_class = int(config["num_class"])
-            ntrees = len(res)
-            num_class = 1
-        else:
-            trees = len(res)
-            if hasattr(booster, "best_ntree_limit"):
-                ntrees = booster.best_ntree_limit
-            elif hasattr(booster, "best_iteration"):
-                ntrees = booster.best_iteration
-            else:
-                raise RuntimeError("Unable to guess the number of classes.")
+        num_class = int(config["num_class"]) if "num_class" in config else 0
+        if num_class == 0 and ntrees > 0:
             num_class = trees // ntrees
     if num_class == 0:
         raise RuntimeError(
-            "Unable to retrieve the number of classes, trees=%d, ntrees=%d."
-            % (trees, ntrees)
+            f"Unable to retrieve the number of classes, num_class={num_class}, "
+            f"trees={trees}, ntrees={ntrees}, config={config}."
         )
 
     kwargs = atts.copy()
diff --git a/onnxmltools/convert/xgboost/common.py b/onnxmltools/convert/xgboost/common.py
index 1be57c4e..799f9d63 100644
--- a/onnxmltools/convert/xgboost/common.py
+++ b/onnxmltools/convert/xgboost/common.py
@@ -20,25 +20,24 @@ def get_xgb_params(xgb_node):
         config = json.loads(xgb_node.save_config())
     else:
         config = json.loads(xgb_node.get_booster().save_config())
-    num_class = int(config["learner"]["learner_model_param"]["num_class"])
     params = {k: v for k, v in params.items() if v is not None}
-    params["num_class"] = num_class
+    num_class = int(config["learner"]["learner_model_param"]["num_class"])
+    if num_class > 0:
+        params["num_class"] = num_class
     if "n_estimators" not in params and hasattr(xgb_node, "n_estimators"):
         # xgboost >= 1.0.2
         if xgb_node.n_estimators is not None:
             params["n_estimators"] = xgb_node.n_estimators
     if params.get("base_score", None) is None:
+        bs = float(config["learner"]["learner_model_param"]["base_score"])
         # xgboost >= 2.0
-        params["base_score"] = float(
-            config["learner"]["learner_model_param"]["base_score"]
-        )
+        params["base_score"] = bs
     return params
 
 
 def get_n_estimators_classifier(xgb_node, params, js_trees):
     if "n_estimators" not in params:
-        config = json.loads(xgb_node.get_booster().save_config())
-        num_class = int(config["learner"]["learner_model_param"]["num_class"])
+        num_class = params.get("num_class", 0)
         if num_class == 0:
             return len(js_trees)
         return len(js_trees) // num_class
diff --git a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py
index ee98539e..519d94ce 100644
--- a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py
+++ b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py
@@ -161,8 +161,7 @@ def _fill_node_attributes(
                     false_child_id=remap[jsnode["no"]],  # ['children'][1]['nodeid'],
                     weights=None,
                     weight_id_bias=None,
-                    missing=jsnode.get("missing", -1)
-                    == jsnode["yes"],  # ['children'][0]['nodeid'],
+                    missing=jsnode.get("missing", -1) == jsnode["yes"],
                     hitrate=jsnode.get("cover", 0),
                 )
 
@@ -294,17 +293,17 @@ def convert(scope, operator, container):
         params = XGBConverter.get_xgb_params(xgb_node)
         n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
+        num_class = params.get("num_class", None)
 
         attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
         XGBConverter.fill_tree_attributes(
             js_trees, attr_pairs, [1 for _ in js_trees], True
         )
 
-        if "num_class" in params:
-            ncl = params["num_class"]
+        if num_class is not None:
+            ncl = num_class
             n_estimators = len(js_trees) // ncl
         else:
             ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators
-        print("**", params)
 
         bst = xgb_node.get_booster()
         best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl
@@ -325,8 +324,9 @@ def convert(scope, operator, container):
             attr_pairs["post_transform"] = "LOGISTIC"
             attr_pairs["class_ids"] = [0 for v in attr_pairs["class_treeids"]]
             if js_trees[0].get("leaf", None) == 0:
-                attr_pairs["base_values"] = [0.5]
+                attr_pairs["base_values"] = [base_score]
             elif base_score != 0.5:
+                # 0.5 -> cst = 0
                 cst = -np.log(1 / np.float32(base_score) - 1.0)
                 attr_pairs["base_values"] = [cst]
             else:
diff --git a/onnxmltools/convert/xgboost/shape_calculators/Classifier.py b/onnxmltools/convert/xgboost/shape_calculators/Classifier.py
index 50095fea..c38bcdac 100644
--- a/onnxmltools/convert/xgboost/shape_calculators/Classifier.py
+++ b/onnxmltools/convert/xgboost/shape_calculators/Classifier.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import json
 import numpy as np
 from ...common._registration import register_shape_calculator
 from ...common.utils import check_input_and_output_numbers, check_input_and_output_types
@@ -27,8 +26,7 @@ def calculate_xgboost_classifier_output_shapes(operator):
     ntrees = len(js_trees)
     objective = params["objective"]
     n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
-    config = json.loads(xgb_node.get_booster().save_config())
-    num_class = int(config["learner"]["learner_model_param"]["num_class"])
+    num_class = params.get("num_class", None)
 
     if num_class is not None:
         ncl = num_class
diff --git a/requirements-dev.txt b/requirements-dev.txt
deleted file mode 100644
index 2dc45777..00000000
--- a/requirements-dev.txt
+++ /dev/null
@@ -1,20 +0,0 @@
-black
-catboost
-cython
-dill
-libsvm
-lightgbm
-mleap
-numpy
-openpyxl
-pandas
-pyspark
-pytest
-pytest-cov
-pytest-spark
-ruff
-scikit-learn>=1.2.0
-scipy
-wheel
-xgboost
-onnxruntime
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 361b3238..00000000
--- a/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-numpy
-onnx
-skl2onnx
diff --git a/tests/baseline/test_convert_baseline.py b/tests/baseline/test_convert_baseline.py
index 3b56bbba..44e6acd4 100644
--- a/tests/baseline/test_convert_baseline.py
+++ b/tests/baseline/test_convert_baseline.py
@@ -45,11 +45,11 @@ def get_diff(self, input_file, ref_file):
     def normalize_diff(self, diff):
         invalid_comparisons = []
         invalid_comparisons.append(
-            re.compile('producer_version: "\d+\.\d+\.\d+\.\d+.*')
+            re.compile('producer_version: "\\d+\\.\\d+\\.\\d+\\.\\d+.*')
         )
-        invalid_comparisons.append(re.compile('\s+name: ".*'))
-        invalid_comparisons.append(re.compile("ir_version: \d+"))
-        invalid_comparisons.append(re.compile("\s+"))
+        invalid_comparisons.append(re.compile('\\s+name: ".*'))
+        invalid_comparisons.append(re.compile("ir_version: \\d+"))
+        invalid_comparisons.append(re.compile("\\s+"))
         valid_diff = set()
         for line in diff:
             if any(comparison.match(line) for comparison in invalid_comparisons):
diff --git a/tests/xgboost/test_xgboost_13.py b/tests/xgboost/test_xgboost_13.py
deleted file mode 100644
index 18d31bab..00000000
--- a/tests/xgboost/test_xgboost_13.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-"""
-Tests scilit-learn's tree-based methods' converters.
-"""
-import os
-import unittest
-import numpy as np
-from numpy.testing import assert_almost_equal
-import pandas
-from sklearn.model_selection import train_test_split
-from xgboost import XGBClassifier
-from onnx.defs import onnx_opset_version
-from onnxconverter_common.onnx_ex import DEFAULT_OPSET_NUMBER
-from onnxmltools.convert import convert_xgboost
-from onnxmltools.convert.common.data_types import FloatTensorType
-from onnxruntime import InferenceSession
-
-
-TARGET_OPSET = min(DEFAULT_OPSET_NUMBER, onnx_opset_version())
-
-
-class TestXGBoost13(unittest.TestCase):
-    def test_xgb_regressor(self):
-        this = os.path.dirname(__file__)
-        df = pandas.read_csv(os.path.join(this, "data_fail_empty.csv"))
-        X, y = df.drop("y", axis=1), df["y"]
-        X_train, X_test, y_train, y_test = train_test_split(X, y)
-
-        clr = XGBClassifier(
-            max_delta_step=0,
-            tree_method="hist",
-            n_estimators=100,
-            booster="gbtree",
-            objective="binary:logistic",
-            eval_metric="logloss",
-            learning_rate=0.1,
-            gamma=10,
-            max_depth=7,
-            min_child_weight=50,
-            subsample=0.75,
-            colsample_bytree=0.75,
-            random_state=42,
-            verbosity=0,
-        )
-
-        clr.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=40)
-
-        initial_type = [("float_input", FloatTensorType([None, 797]))]
-        onx = convert_xgboost(
-            clr, initial_types=initial_type, target_opset=TARGET_OPSET
-        )
-        expected = clr.predict(X_test), clr.predict_proba(X_test)
-        sess = InferenceSession(
-            onx.SerializeToString(), providers=["CPUExecutionProvider"]
-        )
-        X_test = X_test.values.astype(np.float32)
-        got = sess.run(None, {"float_input": X_test})
-        assert_almost_equal(expected[1], got[1])
-        assert_almost_equal(expected[0], got[0])
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/xgboost/test_xgboost_converters.py b/tests/xgboost/test_xgboost_converters.py
index e6dd15de..b4b6d3e2 100644
--- a/tests/xgboost/test_xgboost_converters.py
+++ b/tests/xgboost/test_xgboost_converters.py
@@ -116,7 +116,7 @@ def test_xgb_regressor_poisson(self):
             basename=f"SklearnXGBRegressorPoisson{nest}-Dec3",
         )
 
-    def test_xgb_classifier(self):
+    def test_xgb0_classifier(self):
         xgb, x_test = _fit_classification_model(XGBClassifier(), 2)
         conv_model = convert_xgboost(
             xgb,
@@ -198,7 +198,7 @@ def test_xgb_classifier_multi_discrete_int_labels(self):
            basename="SklearnXGBClassifierMultiDiscreteIntLabels",
        )
 
-    def test_xgboost_booster_classifier_bin(self):
+    def test_xgb1_booster_classifier_bin(self):
         x, y = make_classification(
             n_classes=2, n_features=5, n_samples=100, random_state=42, n_informative=3
         )
@@ -218,10 +218,10 @@ def test_xgb1_booster_classifier_bin(self):
             target_opset=TARGET_OPSET,
         )
         dump_data_and_model(
-            x_test.astype(np.float32), model, model_onnx, basename="XGBBoosterMCl"
+            x_test.astype(np.float32)[:20], model, model_onnx, basename="XGBBoosterMCl"
         )
 
-    def test_xgboost_booster_classifier_multiclass_softprob(self):
+    def test_xgb0_booster_classifier_multiclass_softprob(self):
         x, y = make_classification(
             n_classes=3, n_features=5, n_samples=100, random_state=42, n_informative=3
         )
@@ -431,7 +431,7 @@ def test_xgboost_example_mnist(self):
             X_test.astype(np.float32), clf, onnx_model, basename="XGBoostExample"
         )
 
-    def test_xgb_empty_tree(self):
+    def test_xgb0_empty_tree_classifier(self):
         xgb = XGBClassifier(n_estimators=2, max_depth=2)
 
         # simple dataset
@@ -451,7 +451,7 @@ def test_xgb0_empty_tree_classifier(self):
         assert_almost_equal(xgb.predict_proba(X), res[1])
         assert_almost_equal(xgb.predict(X), res[0])
 
-    def test_xgb_best_tree_limit(self):
+    def test_xgb_best_tree_limit_classifier(self):
         # Train
         iris = load_iris()
         X, y = iris.data, iris.target
@@ -499,7 +499,7 @@ def test_xgb_best_tree_limit_classifier(self):
         )
         assert_almost_equal(bst_loaded.predict(dtest), res[0])
-    def test_onnxrt_python_xgbclassifier(self):
+    def test_xgb_classifier(self):
         x = np.random.randn(100, 10).astype(np.float32)
         y = ((x.sum(axis=1) + np.random.randn(x.shape[0]) / 50 + 0.5) >= 0).astype(
             np.int64
         )
@@ -675,7 +675,44 @@ def test_xgb_classifier_hinge(self):
             x_test, xgb, conv_model, basename="SklearnXGBClassifierHinge"
         )
 
+    def test_xgb_classifier_13(self):
+        this = os.path.dirname(__file__)
+        df = pandas.read_csv(os.path.join(this, "data_fail_empty.csv"))
+        X, y = df.drop("y", axis=1), df["y"]
+        X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+        clr = XGBClassifier(
+            max_delta_step=0,
+            tree_method="hist",
+            n_estimators=100,
+            booster="gbtree",
+            objective="binary:logistic",
+            eval_metric="logloss",
+            learning_rate=0.1,
+            gamma=10,
+            max_depth=7,
+            min_child_weight=50,
+            subsample=0.75,
+            colsample_bytree=0.75,
+            random_state=42,
+            verbosity=0,
+        )
+
+        clr.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=40)
+
+        initial_type = [("float_input", FloatTensorType([None, 797]))]
+        onx = convert_xgboost(
+            clr, initial_types=initial_type, target_opset=TARGET_OPSET
+        )
+        expected = clr.predict(X_test), clr.predict_proba(X_test)
+        sess = InferenceSession(
+            onx.SerializeToString(), providers=["CPUExecutionProvider"]
+        )
+        X_test = X_test.values.astype(np.float32)
+        got = sess.run(None, {"float_input": X_test})
+        assert_almost_equal(expected[1], got[1])
+        assert_almost_equal(expected[0], got[0])
+
 
 if __name__ == "__main__":
-    TestXGBoostModels().test_xgb_best_tree_limit()
     unittest.main(verbosity=2)