Skip to content

Commit

Permalink
FIX: Ver
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Nov 11, 2024
1 parent 1fd06b5 commit b47bed9
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
4 changes: 3 additions & 1 deletion mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import check_scoring
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.utils import check_array, get_tags, indexable
from sklearn.utils import check_array, indexable

from ..parallel import parallel_func
from ..utils import _pl, logger, verbose, warn
Expand Down Expand Up @@ -83,6 +83,8 @@ def __init__(self, model=None):

def __sklearn_tags__(self):
"""Get sklearn tags."""
from sklearn.utils import get_tags # added in 1.6

return get_tags(self.model)

def __getattr__(self, attr):
Expand Down
4 changes: 3 additions & 1 deletion mne/decoding/search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sklearn.base import BaseEstimator, MetaEstimatorMixin, TransformerMixin, clone
from sklearn.metrics import check_scoring
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import check_array, get_tags
from sklearn.utils import check_array

from ..parallel import parallel_func
from ..utils import ProgressBar, _parse_verbose, array_split_idx, fill_doc, verbose
Expand Down Expand Up @@ -63,6 +63,8 @@ def _estimator_type(self):

def __sklearn_tags__(self):
"""Get sklearn tags."""
from sklearn.utils import get_tags

tags = super().__sklearn_tags__()
sub_tags = get_tags(self.base_estimator)
tags.estimator_type = sub_tags.estimator_type
Expand Down
3 changes: 1 addition & 2 deletions mne/decoding/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
)
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils import get_tags
from sklearn.utils.estimator_checks import parametrize_with_checks

from mne import EpochsArray, create_info
Expand Down Expand Up @@ -98,7 +97,7 @@ def test_get_coef():
assert hasattr(lm_classification, "__sklearn_tags__")
print(lm_classification.__sklearn_tags__)
assert is_classifier(lm_classification.model)
assert is_classifier(lm_classification), get_tags(lm_classification).estimator_type
assert is_classifier(lm_classification)
assert not is_regressor(lm_classification.model)
assert not is_regressor(lm_classification)

Expand Down

0 comments on commit b47bed9

Please sign in to comment.