Skip to content

Commit

Permalink
First working version of lasso select in plot_evoked_topo
Browse files Browse the repository at this point in the history
  • Loading branch information
wmvanvliet committed Oct 4, 2023
1 parent 905c12c commit 2c307fe
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 66 deletions.
32 changes: 23 additions & 9 deletions mne/viz/topo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_setup_ax_spines,
_check_cov,
_plot_masked_image,
SelectFromCollection,
)


Expand Down Expand Up @@ -195,8 +196,11 @@ def format_coord_multiaxis(x, y, ch_name=None):
under_ax.set(xlim=[0, 1], ylim=[0, 1])

axs = list()

shown_ch_names = []
for idx, name in iter_ch:
ch_idx = ch_names.index(name)
shown_ch_names.append(name)
if not unified: # old, slow way
ax = plt.axes(pos[idx])
ax.patch.set_facecolor(axis_facecolor)
Expand Down Expand Up @@ -237,15 +241,22 @@ def format_coord_multiaxis(x, y, ch_name=None):
],
[1, 0, 2],
)
if not img:
under_ax.add_collection(
collections.PolyCollection(
verts,
facecolor=axis_facecolor,
edgecolor=axis_spinecolor,
linewidth=1.0,
)
) # Not needed for image plots.
if not img: # Not needed for image plots.
collection = collections.PolyCollection(
verts,
facecolor=axis_facecolor,
edgecolor=axis_spinecolor,
)
under_ax.add_collection(collection)
fig.lasso = SelectFromCollection(
ax=under_ax,
collection=collection,
names=shown_ch_names,
alpha_nonselected=0,
alpha_selected=1,
linewidth_nonselected=0,
linewidth_selected=0.7,
)
for ax in axs:
yield ax, ax._mne_ch_idx

Expand Down Expand Up @@ -344,6 +355,9 @@ def _plot_topo_onpick(event, show_func):
"""Onpick callback that shows a single channel in a new figure."""
# make sure that the swipe gesture in OS-X doesn't open many figures
orig_ax = event.inaxes
if orig_ax.figure.canvas._key in ["shift", "alt"]:
return

import matplotlib.pyplot as plt

try:
Expand Down
20 changes: 20 additions & 0 deletions mne/viz/ui_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,26 @@ class Contours(UIEvent):
contours: List[str]


@dataclass
@fill_doc
class ChannelsSelect(UIEvent):
"""Indicates that the user has selected one or more channels.
Parameters
----------
ch_names : list of str
The names of the channels that were selected.
Attributes
----------
%(ui_event_name_source)s
ch_names : list of str
The names of the channels that were selected.
"""

ch_names: List[str]


def _get_event_channel(fig):
"""Get the event channel associated with a figure.
Expand Down
139 changes: 82 additions & 57 deletions mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
_check_decim,
)
from ..transforms import apply_trans
from .ui_events import publish, subscribe, ChannelsSelect


_channel_type_prettyprint = {
Expand Down Expand Up @@ -1044,7 +1045,7 @@ def plot_sensors(
Whether to plot the sensors as 3d, topomap or as an interactive
sensor selection dialog. Available options ``'topomap'``, ``'3d'``,
``'select'``. If ``'select'``, a set of channels can be selected
interactively by using lasso selector or clicking while holding control
interactively by using lasso selector or clicking while holding the shift
key. The selected channels are returned along with the figure instance.
Defaults to ``'topomap'``.
ch_type : None | str
Expand Down Expand Up @@ -1255,7 +1256,7 @@ def _onpick_sensor(event, fig, ax, pos, ch_names, show_names):
if event.mouseevent.inaxes != ax:
return

if event.mouseevent.key == "control" and fig.lasso is not None:
if event.mouseevent.key in ["shift", "alt"] and fig.lasso is not None:
for ind in event.ind:
fig.lasso.select_one(ind)

Expand Down Expand Up @@ -1360,7 +1361,7 @@ def _plot_sensors(
lw=linewidth,
)
if kind == "select":
fig.lasso = SelectFromCollection(ax, pts, ch_names)
fig.lasso = SelectFromCollection(ax, pts, names=ch_names)
else:
fig.lasso = None

Expand Down Expand Up @@ -1693,72 +1694,95 @@ def _draw_without_rendering(cbar):


class SelectFromCollection:
"""Select channels from a matplotlib collection using ``LassoSelector``.
"""Select objects from a matplotlib collection using ``LassoSelector``.
Selected channels are saved in the ``selection`` attribute. This tool
highlights selected points by fading other points out (i.e., reducing their
alpha values).
The names of the selected objects are saved in the ``selection`` attribute.
This tool highlights selected objects by fading other objects out (i.e.,
reducing their alpha values).
Parameters
----------
ax : instance of Axes
Axes to interact with.
collection : instance of matplotlib collection
Collection you want to select from.
alpha_other : 0 <= float <= 1
To highlight a selection, this tool sets all selected points to an
alpha value of 1 and non-selected points to ``alpha_other``.
Defaults to 0.3.
linewidth_other : float
Linewidth to use for non-selected sensors. Default is 1.
names : list of str
The names of the object. The selection is returned as a subset of these names.
alpha_selected : float
Alpha for selected objects (0=tranparant, 1=opaque).
alpha_nonselected : float
Alpha for non-selected objects (0=tranparant, 1=opaque).
linewidth_selected : float
Linewidth for the borders of selected objects.
linewidth_nonselected : float
Linewidth for the borders of non-selected objects.
Notes
-----
This tool selects collection objects based on their *origins*
(i.e., ``offsets``). Calls all callbacks in self.callbacks when selection
is ready.
This tool selects collection objects which bounding boxes intersect with a lasso
path. Calls all callbacks in self.callbacks when selection is ready.
"""

def __init__(
self,
ax,
collection,
ch_names,
alpha_other=0.5,
linewidth_other=0.5,
*,
names,
alpha_selected=1,
alpha_nonselected=0.5,
linewidth_selected=1,
linewidth_nonselected=0.5,
):
from matplotlib.widgets import LassoSelector

self.fig = ax.figure
self.canvas = ax.figure.canvas
self.collection = collection
self.ch_names = ch_names
self.alpha_other = alpha_other
self.linewidth_other = linewidth_other
self.names = names
self.alpha_selected = alpha_selected
self.alpha_nonselected = alpha_nonselected
self.linewidth_selected = linewidth_selected
self.linewidth_nonselected = linewidth_nonselected

from matplotlib.collections import PolyCollection
from matplotlib.path import Path

self.xys = collection.get_offsets()
self.Npts = len(self.xys)
if isinstance(collection, PolyCollection):
self.paths = collection.get_paths()
else:
self.paths = [Path([point]) for point in collection.get_offsets()]
self.Npts = len(self.paths)
if self.Npts != len(names):
raise ValueError(
f"Number of names ({len(names)}) does not match the number of objects "
f"in the collection ({self.Npts})."
)

# Ensure that we have separate colors for each object
# Ensure that we have colors for each object.
self.fc = collection.get_facecolors()
self.ec = collection.get_edgecolors()
self.lw = collection.get_linewidths()
if len(self.fc) == 0:
raise ValueError("Collection must have a facecolor")
elif len(self.fc) == 1:
self.fc = np.tile(self.fc, self.Npts).reshape(self.Npts, -1)
if len(self.ec) == 0:
self.ec = np.zeros((self.Npts, 4)) # all black
elif len(self.ec) == 1:
self.ec = np.tile(self.ec, self.Npts).reshape(self.Npts, -1)
self.fc[:, -1] = self.alpha_other # deselect in the beginning
self.ec[:, -1] = self.alpha_other
self.lw = np.full(self.Npts, self.linewidth_other)
self.lw = np.full(self.Npts, float(self.linewidth_nonselected))

# Initialize the lasso selector
line_kw = _prop_kw("line", dict(color="red", linewidth=0.5))
self.lasso = LassoSelector(ax, onselect=self.on_select, **line_kw)
self.selection = list()
self.callbacks = list()
self.selection_inds = np.array([], dtype="int")

# Deselect everything in the beginning.
self.style_objects([])

# Respond to UI-Events
subscribe(self.fig, "channels_select", self._on_channels_select)

def on_select(self, verts):
"""Select a subset from the collection."""
Expand All @@ -1768,44 +1792,45 @@ def on_select(self, verts):
return

path = Path(verts)
inds = np.nonzero([path.contains_point(xy) for xy in self.xys])[0]
if self.canvas._key == "control": # Appending selection.
sels = [np.where(self.ch_names == c)[0][0] for c in self.selection]
inters = set(inds) - set(sels)
inds = list(inters.union(set(sels) - set(inds)))
inds = np.nonzero([path.intersects_path(p) for p in self.paths])[0]
if self.canvas._key == "shift": # Appending selection.
self.selection_inds = np.union1d(self.selection_inds, inds)
elif self.canvas._key == "alt": # Removing selection.
self.selection_inds = np.setdiff1d(self.selection_inds, inds)
else:
self.selection_inds = inds
ch_names = [self.names[i] for i in self.selection_inds]
publish(self.fig, ChannelsSelect(ch_names=ch_names))

self.selection[:] = np.array(self.ch_names)[inds].tolist()
self.style_sensors(inds)
self.notify()
def _on_channels_select(self, event):
ch_inds = {name: i for i, name in enumerate(self.names)}
self.selection = [name for name in event.ch_names if name in ch_inds]
self.selection_inds = [ch_inds[name] for name in self.selection]
self.style_objects(self.selection_inds)

def select_one(self, ind):
"""Select or deselect one sensor."""
ch_name = self.ch_names[ind]
if ch_name in self.selection:
sel_ind = self.selection.index(ch_name)
self.selection.pop(sel_ind)
if self.canvas._key == "shift":
self.selection_inds = np.union1d(self.selection_inds, [ind])
elif self.canvas._key == "alt":
self.selection_inds = np.setdiff1d(self.selection_inds, [ind])
else:
self.selection.append(ch_name)
inds = np.isin(self.ch_names, self.selection).nonzero()[0]
self.style_sensors(inds)
self.notify()

def notify(self):
"""Notify listeners that a selection has been made."""
for callback in self.callbacks:
callback()
return # don't notify()
ch_names = [self.names[i] for i in self.selection_inds]
publish(self.fig, ChannelsSelect(ch_names=ch_names))

def select_many(self, inds):
"""Select many sensors using indices (for predefined selections)."""
self.selection[:] = np.array(self.ch_names)[inds].tolist()
self.style_sensors(inds)
self.selected_inds = inds
ch_names = [self.names[i] for i in self.selection_inds]
publish(self.fig, ChannelsSelect(ch_names=ch_names))

def style_sensors(self, inds):
def style_objects(self, inds):
"""Style selected sensors as "active"."""
# reset
self.fc[:, -1] = self.alpha_other
self.ec[:, -1] = self.alpha_other / 2
self.lw[:] = self.linewidth_other
self.fc[:, -1] = self.alpha_nonselected
self.ec[:, -1] = self.alpha_nonselected / 2
self.lw[:] = self.linewidth_nonselected
# style sensors at `inds`
self.fc[inds, -1] = self.alpha_selected
self.ec[inds, -1] = self.alpha_selected
Expand Down

0 comments on commit 2c307fe

Please sign in to comment.