diff --git a/skpro/model_selection/_tuning.py b/skpro/model_selection/_tuning.py index 721da987..cccaac24 100644 --- a/skpro/model_selection/_tuning.py +++ b/skpro/model_selection/_tuning.py @@ -488,8 +488,7 @@ def get_test_params(cls, parameter_set="default"): from skpro.metrics import CRPS, PinballLoss from skpro.regression.residual import ResidualDouble - from skpro.survival.coxph import CoxPH - from skpro.utils.validation._dependencies import _check_estimator_deps + from skpro.survival.compose._reduce_cond_unc import ConditionUncensored linreg1 = LinearRegression() linreg2 = LinearRegression(fit_intercept=False) @@ -510,18 +509,14 @@ def get_test_params(cls, parameter_set="default"): "error_score": "raise", } - params = [param1, param2] - - # testing with survival predictor - if _check_estimator_deps(CoxPH, severity="none"): - param3 = { - "estimator": CoxPH(alpha=0.05), - "cv": KFold(n_splits=4), - "param_grid": {"method": ["lpl", "elastic_net"]}, - "scoring": PinballLoss(), - "error_score": "raise", - } - params.append(param3) + params3 = { + "estimator": ConditionUncensored(ResidualDouble(LinearRegression())), + "cv": KFold(n_splits=4), + "param_grid": {"estimator__fit_intercept": [True, False]}, + "scoring": PinballLoss(), + "error_score": "raise", + } + params = [param1, param2, params3] return params @@ -747,8 +742,7 @@ def get_test_params(cls, parameter_set="default"): from skpro.metrics import CRPS, PinballLoss from skpro.regression.residual import ResidualDouble - from skpro.survival.coxph import CoxPH - from skpro.utils.validation._dependencies import _check_estimator_deps + from skpro.survival.compose._reduce_cond_unc import ConditionUncensored linreg1 = LinearRegression() linreg2 = LinearRegression(fit_intercept=False) @@ -769,17 +763,13 @@ def get_test_params(cls, parameter_set="default"): "error_score": "raise", } - params = [param1, param2] - - # testing with survival predictor - if _check_estimator_deps(CoxPH, severity="none"): - param3 = { - "estimator": CoxPH(alpha=0.05), - "cv": KFold(n_splits=4), - "param_distributions": {"method": ["lpl", "elastic_net"]}, - "scoring": PinballLoss(), - "error_score": "raise", - } - params += [param3] + params3 = { + "estimator": ConditionUncensored(ResidualDouble(LinearRegression())), + "cv": KFold(n_splits=4), + "param_distributions": {"estimator__fit_intercept": [True, False]}, + "scoring": PinballLoss(), + "error_score": "raise", + } + params = [param1, param2, params3] return params