From 8832ef36b1246aa75787e14be803d72e0ebfe7eb Mon Sep 17 00:00:00 2001
From: Xavier Dupre
Date: Sat, 16 Dec 2023 14:17:20 +0100
Subject: [PATCH] fix early stopping

Signed-off-by: Xavier Dupre
---
 onnxmltools/convert/xgboost/common.py       |  8 +++++++
 .../xgboost/operator_converters/XGBoost.py  | 23 ++++++++++++-------
 tests/xgboost/test_xgboost_converters.py    | 17 ++++++++++----
 3 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/onnxmltools/convert/xgboost/common.py b/onnxmltools/convert/xgboost/common.py
index 8f267f58..3a79aaad 100644
--- a/onnxmltools/convert/xgboost/common.py
+++ b/onnxmltools/convert/xgboost/common.py
@@ -32,6 +32,14 @@ def get_xgb_params(xgb_node):
         bs = float(config["learner"]["learner_model_param"]["base_score"])
         # xgboost >= 2.0
         params["base_score"] = bs
+
+    bst = xgb_node.get_booster()
+    if hasattr(bst, "best_ntree_limit"):
+        params["best_ntree_limit"] = bst.best_ntree_limit
+    if "gradient_booster" in config["learner"]:
+        gbp = config["learner"]["gradient_booster"]["gbtree_model_param"]
+        if "num_trees" in gbp:
+            params["best_ntree_limit"] = int(gbp["num_trees"])
     return params
 
 
diff --git a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py
index a9f31211..16cdc5df 100644
--- a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py
+++ b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py
@@ -40,6 +40,10 @@ def common_members(xgb_node, inputs):
         params = XGBConverter.get_xgb_params(xgb_node)
         objective = params["objective"]
         base_score = params["base_score"]
+        if hasattr(xgb_node, "best_iteration"):
+            best_ntree_limit = xgb_node.best_iteration + 1
+        else:
+            best_ntree_limit = params.get("best_ntree_limit", None)
         if base_score is None:
             base_score = 0.5
         booster = xgb_node.get_booster()
@@ -47,7 +51,7 @@
         # XGBoost 0.7 was the first version released with it.
         js_tree_list = booster.get_dump(with_stats=True, dump_format="json")
         js_trees = [json.loads(s) for s in js_tree_list]
-        return objective, base_score, js_trees
+        return objective, base_score, js_trees, best_ntree_limit
 
     @staticmethod
     def _get_default_tree_attribute_pairs(is_classifier):
@@ -231,7 +235,9 @@ def _get_default_tree_attribute_pairs():
     def convert(scope, operator, container):
         xgb_node = operator.raw_operator
         inputs = operator.inputs
-        objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)
+        objective, base_score, js_trees, best_ntree_limit = XGBConverter.common_members(
+            xgb_node, inputs
+        )
 
         if objective in ["reg:gamma", "reg:tweedie"]:
             raise RuntimeError("Objective '{}' not supported.".format(objective))
@@ -239,9 +245,7 @@ def convert(scope, operator, container):
         attr_pairs = XGBRegressorConverter._get_default_tree_attribute_pairs()
         attr_pairs["base_values"] = [base_score]
 
-        bst = xgb_node.get_booster()
-        best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees))
-        if best_ntree_limit < len(js_trees):
+        if best_ntree_limit and best_ntree_limit < len(js_trees):
             js_trees = js_trees[:best_ntree_limit]
 
         XGBConverter.fill_tree_attributes(
@@ -289,7 +293,9 @@ def convert(scope, operator, container):
         xgb_node = operator.raw_operator
         inputs = operator.inputs
 
-        objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)
+        objective, base_score, js_trees, best_ntree_limit = XGBConverter.common_members(
+            xgb_node, inputs
+        )
 
         params = XGBConverter.get_xgb_params(xgb_node)
         n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
@@ -305,8 +311,9 @@ def convert(scope, operator, container):
         else:
             ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators
 
-        bst = xgb_node.get_booster()
-        best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl
+        best_ntree_limit = best_ntree_limit or len(js_trees)
+        if ncl > 0:
+            best_ntree_limit *= ncl
         if 0 < best_ntree_limit < len(js_trees):
             js_trees = js_trees[:best_ntree_limit]
             attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
diff --git a/tests/xgboost/test_xgboost_converters.py b/tests/xgboost/test_xgboost_converters.py
index ce54041c..ab33e845 100644
--- a/tests/xgboost/test_xgboost_converters.py
+++ b/tests/xgboost/test_xgboost_converters.py
@@ -742,7 +742,7 @@ def test_xgb_classifier_13(self):
         assert_almost_equal(expected[1], got[1])
         assert_almost_equal(expected[0], got[0])
 
-    def test_xgb_classifier_13(self):
+    def test_xgb_classifier_13_2(self):
         this = os.path.dirname(__file__)
         df = pandas.read_csv(os.path.join(this, "data_bug.csv"))
         X, y = df.drop("y", axis=1), df["y"]
@@ -753,7 +753,7 @@
         model_param = {
             "objective": "binary:logistic",
             "n_estimators": 1000,
-            "early_stopping_rounds": 120,
+            "early_stopping_rounds": 113,
             "random_state": 42,
             "max_depth": 3,
         }
@@ -769,8 +769,15 @@
         initial_types = [("float_input", FloatTensorType([None, x_train.shape[1]]))]
         onnx_model = convert_xgboost(model, initial_types=initial_types)
 
-        # with open("debug.onnx", "wb") as f:
-        #     f.write(onnx_model.SerializeToString())
+        for att in onnx_model.graph.node[0].attribute:
+            if att.name == "nodes_treeids":
+                self.assertLess(max(att.ints), 1000)
+            if att.name == "class_ids":
+                self.assertEqual(set(att.ints), {0})
+            if att.name == "base_values":
+                self.assertEqual(len(att.floats), 1)
+            if att.name == "post_transform":
+                self.assertEqual(att.s, b"LOGISTIC")
 
         expected = model.predict(x_test), model.predict_proba(x_test)
         sess = InferenceSession(onnx_model.SerializeToString())
@@ -780,5 +787,5 @@
 
 
 if __name__ == "__main__":
-    TestXGBoostModels().test_xgb_classifier_13()
+    TestXGBoostModels().test_xgb_classifier_13_2()
     unittest.main(verbosity=2)
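
Note (not part of the patch): a minimal repro sketch of the scenario this change fixes, for reviewers who want to try it without the data_bug.csv fixture. It assumes xgboost >= 1.6 (where early_stopping_rounds is a constructor parameter) and uses a synthetic dataset; the parameter values and variable names mirror the test above but are illustrative.

# Sketch only: with early stopping the booster keeps fewer trees than
# n_estimators, and before this patch the converter could export an
# inconsistent slice of the tree list.
import numpy as np
from onnxmltools.convert import convert_xgboost
from onnxmltools.convert.common.data_types import FloatTensorType
from onnxruntime import InferenceSession
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

# Synthetic stand-in for the data_bug.csv fixture (assumption, not the patch's data).
X, y = make_classification(n_samples=2000, n_features=10, random_state=42)
x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=42)

model = XGBClassifier(
    objective="binary:logistic",
    n_estimators=1000,
    early_stopping_rounds=10,
    max_depth=3,
    random_state=42,
)
# Early stopping halts training well before 1000 rounds on this data.
model.fit(x_train, y_train, eval_set=[(x_test, y_test)], verbose=False)

initial_types = [("float_input", FloatTensorType([None, x_train.shape[1]]))]
onnx_model = convert_xgboost(model, initial_types=initial_types)

# With the fix, the ONNX model matches the early-stopped predictor.
sess = InferenceSession(onnx_model.SerializeToString())
got = sess.run(None, {"float_input": x_test.astype(np.float32)})
np.testing.assert_almost_equal(model.predict_proba(x_test), got[1], decimal=4)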