Skip to content

Commit

Permalink
fix: mypy mistakes axes3D for axes
Browse files Browse the repository at this point in the history
  • Loading branch information
AthenaCaesura committed Oct 19, 2023
1 parent 28a0f13 commit 7a17567
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 13 deletions.
5 changes: 2 additions & 3 deletions src/orqviz/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import matplotlib.ticker as tck
import numpy as np
from matplotlib.cm import ScalarMappable
import mpl_toolkits


def normalize_color_and_colorbar(
Expand Down Expand Up @@ -65,7 +66,6 @@ def get_colorbar_from_ax(
_, ax = _check_and_create_fig_ax(ax=ax)

if image_index is None:

len_collections = len(ax.collections)
len_images = len(ax.images)

Expand Down Expand Up @@ -158,7 +158,6 @@ def _check_and_create_fig_ax(
fig: Optional[plt.Figure] = None,
ax: Optional[plt.Axes] = None,
) -> Tuple[plt.Figure, plt.Axes]:

if fig is None:
fig = plt.gcf()

Expand All @@ -170,7 +169,7 @@ def _check_and_create_fig_ax(

def _check_and_create_3D_ax(
ax: Optional[plt.Axes] = None,
) -> plt.Axes:
) -> mpl_toolkits.mplot3d.Axes3D:
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
Expand Down
17 changes: 7 additions & 10 deletions src/orqviz/scans/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
import matplotlib
import numpy as np

# this import is unused but solves issues with older matplotlib versions and 3d plots
from mpl_toolkits.mplot3d import Axes3D

from ..plot_utils import _check_and_create_3D_ax, _check_and_create_fig_ax
from .data_structures import Scan1DResult, Scan2DResult

Expand Down Expand Up @@ -161,9 +158,9 @@ def plot_2D_scan_result_as_3D(
plot_kwargs: kwargs for plotting with matplotlib.pyplot.plot_surface
(plt.plot_surface)
"""
ax = _check_and_create_3D_ax(ax=ax)
ax3D = _check_and_create_3D_ax(ax=ax)

assert ax is not None
assert ax3D is not None

x, y = scan2D_result._get_coordinates_on_directions(
in_units_of_direction=in_units_of_direction
Expand All @@ -173,10 +170,10 @@ def plot_2D_scan_result_as_3D(
plot_kwargs_defaults = {"cmap": "viridis", "alpha": 0.8}
plot_kwargs = {**plot_kwargs_defaults, **plot_kwargs}

ax.plot_surface(XX, YY, scan2D_result.values, **plot_kwargs)
ax3D.plot_surface(XX, YY, scan2D_result.values, **plot_kwargs)

ax.view_init(elev=35, azim=-70)
ax3D.view_init(elev=35, azim=-70)

ax.set_xlabel("Scan Direction x")
ax.set_ylabel("Scan Direction y")
ax.set_zlabel("Loss Value")
ax3D.set_xlabel("Scan Direction x")
ax3D.set_ylabel("Scan Direction y")
ax3D.set_zlabel("Loss Value")

0 comments on commit 7a17567

Please sign in to comment.