Skip to content

Commit

Permalink
moved gradio_support to ret_plot for better dx
Browse files Browse the repository at this point in the history
  • Loading branch information
devbrones committed Dec 15, 2023
1 parent b25f029 commit f7f1db4
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions fdtd/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def visualize(
index=None, # index for each frame of animation (visualize fn runs in a loop, loop variable is passed as index)
save=False, # True to save frames (requires parameters index, folder)
folder=None, # folder path to save frames
gradio_support=False, # True to return figure for gradio support
ret_plot=False, # True to return figure for gradio support
pltstyle="https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle",

):
"""visualize a projection of the grid and the optical energy inside the grid
Expand All @@ -63,9 +64,11 @@ def visualize(
index: index for each frame of animation (typically a loop variable is passed)
save: save frames in a folder
folder: path to folder to save frames
ret_plot: return figure instead of showing it
pltstyle: Matplotlib style sheet to use for plotting. Default "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle".
"""
if gradio_support:
plt.style.use("https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle")
if ret_plot:
plt.style.use(pltstyle)
if norm not in ("linear", "lin", "log"):
raise ValueError("Color map normalization should be 'linear' or 'log'.")
# imports (placed here to circumvent circular imports)
Expand Down Expand Up @@ -331,11 +334,11 @@ def visualize(
if show:
plt.show()

if gradio_support:
if ret_plot:
return plt.gcf() # return figure for gradio support


def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", gradio_support=False):
def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", ret_plot=False, pltstyle="https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle"):
"""
Displays detector readings from an 'fdtd.BlockDetector' in a decibel map spanning a 2D slice region inside the BlockDetector.
Compatible with continuous sources (not pulse).
Expand All @@ -345,6 +348,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", gradio_su
block_det (numpy array): 5 axes numpy array (timestep, row, column, height, {x, y, z} parameter) created by 'fdtd.BlockDetector'.
(optional) choose_axis (int): Choose between {0, 1, 2} to display {x, y, z} data. Default 2 (-> z).
(optional) interpolation (string): Preferred 'matplotlib.pyplot.imshow' interpolation. Default "spline16".
ret_plot (bool): True to return figure instead of showing it.
pltstyle (string): Matplotlib style sheet to use for plotting. Default "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle".
"""
if block_det is None:
raise ValueError(
Expand All @@ -356,8 +361,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", gradio_su
)

# TODO: convert all 2D slices (y-z, x-z plots) into x-y plot data structure
if gradio_support:
plt.style.use("https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle")
if ret_plot:
plt.style.use(pltstyle)
plt.ioff()
plt.close()
a = [] # array to store wave intensities
Expand All @@ -382,7 +387,7 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", gradio_su
plt.imshow(a, cmap="inferno", interpolation=interpolation)
cbar = plt.colorbar()
cbar.ax.set_ylabel("dB scale", rotation=270)
if gradio_support:
if ret_plot:
return plt.gcf()
else:
plt.show()
Expand Down

0 comments on commit f7f1db4

Please sign in to comment.