diff --git a/xbout/boutdataarray.py b/xbout/boutdataarray.py index b3a4edba..2a2c5ed6 100644 --- a/xbout/boutdataarray.py +++ b/xbout/boutdataarray.py @@ -1071,7 +1071,7 @@ def pcolormesh(self, ax=None, **kwargs): Colour-plot a radial-poloidal slice on the R-Z plane """ return plotfuncs.plot2d_wrapper(self.data, xr.plot.pcolormesh, ax=ax, **kwargs) - + def polygon(self, ax=None, **kwargs): """ Colour-plot of a radial-poloidal slice on the R-Z plane using polygons diff --git a/xbout/geometries.py b/xbout/geometries.py index 985422ee..2299f54b 100644 --- a/xbout/geometries.py +++ b/xbout/geometries.py @@ -381,14 +381,14 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): "total_poloidal_distance", "zShift", "zShift_ylow", - "Rxy_corners", # Lower left corners + "Rxy_corners", # Lower left corners "Rxy_lower_right_corners", "Rxy_upper_left_corners", "Rxy_upper_right_corners", - "Zxy_corners", # Lower left corners + "Zxy_corners", # Lower left corners "Zxy_lower_right_corners", "Zxy_upper_left_corners", - "Zxy_upper_right_corners" + "Zxy_upper_right_corners", ], ) @@ -427,23 +427,24 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): ds = ds.set_coords(("R", "Z")) else: ds = ds.set_coords(("Rxy", "Zxy")) - + # Add cell corners as coordinates for polygon plotting if "Rxy_lower_right_corners" in ds: ds = ds.rename( - Rxy_corners = "Rxy_lower_left_corners", - Zxy_corners = "Zxy_lower_left_corners" + Rxy_corners="Rxy_lower_left_corners", Zxy_corners="Zxy_lower_left_corners" + ) + ds = ds.set_coords( + ( + "Rxy_lower_left_corners", + "Rxy_lower_right_corners", + "Rxy_upper_left_corners", + "Rxy_upper_right_corners", + "Zxy_lower_left_corners", + "Zxy_lower_right_corners", + "Zxy_upper_left_corners", + "Zxy_upper_right_corners", ) - ds = ds.set_coords(( - "Rxy_lower_left_corners", - "Rxy_lower_right_corners", - "Rxy_upper_left_corners", - "Rxy_upper_right_corners", - "Zxy_lower_left_corners", - "Zxy_lower_right_corners", - "Zxy_upper_left_corners", - "Zxy_upper_right_corners" - )) + ) # Rename zShift_ylow if it was added from grid file, to be consistent with name if # it was added from dump file diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 32007c3b..1ea10bc1 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -752,9 +752,9 @@ def create_or_update_plot(plot_objects=None, tind=None, this_save_as=None): X, Y, Z, scalars=data, vmin=vmin, vmax=vmax, **kwargs ) else: - plot_objects[ - region_name + str(i) - ].mlab_source.scalars = data + plot_objects[region_name + str(i)].mlab_source.scalars = ( + data + ) if mayavi_view is not None: mlab.view(*mayavi_view) @@ -850,49 +850,64 @@ def animation_func(): else: raise ValueError(f"Unrecognised plot3d() 'engine' argument: {engine}") + def plot2d_polygon( da, - ax = None, - cax = None, - cmap = "viridis", - norm = None, - logscale = False, - antialias = False, - vmin = None, - vmax = None, - extend = None, - add_colorbar = True, - colorbar_label = None, - separatrix = True, - separatrix_kwargs = {"color":"white", "linestyle":"-", "linewidth":1}, - targets = False, + ax=None, + cax=None, + cmap="viridis", + norm=None, + logscale=False, + antialias=False, + vmin=None, + vmax=None, + extend=None, + add_colorbar=True, + colorbar_label=None, + separatrix=True, + separatrix_kwargs={"color": "white", "linestyle": "-", "linewidth": 1}, + targets=False, add_limiter_hatching=True, - grid_only = False, - linewidth = 0, - linecolor = "black", - + grid_only=False, + linewidth=0, + linecolor="black", ): - + if ax == None: - fig, ax = plt.subplots(figsize=(3, 6), dpi = 120) + fig, ax = plt.subplots(figsize=(3, 6), dpi=120) else: fig = ax.get_figure() - + if cax == None: - cax = ax - + cax = ax + if vmin is None: vmin = np.nanmin(da.values) - + if vmax is None: vmax = np.nanmax(da.max().values) - if "Rxy_lower_right_corners" in da.coords: - r_nodes = ["R", "Rxy_lower_left_corners", "Rxy_lower_right_corners", "Rxy_upper_left_corners", "Rxy_upper_right_corners"] - z_nodes = ["Z", "Zxy_lower_left_corners", "Zxy_lower_right_corners", "Zxy_upper_left_corners", "Zxy_upper_right_corners"] - cell_r = np.concatenate([np.expand_dims(da[x], axis = 2) for x in r_nodes], axis = 2) - cell_z = np.concatenate([np.expand_dims(da[x], axis = 2) for x in z_nodes], axis = 2) + r_nodes = [ + "R", + "Rxy_lower_left_corners", + "Rxy_lower_right_corners", + "Rxy_upper_left_corners", + "Rxy_upper_right_corners", + ] + z_nodes = [ + "Z", + "Zxy_lower_left_corners", + "Zxy_lower_right_corners", + "Zxy_upper_left_corners", + "Zxy_upper_right_corners", + ] + cell_r = np.concatenate( + [np.expand_dims(da[x], axis=2) for x in r_nodes], axis=2 + ) + cell_z = np.concatenate( + [np.expand_dims(da[x], axis=2) for x in z_nodes], axis=2 + ) else: raise Exception("Cell corners not present in mesh, cannot do polygon plot") @@ -908,14 +923,15 @@ def plot2d_polygon( for i in range(Nx): for j in range(Ny): p = matplotlib.patches.Polygon( - np.concatenate((cell_r[i][j][tuple(idx)], cell_z[i][j][tuple(idx)])).reshape(2, 5).T, + np.concatenate((cell_r[i][j][tuple(idx)], cell_z[i][j][tuple(idx)])) + .reshape(2, 5) + .T, fill=False, closed=True, - facecolor = None, + facecolor=None, ) patches.append(p) - # create colorbar norm = _create_norm(logscale, norm, vmin, vmax) @@ -923,30 +939,31 @@ def plot2d_polygon( cmap = matplotlib.colors.ListedColormap(["white"]) colors = da.data.flatten() polys = matplotlib.collections.PatchCollection( - patches, alpha = 1, norm = norm, cmap = cmap, - antialiaseds = antialias, - edgecolors = linecolor, - linewidths = linewidth, - joinstyle = "bevel") - - - + patches, + alpha=1, + norm=norm, + cmap=cmap, + antialiaseds=antialias, + edgecolors=linecolor, + linewidths=linewidth, + joinstyle="bevel", + ) polys.set_array(colors) if add_colorbar: - fig.colorbar(polys, ax = cax, label = colorbar_label) - ax.add_collection(polys) - + fig.colorbar(polys, ax=cax, label=colorbar_label) + ax.add_collection(polys) + ax.set_aspect("equal", adjustable="box") ax.set_xlabel("R [m]") ax.set_ylabel("Z [m]") ax.set_ylim(cell_z.min(), cell_z.max()) ax.set_xlim(cell_r.min(), cell_r.max()) ax.set_title(da.name) - + if separatrix: - plot_separatrices(da, ax, x = "R", y = "Z", **separatrix_kwargs) - + plot_separatrices(da, ax, x="R", y="Z", **separatrix_kwargs) + if targets: - plot_targets(da, ax, x = "R", y = "Z", hatching = add_limiter_hatching) \ No newline at end of file + plot_targets(da, ax, x="R", y="Z", hatching=add_limiter_hatching) diff --git a/xbout/plotting/utils.py b/xbout/plotting/utils.py index 0e19331e..8a8d3c1a 100644 --- a/xbout/plotting/utils.py +++ b/xbout/plotting/utils.py @@ -120,7 +120,9 @@ def plot_separatrices(da, ax, *, x="R", y="Z", **kwargs): ) default_style = {"color": "black", "linestyle": "--"} if any(x for x in kwargs if x in ["c", "ls"]): - raise ValueError("When passing separatrix plot style kwargs, use 'color' and 'linestyle' instead lf 'c' and 'ls'") + raise ValueError( + "When passing separatrix plot style kwargs, use 'color' and 'linestyle' instead lf 'c' and 'ls'" + ) style = {**default_style, **kwargs} ax.plot(x_sep, y_sep, **style) diff --git a/xbout/tests/test_against_collect.py b/xbout/tests/test_against_collect.py index 5f22cf97..a1bf3da3 100644 --- a/xbout/tests/test_against_collect.py +++ b/xbout/tests/test_against_collect.py @@ -220,5 +220,4 @@ def test_new_collect_indexing_slice(self, tmp_path_factory): @pytest.mark.skip -class test_speed_against_old_collect: - ... +class test_speed_against_old_collect: ... diff --git a/xbout/tests/test_load.py b/xbout/tests/test_load.py index d8766236..bb4c917e 100644 --- a/xbout/tests/test_load.py +++ b/xbout/tests/test_load.py @@ -472,8 +472,7 @@ def test_combine_along_y(self, tmp_path_factory, bout_xyt_example_files): xrt.assert_identical(actual, fake) @pytest.mark.skip - def test_combine_along_t(self): - ... + def test_combine_along_t(self): ... @pytest.mark.parametrize( "bout_v5,metric_3D", [(False, False), (True, False), (True, True)] @@ -623,8 +622,7 @@ def test_drop_vars(self, tmp_path_factory, bout_xyt_example_files): assert "n" in ds.keys() @pytest.mark.skip - def test_combine_along_tx(self): - ... + def test_combine_along_tx(self): ... def test_restarts(self): datapath = Path(__file__).parent.joinpath( diff --git a/xbout/utils.py b/xbout/utils.py index 32be7edc..66dd9593 100644 --- a/xbout/utils.py +++ b/xbout/utils.py @@ -167,12 +167,16 @@ def _1d_coord_from_spacing(spacing, dim, ds=None, *, origin_at=None): ) point_to_use = { - spacing.metadata["bout_xdim"]: spacing.metadata.get("MXG", 0) - if spacing.metadata["keep_xboundaries"] - else 0, - spacing.metadata["bout_ydim"]: spacing.metadata.get("MYG", 0) - if spacing.metadata["keep_yboundaries"] - else 0, + spacing.metadata["bout_xdim"]: ( + spacing.metadata.get("MXG", 0) + if spacing.metadata["keep_xboundaries"] + else 0 + ), + spacing.metadata["bout_ydim"]: ( + spacing.metadata.get("MYG", 0) + if spacing.metadata["keep_yboundaries"] + else 0 + ), spacing.metadata["bout_zdim"]: spacing.metadata.get("MZG", 0), }