Skip to content

Commit

Permalink
[DOC] Fix misleading fit_transform docstrings (#12827)
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Larson <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Sep 6, 2024
1 parent c993ae5 commit f3a3ca4
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12827.other.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve documentation clarity of ``fit_transform`` methods for :class:`mne.decoding.SSD`, :class:`mne.decoding.CSP`, and :class:`mne.decoding.SPoC` classes, by `Thomas Binns`_.
55 changes: 52 additions & 3 deletions mne/decoding/csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
_check_option,
_validate_type,
_verbose_safe_false,
copy_doc,
fill_doc,
pinv,
warn,
Expand Down Expand Up @@ -273,8 +272,31 @@ def inverse_transform(self, X):
)
return X[:, np.newaxis, :] * self.patterns_[: self.n_components].T

@copy_doc(TransformerMixin.fit_transform)
def fit_transform(self, X, y, **fit_params): # noqa: D102
def fit_transform(self, X, y=None, **fit_params):
"""Fit CSP to data, then transform it.
Fits transformer to ``X`` and ``y`` with optional parameters ``fit_params``, and
returns a transformed version of ``X``.
Parameters
----------
X : array, shape (n_epochs, n_channels, n_times)
The data on which to estimate the CSP.
y : array, shape (n_epochs,)
The class for each epoch.
**fit_params : dict
Additional fitting parameters passed to the :meth:`mne.decoding.CSP.fit`
method. Not used for this class.
Returns
-------
X_csp : array, shape (n_epochs, n_components[, n_times])
If ``self.transform_into == 'average_power'`` then returns the power of CSP
features averaged over time and shape is ``(n_epochs, n_components)``. If
``self.transform_into == 'csp_space'`` then returns the data in CSP space
and shape is ``(n_epochs, n_components, n_times)``.
"""
# use parent TransformerMixin method but with custom docstring
return super().fit_transform(X, y=y, **fit_params)

@fill_doc
Expand Down Expand Up @@ -953,3 +975,30 @@ def transform(self, X):
space and shape is (n_epochs, n_components, n_times).
"""
return super().transform(X)

def fit_transform(self, X, y=None, **fit_params):
"""Fit SPoC to data, then transform it.
Fits transformer to ``X`` and ``y`` with optional parameters ``fit_params``, and
returns a transformed version of ``X``.
Parameters
----------
X : array, shape (n_epochs, n_channels, n_times)
The data on which to estimate the SPoC.
y : array, shape (n_epochs,)
The class for each epoch.
**fit_params : dict
Additional fitting parameters passed to the :meth:`mne.decoding.CSP.fit`
method. Not used for this class.
Returns
-------
X : array, shape (n_epochs, n_components[, n_times])
If ``self.transform_into == 'average_power'`` then returns the power of CSP
features averaged over time and shape is ``(n_epochs, n_components)``. If
``self.transform_into == 'csp_space'`` then returns the data in CSP space
and shape is ``(n_epochs, n_components, n_times)``.
"""
# use parent TransformerMixin method but with custom docstring
return super().fit_transform(X, y=y, **fit_params)
25 changes: 25 additions & 0 deletions mne/decoding/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,31 @@ def transform(self, X):
X_ssd = X_ssd[:, self.sorter_spec, :][:, : self.n_components, :]
return X_ssd

def fit_transform(self, X, y=None, **fit_params):
"""Fit SSD to data, then transform it.
Fits transformer to ``X`` and ``y`` with optional parameters ``fit_params``, and
returns a transformed version of ``X``.
Parameters
----------
X : array, shape ([n_epochs, ]n_channels, n_times)
The input data from which to estimate the SSD. Either 2D array obtained from
continuous data or 3D array obtained from epoched data.
y : None
Ignored; exists for compatibility with scikit-learn pipelines.
**fit_params : dict
Additional fitting parameters passed to the :meth:`mne.decoding.SSD.fit`
method. Not used for this class.
Returns
-------
X_ssd : array, shape ([n_epochs, ]n_components, n_times)
The processed data.
"""
# use parent TransformerMixin method but with custom docstring
return super().fit_transform(X, y=y, **fit_params)

def get_spectral_ratio(self, ssd_sources):
"""Get the spectal signal-to-noise ratio for each spatial filter.
Expand Down

0 comments on commit f3a3ca4

Please sign in to comment.