fix early stopping
Signed-off-by: Xavier Dupre <[email protected]>
xadupre committed Dec 16, 2023
1 parent 15f566d commit 8832ef3
Showing 3 changed files with 35 additions and 13 deletions.
onnxmltools/convert/xgboost/common.py (8 additions, 0 deletions)

@@ -32,6 +32,14 @@ def get_xgb_params(xgb_node):
         bs = float(config["learner"]["learner_model_param"]["base_score"])
         # xgboost >= 2.0
         params["base_score"] = bs
+
+    bst = xgb_node.get_booster()
+    if hasattr(bst, "best_ntree_limit"):
+        params["best_ntree_limit"] = bst.best_ntree_limit
+    if "gradient_booster" in config["learner"]:
+        gbp = config["learner"]["gradient_booster"]["gbtree_model_param"]
+        if "num_trees" in gbp:
+            params["best_ntree_limit"] = int(gbp["num_trees"])
     return params


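Note on the two sources consulted above: get_xgb_params parses the booster's save_config() JSON, so the stored tree count can be checked directly against the sklearn-side best_iteration. A minimal sketch, not part of the commit, on made-up synthetic data (assumes xgboost >= 1.6 and that early stopping actually triggers):

    import json
    import numpy as np
    from xgboost import XGBClassifier

    rng = np.random.RandomState(0)
    X = rng.rand(300, 4).astype(np.float32)
    y = ((X[:, 0] + 0.25 * rng.randn(300)) > 0.5).astype(int)  # noisy labels

    model = XGBClassifier(n_estimators=1000, early_stopping_rounds=10, max_depth=3)
    model.fit(X[:200], y[:200], eval_set=[(X[200:], y[200:])], verbose=False)

    config = json.loads(model.get_booster().save_config())
    gbp = config["learner"]["gradient_booster"]["gbtree_model_param"]
    print("num_trees in config:", gbp["num_trees"])         # every tree the booster kept
    print("best_iteration + 1:", model.best_iteration + 1)  # trees up to the best round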
onnxmltools/convert/xgboost/operator_converters/XGBoost.py (15 additions, 8 deletions)

@@ -40,14 +40,18 @@ def common_members(xgb_node, inputs):
         params = XGBConverter.get_xgb_params(xgb_node)
         objective = params["objective"]
         base_score = params["base_score"]
+        if hasattr(xgb_node, "best_iteration"):
+            best_ntree_limit = xgb_node.best_iteration + 1
+        else:
+            best_ntree_limit = params.get("best_ntree_limit", None)
         if base_score is None:
             base_score = 0.5
         booster = xgb_node.get_booster()
         # The json format was available in October 2017.
         # XGBoost 0.7 was the first version released with it.
         js_tree_list = booster.get_dump(with_stats=True, dump_format="json")
         js_trees = [json.loads(s) for s in js_tree_list]
-        return objective, base_score, js_trees
+        return objective, base_score, js_trees, best_ntree_limit

     @staticmethod
     def _get_default_tree_attribute_pairs(is_classifier):
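This hunk is the heart of the fix: Booster.best_ntree_limit was deprecated in xgboost 1.x and is gone in 2.0, so the converter's old getattr lookup silently fell back to the full tree count and early stopping no longer shortened the export. A hedged repro sketch, again on made-up data:

    import numpy as np
    from xgboost import XGBRegressor

    rng = np.random.RandomState(0)
    X = rng.rand(300, 4).astype(np.float32)
    y = 2.0 * X[:, 0] + 0.1 * rng.randn(300)

    model = XGBRegressor(n_estimators=1000, early_stopping_rounds=10)
    model.fit(X[:200], y[:200], eval_set=[(X[200:], y[200:])], verbose=False)

    booster = model.get_booster()
    js_trees = booster.get_dump(dump_format="json")
    old_limit = getattr(booster, "best_ntree_limit", len(js_trees))  # pre-fix lookup
    new_limit = model.best_iteration + 1                             # post-fix lookup
    # On xgboost 2.x old_limit equals len(js_trees), so nothing was truncated.
    print(len(js_trees), old_limit, new_limit)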
@@ -231,17 +235,17 @@ def _get_default_tree_attribute_pairs():
     def convert(scope, operator, container):
         xgb_node = operator.raw_operator
         inputs = operator.inputs
-        objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)
+        objective, base_score, js_trees, best_ntree_limit = XGBConverter.common_members(
+            xgb_node, inputs
+        )

         if objective in ["reg:gamma", "reg:tweedie"]:
             raise RuntimeError("Objective '{}' not supported.".format(objective))

         attr_pairs = XGBRegressorConverter._get_default_tree_attribute_pairs()
         attr_pairs["base_values"] = [base_score]

-        bst = xgb_node.get_booster()
-        best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees))
-        if best_ntree_limit < len(js_trees):
+        if best_ntree_limit and best_ntree_limit < len(js_trees):
             js_trees = js_trees[:best_ntree_limit]

         XGBConverter.fill_tree_attributes(
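The rewritten guard also has to accept None, which common_members now returns when a model was trained without early stopping. A toy illustration of the slicing rule, with made-up values:

    # Stand-ins for the dumped trees; limits of None and oversized values
    # must leave the list untouched, matching the guard above.
    js_trees = ["t0", "t1", "t2", "t3"]
    for limit in (None, 2, 4, 100):
        kept = js_trees[:limit] if limit and limit < len(js_trees) else js_trees
        print(limit, "->", len(kept))  # None->4, 2->2, 4->4, 100->4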
@@ -289,7 +293,9 @@ def convert(scope, operator, container):
         xgb_node = operator.raw_operator
         inputs = operator.inputs

-        objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)
+        objective, base_score, js_trees, best_ntree_limit = XGBConverter.common_members(
+            xgb_node, inputs
+        )

         params = XGBConverter.get_xgb_params(xgb_node)
         n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
@@ -305,8 +311,9 @@
         else:
             ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators

-        bst = xgb_node.get_booster()
-        best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl
+        best_ntree_limit = best_ntree_limit or len(js_trees)
+        if ncl > 0:
+            best_ntree_limit *= ncl
         if 0 < best_ntree_limit < len(js_trees):
             js_trees = js_trees[:best_ntree_limit]
         attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
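On the classifier side the booster emits one tree per class per boosting round, so a limit counted in rounds must be scaled by the class count (ncl above) before slicing the flat tree dump. A worked sketch with assumed numbers:

    n_classes = 3               # plays the role of ncl
    best_iteration = 40         # best round reported by early stopping (0-based)
    n_dumped = 150 * n_classes  # booster kept 150 rounds, 3 trees per round

    best_ntree_limit = (best_iteration + 1) * n_classes  # rounds -> trees
    if 0 < best_ntree_limit < n_dumped:
        n_dumped = best_ntree_limit
    print(n_dumped)  # 123 trees survive, i.e. 41 rounds x 3 classes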
tests/xgboost/test_xgboost_converters.py (12 additions, 5 deletions)

@@ -742,7 +742,7 @@ def test_xgb_classifier_13(self):
         assert_almost_equal(expected[1], got[1])
         assert_almost_equal(expected[0], got[0])

-    def test_xgb_classifier_13(self):
+    def test_xgb_classifier_13_2(self):
         this = os.path.dirname(__file__)
         df = pandas.read_csv(os.path.join(this, "data_bug.csv"))
         X, y = df.drop("y", axis=1), df["y"]
@@ … @@
         model_param = {
             "objective": "binary:logistic",
             "n_estimators": 1000,
-            "early_stopping_rounds": 120,
+            "early_stopping_rounds": 113,
             "random_state": 42,
             "max_depth": 3,
         }
@@ … @@

         initial_types = [("float_input", FloatTensorType([None, x_train.shape[1]]))]
         onnx_model = convert_xgboost(model, initial_types=initial_types)
-        # with open("debug.onnx", "wb") as f:
-        #     f.write(onnx_model.SerializeToString())
+        for att in onnx_model.graph.node[0].attribute:
+            if att.name == "nodes_treeids":
+                self.assertLess(max(att.ints), 1000)
+            if att.name == "class_ids":
+                self.assertEqual(set(att.ints), {0})
+            if att.name == "base_values":
+                self.assertEqual(len(att.floats), 1)
+            if att.name == "post_transform":
+                self.assertEqual(att.s, b"LOGISTIC")

         expected = model.predict(x_test), model.predict_proba(x_test)
         sess = InferenceSession(onnx_model.SerializeToString())
@@ … @@


 if __name__ == "__main__":
-    TestXGBoostModels().test_xgb_classifier_13()
+    TestXGBoostModels().test_xgb_classifier_13_2()
     unittest.main(verbosity=2)
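A standalone variant of the test above, with synthetic data standing in for the repository's data_bug.csv fixture (dataset and split sizes are made up; the assertion mirrors the test):

    import numpy as np
    from onnxmltools.convert import convert_xgboost
    from onnxmltools.convert.common.data_types import FloatTensorType
    from xgboost import XGBClassifier

    rng = np.random.RandomState(42)
    X = rng.rand(400, 5).astype(np.float32)
    y = ((X[:, 0] + 0.25 * rng.randn(400)) > 0.5).astype(int)

    model = XGBClassifier(
        objective="binary:logistic",
        n_estimators=1000,
        early_stopping_rounds=20,
        max_depth=3,
    )
    model.fit(X[:300], y[:300], eval_set=[(X[300:], y[300:])], verbose=False)

    initial_types = [("float_input", FloatTensorType([None, X.shape[1]]))]
    onx = convert_xgboost(model, initial_types=initial_types)
    # node[0] is the tree ensemble; with the fix, its tree ids stop well
    # short of n_estimators whenever early stopping kicked in.
    for att in onx.graph.node[0].attribute:
        if att.name == "nodes_treeids":
            assert max(att.ints) < 1000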
