diff --git a/fdtd/visualization.py b/fdtd/visualization.py index adffc42..c8fd284 100644 --- a/fdtd/visualization.py +++ b/fdtd/visualization.py @@ -42,6 +42,8 @@ def visualize( index=None, # index for each frame of animation (visualize fn runs in a loop, loop variable is passed as index) save=False, # True to save frames (requires parameters index, folder) folder=None, # folder path to save frames + gradio_support=False, # True to return figure for gradio support + ): """visualize a projection of the grid and the optical energy inside the grid @@ -62,6 +64,8 @@ def visualize( save: save frames in a folder folder: path to folder to save frames """ + if gradio_support: + plt.style.use("https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle") if norm not in ("linear", "lin", "log"): raise ValueError("Color map normalization should be 'linear' or 'log'.") # imports (placed here to circumvent circular imports) @@ -327,8 +331,11 @@ def visualize( if show: plt.show() + if gradio_support: + return plt.gcf() # return figure for gradio support + -def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"): +def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", gradio_support=False): """ Displays detector readings from an 'fdtd.BlockDetector' in a decibel map spanning a 2D slice region inside the BlockDetector. Compatible with continuous sources (not pulse). @@ -349,7 +356,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"): ) # TODO: convert all 2D slices (y-z, x-z plots) into x-y plot data structure - + if gradio_support: + plt.style.use("https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle") plt.ioff() plt.close() a = [] # array to store wave intensities @@ -374,7 +382,10 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"): plt.imshow(a, cmap="inferno", interpolation=interpolation) cbar = plt.colorbar() cbar.ax.set_ylabel("dB scale", rotation=270) - plt.show() + if gradio_support: + return plt.gcf() + else: + plt.show() def plot_detection(detector_dict=None, specific_plot=None):