many fixes for xgboost
Signed-off-by: Xavier Dupre <[email protected]>
xadupre committed Dec 15, 2023
1 parent 47cbee2 commit e4d9dd4
Showing 9 changed files with 68 additions and 129 deletions.
20 changes: 6 additions & 14 deletions onnxmltools/convert/xgboost/_parse.py
@@ -69,24 +69,16 @@ def _get_attributes(booster):
except AttributeError:
ntrees = trees // num_class if num_class > 0 else trees
else:
trees = len(res)
ntrees = getattr(booster, "best_ntree_limit", trees)
config = json.loads(booster.save_config())["learner"]["learner_model_param"]
if "num_class" in config:
num_class = int(config["num_class"])
ntrees = len(res)
num_class = 1
else:
trees = len(res)
if hasattr(booster, "best_ntree_limit"):
ntrees = booster.best_ntree_limit
elif hasattr(booster, "best_iteration"):
ntrees = booster.best_iteration
else:
raise RuntimeError("Unable to guess the number of classes.")
num_class = int(config["num_class"]) if "num_class" in config else 0
if num_class == 0 and ntrees > 0:
num_class = trees // ntrees
if num_class == 0:
raise RuntimeError(
"Unable to retrieve the number of classes, trees=%d, ntrees=%d."
% (trees, ntrees)
f"Unable to retrieve the number of classes, num_class={num_class}, "
f"trees={trees}, ntrees={ntrees}, config={config}."
)

kwargs = atts.copy()
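A note on the hunk above: when the saved config lacks num_class, the patched _get_attributes falls back to best_ntree_limit or best_iteration for the round count and derives the class count as trees // ntrees, since a softprob multiclass booster grows one tree per class per boosting round. A minimal standalone sketch of that arithmetic (helper name and figures are illustrative, not part of the diff):

def infer_num_class(total_trees, ntrees):
    # A multiclass softprob booster grows one tree per class per round,
    # so total_trees == ntrees * num_class; binary/regression gives 1.
    if ntrees > 0:
        return total_trees // ntrees
    raise RuntimeError("Unable to retrieve the number of classes.")

assert infer_num_class(30, 10) == 3  # e.g. 3-class model, 10 rounds
assert infer_num_class(10, 10) == 1  # binary or regression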
13 changes: 6 additions & 7 deletions onnxmltools/convert/xgboost/common.py
@@ -20,25 +20,24 @@ def get_xgb_params(xgb_node):
config = json.loads(xgb_node.save_config())
else:
config = json.loads(xgb_node.get_booster().save_config())
num_class = int(config["learner"]["learner_model_param"]["num_class"])
params = {k: v for k, v in params.items() if v is not None}
params["num_class"] = num_class
num_class = int(config["learner"]["learner_model_param"]["num_class"])
if num_class > 0:
params["num_class"] = num_class
if "n_estimators" not in params and hasattr(xgb_node, "n_estimators"):
# xgboost >= 1.0.2
if xgb_node.n_estimators is not None:
params["n_estimators"] = xgb_node.n_estimators
if params.get("base_score", None) is None:
bs = float(config["learner"]["learner_model_param"]["base_score"])
# xgboost >= 2.0
params["base_score"] = float(
config["learner"]["learner_model_param"]["base_score"]
)
params["base_score"] = bs
return params


def get_n_estimators_classifier(xgb_node, params, js_trees):
if "n_estimators" not in params:
config = json.loads(xgb_node.get_booster().save_config())
num_class = int(config["learner"]["learner_model_param"]["num_class"])
num_class = params.get("num_class", 0)
if num_class == 0:
return len(js_trees)
return len(js_trees) // num_class
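For context on the get_xgb_params change: booster.save_config() returns a JSON document whose learner.learner_model_param block stores num_class and base_score as strings, and the patch now copies num_class into params only when it is positive, so get_n_estimators_classifier can read it back from params instead of re-parsing the config. A small sketch of the lookup, using an abridged, illustrative config literal:

import json

# Abridged shape of xgboost's save_config() output; real configs carry
# many more keys under "learner".
config = json.loads(
    '{"learner": {"learner_model_param": '
    '{"num_class": "0", "base_score": "5E-1"}}}'
)
lmp = config["learner"]["learner_model_param"]
num_class = int(lmp["num_class"])      # 0 here: binary or regression
base_score = float(lmp["base_score"])  # scientific notation parses fine
assert num_class == 0 and base_score == 0.5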
12 changes: 6 additions & 6 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
@@ -161,8 +161,7 @@ def _fill_node_attributes(
false_child_id=remap[jsnode["no"]], # ['children'][1]['nodeid'],
weights=None,
weight_id_bias=None,
missing=jsnode.get("missing", -1)
== jsnode["yes"], # ['children'][0]['nodeid'],
missing=jsnode.get("missing", -1) == jsnode["yes"],
hitrate=jsnode.get("cover", 0),
)

@@ -294,17 +293,17 @@ def convert(scope, operator, container):

params = XGBConverter.get_xgb_params(xgb_node)
n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
num_class = params.get("num_class", None)

attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
XGBConverter.fill_tree_attributes(
js_trees, attr_pairs, [1 for _ in js_trees], True
)
if "num_class" in params:
ncl = params["num_class"]
if num_class is not None:
ncl = num_class
n_estimators = len(js_trees) // ncl
else:
ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators
print("**", params)

bst = xgb_node.get_booster()
best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl
@@ -325,8 +324,9 @@ def convert(scope, operator, container):
attr_pairs["post_transform"] = "LOGISTIC"
attr_pairs["class_ids"] = [0 for v in attr_pairs["class_treeids"]]
if js_trees[0].get("leaf", None) == 0:
attr_pairs["base_values"] = [0.5]
attr_pairs["base_values"] = [base_score]
elif base_score != 0.5:
# 0.5 -> cst = 0
cst = -np.log(1 / np.float32(base_score) - 1.0)
attr_pairs["base_values"] = [cst]
else:
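The base_values fix above stops hard-coding 0.5 and instead folds the model's base_score through the inverse of the LOGISTIC post-transform, so that sigmoid(cst + leaf sum) reproduces xgboost's probability. A standalone check of that constant, assuming only numpy:

import numpy as np

def base_value(base_score):
    # Inverse sigmoid (logit); for base_score = 0.5 the constant is 0,
    # which is why the old hard-coded 0.5 happened to work.
    return -np.log(1.0 / np.float32(base_score) - 1.0)

assert abs(base_value(0.5)) < 1e-7
p = 1.0 / (1.0 + np.exp(-base_value(0.9)))  # sigmoid(logit(0.9))
assert abs(p - 0.9) < 1e-6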
4 changes: 1 addition & 3 deletions onnxmltools/convert/xgboost/shape_calculators/Classifier.py
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

import json
import numpy as np
from ...common._registration import register_shape_calculator
from ...common.utils import check_input_and_output_numbers, check_input_and_output_types
@@ -27,8 +26,7 @@ def calculate_xgboost_classifier_output_shapes(operator):
ntrees = len(js_trees)
objective = params["objective"]
n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
config = json.loads(xgb_node.get_booster().save_config())
num_class = int(config["learner"]["learner_model_param"]["num_class"])
num_class = params.get("num_class", None)

if num_class is not None:
ncl = num_class
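The shape calculator now reuses the num_class that get_xgb_params already extracted rather than re-parsing save_config(). A sketch of the fallback it applies when num_class is absent (the function name and the binary:logistic special case are my reading of the surrounding file, not shown in the diff):

def resolve_ncl(num_class, objective, ntrees, n_estimators):
    # Prefer the explicit class count; otherwise divide total trees by
    # boosting rounds, and treat binary:logistic as two class columns.
    if num_class is not None and num_class > 0:
        return num_class
    ncl = ntrees // n_estimators
    if objective == "binary:logistic":
        ncl = 2
    return ncl

assert resolve_ncl(3, "multi:softprob", 30, 10) == 3
assert resolve_ncl(None, "binary:logistic", 10, 10) == 2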
20 changes: 0 additions & 20 deletions requirements-dev.txt

This file was deleted.

3 changes: 0 additions & 3 deletions requirements.txt

This file was deleted.

8 changes: 4 additions & 4 deletions tests/baseline/test_convert_baseline.py
@@ -45,11 +45,11 @@ def get_diff(self, input_file, ref_file):
def normalize_diff(self, diff):
invalid_comparisons = []
invalid_comparisons.append(
re.compile('producer_version: "\d+\.\d+\.\d+\.\d+.*')
re.compile('producer_version: "\\d+\\.\\d+\\.\\d+\\.\\d+.*')
)
invalid_comparisons.append(re.compile('\s+name: ".*'))
invalid_comparisons.append(re.compile("ir_version: \d+"))
invalid_comparisons.append(re.compile("\s+"))
invalid_comparisons.append(re.compile('\\s+name: ".*'))
invalid_comparisons.append(re.compile("ir_version: \\d+"))
invalid_comparisons.append(re.compile("\\s+"))
valid_diff = set()
for line in diff:
if any(comparison.match(line) for comparison in invalid_comparisons):
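The baseline-test change only doubles the backslashes: "\\d" in an ordinary string literal denotes the same pattern as the raw string r"\d", while the old single-backslash form is an invalid escape sequence that recent Pythons flag with a DeprecationWarning. A quick equivalence check:

import re

# Both spellings compile to the identical pattern string.
assert re.compile("ir_version: \\d+").pattern == r"ir_version: \d+"
assert re.match(r"ir_version: \d+", "ir_version: 8")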
64 changes: 0 additions & 64 deletions tests/xgboost/test_xgboost_13.py

This file was deleted.

53 changes: 45 additions & 8 deletions tests/xgboost/test_xgboost_converters.py
@@ -116,7 +116,7 @@ def test_xgb_regressor_poisson(self):
basename=f"SklearnXGBRegressorPoisson{nest}-Dec3",
)

def test_xgb_classifier(self):
def test_xgb0_classifier(self):
xgb, x_test = _fit_classification_model(XGBClassifier(), 2)
conv_model = convert_xgboost(
xgb,
@@ -198,7 +198,7 @@ def test_xgb_classifier_multi_discrete_int_labels(self):
basename="SklearnXGBClassifierMultiDiscreteIntLabels",
)

def test_xgboost_booster_classifier_bin(self):
def test_xgb1_booster_classifier_bin(self):
x, y = make_classification(
n_classes=2, n_features=5, n_samples=100, random_state=42, n_informative=3
)
@@ -218,10 +218,10 @@ def test_xgboost_booster_classifier_bin(self):
target_opset=TARGET_OPSET,
)
dump_data_and_model(
x_test.astype(np.float32), model, model_onnx, basename="XGBBoosterMCl"
x_test.astype(np.float32)[:20], model, model_onnx, basename="XGBBoosterMCl"
)

def test_xgboost_booster_classifier_multiclass_softprob(self):
def test_xgb0_booster_classifier_multiclass_softprob(self):
x, y = make_classification(
n_classes=3, n_features=5, n_samples=100, random_state=42, n_informative=3
)
@@ -431,7 +431,7 @@ def test_xgboost_example_mnist(self):
X_test.astype(np.float32), clf, onnx_model, basename="XGBoostExample"
)

def test_xgb_empty_tree(self):
def test_xgb0_empty_tree_classifier(self):
xgb = XGBClassifier(n_estimators=2, max_depth=2)

# simple dataset
@@ -451,7 +451,7 @@ def test_xgb_empty_tree(self):
assert_almost_equal(xgb.predict_proba(X), res[1])
assert_almost_equal(xgb.predict(X), res[0])

def test_xgb_best_tree_limit(self):
def test_xgb_best_tree_limit_classifier(self):
# Train
iris = load_iris()
X, y = iris.data, iris.target
@@ -499,7 +499,7 @@ def test_xgb_best_tree_limit(self):
)
assert_almost_equal(bst_loaded.predict(dtest), res[0])

def test_onnxrt_python_xgbclassifier(self):
def test_xgb_classifier(self):
x = np.random.randn(100, 10).astype(np.float32)
y = ((x.sum(axis=1) + np.random.randn(x.shape[0]) / 50 + 0.5) >= 0).astype(
np.int64
@@ -675,7 +675,44 @@ def test_xgb_classifier_hinge(self):
x_test, xgb, conv_model, basename="SklearnXGBClassifierHinge"
)

def test_xgb_classifier_13(self):
this = os.path.dirname(__file__)
df = pandas.read_csv(os.path.join(this, "data_fail_empty.csv"))
X, y = df.drop("y", axis=1), df["y"]
X_train, X_test, y_train, y_test = train_test_split(X, y)

clr = XGBClassifier(
max_delta_step=0,
tree_method="hist",
n_estimators=100,
booster="gbtree",
objective="binary:logistic",
eval_metric="logloss",
learning_rate=0.1,
gamma=10,
max_depth=7,
min_child_weight=50,
subsample=0.75,
colsample_bytree=0.75,
random_state=42,
verbosity=0,
)

clr.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=40)

initial_type = [("float_input", FloatTensorType([None, 797]))]
onx = convert_xgboost(
clr, initial_types=initial_type, target_opset=TARGET_OPSET
)
expected = clr.predict(X_test), clr.predict_proba(X_test)
sess = InferenceSession(
onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
X_test = X_test.values.astype(np.float32)
got = sess.run(None, {"float_input": X_test})
assert_almost_equal(expected[1], got[1])
assert_almost_equal(expected[0], got[0])


if __name__ == "__main__":
TestXGBoostModels().test_xgb_best_tree_limit_classifier()
unittest.main(verbosity=2)
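The new test_xgb_classifier_13 restores coverage the deleted tests/xgboost/test_xgboost_13.py used to provide: converting a binary classifier trained with early stopping, where the booster holds more trees than the best iteration. A minimal sketch of the truncation this relies on (helper name and figures illustrative; best_ntree_limit follows the pre-2.0 xgboost attribute seen in the converter hunk above):

def trees_to_keep(total_trees, best_ntree_limit, ncl):
    # The converter scales the per-round limit by the class count and
    # drops any trees grown after the best iteration.
    limit = best_ntree_limit * max(ncl, 1)
    return min(total_trees, limit)

assert trees_to_keep(100, 60, 1) == 60   # binary: stop at best round
assert trees_to_keep(300, 60, 3) == 180  # 3 classes: 60 rounds x 3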
