diff --git a/doc/changes/devel/12827.other.rst b/doc/changes/devel/12827.other.rst new file mode 100644 index 00000000000..3ccbaa0bff6 --- /dev/null +++ b/doc/changes/devel/12827.other.rst @@ -0,0 +1 @@ +Improve documentation clarity of ``fit_transform`` methods for :class:`mne.decoding.SSD`, :class:`mne.decoding.CSP`, and :class:`mne.decoding.SPoC` classes, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index b0e5d7b4bc8..6d92b5d17bd 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -15,7 +15,6 @@ _check_option, _validate_type, _verbose_safe_false, - copy_doc, fill_doc, pinv, warn, @@ -273,8 +272,31 @@ def inverse_transform(self, X): ) return X[:, np.newaxis, :] * self.patterns_[: self.n_components].T - @copy_doc(TransformerMixin.fit_transform) - def fit_transform(self, X, y, **fit_params): # noqa: D102 + def fit_transform(self, X, y=None, **fit_params): + """Fit CSP to data, then transform it. + + Fits transformer to ``X`` and ``y`` with optional parameters ``fit_params``, and + returns a transformed version of ``X``. + + Parameters + ---------- + X : array, shape (n_epochs, n_channels, n_times) + The data on which to estimate the CSP. + y : array, shape (n_epochs,) + The class for each epoch. + **fit_params : dict + Additional fitting parameters passed to the :meth:`mne.decoding.CSP.fit` + method. Not used for this class. + + Returns + ------- + X_csp : array, shape (n_epochs, n_components[, n_times]) + If ``self.transform_into == 'average_power'`` then returns the power of CSP + features averaged over time and shape is ``(n_epochs, n_components)``. If + ``self.transform_into == 'csp_space'`` then returns the data in CSP space + and shape is ``(n_epochs, n_components, n_times)``. + """ + # use parent TransformerMixin method but with custom docstring return super().fit_transform(X, y=y, **fit_params) @fill_doc @@ -953,3 +975,30 @@ def transform(self, X): space and shape is (n_epochs, n_components, n_times). """ return super().transform(X) + + def fit_transform(self, X, y=None, **fit_params): + """Fit SPoC to data, then transform it. + + Fits transformer to ``X`` and ``y`` with optional parameters ``fit_params``, and + returns a transformed version of ``X``. + + Parameters + ---------- + X : array, shape (n_epochs, n_channels, n_times) + The data on which to estimate the SPoC. + y : array, shape (n_epochs,) + The class for each epoch. + **fit_params : dict + Additional fitting parameters passed to the :meth:`mne.decoding.CSP.fit` + method. Not used for this class. + + Returns + ------- + X : array, shape (n_epochs, n_components[, n_times]) + If ``self.transform_into == 'average_power'`` then returns the power of CSP + features averaged over time and shape is ``(n_epochs, n_components)``. If + ``self.transform_into == 'csp_space'`` then returns the data in CSP space + and shape is ``(n_epochs, n_components, n_times)``. + """ + # use parent TransformerMixin method but with custom docstring + return super().fit_transform(X, y=y, **fit_params) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index f5f1ff94516..23e3136ce36 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -261,6 +261,31 @@ def transform(self, X): X_ssd = X_ssd[:, self.sorter_spec, :][:, : self.n_components, :] return X_ssd + def fit_transform(self, X, y=None, **fit_params): + """Fit SSD to data, then transform it. + + Fits transformer to ``X`` and ``y`` with optional parameters ``fit_params``, and + returns a transformed version of ``X``. + + Parameters + ---------- + X : array, shape ([n_epochs, ]n_channels, n_times) + The input data from which to estimate the SSD. Either 2D array obtained from + continuous data or 3D array obtained from epoched data. + y : None + Ignored; exists for compatibility with scikit-learn pipelines. + **fit_params : dict + Additional fitting parameters passed to the :meth:`mne.decoding.SSD.fit` + method. Not used for this class. + + Returns + ------- + X_ssd : array, shape ([n_epochs, ]n_components, n_times) + The processed data. + """ + # use parent TransformerMixin method but with custom docstring + return super().fit_transform(X, y=y, **fit_params) + def get_spectral_ratio(self, ssd_sources): """Get the spectal signal-to-noise ratio for each spatial filter.