diff --git a/fdtd/visualization.py b/fdtd/visualization.py index c8fd284..35ab1e5 100644 --- a/fdtd/visualization.py +++ b/fdtd/visualization.py @@ -42,7 +42,8 @@ def visualize( index=None, # index for each frame of animation (visualize fn runs in a loop, loop variable is passed as index) save=False, # True to save frames (requires parameters index, folder) folder=None, # folder path to save frames - gradio_support=False, # True to return figure for gradio support + ret_plot=False, # True to return figure for gradio support + pltstyle="https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle", ): """visualize a projection of the grid and the optical energy inside the grid @@ -63,9 +64,11 @@ def visualize( index: index for each frame of animation (typically a loop variable is passed) save: save frames in a folder folder: path to folder to save frames + ret_plot: return figure instead of showing it + pltstyle: Matplotlib style sheet to use for plotting. Default "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle". """ - if gradio_support: - plt.style.use("https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle") + if ret_plot: + plt.style.use(pltstyle) if norm not in ("linear", "lin", "log"): raise ValueError("Color map normalization should be 'linear' or 'log'.") # imports (placed here to circumvent circular imports) @@ -331,11 +334,11 @@ def visualize( if show: plt.show() - if gradio_support: + if ret_plot: return plt.gcf() # return figure for gradio support -def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", gradio_support=False): +def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", ret_plot=False, pltstyle="https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle"): """ Displays detector readings from an 'fdtd.BlockDetector' in a decibel map spanning a 2D slice region inside the BlockDetector. Compatible with continuous sources (not pulse). @@ -345,6 +348,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", gradio_su block_det (numpy array): 5 axes numpy array (timestep, row, column, height, {x, y, z} parameter) created by 'fdtd.BlockDetector'. (optional) choose_axis (int): Choose between {0, 1, 2} to display {x, y, z} data. Default 2 (-> z). (optional) interpolation (string): Preferred 'matplotlib.pyplot.imshow' interpolation. Default "spline16". + ret_plot (bool): True to return figure instead of showing it. + pltstyle (string): Matplotlib style sheet to use for plotting. Default "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle". """ if block_det is None: raise ValueError( @@ -356,8 +361,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", gradio_su ) # TODO: convert all 2D slices (y-z, x-z plots) into x-y plot data structure - if gradio_support: - plt.style.use("https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle") + if ret_plot: + plt.style.use(pltstyle) plt.ioff() plt.close() a = [] # array to store wave intensities @@ -382,7 +387,7 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", gradio_su plt.imshow(a, cmap="inferno", interpolation=interpolation) cbar = plt.colorbar() cbar.ax.set_ylabel("dB scale", rotation=270) - if gradio_support: + if ret_plot: return plt.gcf() else: plt.show()