Skip to content

Commit

Permalink
Merge pull request #59 from fusion-energy/simpler_2d_mesh_plotting
Browse files Browse the repository at this point in the history
Improved 2d mesh plotting by squeezing end of tally shape only
  • Loading branch information
shimwell authored Nov 27, 2023
2 parents 5a4a8a7 + 97c2772 commit 772b74b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 33 deletions.
60 changes: 28 additions & 32 deletions src/openmc_regular_mesh_plotter/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
_default_outline_kwargs = {"colors": "black", "linestyles": "solid", "linewidths": 1}


def _squeeze_end_of_array(array, dims_required=3):
while len(array.shape) > dims_required:
array = np.squeeze(array, axis=len(array.shape) - 1)
return array


def plot_mesh_tally(
tally: "openmc.Tally",
basis: str = "xy",
Expand Down Expand Up @@ -95,12 +101,7 @@ def plot_mesh_tally(

mesh = tally.find_filter(filter_type=openmc.MeshFilter).mesh
if not isinstance(mesh, openmc.RegularMesh):
raise NotImplemented(
f"Only RegularMesh are currently supported not {type(mesh)}"
)
# if mesh.n_dimension != 3:
# msg = "Your mesh has {mesh.n_dimension} dimension and currently only RegularMesh with 3 dimensions are supported"
# raise NotImplementedError(msg)
raise NotImplemented(f"Only RegularMesh are supported not {type(mesh)}")

# if score is not specified and tally has a single score then we know which score to use
if score is None:
Expand All @@ -112,19 +113,32 @@ def plot_mesh_tally(

tally_slice = tally.get_slice(scores=[score])

# if mesh.n_dimension == 3:
basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis]

if 1 in mesh.dimension:
index_of_2d = mesh.dimension.index(1)
axis_of_2d = {0: "x", 1: "y", 2: "z"}[index_of_2d]
if (
axis_of_2d in basis
): # checks if the axis is being plotted, e.g is 'x' in 'xy'
raise ValueError(
"The selected tally has a mesh that has 1 dimension in the "
f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis "
f"of {basis}."
)

# todo check if 1 appears twice or three times, raise value error if so
# TODO check if 1 appears twice or three times, raise value error if so

tally_data = tally_slice.get_reshaped_data(expand_dims=True, value=value).squeeze()
tally_data = tally_slice.get_reshaped_data(
expand_dims=True, value=value
) # .squeeze()

basis_to_index = {"xy": 2, "xz": 1, "yz": 0}[basis]
if len(tally_data.shape) == 3:
tally_data = _squeeze_end_of_array(tally_data, dims_required=3)

# if len(tally_data.shape) == 3:
if mesh.n_dimension == 3:
if slice_index is None:
# finds the mid index
slice_index = int(tally_data.shape[basis_to_index] / 2)

if basis == "xz":
Expand All @@ -139,29 +153,11 @@ def plot_mesh_tally(
slice_data = tally_data[:, :, slice_index]
data = np.rot90(slice_data, -3)
xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]"
# elif mesh.n_dimension == 2:
elif len(tally_data.shape) == 2:
if basis_to_index == index_of_2d:
slice_data = tally_data[:, :]
if basis == "xz":
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"x [{axis_units}]", f"z [{axis_units}]"
elif basis == "yz":
data = np.flip(np.rot90(slice_data, -1))
xlabel, ylabel = f"y [{axis_units}]", f"z [{axis_units}]"
else: # basis == 'xy'
data = np.rot90(slice_data, -3)
xlabel, ylabel = f"x [{axis_units}]", f"y [{axis_units}]"

else:
raise ValueError(
"The selected tally has a mesh that has 1 dimension in the "
f"{axis_of_2d} axis, minimum of 2 needed to plot with a basis "
f"of {basis}."
)

else:
raise ValueError("mesh n_dimension")
raise ValueError(
f"mesh n_dimension is not 3 or 2 but is {mesh.n_dimension} which is not supported"
)

if volume_normalization:
# in a regular mesh all volumes are the same so we just divide by the first
Expand Down
4 changes: 3 additions & 1 deletion tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ def test_plot_2d_mesh_tally(model):
tally_result = statepoint.get_tally(name="mesh-tal")

plot = plot_mesh_tally(
tally=tally_result, basis="yz", slice_index=29 # max value of slice selected
tally=tally_result, basis="yz", slice_index=0 # max value of slice selected
)

plot = plot_mesh_tally(tally=tally_result, basis="yz")
# axis_units defaults to cm
assert plot.xaxis.get_label().get_text() == "y [cm]"
assert plot.yaxis.get_label().get_text() == "z [cm]"
Expand Down

0 comments on commit 772b74b

Please sign in to comment.