Skip to content

Commit

Permalink
BUG: Fix bug with CSP rank="full" (#12694)
Browse files Browse the repository at this point in the history
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
larsoner and autofix-ci[bot] authored Jul 1, 2024
1 parent bbc2a82 commit 31ef32e
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 16 deletions.
2 changes: 2 additions & 0 deletions doc/changes/devel/12694.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix regression with :class:`mne.decoding.CSP` where using ``rank="full"`` errantly
raised an error, by `Eric Larson`_.
2 changes: 1 addition & 1 deletion doc/development/governance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ Substantial portions of this document were adapted from the
<https://github.com/scipy/scipy/blob/main/doc/source/dev/governance.rst>`_,
which in turn was adapted from
`Jupyter/IPython project's governance document
<https://github.com/jupyter/governance/blob/master/governance.md>`_ and
<https://github.com/jupyter/governance/blob/main/archive/governance.md>`_ and
`NumPy's governance document
<https://github.com/numpy/numpy/blob/master/doc/source/dev/governance/governance.rst>`_.

Expand Down
10 changes: 5 additions & 5 deletions mne/decoding/csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def fit(self, X, y):
)

# Convert rank to one that will run
_validate_type(self.rank, (dict, None), "rank")
_validate_type(self.rank, (dict, None, str), "rank")

covs, sample_weights = self._compute_covariance_matrices(X, y)
eigen_vectors, eigen_values = self._decompose_covs(covs, sample_weights)
Expand Down Expand Up @@ -554,16 +554,16 @@ def _compute_covariance_matrices(self, X, y):
# Someday we could allow the user to pass this, then we wouldn't need to convert
# but in the meantime they can use a pipeline with a scaler
self._info = create_info(n_channels, 1000.0, "mag")
if self.rank is None:
if isinstance(self.rank, dict):
self._rank = {"mag": sum(self.rank.values())}
else:
self._rank = _compute_rank_raw_array(
X.transpose(1, 0, 2).reshape(X.shape[1], -1),
self._info,
rank=None,
rank=self.rank,
scalings=None,
log_ch_type="data",
)
else:
self._rank = {"mag": sum(self.rank.values())}

covs = []
sample_weights = []
Expand Down
34 changes: 24 additions & 10 deletions mne/decoding/tests/test_csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
assert_equal,
)

from mne import Epochs, io, pick_types, read_events
from mne import Epochs, compute_proj_raw, io, pick_types, read_events
from mne.decoding import CSP, LinearModel, Scaler, SPoC, get_coef
from mne.decoding.csp import _ajd_pham
from mne.utils import catch_logging
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_csp():
# Even the "reg is None and rank is None" case should pass now thanks to the
# do_compute_rank
@pytest.mark.parametrize("ch_type", ("mag", "eeg", ("mag", "eeg")))
@pytest.mark.parametrize("rank", (None, "correct"))
@pytest.mark.parametrize("rank", (None, "full", "correct"))
@pytest.mark.parametrize("reg", [None, 0.001, "oas"])
def test_regularized_csp(ch_type, rank, reg):
"""Test Common Spatial Patterns algorithm using regularized covariance."""
Expand All @@ -268,29 +268,36 @@ def test_regularized_csp(ch_type, rank, reg):
n_orig = len(raw.ch_names)
ch_decim = 2
raw.pick_channels(raw.ch_names[::ch_decim])
raw.info.normalize_proj()
if "eeg" in ch_type:
raw.set_eeg_reference(projection=True)
# TODO: for some reason we need to add a second EEG projector in order to get
# the non-semidefinite error for EEG data. Hopefully this won't make much
# difference in practice given our default is rank=None and regularization
# is easy to use.
raw.add_proj(compute_proj_raw(raw, n_eeg=1, n_mag=0, n_grad=0, n_jobs=1))
n_eig = len(raw.ch_names) - len(raw.info["projs"])
n_ch = n_orig // ch_decim
if ch_type == "eeg":
assert n_eig == n_ch - 1
assert n_eig == n_ch - 2
elif ch_type == "mag":
assert n_eig == n_ch - 3
else:
assert n_eig == n_ch - 4
assert n_eig == n_ch - 5
if rank == "correct":
if isinstance(ch_type, str):
rank = {ch_type: n_eig}
else:
assert ch_type == ("mag", "eeg")
rank = dict(
mag=102 // ch_decim - 3,
eeg=60 // ch_decim - 1,
eeg=60 // ch_decim - 2,
)
else:
assert rank is None, rank
raw.info.normalize_proj()
raw.filter(2, 40)
assert rank is None or rank == "full", rank
if rank == "full":
n_eig = n_ch
raw.filter(2, 40).apply_proj()
events = read_events(event_name)
# map make left and right events the same
events[events[:, 2] == 2, 2] = 1
Expand All @@ -307,18 +314,25 @@ def test_regularized_csp(ch_type, rank, reg):
epochs_data_orig = epochs_data.copy()
epochs_data = sc.fit_transform(epochs_data)
csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=rank)
if rank == "full" and reg is None:
with pytest.raises(np.linalg.LinAlgError, match="leading minor"):
csp.fit(epochs_data, epochs.events[:, -1])
return
with catch_logging(verbose=True) as log:
X = csp.fit_transform(epochs_data, epochs.events[:, -1])
log = log.getvalue()
assert "Setting small MAG" not in log
assert "Setting small data eigen" in log
if rank != "full":
assert "Setting small data eigen" in log
else:
assert "Setting small data eigen" not in log
if rank is None:
assert "Computing rank from data" in log
assert " mag: rank" not in log.lower()
assert " data: rank" in log
assert "rank (mag)" not in log.lower()
assert "rank (data)" in log
else: # if rank is passed no computation is done
elif rank != "full": # if rank is passed no computation is done
assert "Computing rank" not in log
assert ": rank" not in log
assert "rank (" not in log
Expand Down

0 comments on commit 31ef32e

Please sign in to comment.