Commit b03a98e: fix many issues
Signed-off-by: Xavier Dupre <[email protected]>
xadupre committed Dec 11, 2023
1 parent 34c0965 commit b03a98e
Showing 6 changed files with 38 additions and 15 deletions.
18 changes: 14 additions & 4 deletions onnxmltools/convert/xgboost/_parse.py
@@ -69,9 +69,19 @@ def _get_attributes(booster):
     except AttributeError:
         ntrees = trees // num_class if num_class > 0 else trees
     else:
-        trees = len(res)
-        ntrees = booster.best_ntree_limit
-        num_class = trees // ntrees
+        config = json.loads(booster.save_config())["learner"]["learner_model_param"]
+        if "num_class" in config:
+            num_class = int(config["num_class"])
+            ntrees = len(res) // num_class
+        else:
+            trees = len(res)
+            if hasattr(booster, "best_ntree_limit"):
+                ntrees = booster.best_ntree_limit
+            elif hasattr(booster, "best_iteration"):
+                ntrees = booster.best_iteration
+            else:
+                raise RuntimeError("Unable to guess the number of classes.")
+            num_class = trees // ntrees
     if num_class == 0:
         raise RuntimeError(
             "Unable to retrieve the number of classes, trees=%d, ntrees=%d."
@@ -137,7 +147,7 @@ def __init__(self, booster):
         self.operator_name = "XGBRegressor"

     def get_xgb_params(self):
-        return self.kwargs
+        return {k: v for k, v in self.kwargs.items() if v is not None}

     def get_booster(self):
         return self.booster_
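Note: `Booster.best_ntree_limit` was removed from the XGBoost 2.0 API, which is presumably why the parser now prefers the `num_class` field recorded by `save_config()` and only falls back to the older attributes. A minimal sketch of what that JSON exposes (the toy model below is illustrative, not part of the diff):

import json
import numpy as np
import xgboost

# Toy 3-class model: 4 boosting rounds -> 12 trees (one per class per round).
X = np.random.rand(30, 4)
y = np.array([i % 3 for i in range(30)])
clf = xgboost.XGBClassifier(n_estimators=4, max_depth=2).fit(X, y)

booster = clf.get_booster()
config = json.loads(booster.save_config())["learner"]["learner_model_param"]
num_class = int(config["num_class"])   # 3 (XGBoost stores 0 for binary/regression)
trees = len(booster.get_dump())        # 12 dumped trees in total
print(num_class, trees // num_class)   # 3 4 -> four trees per class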
15 changes: 13 additions & 2 deletions onnxmltools/convert/xgboost/common.py
@@ -16,10 +16,11 @@ def get_xgb_params(xgb_node):
     else:
         # XGBoost < 0.7
         params = xgb_node.__dict__
-
+    params = {k: v for k, v in params.items() if v is not None}
     if "n_estimators" not in params and hasattr(xgb_node, "n_estimators"):
         # xgboost >= 1.0.2
-        params["n_estimators"] = xgb_node.n_estimators
+        if xgb_node.n_estimators is not None:
+            params["n_estimators"] = xgb_node.n_estimators
     if params.get("base_score", None) is None:
         # xgboost >= 2.0
         if hasattr(xgb_node, "save_config"):
@@ -31,3 +32,13 @@ def get_xgb_params(xgb_node):
             config["learner"]["learner_model_param"]["base_score"]
         )
     return params
+
+
+def get_n_estimators_classifier(xgb_node, params, js_trees):
+    if "n_estimators" not in params:
+        config = json.loads(xgb_node.get_booster().save_config())
+        num_class = int(config["learner"]["learner_model_param"]["num_class"])
+        if num_class == 0:
+            return len(js_trees)
+        return len(js_trees) // num_class
+    return params["n_estimators"]
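For reference, a minimal sketch of how the new helper would be called, assuming the module path shown in this diff (the fitted model is illustrative):

import numpy as np
import xgboost
from onnxmltools.convert.xgboost.common import (
    get_n_estimators_classifier,
    get_xgb_params,
)

# Any fitted XGBClassifier works here.
X = np.random.rand(20, 3)
y = np.array([i % 2 for i in range(20)])
clf = xgboost.XGBClassifier(n_estimators=5, max_depth=2).fit(X, y)

params = get_xgb_params(clf)
js_trees = clf.get_booster().get_dump(dump_format="json")
# Uses params["n_estimators"] when present; otherwise falls back to
# num_class from save_config() to infer the round count from the dump.
print(get_n_estimators_classifier(clf, params, js_trees))  # 5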
10 changes: 6 additions & 4 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
@@ -10,7 +10,7 @@
 except ImportError:
     XGBRFClassifier = None
 from ...common._registration import register_converter
-from ..common import get_xgb_params
+from ..common import get_xgb_params, get_n_estimators_classifier


 class XGBConverter:
@@ -293,11 +293,13 @@ def convert(scope, operator, container):
         objective, base_score, js_trees = XGBConverter.common_members(xgb_node, inputs)

         params = XGBConverter.get_xgb_params(xgb_node)
+        n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)
+
         attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
         XGBConverter.fill_tree_attributes(
             js_trees, attr_pairs, [1 for _ in js_trees], True
         )
-        ncl = (max(attr_pairs["class_treeids"]) + 1) // params["n_estimators"]
+        ncl = (max(attr_pairs["class_treeids"]) + 1) // n_estimators

         bst = xgb_node.get_booster()
         best_ntree_limit = getattr(bst, "best_ntree_limit", len(js_trees)) * ncl
@@ -373,7 +375,7 @@ def convert(scope, operator, container):
                 "Where", [greater, one, zero], operator.output_full_names[1]
             )
         elif objective in ("multi:softprob", "multi:softmax"):
-            ncl = len(js_trees) // params["n_estimators"]
+            ncl = len(js_trees) // n_estimators
             if objective == "multi:softmax":
                 attr_pairs["post_transform"] = "NONE"
             container.add_node(
@@ -385,7 +387,7 @@ def convert(scope, operator, container):
                 **attr_pairs,
             )
         elif objective == "reg:logistic":
-            ncl = len(js_trees) // params["n_estimators"]
+            ncl = len(js_trees) // n_estimators
             if ncl == 1:
                 ncl = 2
             container.add_node(
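The `ncl` divisions above lean on XGBoost's tree layout for multi-class objectives: one tree per class per boosting round. A toy check of that invariant (numbers illustrative):

# multi:softprob with 3 classes trained for 4 rounds dumps 12 trees,
# laid out class by class, so the class count falls out by division.
n_trees = 12        # len(js_trees)
n_estimators = 4    # boosting rounds
ncl = n_trees // n_estimators
assert ncl == 3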
8 changes: 5 additions & 3 deletions onnxmltools/convert/xgboost/shape_calculators/Classifier.py
@@ -8,7 +8,7 @@
     Int64TensorType,
     StringTensorType,
 )
-from ..common import get_xgb_params
+from ..common import get_xgb_params, get_n_estimators_classifier


 def calculate_xgboost_classifier_output_shapes(operator):
@@ -22,13 +22,15 @@ def calculate_xgboost_classifier_output_shapes(operator):
     params = get_xgb_params(xgb_node)
     booster = xgb_node.get_booster()
     booster.attributes()
-    ntrees = len(booster.get_dump(with_stats=True, dump_format="json"))
+    js_trees = booster.get_dump(with_stats=True, dump_format="json")
+    ntrees = len(js_trees)
     objective = params["objective"]
+    n_estimators = get_n_estimators_classifier(xgb_node, params, js_trees)

     if objective == "binary:logistic":
         ncl = 2
     else:
-        ncl = ntrees // params["n_estimators"]
+        ncl = ntrees // n_estimators
         if objective == "reg:logistic" and ncl == 1:
             ncl = 2
     classes = xgb_node.classes_
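For context, the `ncl` computed here sets the width of the probability output the shape calculator declares. A minimal sketch of that convention, with types from onnxconverter_common (the variable names are assumptions, not from the diff):

from onnxconverter_common.data_types import FloatTensorType, Int64TensorType

N = None   # dynamic batch dimension
ncl = 2    # e.g. binary:logistic

label_type = Int64TensorType(shape=[N])       # one predicted label per row
proba_type = FloatTensorType(shape=[N, ncl])  # one probability per class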
1 change: 0 additions & 1 deletion tests/h2o/test_h2o_converters.py
@@ -223,7 +223,6 @@ def test_h2o_classifier_multi_cat(self):
         train, test = _prepare_one_hot("airlines.csv", y)
         gbm = H2OGradientBoostingEstimator(ntrees=8, max_depth=5)
         mojo_path = _make_mojo(gbm, train, y=train.columns.index(y))
-        print("****", mojo_path)
         onnx_model = _convert_mojo(mojo_path)
         self.assertIsNot(onnx_model, None)
         dump_data_and_model(
1 change: 0 additions & 1 deletion tests/xgboost/test_xgboost_converters.py
@@ -677,5 +677,4 @@ def test_xgb_classifier_hinge(self):


 if __name__ == "__main__":
-    TestXGBoostModels().test_xgb_regressor_poisson()
     unittest.main(verbosity=2)
