Skip to content

Commit

Permalink
test: generate max_depth parameter and test values using helper asser…
Browse files Browse the repository at this point in the history
…t function in test_sklearn_predict
  • Loading branch information
Ishticode committed Dec 30, 2023
1 parent 515200e commit b0cb95d
Showing 1 changed file with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from ivy.functional.frontends.sklearn.tree import DecisionTreeClassifier as ivy_DTC
import numpy as np
import ivy
from hypothesis import given
import ivy_tests.test_ivy.helpers as helpers
Expand Down Expand Up @@ -28,13 +27,14 @@ def _get_sklearn_predict(X, y, max_depth, DecisionTreeClassifier):
y=helpers.array_values(
shape=(5,), dtype=helpers.get_dtypes("signed_integer", prune_function=False)
),
max_depth=helpers.ints(max_value=5, min_value=1),
)
def test_sklearn_tree_predict(X, y):
def test_sklearn_tree_predict(X, y, max_depth):
try:
from sklearn.tree import DecisionTreeClassifier as sklearn_DTC
except ImportError:
print("sklearn not installed, skipping test_sklearn_tree_predict")
return
sklearn_pred = _get_sklearn_predict(X, y, 3, sklearn_DTC)(X)
ivy_pred = _get_sklearn_predict(ivy.array(X), ivy.array(y), 3, ivy_DTC)(X)
assert np.allclose(ivy_pred.to_numpy(), sklearn_pred)
sklearn_pred = _get_sklearn_predict(X, y, max_depth, sklearn_DTC)(X)
ivy_pred = _get_sklearn_predict(ivy.array(X), ivy.array(y), max_depth, ivy_DTC)(X)
helpers.assert_same_type_and_shape([sklearn_pred, ivy_pred])

0 comments on commit b0cb95d

Please sign in to comment.