Skip to content

Commit

Permalink
Apply black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mikekryjak authored and github-actions[bot] committed Mar 22, 2024
1 parent d76febe commit d1fdc43
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 80 deletions.
2 changes: 1 addition & 1 deletion xbout/boutdataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,7 @@ def pcolormesh(self, ax=None, **kwargs):
Colour-plot a radial-poloidal slice on the R-Z plane
"""
return plotfuncs.plot2d_wrapper(self.data, xr.plot.pcolormesh, ax=ax, **kwargs)

def polygon(self, ax=None, **kwargs):
"""
Colour-plot of a radial-poloidal slice on the R-Z plane using polygons
Expand Down
33 changes: 17 additions & 16 deletions xbout/geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,14 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
"total_poloidal_distance",
"zShift",
"zShift_ylow",
"Rxy_corners", # Lower left corners
"Rxy_corners", # Lower left corners
"Rxy_lower_right_corners",
"Rxy_upper_left_corners",
"Rxy_upper_right_corners",
"Zxy_corners", # Lower left corners
"Zxy_corners", # Lower left corners
"Zxy_lower_right_corners",
"Zxy_upper_left_corners",
"Zxy_upper_right_corners"
"Zxy_upper_right_corners",
],
)

Expand Down Expand Up @@ -427,23 +427,24 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
ds = ds.set_coords(("R", "Z"))
else:
ds = ds.set_coords(("Rxy", "Zxy"))

# Add cell corners as coordinates for polygon plotting
if "Rxy_lower_right_corners" in ds:
ds = ds.rename(
Rxy_corners = "Rxy_lower_left_corners",
Zxy_corners = "Zxy_lower_left_corners"
Rxy_corners="Rxy_lower_left_corners", Zxy_corners="Zxy_lower_left_corners"
)
ds = ds.set_coords(
(
"Rxy_lower_left_corners",
"Rxy_lower_right_corners",
"Rxy_upper_left_corners",
"Rxy_upper_right_corners",
"Zxy_lower_left_corners",
"Zxy_lower_right_corners",
"Zxy_upper_left_corners",
"Zxy_upper_right_corners",
)
ds = ds.set_coords((
"Rxy_lower_left_corners",
"Rxy_lower_right_corners",
"Rxy_upper_left_corners",
"Rxy_upper_right_corners",
"Zxy_lower_left_corners",
"Zxy_lower_right_corners",
"Zxy_upper_left_corners",
"Zxy_upper_right_corners"
))
)

# Rename zShift_ylow if it was added from grid file, to be consistent with name if
# it was added from dump file
Expand Down
117 changes: 67 additions & 50 deletions xbout/plotting/plotfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,9 +752,9 @@ def create_or_update_plot(plot_objects=None, tind=None, this_save_as=None):
X, Y, Z, scalars=data, vmin=vmin, vmax=vmax, **kwargs
)
else:
plot_objects[
region_name + str(i)
].mlab_source.scalars = data
plot_objects[region_name + str(i)].mlab_source.scalars = (
data
)

if mayavi_view is not None:
mlab.view(*mayavi_view)
Expand Down Expand Up @@ -850,49 +850,64 @@ def animation_func():
else:
raise ValueError(f"Unrecognised plot3d() 'engine' argument: {engine}")


def plot2d_polygon(
da,
ax = None,
cax = None,
cmap = "viridis",
norm = None,
logscale = False,
antialias = False,
vmin = None,
vmax = None,
extend = None,
add_colorbar = True,
colorbar_label = None,
separatrix = True,
separatrix_kwargs = {"color":"white", "linestyle":"-", "linewidth":1},
targets = False,
ax=None,
cax=None,
cmap="viridis",
norm=None,
logscale=False,
antialias=False,
vmin=None,
vmax=None,
extend=None,
add_colorbar=True,
colorbar_label=None,
separatrix=True,
separatrix_kwargs={"color": "white", "linestyle": "-", "linewidth": 1},
targets=False,
add_limiter_hatching=True,
grid_only = False,
linewidth = 0,
linecolor = "black",

grid_only=False,
linewidth=0,
linecolor="black",
):

if ax == None:
fig, ax = plt.subplots(figsize=(3, 6), dpi = 120)
fig, ax = plt.subplots(figsize=(3, 6), dpi=120)
else:
fig = ax.get_figure()

if cax == None:
cax = ax
cax = ax

if vmin is None:
vmin = np.nanmin(da.values)

if vmax is None:
vmax = np.nanmax(da.max().values)


if "Rxy_lower_right_corners" in da.coords:
r_nodes = ["R", "Rxy_lower_left_corners", "Rxy_lower_right_corners", "Rxy_upper_left_corners", "Rxy_upper_right_corners"]
z_nodes = ["Z", "Zxy_lower_left_corners", "Zxy_lower_right_corners", "Zxy_upper_left_corners", "Zxy_upper_right_corners"]
cell_r = np.concatenate([np.expand_dims(da[x], axis = 2) for x in r_nodes], axis = 2)
cell_z = np.concatenate([np.expand_dims(da[x], axis = 2) for x in z_nodes], axis = 2)
r_nodes = [
"R",
"Rxy_lower_left_corners",
"Rxy_lower_right_corners",
"Rxy_upper_left_corners",
"Rxy_upper_right_corners",
]
z_nodes = [
"Z",
"Zxy_lower_left_corners",
"Zxy_lower_right_corners",
"Zxy_upper_left_corners",
"Zxy_upper_right_corners",
]
cell_r = np.concatenate(
[np.expand_dims(da[x], axis=2) for x in r_nodes], axis=2
)
cell_z = np.concatenate(
[np.expand_dims(da[x], axis=2) for x in z_nodes], axis=2
)
else:
raise Exception("Cell corners not present in mesh, cannot do polygon plot")

Expand All @@ -908,45 +923,47 @@ def plot2d_polygon(
for i in range(Nx):
for j in range(Ny):
p = matplotlib.patches.Polygon(
np.concatenate((cell_r[i][j][tuple(idx)], cell_z[i][j][tuple(idx)])).reshape(2, 5).T,
np.concatenate((cell_r[i][j][tuple(idx)], cell_z[i][j][tuple(idx)]))
.reshape(2, 5)
.T,
fill=False,
closed=True,
facecolor = None,
facecolor=None,
)
patches.append(p)


# create colorbar
norm = _create_norm(logscale, norm, vmin, vmax)

if grid_only is True:
cmap = matplotlib.colors.ListedColormap(["white"])
colors = da.data.flatten()
polys = matplotlib.collections.PatchCollection(
patches, alpha = 1, norm = norm, cmap = cmap,
antialiaseds = antialias,
edgecolors = linecolor,
linewidths = linewidth,
joinstyle = "bevel")



patches,
alpha=1,
norm=norm,
cmap=cmap,
antialiaseds=antialias,
edgecolors=linecolor,
linewidths=linewidth,
joinstyle="bevel",
)

polys.set_array(colors)

if add_colorbar:
fig.colorbar(polys, ax = cax, label = colorbar_label)
ax.add_collection(polys)
fig.colorbar(polys, ax=cax, label=colorbar_label)
ax.add_collection(polys)

ax.set_aspect("equal", adjustable="box")
ax.set_xlabel("R [m]")
ax.set_ylabel("Z [m]")
ax.set_ylim(cell_z.min(), cell_z.max())
ax.set_xlim(cell_r.min(), cell_r.max())
ax.set_title(da.name)

if separatrix:
plot_separatrices(da, ax, x = "R", y = "Z", **separatrix_kwargs)
plot_separatrices(da, ax, x="R", y="Z", **separatrix_kwargs)

if targets:
plot_targets(da, ax, x = "R", y = "Z", hatching = add_limiter_hatching)
plot_targets(da, ax, x="R", y="Z", hatching=add_limiter_hatching)
4 changes: 3 additions & 1 deletion xbout/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def plot_separatrices(da, ax, *, x="R", y="Z", **kwargs):
)
default_style = {"color": "black", "linestyle": "--"}
if any(x for x in kwargs if x in ["c", "ls"]):
raise ValueError("When passing separatrix plot style kwargs, use 'color' and 'linestyle' instead lf 'c' and 'ls'")
raise ValueError(
"When passing separatrix plot style kwargs, use 'color' and 'linestyle' instead lf 'c' and 'ls'"
)
style = {**default_style, **kwargs}
ax.plot(x_sep, y_sep, **style)

Expand Down
3 changes: 1 addition & 2 deletions xbout/tests/test_against_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,5 +220,4 @@ def test_new_collect_indexing_slice(self, tmp_path_factory):


@pytest.mark.skip
class test_speed_against_old_collect:
...
class test_speed_against_old_collect: ...
6 changes: 2 additions & 4 deletions xbout/tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,7 @@ def test_combine_along_y(self, tmp_path_factory, bout_xyt_example_files):
xrt.assert_identical(actual, fake)

@pytest.mark.skip
def test_combine_along_t(self):
...
def test_combine_along_t(self): ...

@pytest.mark.parametrize(
"bout_v5,metric_3D", [(False, False), (True, False), (True, True)]
Expand Down Expand Up @@ -623,8 +622,7 @@ def test_drop_vars(self, tmp_path_factory, bout_xyt_example_files):
assert "n" in ds.keys()

@pytest.mark.skip
def test_combine_along_tx(self):
...
def test_combine_along_tx(self): ...

def test_restarts(self):
datapath = Path(__file__).parent.joinpath(
Expand Down
16 changes: 10 additions & 6 deletions xbout/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,16 @@ def _1d_coord_from_spacing(spacing, dim, ds=None, *, origin_at=None):
)

point_to_use = {
spacing.metadata["bout_xdim"]: spacing.metadata.get("MXG", 0)
if spacing.metadata["keep_xboundaries"]
else 0,
spacing.metadata["bout_ydim"]: spacing.metadata.get("MYG", 0)
if spacing.metadata["keep_yboundaries"]
else 0,
spacing.metadata["bout_xdim"]: (
spacing.metadata.get("MXG", 0)
if spacing.metadata["keep_xboundaries"]
else 0
),
spacing.metadata["bout_ydim"]: (
spacing.metadata.get("MYG", 0)
if spacing.metadata["keep_yboundaries"]
else 0
),
spacing.metadata["bout_zdim"]: spacing.metadata.get("MZG", 0),
}

Expand Down

0 comments on commit d1fdc43

Please sign in to comment.