Skip to content

Commit

Permalink
poisson
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Dec 16, 2023
1 parent 4400198 commit 779f986
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
4 changes: 2 additions & 2 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ def convert(scope, operator, container):
)

if objective == "count:poisson":
cst = scope.get_unique_variable_name("half")
container.add_initializer(cst, TensorProto.FLOAT, [1], [0.5])
cst = scope.get_unique_variable_name("poisson")
container.add_initializer(cst, TensorProto.FLOAT, [1], [base_score])
new_name = scope.get_unique_variable_name("exp")
container.add_node("Exp", names, [new_name])
container.add_node("Mul", [new_name, cst], operator.output_full_names)
Expand Down
5 changes: 2 additions & 3 deletions tests/xgboost/test_xgboost_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ def test_xgb_regressor_poisson(self):
x = iris.data
y = iris.target / 100
x_train, x_test, y_train, _ = train_test_split(
x, y, test_size=0.5, random_state=42
x, y, test_size=0.5, random_state=17
)
for nest in [5, 50]:
xgb = XGBRegressor(
objective="count:poisson",
random_state=0,
random_state=5,
max_depth=3,
n_estimators=nest,
)
Expand Down Expand Up @@ -716,5 +716,4 @@ def test_xgb_classifier_13(self):


if __name__ == "__main__":
TestXGBoostModels().test_xgb_regressor_poisson()
unittest.main(verbosity=2)

0 comments on commit 779f986

Please sign in to comment.