Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre committed Dec 16, 2023
1 parent c94357b commit 15f566d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/xgboost/test_xgboost_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ def test_xgboost_booster_reg(self):
n_classes=2, n_features=5, n_samples=100, random_state=42, n_informative=3
)
y = y.astype(np.float32) + 0.567
print(y)
x_train, x_test, y_train, _ = train_test_split(
x, y, test_size=0.5, random_state=42
)
Expand Down Expand Up @@ -770,6 +769,8 @@ def test_xgb_classifier_13(self):

initial_types = [("float_input", FloatTensorType([None, x_train.shape[1]]))]
onnx_model = convert_xgboost(model, initial_types=initial_types)
# with open("debug.onnx", "wb") as f:
# f.write(onnx_model.SerializeToString())

expected = model.predict(x_test), model.predict_proba(x_test)
sess = InferenceSession(onnx_model.SerializeToString())
Expand All @@ -779,4 +780,5 @@ def test_xgb_classifier_13(self):


if __name__ == "__main__":
TestXGBoostModels().test_xgb_classifier_13()
unittest.main(verbosity=2)

0 comments on commit 15f566d

Please sign in to comment.