Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for dictionary-type ref_channels in set_eeg_reference() #12366

Merged
merged 72 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from 56 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
1bce965
init the PR draft
qian-chu Jan 16, 2024
f42d5fb
Merge branch 'main' into dict_ref
qian-chu Jan 16, 2024
278fdf3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
f3cff5a
Create 12366.newfeature.rst
qian-chu Jan 16, 2024
b7b5c0c
Update 12366.newfeature.rst
qian-chu Jan 17, 2024
7ece510
Merge branch 'mne-tools:main' into dict_ref
qian-chu May 2, 2024
8d4516d
Add custom reference based on dict
May 3, 2024
42f45b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 3, 2024
95a1434
BF: use isintance to check if dict
May 3, 2024
c921d69
BF: remove extra copy of data
May 3, 2024
b0c91d2
Add custom reference
May 3, 2024
78a5c7e
change doc (add Alex)
May 3, 2024
ca6908c
Add warning if bad channels in re-referencing scheme
May 3, 2024
6ac7bed
Merge branch 'dict_ref' of https://github.com/qian-chu/mne-python int…
May 3, 2024
b1165b9
add _check_before_dict_reference and enrich set_eeg_reference_see_als…
qian-chu May 3, 2024
ed56c97
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 3, 2024
d27ce12
Update reference.py
qian-chu May 31, 2024
073ca9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2024
8d744ca
Update test_reference.py
qian-chu May 31, 2024
9e9507d
Update reference.py
qian-chu Jun 3, 2024
24d16a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
2a6db40
Update test_reference.py
qian-chu Jun 3, 2024
6699de0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
8a5232a
formatting
qian-chu Jun 3, 2024
73ad30d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
3474370
dict does not accept repeated keys, no need to test
qian-chu Jun 5, 2024
cd213ee
add test for warnings and raises
qian-chu Jun 5, 2024
6c79a60
Update test_reference.py
qian-chu Jun 5, 2024
e36ddd2
Update docs.py
qian-chu Jun 5, 2024
41893f3
Merge branch 'mne-tools:main' into dict_ref
qian-chu Jun 5, 2024
5b8bd94
Data check test
AlexLepauvre Jun 7, 2024
f354d0a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2024
6dc0633
Add check of reference flag and bug correction
AlexLepauvre Jun 7, 2024
0293e7b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2024
32fa6ac
Add tests for epochs object
AlexLepauvre Jun 7, 2024
1fa6a4e
Merge branch 'dict_ref' of https://github.com/qian-chu/mne-python int…
AlexLepauvre Jun 7, 2024
18cf0d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2024
aee05a5
Merge branch 'main' into dict_ref
qian-chu Jun 7, 2024
dcd8f9d
formatting
qian-chu Jun 10, 2024
89a4e43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2024
c9a52d4
Add warning for re-referencing electrode by itself
AlexLepauvre Jun 19, 2024
e5cb792
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2024
f0cac35
Refactorize tests of epochs and raws
AlexLepauvre Jun 19, 2024
12ae73a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2024
10da1dd
Apply suggestions from code review
qian-chu Jun 19, 2024
d284912
Re-organize warnings and errors
qian-chu Jun 19, 2024
bc09e77
Merge branch 'main' into dict_ref
qian-chu Jun 21, 2024
0122aeb
Merge branch 'main' into dict_ref
qian-chu Jul 2, 2024
4432adc
new dict check function as suggested
qian-chu Jul 3, 2024
b3fade6
Merge branch 'main' into dict_ref
qian-chu Jul 3, 2024
d5770b0
simplify (now that we guarantee list-like dict vals)
drammock Jul 9, 2024
85d1eb1
warn when keys (not just vals) are bad chs
drammock Jul 9, 2024
eab984b
clearer var name; only compute mismatch pairs if needed
drammock Jul 9, 2024
1a57807
refactor: convert to ch indices in helper func
drammock Jul 9, 2024
dfaab21
return None (like REST/proj references) instead of copy of inst data
drammock Jul 9, 2024
6ff72b6
slightly less misleading docstring
drammock Jul 9, 2024
2ec411f
modularize check_ssp and adds the check to dict-based reference
qian-chu Jul 29, 2024
2d6eeff
Merge branch 'main' into dict_ref
qian-chu Jul 29, 2024
1025e98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 29, 2024
8858bfc
doc improvement
qian-chu Jul 30, 2024
d39c160
Merge branch 'main' into dict_ref
qian-chu Jul 30, 2024
7dae296
Merge branch 'main' into dict_ref
qian-chu Aug 10, 2024
71b31bb
Alex not new contributor now
qian-chu Aug 23, 2024
79f5bc4
Merge branch 'main' into dict_ref
qian-chu Aug 23, 2024
7704e35
Merge branch 'dict_ref' of https://github.com/qian-chu/mne-python int…
qian-chu Aug 23, 2024
2af639b
Update 12366.newfeature.rst
qian-chu Aug 23, 2024
f677570
Update 12366.newfeature.rst
qian-chu Aug 23, 2024
85b90ff
Merge branch 'main' into dict_ref
qian-chu Sep 13, 2024
c1d0339
MAINT: Reorder
larsoner Sep 20, 2024
8297a5f
Merge branch 'main' into dict_ref
larsoner Sep 20, 2024
2ce1e8d
Merge branch 'main' into dict_ref
qian-chu Sep 21, 2024
9344a4f
Merge branch 'main' into dict_ref
qian-chu Sep 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/12366.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for `dict` type argument ``ref_channels`` to :func:`mne.set_eeg_reference`, to allow flexible re-referencing (e.g. ``raw.set_eeg_reference(ref_channels={'A1': ['A2', 'A3']})`` will set the new A1 data to be ``A1 - (A2 + A3)/2``), by :newcontrib:`Alex Lepauvre` and `Qian Chu`_
2 changes: 2 additions & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

.. _Alex Kiefer: https://home.alexk101.dev

.. _Alex Lepauvre: https://github.com/AlexLepauvre

.. _Alex Rockhill: https://github.com/alexrockhill/

.. _Alexander Rudiuk: https://github.com/ARudiuk
Expand Down
90 changes: 88 additions & 2 deletions mne/_fiff/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,71 @@ def _check_before_reference(inst, ref_from, ref_to, ch_type):
return ref_to


def _check_before_dict_reference(inst, ref_dict):
qian-chu marked this conversation as resolved.
Show resolved Hide resolved
"""Prepare instance for dict-based referencing."""
# Check to see that data is preloaded
_check_preload(inst, "Applying a reference")

# Promote all values to list-like. This simplifies our logic and also helps catch
# self-referencing cases like `{"Cz": ["Cz"]}`
_refdict = {k: [v] if isinstance(v, str) else list(v) for k, v in ref_dict.items()}

# Check that keys are strings and values are lists-of-strings
key_types = {type(k) for k in _refdict}
value_types = {type(v) for val in _refdict.values() for v in val}
for elem_name, elem in dict(key=key_types, value=value_types).items():
if bad_elem := elem - {str}:
raise TypeError(
f"{elem_name.capitalize()}s in the ref_channels dict must be strings. "
f"Your dict has {elem_name}s of type "
f'{", ".join(map(lambda x: x.__name__, bad_elem))}.'
)

# Check that keys are valid channels and values are lists-of-valid-channels
ch_set = set(inst.ch_names)
bad_ch_set = set(inst.info["bads"])
keys = set(_refdict)
values = set(sum(_refdict.values(), []))
for elem_name, elem in dict(key=keys, value=values).items():
if bad_elem := elem - ch_set:
raise ValueError(
f'ref_channels dict contains invalid {elem_name}(s) '
f'({", ".join(bad_elem)}) '
"that are not names of channels in the instance."
)
# Check that values are not bad channels
if bad_elem := elem.intersection(bad_ch_set):
warn(
f"ref_channels dict contains {elem_name}(s) "
f"({', '.join(bad_elem)}) "
"that are marked as bad channels."
)
qian-chu marked this conversation as resolved.
Show resolved Hide resolved

# Check for self-referencing
self_ref = [[k] == v for k, v in _refdict.items()]
if any(self_ref):
which = np.array(list(_refdict))[np.nonzero(self_ref)]
for ch in which:
warn(f"Channel {ch} is self-referenced, which will nullify the channel.")

# Check that channel types match. First unpack list-like vals into separate items:
pairs = [(k, v) for k in _refdict for v in _refdict[k]]
ch_type_map = dict(zip(inst.ch_names, inst.get_channel_types()))
mismatch = [ch_type_map[k] != ch_type_map[v] for k, v in pairs]
if any(mismatch):
mismatch_pairs = np.array(pairs)[mismatch]
for k, v in mismatch_pairs:
warn(
f"Channel {k} ({ch_type_map[k]}) is referenced to channel {v} which is "
f"a different channel type ({ch_type_map[v]})."
)

# convert channel names to indices
keys_ix = pick_channels(inst.ch_names, list(_refdict), ordered=True)
vals_ix = (pick_channels(inst.ch_names, v, ordered=True) for v in _refdict.values())
return dict(zip(keys_ix, vals_ix))
qian-chu marked this conversation as resolved.
Show resolved Hide resolved


def _apply_reference(inst, ref_from, ref_to=None, forward=None, ch_type="auto"):
"""Apply a custom EEG referencing scheme."""
ref_to = _check_before_reference(inst, ref_from, ref_to, ch_type)
Expand Down Expand Up @@ -128,6 +193,22 @@ def _apply_reference(inst, ref_from, ref_to=None, forward=None, ch_type="auto"):
return inst, ref_data


def _apply_dict_reference(inst, ref_dict):
"""Apply a dict-based custom EEG referencing scheme."""
# this converts all keys to channel indices and all values to arrays of ch. indices:
ref_dict = _check_before_dict_reference(inst, ref_dict)

data = inst._data
orig_data = data.copy()
for ref_to, ref_from in ref_dict.items():
ref_data = orig_data[..., ref_from, :].mean(-2, keepdims=True)
data[..., [ref_to], :] -= ref_data

with inst.info._unlock():
inst.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_ON
return inst, None
qian-chu marked this conversation as resolved.
Show resolved Hide resolved


@fill_doc
def add_reference_channels(inst, ref_channels, copy=True):
"""Add reference channels to data that consists of all zeros.
Expand Down Expand Up @@ -319,18 +400,23 @@ def set_eeg_reference(
Returns
-------
inst : instance of Raw | Epochs | Evoked
Data with EEG channels re-referenced. If ``ref_channels='average'`` and
Data with EEG channels re-referenced. If ``ref_channels="average"`` and
``projection=True`` a projection will be added instead of directly
re-referencing the data.
ref_data : array
Array of reference data subtracted from EEG channels. This will be
``None`` if ``projection=True`` or ``ref_channels='REST'``.
``None`` if ``projection=True``, or if ``ref_channels`` is ``"REST"`` or a
:class:`dict`.
%(set_eeg_reference_see_also_notes)s
"""
from ..forward import Forward

_check_can_reref(inst)

if isinstance(ref_channels, dict):
logger.info("Applying a custom dict-based reference.")
return _apply_dict_reference(inst, ref_channels)

ch_type = _get_ch_type(inst, ch_type)

if projection: # average reference projector
Expand Down
154 changes: 154 additions & 0 deletions mne/_fiff/tests/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,160 @@ def test_set_eeg_reference_rest():
assert 0.995 < exp_var <= 1


@testing.requires_testing_data
@pytest.mark.parametrize("inst_type", ["raw", "epochs"])
@pytest.mark.parametrize(
"ref_channels, expectation",
[
(
{2: "EEG 001"},
pytest.raises(
TypeError,
match="Keys in the ref_channels dict must be strings. "
"Your dict has keys of type int.",
),
),
(
{"EEG 001": (1, 2)},
pytest.raises(
TypeError,
match="Values in the ref_channels dict must be strings. "
"Your dict has values of type int.",
),
),
(
{"EEG 001": [1, 2]},
pytest.raises(
TypeError,
match="Values in the ref_channels dict must be strings. "
"Your dict has values of type int.",
),
),
(
{"EEG 999": "EEG 001"},
pytest.raises(
ValueError,
match=r"ref_channels dict contains invalid key\(s\) \(EEG 999\) "
"that are not names of channels in the instance.",
),
),
(
{"EEG 001": "EEG 999"},
pytest.raises(
ValueError,
match=r"ref_channels dict contains invalid value\(s\) \(EEG 999\) "
"that are not names of channels in the instance.",
),
),
(
{"EEG 001": "EEG 057"},
pytest.warns(
RuntimeWarning,
match=r"ref_channels dict contains value\(s\) \(EEG 057\) "
"that are marked as bad channels.",
),
),
(
{"EEG 001": "STI 001"},
pytest.warns(
RuntimeWarning,
match=(
r"Channel EEG 001 \(eeg\) is referenced to channel "
r"STI 001 which is a different channel type \(stim\)."
),
),
),
(
{"EEG 001": "EEG 001"},
pytest.warns(
RuntimeWarning,
match=(
"Channel EEG 001 is self-referenced, "
"which will nullify the channel."
),
),
),
(
{"EEG 001": "EEG 002", "EEG 002": "EEG 003", "EEG 003": "EEG 005"},
nullcontext(),
),
(
{
"EEG 001": ["EEG 002", "EEG 003"],
"EEG 002": "EEG 003",
"EEG 003": "EEG 005",
},
nullcontext(),
),
],
)
def test_set_eeg_reference_dict(ref_channels, inst_type, expectation):
"""Test setting dict-based reference."""
if inst_type == "raw":
inst = read_raw_fif(fif_fname).crop(0, 1).pick(picks=["eeg", "stim"])
# Test re-referencing Epochs object
elif inst_type == "epochs":
raw = read_raw_fif(fif_fname, preload=False)
events = read_events(eve_fname)
inst = Epochs(
raw,
events=events,
event_id=1,
tmin=-0.2,
tmax=0.5,
preload=False,
)
with pytest.raises(
RuntimeError,
match="By default, MNE does not load data.*Applying a reference requires.*",
):
inst.set_eeg_reference(ref_channels=ref_channels)
inst.load_data()
inst.info["bads"] = ["EEG 057"]
with expectation:
reref, _ = set_eeg_reference(inst.copy(), ref_channels, copy=False)

if isinstance(expectation, nullcontext):
# Check that the custom_ref_applied is set correctly:
assert reref.info["custom_ref_applied"] == FIFF.FIFFV_MNE_CUSTOM_REF_ON

# Get raw data
_data = inst._data

# Get that channels that were and weren't re-referenced:
ch_raw = pick_channels(
inst.ch_names,
[ch for ch in inst.ch_names if ch not in list(ref_channels.keys())],
)
ch_reref = pick_channels(inst.ch_names, list(ref_channels.keys()), ordered=True)

# Check that the non re-reference channels are untouched:
assert_allclose(
_data[..., ch_raw, :], reref._data[..., ch_raw, :], 1e-6, atol=1e-15
)

# Compute the reference data:
ref_data = []
for val in ref_channels.values():
if isinstance(val, str):
val = [val] # pick_channels expects a list
ref_data.append(
_data[..., pick_channels(inst.ch_names, val, ordered=True), :].mean(
-2, keepdims=True
)
)
if inst_type == "epochs":
ref_data = np.concatenate(ref_data, axis=1)
else:
ref_data = np.squeeze(np.array(ref_data))
assert_allclose(
_data[..., ch_reref, :],
reref._data[..., ch_reref, :] + ref_data,
1e-6,
atol=1e-15,
)


@testing.requires_testing_data
@pytest.mark.parametrize("inst_type", ("raw", "epochs", "evoked"))
def test_set_bipolar_reference(inst_type):
Expand Down
21 changes: 19 additions & 2 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3693,13 +3693,20 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
"""

docdict["ref_channels_set_eeg_reference"] = """
ref_channels : list of str | str
ref_channels : list of str | str | dict
Can be:

- The name(s) of the channel(s) used to construct the reference.
- The name(s) of the channel(s) used to construct the reference for
every channel of ``ch_type``.
- ``'average'`` to apply an average reference (default)
- ``'REST'`` to use the Reference Electrode Standardization Technique
infinity reference :footcite:`Yao2001`.
- A dictionary mapping names of data channels to (lists of) names of
reference channels. For example, {'A1': 'A3'} would replace the
data in channel 'A1' with the difference between 'A1' and 'A3'. To take
the average of multiple channels as reference, supply a list of channel
names as the dictionary value, e.g. {'A1': ['A2', 'A3']} would replace
channel A1 with ``A1 - mean(A2, A3)``.
qian-chu marked this conversation as resolved.
Show resolved Hide resolved
- An empty list, in which case MNE will not attempt any re-referencing of
the data
"""
Expand Down Expand Up @@ -4027,6 +4034,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
The given EEG electrodes are referenced to a point at infinity using the
lead fields in ``forward``, which helps standardize the signals.

- Different references for different channels
Set ``ref_channels`` to a dictionary mapping source channel names (str)
to the reference channel names (str or list of str). Unlike the other
approaches where the same reference is applied globally, you can set
different references for different channels with this method. For example,
to re-reference channel 'A1' to 'A2' and 'B1' to the average of 'B2' and
'B3', set ``ref_channels={'A1': 'A2', 'B1': ['B2', 'B3']}``. Warnings are
issued when a bad channel is used as a reference or when a mapping involves
channels of different types.

1. If a reference is requested that is not the average reference, this
function removes any pre-existing average reference projections.

Expand Down
Loading