diff --git a/mne/viz/topo.py b/mne/viz/topo.py index a01ee72a0c2..ceab61a0e0c 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -26,6 +26,7 @@ _setup_ax_spines, _check_cov, _plot_masked_image, + SelectFromCollection, ) @@ -195,8 +196,11 @@ def format_coord_multiaxis(x, y, ch_name=None): under_ax.set(xlim=[0, 1], ylim=[0, 1]) axs = list() + + shown_ch_names = [] for idx, name in iter_ch: ch_idx = ch_names.index(name) + shown_ch_names.append(name) if not unified: # old, slow way ax = plt.axes(pos[idx]) ax.patch.set_facecolor(axis_facecolor) @@ -237,15 +241,22 @@ def format_coord_multiaxis(x, y, ch_name=None): ], [1, 0, 2], ) - if not img: - under_ax.add_collection( - collections.PolyCollection( - verts, - facecolor=axis_facecolor, - edgecolor=axis_spinecolor, - linewidth=1.0, - ) - ) # Not needed for image plots. + if not img: # Not needed for image plots. + collection = collections.PolyCollection( + verts, + facecolor=axis_facecolor, + edgecolor=axis_spinecolor, + ) + under_ax.add_collection(collection) + fig.lasso = SelectFromCollection( + ax=under_ax, + collection=collection, + names=shown_ch_names, + alpha_nonselected=0, + alpha_selected=1, + linewidth_nonselected=0, + linewidth_selected=0.7, + ) for ax in axs: yield ax, ax._mne_ch_idx @@ -344,6 +355,9 @@ def _plot_topo_onpick(event, show_func): """Onpick callback that shows a single channel in a new figure.""" # make sure that the swipe gesture in OS-X doesn't open many figures orig_ax = event.inaxes + if orig_ax.figure.canvas._key in ["shift", "alt"]: + return + import matplotlib.pyplot as plt try: diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index ba5b1db9a33..c4a90e44132 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -192,6 +192,26 @@ class Contours(UIEvent): contours: List[str] +@dataclass +@fill_doc +class ChannelsSelect(UIEvent): + """Indicates that the user has selected one or more channels. + + Parameters + ---------- + ch_names : list of str + The names of the channels that were selected. + + Attributes + ---------- + %(ui_event_name_source)s + ch_names : list of str + The names of the channels that were selected. + """ + + ch_names: List[str] + + def _get_event_channel(fig): """Get the event channel associated with a figure. diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 264505b67ad..75ede905d96 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -64,6 +64,7 @@ _check_decim, ) from ..transforms import apply_trans +from .ui_events import publish, subscribe, ChannelsSelect _channel_type_prettyprint = { @@ -1044,7 +1045,7 @@ def plot_sensors( Whether to plot the sensors as 3d, topomap or as an interactive sensor selection dialog. Available options ``'topomap'``, ``'3d'``, ``'select'``. If ``'select'``, a set of channels can be selected - interactively by using lasso selector or clicking while holding control + interactively by using lasso selector or clicking while holding the shift key. The selected channels are returned along with the figure instance. Defaults to ``'topomap'``. ch_type : None | str @@ -1255,7 +1256,7 @@ def _onpick_sensor(event, fig, ax, pos, ch_names, show_names): if event.mouseevent.inaxes != ax: return - if event.mouseevent.key == "control" and fig.lasso is not None: + if event.mouseevent.key in ["shift", "alt"] and fig.lasso is not None: for ind in event.ind: fig.lasso.select_one(ind) @@ -1360,7 +1361,7 @@ def _plot_sensors( lw=linewidth, ) if kind == "select": - fig.lasso = SelectFromCollection(ax, pts, ch_names) + fig.lasso = SelectFromCollection(ax, pts, names=ch_names) else: fig.lasso = None @@ -1693,11 +1694,11 @@ def _draw_without_rendering(cbar): class SelectFromCollection: - """Select channels from a matplotlib collection using ``LassoSelector``. + """Select objects from a matplotlib collection using ``LassoSelector``. - Selected channels are saved in the ``selection`` attribute. This tool - highlights selected points by fading other points out (i.e., reducing their - alpha values). + The names of the selected objects are saved in the ``selection`` attribute. + This tool highlights selected objects by fading other objects out (i.e., + reducing their alpha values). Parameters ---------- @@ -1705,60 +1706,83 @@ class SelectFromCollection: Axes to interact with. collection : instance of matplotlib collection Collection you want to select from. - alpha_other : 0 <= float <= 1 - To highlight a selection, this tool sets all selected points to an - alpha value of 1 and non-selected points to ``alpha_other``. - Defaults to 0.3. - linewidth_other : float - Linewidth to use for non-selected sensors. Default is 1. + names : list of str + The names of the object. The selection is returned as a subset of these names. + alpha_selected : float + Alpha for selected objects (0=tranparant, 1=opaque). + alpha_nonselected : float + Alpha for non-selected objects (0=tranparant, 1=opaque). + linewidth_selected : float + Linewidth for the borders of selected objects. + linewidth_nonselected : float + Linewidth for the borders of non-selected objects. Notes ----- - This tool selects collection objects based on their *origins* - (i.e., ``offsets``). Calls all callbacks in self.callbacks when selection - is ready. + This tool selects collection objects which bounding boxes intersect with a lasso + path. Calls all callbacks in self.callbacks when selection is ready. """ def __init__( self, ax, collection, - ch_names, - alpha_other=0.5, - linewidth_other=0.5, + *, + names, alpha_selected=1, + alpha_nonselected=0.5, linewidth_selected=1, + linewidth_nonselected=0.5, ): from matplotlib.widgets import LassoSelector + self.fig = ax.figure self.canvas = ax.figure.canvas self.collection = collection - self.ch_names = ch_names - self.alpha_other = alpha_other - self.linewidth_other = linewidth_other + self.names = names self.alpha_selected = alpha_selected + self.alpha_nonselected = alpha_nonselected self.linewidth_selected = linewidth_selected + self.linewidth_nonselected = linewidth_nonselected + + from matplotlib.collections import PolyCollection + from matplotlib.path import Path - self.xys = collection.get_offsets() - self.Npts = len(self.xys) + if isinstance(collection, PolyCollection): + self.paths = collection.get_paths() + else: + self.paths = [Path([point]) for point in collection.get_offsets()] + self.Npts = len(self.paths) + if self.Npts != len(names): + raise ValueError( + f"Number of names ({len(names)}) does not match the number of objects " + f"in the collection ({self.Npts})." + ) - # Ensure that we have separate colors for each object + # Ensure that we have colors for each object. self.fc = collection.get_facecolors() self.ec = collection.get_edgecolors() - self.lw = collection.get_linewidths() if len(self.fc) == 0: raise ValueError("Collection must have a facecolor") elif len(self.fc) == 1: self.fc = np.tile(self.fc, self.Npts).reshape(self.Npts, -1) + if len(self.ec) == 0: + self.ec = np.zeros((self.Npts, 4)) # all black + elif len(self.ec) == 1: self.ec = np.tile(self.ec, self.Npts).reshape(self.Npts, -1) - self.fc[:, -1] = self.alpha_other # deselect in the beginning - self.ec[:, -1] = self.alpha_other - self.lw = np.full(self.Npts, self.linewidth_other) + self.lw = np.full(self.Npts, float(self.linewidth_nonselected)) + # Initialize the lasso selector line_kw = _prop_kw("line", dict(color="red", linewidth=0.5)) self.lasso = LassoSelector(ax, onselect=self.on_select, **line_kw) self.selection = list() - self.callbacks = list() + self.selection_inds = np.array([], dtype="int") + + # Deselect everything in the beginning. + self.style_objects([]) + + # Respond to UI-Events + subscribe(self.fig, "channels_select", self._on_channels_select) def on_select(self, verts): """Select a subset from the collection.""" @@ -1768,44 +1792,45 @@ def on_select(self, verts): return path = Path(verts) - inds = np.nonzero([path.contains_point(xy) for xy in self.xys])[0] - if self.canvas._key == "control": # Appending selection. - sels = [np.where(self.ch_names == c)[0][0] for c in self.selection] - inters = set(inds) - set(sels) - inds = list(inters.union(set(sels) - set(inds))) + inds = np.nonzero([path.intersects_path(p) for p in self.paths])[0] + if self.canvas._key == "shift": # Appending selection. + self.selection_inds = np.union1d(self.selection_inds, inds) + elif self.canvas._key == "alt": # Removing selection. + self.selection_inds = np.setdiff1d(self.selection_inds, inds) + else: + self.selection_inds = inds + ch_names = [self.names[i] for i in self.selection_inds] + publish(self.fig, ChannelsSelect(ch_names=ch_names)) - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) - self.notify() + def _on_channels_select(self, event): + ch_inds = {name: i for i, name in enumerate(self.names)} + self.selection = [name for name in event.ch_names if name in ch_inds] + self.selection_inds = [ch_inds[name] for name in self.selection] + self.style_objects(self.selection_inds) def select_one(self, ind): """Select or deselect one sensor.""" - ch_name = self.ch_names[ind] - if ch_name in self.selection: - sel_ind = self.selection.index(ch_name) - self.selection.pop(sel_ind) + if self.canvas._key == "shift": + self.selection_inds = np.union1d(self.selection_inds, [ind]) + elif self.canvas._key == "alt": + self.selection_inds = np.setdiff1d(self.selection_inds, [ind]) else: - self.selection.append(ch_name) - inds = np.isin(self.ch_names, self.selection).nonzero()[0] - self.style_sensors(inds) - self.notify() - - def notify(self): - """Notify listeners that a selection has been made.""" - for callback in self.callbacks: - callback() + return # don't notify() + ch_names = [self.names[i] for i in self.selection_inds] + publish(self.fig, ChannelsSelect(ch_names=ch_names)) def select_many(self, inds): """Select many sensors using indices (for predefined selections).""" - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + self.selected_inds = inds + ch_names = [self.names[i] for i in self.selection_inds] + publish(self.fig, ChannelsSelect(ch_names=ch_names)) - def style_sensors(self, inds): + def style_objects(self, inds): """Style selected sensors as "active".""" # reset - self.fc[:, -1] = self.alpha_other - self.ec[:, -1] = self.alpha_other / 2 - self.lw[:] = self.linewidth_other + self.fc[:, -1] = self.alpha_nonselected + self.ec[:, -1] = self.alpha_nonselected / 2 + self.lw[:] = self.linewidth_nonselected # style sensors at `inds` self.fc[inds, -1] = self.alpha_selected self.ec[inds, -1] = self.alpha_selected