diff --git a/xugrid/regrid/regridder.py b/xugrid/regrid/regridder.py index 88844cb31..35b0bcae3 100644 --- a/xugrid/regrid/regridder.py +++ b/xugrid/regrid/regridder.py @@ -15,10 +15,8 @@ import dask.array DaskArray = dask.array.Array - DaskRechunk = dask.array.rechunk except ImportError: DaskArray = () - DaskRechunk = () import xugrid from xugrid.constants import FloatArray @@ -148,9 +146,14 @@ def _regrid_array(self, source): size = self._target.size if isinstance(source, DaskArray): - # for DaskArray's from multiple partitions, rechunk first to single size per dimension - # for now always rechunk, could be optional only when explicit chunks in single dimension - source = DaskRechunk(source, source.shape) + # It's possible that the topology dimensions are chunked (e.g. from + # reading multiple partitions). The regrid operation does not + # support this, since we might need multiple source chunks for a + # single target chunk, which destroys the 1:1 relation between + # chunks. Here we ensure that the topology dimensions are contained + # in a single contiguous chunk. + contiguous_chunks = (source.chunks[0], (source.shape[-1],)) + source = source.rechunk(contiguous_chunks) chunks = source.chunks[: -source_grid.ndim] + (self._target.shape) out = dask.array.map_blocks( self._regrid, # func @@ -161,9 +164,6 @@ def _regrid_array(self, source): chunks=chunks, meta=np.array((), dtype=source.dtype), ) - # TODO: for now we compute first, since .reshape and dask.array.reshape - # does not reshapes the underlying data somehow. This need to be evaluated. - out = out.compute() elif isinstance(source, np.ndarray): out = self._regrid(source, self._weights, size) else: