Skip to content

Commit

Permalink
MRG, FIX: Fix tick labeling (#7811)
Browse files Browse the repository at this point in the history
* FIX: Fix tick labeling

* FIX: Add debugging
  • Loading branch information
larsoner committed Sep 4, 2020
1 parent 76fc838 commit 4577b5f
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 28 deletions.
11 changes: 11 additions & 0 deletions doc/changes/0.20.inc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@
- "API" for backward-incompatible changes


.. _changes_0_20_8:

Version 0.20.8
--------------

Bug
~~~

- Fix bug with :meth:`mne.io.Raw.plot` with newer matplotlib where tick labeling would raise an error by `Eric Larson`_


.. _changes_0_20_7:

Version 0.20.7
Expand Down
2 changes: 1 addition & 1 deletion mne/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# Dev branch marker is: 'X.Y.devN' where N is an integer.
#

__version__ = '0.20.7'
__version__ = '0.20.8'

# have to import verbose first since it's needed by many things
from .utils import (set_log_level, set_log_file, verbose, set_config,
Expand Down
15 changes: 11 additions & 4 deletions mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def pytest_configure(config):
ignore:.*pandas\.util\.testing is deprecated.*:
ignore:.*tostring.*is deprecated.*:DeprecationWarning
always:.*get_data.* is deprecated in favor of.*:DeprecationWarning
ignore:.*Passing the dash.*:
""" # noqa: E501
for warning_line in warning_lines.split('\n'):
warning_line = warning_line.strip()
Expand Down Expand Up @@ -115,13 +116,19 @@ def matplotlib_config():
"""Configure matplotlib for viz tests."""
import matplotlib
from matplotlib import cbook
# "force" should not really be necessary but should not hurt
kwargs = dict()
# Allow for easy interactive debugging with a call like:
#
# $ MNE_MPL_TESTING_BACKEND=Qt5Agg pytest mne/viz/tests/test_raw.py -k annotation -x --pdb # noqa: E501
#
try:
want = os.environ['MNE_MPL_TESTING_BACKEND']
except KeyError:
want = 'agg' # don't pop up windows
with warnings.catch_warnings(record=True): # ignore warning
warnings.filterwarnings('ignore')
matplotlib.use('agg', force=True, **kwargs) # don't pop up windows
matplotlib.use(want, force=True)
import matplotlib.pyplot as plt
assert plt.get_backend() == 'agg'
assert plt.get_backend() == want
# overwrite some params that can horribly slow down tests that
# users might have changed locally (but should not otherwise affect
# functionality)
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3192,7 +3192,7 @@ def plot_brain_colorbar(ax, clim, colormap='auto', transparent=True,
colormap, lims = _linearize_map(mapdata)
del mapdata
norm = Normalize(vmin=lims[0], vmax=lims[2])
cbar = ColorbarBase(ax, colormap, norm=norm, ticks=ticks,
cbar = ColorbarBase(ax, cmap=colormap, norm=norm, ticks=ticks,
label=label, orientation=orientation)
# make the colorbar background match the brain color
cbar.patch.set(facecolor=bgcolor)
Expand Down
36 changes: 21 additions & 15 deletions mne/viz/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20, n_channels=20,
projs = epochs.info['projs']
noise_cov = _check_cov(noise_cov, epochs.info)

params = dict(epochs=epochs, info=epochs.info.copy(), t_start=0.,
params = dict(epochs=epochs, info=epochs.info.copy(), t_start=0,
bad_color=(0.8, 0.8, 0.8), histogram=None, decim=decim,
data_picks=data_picks, noise_cov=noise_cov,
use_noise_cov=noise_cov is not None,
Expand Down Expand Up @@ -1092,10 +1092,10 @@ def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings,
ax.set_yticks(offsets)
ax.set_ylim(ylim)
ticks = epoch_times + 0.5 * n_times
ax.set_xticks(ticks)
ax2.set_xticks(ticks[:n_epochs])
for ax_ in (ax, ax2):
ax_.set_xticks(ticks[:n_epochs])
labels = list(range(0, len(ticks))) # epoch numbers
ax.set_xticklabels(labels)
ax.set_xticklabels(labels[:n_epochs])
xlim = epoch_times[-1] + len(orig_epoch_times)
ax_hscroll.set_xlim(0, xlim)
vertline_t = ax_hscroll.text(0, 1, '', color='y', va='bottom', ha='right')
Expand Down Expand Up @@ -1252,13 +1252,14 @@ def _plot_traces(params):

n_times = len(epochs.times)
tick_list = list()
start_idx = int(params['t_start'] / n_times)
start_idx = params['t_start'] // n_times
end = params['t_start'] + params['duration']
end_idx = int(end / n_times)
xlabels = params['labels'][start_idx:]
event_ids = params['epochs'].events[:, 2]
params['ax2'].set_xticklabels(event_ids[start_idx:])
end_idx = end // n_times
xlabels = params['labels'][start_idx:end_idx]
event_ids = params['epochs'].events[start_idx:end_idx, 2]
params['ax2'].set_xticklabels(event_ids)
ax.set_xticklabels(xlabels)
del event_ids, xlabels
ylabels = ax.yaxis.get_ticklabels()
# do the plotting
for line_idx in range(n_channels):
Expand Down Expand Up @@ -1339,7 +1340,7 @@ def _plot_traces(params):
0, ylim + 1, ylim / (4 * max(len(chan_types_split), 1)))
offset_pos = np.arange(2, (len(chan_types_split) * 4) + 1, 4)
ax.set_yticks(ticks)
labels = [''] * 20
labels = [''] * len(ticks)
labels = [0 if idx in range(2, len(labels), 4) else label
for idx, label in enumerate(labels)]
for idx_chan, chan_type in enumerate(chan_types_split):
Expand All @@ -1356,6 +1357,7 @@ def _plot_traces(params):
labels[li] = round(label, 2)
ax.set_yticklabels(labels, fontsize=12, color='black')
else:
ax.set_yticks(params['offsets'][:len(tick_list)])
ax.set_yticklabels(tick_list, fontsize=12)
_set_ax_label_style(ax, params)

Expand All @@ -1382,7 +1384,7 @@ def _plot_update_epochs_proj(params, bools=None):
n_epochs = params['n_epochs']
params['projector'], params['whitened_ch_names'] = _setup_plot_projector(
params['info'], params['noise_cov'], True, params['use_noise_cov'])
start = int(params['t_start'] / len(epochs.times))
start = params['t_start'] // len(epochs.times)
end = start + n_epochs
if epochs.preload:
data = np.concatenate(epochs.get_data()[start:end], axis=1)
Expand Down Expand Up @@ -1419,6 +1421,7 @@ def _handle_picks(epochs):
def _plot_window(value, params):
"""Deal with horizontal shift of the viewport."""
max_times = len(params['times']) - params['duration']
value = int(round(value))
if value > max_times:
value = len(params['times']) - params['duration']
if value < 0:
Expand Down Expand Up @@ -1681,7 +1684,8 @@ def _plot_onkey(event, params):
return
n_times = len(params['epochs'].times)
ticks = params['epoch_times'] + 0.5 * n_times
params['ax2'].set_xticks(ticks[:n_epochs])
for key in ('ax', 'ax2'):
params[key].set_xticks(ticks[:n_epochs])
params['n_epochs'] = n_epochs
params['duration'] -= n_times
params['hsel_patch'].set_width(params['duration'])
Expand All @@ -1693,7 +1697,8 @@ def _plot_onkey(event, params):
if n_times * n_epochs > len(params['times']):
return
ticks = params['epoch_times'] + 0.5 * n_times
params['ax2'].set_xticks(ticks[:n_epochs])
for key in ('ax', 'ax2'):
params[key].set_xticks(ticks[:n_epochs])
params['n_epochs'] = n_epochs
if len(params['vert_lines']) > 0:
ax = params['ax']
Expand Down Expand Up @@ -1842,7 +1847,8 @@ def _update_channels_epochs(event, params):
n_epochs = int(np.around(params['epoch_slider'].val))
n_times = len(params['epochs'].times)
ticks = params['epoch_times'] + 0.5 * n_times
params['ax2'].set_xticks(ticks[:n_epochs])
for key in ('ax', 'ax2'):
params[key].set_xticks(ticks[:n_epochs])
params['n_epochs'] = n_epochs
params['duration'] = n_times * n_epochs
params['hsel_patch'].set_width(params['duration'])
Expand Down Expand Up @@ -2003,7 +2009,7 @@ def _draw_event_lines(params):
includes_tzero = False
epochs = params['epochs']
n_times = len(epochs.times)
start_idx = int(params['t_start'] / n_times)
start_idx = params['t_start'] // n_times
color = params['event_colors']
ax = params['ax']
for ev_line in params['ev_lines']:
Expand Down
1 change: 1 addition & 0 deletions mne/viz/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ def _plot_raw_traces(params, color, bad_color, event_lines=None,
params['times'][0] + params['first_time'] +
params['duration'], False)
if not butterfly:
params['ax'].set_yticks(params['offsets'][:len(tick_list)])
params['ax'].set_yticklabels(tick_list, rotation=0)
_set_ax_label_style(params['ax'], params)
if 'fig_selection' not in params:
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _annotation_helper(raw, events=False):
kind='release')
if mpl_good_enough:
assert raw.annotations.onset[n_anns] == onset
assert_allclose(raw.annotations.duration[n_anns], 1.5)
assert_allclose(raw.annotations.duration[n_anns], 1.5) # 4->1.5
# modify annotation from beginning
_fake_click(fig, data_ax, [1., 1.], xform='data', button=1, kind='press')
_fake_click(fig, data_ax, [0.5, 1.], xform='data', button=1, kind='motion')
Expand Down
12 changes: 6 additions & 6 deletions mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2433,8 +2433,8 @@ def _setup_butterfly(params):
ylim = (5. * len(picks), 0.)
ax.set_ylim(ylim)
offset = ylim[0] / (len(picks) + 1)
ticks = np.arange(0, ylim[0], offset)
ticks = [ticks[x] if x < len(ticks) else 0 for x in range(20)]
# ensure the last is not included
ticks = np.arange(0, ylim[0] - offset / 2., offset)
ax.set_yticks(ticks)
offsets = np.zeros(len(params['types']))

Expand All @@ -2443,8 +2443,8 @@ def _setup_butterfly(params):
offsets[pick] = offset * (group_idx + 1)
params['inds'] = params['orig_inds'].copy()
params['offsets'] = offsets
ax.set_yticklabels([''] + selections, color='black', rotation=45,
va='top')
ax.set_yticklabels(
[''] + selections, color='black', rotation=45, va='top')
else:
params['inds'] = params['orig_inds'].copy()
if 'fig_selection' not in params:
Expand Down Expand Up @@ -2600,8 +2600,8 @@ def log_fix(tval):
axes.spines['left'].set_bounds(*ybounds)
# handle axis labels
if skip_axlabel:
axes.set_yticklabels([])
axes.set_xticklabels([])
axes.set_yticklabels([''] * len(yticks))
axes.set_xticklabels([''] * len(xticks))
else:
if unit is not None:
axes.set_ylabel(unit, rotation=90)
Expand Down

0 comments on commit 4577b5f

Please sign in to comment.