Skip to content

Commit

Permalink
Improve _evalat
Browse files Browse the repository at this point in the history
slow=True is not only slower in the tested case, but also requires more
memmory. However, it is unclear whether that is true for all cases.
  • Loading branch information
dschwoerer committed Jan 8, 2024
1 parent 4dd3876 commit 5c8d43f
Showing 1 changed file with 49 additions and 28 deletions.
77 changes: 49 additions & 28 deletions xbout/fci/evaluate_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def weights(rz, mesh):
return weigth, ij


def _evalat(ds, r, phi, z, key, delta_phi, fill_value, progress):
def _evalat(ds, r, phi, z, key, delta_phi, fill_value, progress, slow=True):
r0, phi0, z0 = r, phi, z
dims, shape, coords, (r, phi, z) = get_out_shape(r, phi, z)
plus = tuple([f"delta_{x}" for x in "xyz"])
Expand Down Expand Up @@ -90,8 +90,9 @@ def add_dims(out):
out[k] = dims, v
return out

if key is None:
return add_dims(
return evaluate_at_keys(
ds,
add_dims(
xr.Dataset(
dict(
x=(dimsplus, ids[..., 0]),
Expand All @@ -101,27 +102,11 @@ def add_dims(out):
missing=(dims, missing),
)
)
)
if isinstance(key, str):
key = (key,)
slc = {x: xr.DataArray(ids[..., i], dims=dimsplus) for i, x in enumerate("xyz")}
weights = xr.DataArray(weights, dims=dimsplus)
out = xr.Dataset()
for k in key:
theisel = ds[k].isel(**slc, missing_dims="ignore")
try:
theisel = theisel.compute()
except AttributeError:
pass
out[k] = (theisel * weights).sum(dim=plus)
if np.any(missing):
# out[k].isel(
assert (
tuple(dims) == out[k].dims[-len(dims) :]
), f"{tuple(dims)} != {out[k].dims}"
out[k].values[..., missing] = _fill_value(fill_value, out[k].dtype)
# raise NotImplementedError("Missing data")
return add_dims(out)
),
key,
fill_value=fill_value,
slow=slow,
)


def _fill_value(fill_value, dtype):
Expand Down Expand Up @@ -161,19 +146,53 @@ def evaluate_at_keys(ds, keys, key, fill_value=np.nan, slow=False):
for x in "xyz":
_startswith(slc[x].dims, missing.dims)
slc[x].values[missing] = 0
print(x, np.min(slc[x]), np.max(slc[x]))
# print(x, np.min(slc[x]), np.max(slc[x]))
# Fix periodic indexing
for x in "yz":
slc[x] %= len(ds[x])
if slow:
theisel = ds[k].isel(**slc, missing_dims="ignore")
try:
theisel = theisel.compute()
except AttributeError:
pass
else:
slcp = [slc[d] if d in slc else slice(None) for d in ds[k].dims]
theisel = ds[k].values[tuple(slcp)]
out[k] = (theisel * weights).sum(dim=tuple([f"delta_{x}" for x in "xyz"]))
dat = np.asanyarray(ds[k].values)
ds[c] = ds[k].dims, dat
theisel = dat[tuple(slcp)]
assert theisel.shape[-3:] == (2, 2, 2)
assert weights.shape[-3:] == (2, 2, 2)
try:
assert theisel.dims[-3:] == tuple(
[f"delta_{x}" for x in "xyz"]
), f"Unexpected dims {theisel.dims}"
except AttributeError:
pass
try:
assert weights.dims[-3:] == tuple(
[f"delta_{x}" for x in "xyz"]
), f"Unexpected dims {weights.dims}"
except AttributeError:
pass
outd = []
for d in ds[k].dims:
if d == "x":
outd += weights.dims[:-3]
if d in "xyz":
continue
outd.append(d)
tmp = np.einsum("...ijk,...ijk->...", theisel, weights)
out[k] = tuple(outd), tmp
if np.any(missing):
outk = out[k].transpose(*missing.dims, ...)
outk.values[missing.values] = _fill_value(fill_value, outk.dtype)

for k in "R", "phi", "Z":
k = f"dim_{k}"
if k not in out and k in keys:
out[k] = keys[k]

return out


Expand Down Expand Up @@ -217,6 +236,7 @@ def _evaluate_get_single(ds, r, phi, z, delta_phi):
f1 += len(ds.y)
if f2 < 0:
f2 += len(ds.y)
assert np.isclose(f1 + f2, 1)
# print(f1, f2, dy)
fs = [f1, f2]
mshs = []
Expand Down Expand Up @@ -296,6 +316,7 @@ def evaluate_at_rpz(
delta_phi: float = None,
fill_value=np.nan,
progress=False,
slow=False,
):
"""
Evaluate the field key in the dataset at the positions given by
Expand Down Expand Up @@ -328,7 +349,7 @@ def evaluate_at_rpz(
Show the progress of the mapping. Defaults to False.
"""

return _evalat(self, r, phi, z, key, delta_phi, fill_value, progress)
return _evalat(self, r, phi, z, key, delta_phi, fill_value, progress, slow=slow)


setattr(evaluate_at_rpz, "evaluate_at_rpz", BoutDatasetAccessor)
Expand Down

0 comments on commit 5c8d43f

Please sign in to comment.