diff --git a/xbout/fci/evaluate_at.py b/xbout/fci/evaluate_at.py index 7a3876b1..8f785f47 100644 --- a/xbout/fci/evaluate_at.py +++ b/xbout/fci/evaluate_at.py @@ -54,7 +54,7 @@ def weights(rz, mesh): return weigth, ij -def _evalat(ds, r, phi, z, key, delta_phi, fill_value, progress): +def _evalat(ds, r, phi, z, key, delta_phi, fill_value, progress, slow=True): r0, phi0, z0 = r, phi, z dims, shape, coords, (r, phi, z) = get_out_shape(r, phi, z) plus = tuple([f"delta_{x}" for x in "xyz"]) @@ -90,8 +90,9 @@ def add_dims(out): out[k] = dims, v return out - if key is None: - return add_dims( + return evaluate_at_keys( + ds, + add_dims( xr.Dataset( dict( x=(dimsplus, ids[..., 0]), @@ -101,27 +102,11 @@ def add_dims(out): missing=(dims, missing), ) ) - ) - if isinstance(key, str): - key = (key,) - slc = {x: xr.DataArray(ids[..., i], dims=dimsplus) for i, x in enumerate("xyz")} - weights = xr.DataArray(weights, dims=dimsplus) - out = xr.Dataset() - for k in key: - theisel = ds[k].isel(**slc, missing_dims="ignore") - try: - theisel = theisel.compute() - except AttributeError: - pass - out[k] = (theisel * weights).sum(dim=plus) - if np.any(missing): - # out[k].isel( - assert ( - tuple(dims) == out[k].dims[-len(dims) :] - ), f"{tuple(dims)} != {out[k].dims}" - out[k].values[..., missing] = _fill_value(fill_value, out[k].dtype) - # raise NotImplementedError("Missing data") - return add_dims(out) + ), + key, + fill_value=fill_value, + slow=slow, + ) def _fill_value(fill_value, dtype): @@ -161,19 +146,53 @@ def evaluate_at_keys(ds, keys, key, fill_value=np.nan, slow=False): for x in "xyz": _startswith(slc[x].dims, missing.dims) slc[x].values[missing] = 0 - print(x, np.min(slc[x]), np.max(slc[x])) + # print(x, np.min(slc[x]), np.max(slc[x])) # Fix periodic indexing for x in "yz": slc[x] %= len(ds[x]) if slow: theisel = ds[k].isel(**slc, missing_dims="ignore") + try: + theisel = theisel.compute() + except AttributeError: + pass else: slcp = [slc[d] if d in slc else slice(None) for d in ds[k].dims] - theisel = ds[k].values[tuple(slcp)] - out[k] = (theisel * weights).sum(dim=tuple([f"delta_{x}" for x in "xyz"])) + dat = np.asanyarray(ds[k].values) + ds[c] = ds[k].dims, dat + theisel = dat[tuple(slcp)] + assert theisel.shape[-3:] == (2, 2, 2) + assert weights.shape[-3:] == (2, 2, 2) + try: + assert theisel.dims[-3:] == tuple( + [f"delta_{x}" for x in "xyz"] + ), f"Unexpected dims {theisel.dims}" + except AttributeError: + pass + try: + assert weights.dims[-3:] == tuple( + [f"delta_{x}" for x in "xyz"] + ), f"Unexpected dims {weights.dims}" + except AttributeError: + pass + outd = [] + for d in ds[k].dims: + if d == "x": + outd += weights.dims[:-3] + if d in "xyz": + continue + outd.append(d) + tmp = np.einsum("...ijk,...ijk->...", theisel, weights) + out[k] = tuple(outd), tmp if np.any(missing): outk = out[k].transpose(*missing.dims, ...) outk.values[missing.values] = _fill_value(fill_value, outk.dtype) + + for k in "R", "phi", "Z": + k = f"dim_{k}" + if k not in out and k in keys: + out[k] = keys[k] + return out @@ -217,6 +236,7 @@ def _evaluate_get_single(ds, r, phi, z, delta_phi): f1 += len(ds.y) if f2 < 0: f2 += len(ds.y) + assert np.isclose(f1 + f2, 1) # print(f1, f2, dy) fs = [f1, f2] mshs = [] @@ -296,6 +316,7 @@ def evaluate_at_rpz( delta_phi: float = None, fill_value=np.nan, progress=False, + slow=False, ): """ Evaluate the field key in the dataset at the positions given by @@ -328,7 +349,7 @@ def evaluate_at_rpz( Show the progress of the mapping. Defaults to False. """ - return _evalat(self, r, phi, z, key, delta_phi, fill_value, progress) + return _evalat(self, r, phi, z, key, delta_phi, fill_value, progress, slow=slow) setattr(evaluate_at_rpz, "evaluate_at_rpz", BoutDatasetAccessor)