Skip to content

Commit

Permalink
added support for returning matplot instead of showing on grid.visual…
Browse files Browse the repository at this point in the history
…ize and grid.dB_map_2D
  • Loading branch information
devbrones committed Dec 11, 2023
1 parent 3f4df26 commit b25f029
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions fdtd/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def visualize(
index=None, # index for each frame of animation (visualize fn runs in a loop, loop variable is passed as index)
save=False, # True to save frames (requires parameters index, folder)
folder=None, # folder path to save frames
gradio_support=False, # True to return figure for gradio support

):
"""visualize a projection of the grid and the optical energy inside the grid
Expand All @@ -62,6 +64,8 @@ def visualize(
save: save frames in a folder
folder: path to folder to save frames
"""
if gradio_support:
plt.style.use("https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle")
if norm not in ("linear", "lin", "log"):
raise ValueError("Color map normalization should be 'linear' or 'log'.")
# imports (placed here to circumvent circular imports)
Expand Down Expand Up @@ -327,8 +331,11 @@ def visualize(
if show:
plt.show()

if gradio_support:
return plt.gcf() # return figure for gradio support


def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"):
def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", gradio_support=False):
"""
Displays detector readings from an 'fdtd.BlockDetector' in a decibel map spanning a 2D slice region inside the BlockDetector.
Compatible with continuous sources (not pulse).
Expand All @@ -349,7 +356,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"):
)

# TODO: convert all 2D slices (y-z, x-z plots) into x-y plot data structure

if gradio_support:
plt.style.use("https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle")
plt.ioff()
plt.close()
a = [] # array to store wave intensities
Expand All @@ -374,7 +382,10 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"):
plt.imshow(a, cmap="inferno", interpolation=interpolation)
cbar = plt.colorbar()
cbar.ax.set_ylabel("dB scale", rotation=270)
plt.show()
if gradio_support:
return plt.gcf()
else:
plt.show()


def plot_detection(detector_dict=None, specific_plot=None):
Expand Down

0 comments on commit b25f029

Please sign in to comment.