Lazy rectilinear interpolator #6084

Open. Wants to merge 10 commits into base: main.
5 changes: 5 additions & 0 deletions docs/src/whatsnew/latest.rst
@@ -71,6 +71,10 @@ This document explains the changes made to Iris for this release
the concatenation axis. This issue can be avoided by disabling the
problematic check. (:pull:`5926`)

+#. `@fnattino`_ enabled lazy cube interpolation using the linear and
+   nearest-neighbour interpolators (:class:`iris.analysis.Linear` and
+   :class:`iris.analysis.Nearest`). (:pull:`6084`)

🔥 Deprecations
===============

@@ -101,6 +105,7 @@ This document explains the changes made to Iris for this release
Whatsnew author names (@github name) in alphabetical order. Note that,
core dev names are automatically included by the common_links.inc:

+.. _@fnattino: https://github.com/fnattino
.. _@jrackham-mo: https://github.com/jrackham-mo


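To illustrate the behaviour this entry describes, here is a minimal sketch (the cube construction, coordinate values, and chunking below are illustrative, not taken from the PR): interpolating a cube backed by lazy data now produces a cube whose data is still lazy.

import dask.array as da
import numpy as np

import iris.analysis
from iris.coords import DimCoord
from iris.cube import Cube

# A small cube backed by a lazy (dask) array.
lat = DimCoord(np.linspace(-45, 45, 10), standard_name="latitude", units="degrees")
lon = DimCoord(np.linspace(0, 90, 10), standard_name="longitude", units="degrees")
cube = Cube(
    da.zeros((10, 10), chunks=(5, 10)),
    dim_coords_and_dims=[(lat, 0), (lon, 1)],
)

# With this change, interpolation no longer realises the data.
result = cube.interpolate([("latitude", [0.5, 1.5])], iris.analysis.Linear())
print(result.has_lazy_data())  # True - computation is deferred until requested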
131 changes: 80 additions & 51 deletions lib/iris/analysis/_interpolation.py
@@ -12,6 +12,7 @@
from numpy.lib.stride_tricks import as_strided
import numpy.ma as ma

+from iris._lazy_data import map_complete_blocks
from iris.coords import AuxCoord, DimCoord
import iris.util

@@ -163,6 +164,15 @@ def snapshot_grid(cube):
return x.copy(), y.copy()


+def _interpolated_dtype(dtype, method):
+    """Determine the minimum base dtype required by the underlying interpolator."""
+    if method == "nearest":
+        result = dtype
+    else:
+        result = np.result_type(_DEFAULT_DTYPE, dtype)
+    return result
Comment on lines +167 to +173
@trexfeathers (Contributor), Sep 5, 2024:
If this needs to stay (see my other comment about args=[self] - #6084 (comment)), then I'd be interested in us unifying this function with RectilinearInterpolator._interpolated_dtype().

@fnattino (Contributor, Author):
Do you mean to put this back as a (static)method of the RectilinearInterpolator? Or to merge the body of this function with RectilinearInterpolator._interpolate?
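As a side note on what the new helper computes: nearest-neighbour interpolation keeps the source dtype, while linear interpolation promotes it via np.result_type. A quick illustration, assuming _DEFAULT_DTYPE is np.float16 (the smallest float, so that integer inputs are promoted no wider than necessary; the constant's value is not shown in this diff):

import numpy as np

_DEFAULT_DTYPE = np.float16  # assumed value of the module constant

# Linear interpolation of integer data needs a float result, but only
# one wide enough to hold the source values exactly:
print(np.result_type(_DEFAULT_DTYPE, np.int16))    # float32
print(np.result_type(_DEFAULT_DTYPE, np.int64))    # float64
print(np.result_type(_DEFAULT_DTYPE, np.float32))  # float32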



class RectilinearInterpolator:
"""Provide support for performing nearest-neighbour or linear interpolation.

@@ -200,13 +210,8 @@ def __init__(self, src_cube, coords, method, extrapolation_mode):
set to NaN.

"""
-        # Trigger any deferred loading of the source cube's data and snapshot
-        # its state to ensure that the interpolator is impervious to external
-        # changes to the original source cube. The data is loaded to prevent
-        # the snapshot having lazy data, avoiding the potential for the
-        # same data to be loaded again and again.
-        if src_cube.has_lazy_data():
-            src_cube.data
+        # Snapshot the cube state to ensure that the interpolator is impervious
+        # to external changes to the original source cube.
self._src_cube = src_cube.copy()
# Coordinates defining the dimensions to be interpolated.
self._src_coords = [self._src_cube.coord(coord) for coord in coords]
@@ -277,17 +282,27 @@ def _account_for_inverted(self, data):
data = data[tuple(dim_slices)]
return data

-    def _interpolate(self, data, interp_points):
+    @staticmethod
+    def _interpolate(
+        data,
+        src_points,
+        interp_points,
+        interp_shape,
+        method="linear",
+        extrapolation_mode="nanmask",
+    ):
"""Interpolate a data array over N dimensions.

-        Create and cache the underlying interpolator instance before invoking
-        it to perform interpolation over the data at the given coordinate point
-        values.
+        Create the interpolator instance before invoking it to perform
+        interpolation over the data at the given coordinate point values.

Parameters
----------
data : ndarray
A data array, to be interpolated in its first 'N' dimensions.
+        src_points :
+            The point values defining the dimensions to be interpolated.
+            (len(src_points) should be N).
interp_points : ndarray
An array of interpolation coordinate values.
Its shape is (..., N) where N is the number of interpolation
@@ -296,44 +311,53 @@ def _interpolate(self, data, interp_points):
coordinate, which is mapped to the i'th data dimension.
The other (leading) dimensions index over the different required
sample points.
+        interp_shape :
+            The shape of the interpolated array in its first 'N' dimensions
+            (len(interp_shape) should be N).
+        method : str
+            Interpolation method (see :class:`iris.analysis._interpolation.RectilinearInterpolator`).
+        extrapolation_mode : str
+            Extrapolation mode (see :class:`iris.analysis._interpolation.RectilinearInterpolator`).

Returns
-------
:class:`np.ndarray`.
-            Its shape is "points_shape + extra_shape",
+            Its shape is "interp_shape + extra_shape",
             where "extra_shape" is the remaining non-interpolated dimensions of
-            the data array (i.e. 'data.shape[N:]'), and "points_shape" is the
-            leading dimensions of interp_points,
-            (i.e. 'interp_points.shape[:-1]').
-
+            the data array (i.e. 'data.shape[N:]').
"""
from iris.analysis._scipy_interpolate import _RegularGridInterpolator

-        dtype = self._interpolated_dtype(data.dtype)
+        dtype = _interpolated_dtype(data.dtype, method)
if data.dtype != dtype:
# Perform dtype promotion.
data = data.astype(dtype)

-        mode = EXTRAPOLATION_MODES[self._mode]
-        if self._interpolator is None:
-            # Cache the interpolator instance.
-            # NB. The constructor of the _RegularGridInterpolator class does
-            # some unnecessary checks on the fill_value parameter,
-            # so we set it afterwards instead. Sneaky. ;-)
-            self._interpolator = _RegularGridInterpolator(
-                self._src_points,
-                data,
-                method=self.method,
-                bounds_error=mode.bounds_error,
-                fill_value=None,
-            )
-        else:
-            self._interpolator.values = data
Comment on lines -318 to -331
@fnattino (Contributor, Author):
_interpolate is now mapped to all the chunks of the data we are interpolating, thus we cannot cache the _RegularGridInterpolator instance anymore. Is this acceptable?

One can still cache the RectilinearInterpolator though, and, as far as I can tell, this would be similar to the way in which the caching of the regridder works?

@trexfeathers (Contributor), Sep 12, 2024:

Having read through the code base: I believe the caching is because the same interpolation routine gets used on a series of coordinates - start here and follow the call stack down to _points():

for coord in cube.dim_coords + cube.aux_coords:
    new_coord, dims = construct_new_coord(coord)
    gen_new_cube()

... meaning multiple uses of _RegularGridInterpolator but with only .values changing. I assume that there is noticeable overhead when creating a new instance of _RegularGridInterpolator, but that there is no problem applying the same instance to a series of different arrays, hence the caching.

There is a risk that your refactoring - which makes the routine much faster when parallelising large datasets - will make small-to-medium cases slower, especially for those cases that need to interpolate a larger number of coordinates. So I am very keen that we try out my suggestion using args=[self] (#6084 (comment)), which I believe would allow us to retain the existing caching? If that does not work then I'll need to assess the performance impact with some benchmarking.

@fnattino (Contributor, Author):
I see your point. But I am slightly afraid the args=[self] approach could be problematic due to the reference to the full cube being sent to all workers (see answer to your other comment below).

So I had a closer look at the _RegularGridInterpolator: it looks to me that only a few checks are carried out when initializing the object, while all the expensive tasks (calculation of weights, interpolation) only take place when calling the interpolator:

weights = self.compute_interp_weights(xi, method)
return self.interp_using_pre_computed_weights(weights)

I have measured some timings based on the example you provided in the comment above, using it as a test case for a small dataset:

Code (run in Jupyter to benchmark timings)
import numpy as np
import iris.analysis

from iris.coords import DimCoord
from iris.cube import Cube
from iris.analysis._scipy_interpolate import _RegularGridInterpolator

# Create a series of orthogonal coordinates
longitude = np.linspace(1, 9, 5)
latitude = np.linspace(10, 90, 9)
coord_names_and_points = [
    ('longitude', longitude),
    ('latitude', latitude),
    ('altitude', np.linspace(100, 900, 18)),
    ('time', np.linspace(1000, 9000, 36)),
]
coord_list = [
    DimCoord(points, standard_name=name)
    for name, points in coord_names_and_points
]

# Create a Cube to house the above coordinates.
shape = [len(points) for _, points in coord_names_and_points]
dimension_mapping = [
    (c, ix) for ix, c in enumerate(coord_list)
]  # [(time, 0), (height, 1), ...]
data = np.arange(np.prod(shape)).reshape(shape)
cube = Cube(data, dim_coords_and_dims=dimension_mapping)

# Perform interpolation over multiple dimensions, but NOT all the dimensions of the Cube
# So we're estimating the values that would appear at:
# (3.5, 15), (3.5, 25), (3.5, 75), (8.5, 15), (8.5, 25), (8.5, 75)
coords = ("longitude", "latitude")
points = [[3.5, 8.5], [15, 25, 75]]

# Create the interpolator instance to benchmark total time of interpolation
interpolator = iris.analysis.RectilinearInterpolator(
    cube, coords, "linear", "mask"
)

%%timeit
# Measure total time of interpolation
result = interpolator(points, collapse_scalar=True)
# 1.28 ms ± 40.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%%timeit
# Measure time required to instantiate the interpolator
_ = _RegularGridInterpolator(
    [longitude, latitude],
    data,
    method="linear",
    bounds_error=False,
    fill_value=None,
)
# 60.6 μs ± 17 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

So the time to instantiate the interpolator is quite small (<5%) compared to the total time for interpolation, even for such a small dataset. So I am actually wondering whether caching the interpolator ultimately brings much benefit - but maybe I am not looking at a suitable test case? What do you think?
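The serialization worry mentioned above (a reference to the full cube being sent to all workers) can be made concrete with a small sketch. The class below is a hypothetical stand-in for the interpolator, not Iris code: pickling a bound method, which is roughly what shipping self in a Dask task graph amounts to, serializes the whole instance, including its large data snapshot.

import pickle

import numpy as np

class FakeInterpolator:
    """Stand-in for an interpolator holding a snapshot of the source data."""

    def __init__(self):
        self.src_data = np.zeros((1000, 1000))  # ~8 MB payload

    def _interpolate(self, chunk):
        return chunk  # per-chunk work is irrelevant to the point here

interp = FakeInterpolator()

# The bound method drags the whole instance, array included, into the
# pickle payload; the plain function is pickled by reference and stays tiny.
print(len(pickle.dumps(interp._interpolate)))            # roughly 8_000_000
print(len(pickle.dumps(FakeInterpolator._interpolate)))  # a few dozen bytes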

+        # Determine the shape of the interpolated result.
+        ndims_interp = len(interp_shape)
+        extra_shape = data.shape[ndims_interp:]
+        final_shape = [*interp_shape, *extra_shape]
Comment on lines +336 to +339
Contributor:
TODO for @trexfeathers: understand the logic of reshaping, and the way it is now split between _interpolate() and _points().

@trexfeathers (Contributor), Sep 12, 2024:
@HarriKeenan after you shared my struggle last week, here is the script I used to help me eventually understand, in case you were interested:

"""
Script to help understand how array shapes are handled in Iris interpolation.

Creates a scenario that is easy to follow when debugging iris.analysis._interpolation.py.
"""

import numpy as np

import iris.analysis
from iris.coords import DimCoord
from iris.cube import Cube


# Create a series of orthogonal coordinates that are obviously different - to
#  help with debugging Iris internals.
coord_names_and_points = [
    ('longitude', np.linspace(1, 9, 5)),
    ('latitude', np.linspace(10, 90, 9)),
    ('altitude', np.linspace(100, 900, 18)),
    ('time', np.linspace(1000, 9000, 36)),
]
coord_list = [
    DimCoord(points, standard_name=name)
    for name, points in coord_names_and_points
]

# Create a Cube to house the above coordinates.
#  The data is not the subject of this investigation so can be arbitrary.
shape = [len(points) for name, points in coord_names_and_points]
dimension_mapping = [
    (c, ix) for ix, c in enumerate(coord_list)
]  # [(time, 0), (height, 1), ...]
cube = Cube(
    np.arange(np.prod(shape)).reshape(shape),
    dim_coords_and_dims=dimension_mapping,
)
print(cube)
# Output:
# unknown / (unknown)                 (longitude: 5; latitude: 9; altitude: 18; time: 36)
#     Dimension coordinates:
#         longitude                             x            -            -         -
#         latitude                              -            x            -         -
#         altitude                              -            -            x         -
#         time                                  -            -            -         x
# Perform interpolation over multiple dimensions, but NOT all the dimensions
#  of the Cube, so we can see how the most complex cases are handled when
#  debugging.
# So we're estimating the values that would appear at:
#  (3.5, 15), (3.5, 25), (3.5, 75), (8.5, 15), (8.5, 25), (8.5, 75)
sampling_points = [("longitude", [3.5, 8.5]), ("latitude", [15, 25, 75])]
interpolated_cube = cube.interpolate(sampling_points, iris.analysis.Linear())
print(interpolated_cube)
# Output:
# unknown / (unknown)                 (longitude: 2; latitude: 3; altitude: 18; time: 36)
#     Dimension coordinates:
#         longitude                             x            -            -         -
#         latitude                              -            x            -         -
#         altitude                              -            -            x         -
#         time                                  -            -            -         x


+        mode = EXTRAPOLATION_MODES[extrapolation_mode]
+        _data = np.ma.getdata(data)
+        # NB. The constructor of the _RegularGridInterpolator class does
+        # some unnecessary checks on the fill_value parameter,
+        # so we set it afterwards instead. Sneaky. ;-)
+        interpolator = _RegularGridInterpolator(
+            src_points,
+            _data,
+            method=method,
+            bounds_error=mode.bounds_error,
+            fill_value=None,
+        )
+        interpolator.fill_value = mode.fill_value
+        result = interpolator(interp_points)

-        # We may be re-using a cached interpolator, so ensure the fill
-        # value is set appropriately for extrapolating data values.
-        self._interpolator.fill_value = mode.fill_value
-        result = self._interpolator(interp_points)
+        # The interpolated result has now shape "points_shape + extra_shape"
+        # where "points_shape" is the leading dimension of "interp_points"
Comment on lines +356 to +357
Contributor:
Has there been some renaming since this was written? points_shape -> interp_shape?

@fnattino (Contributor, Author):
The renaming of the shape in the docstring comes from the fact that I have moved some of the reshaping of the interpolated results from outside to inside _interpolate() (see comment), so the shape of the data that _interpolate() returns is actually different.

Here, I am describing the reshaping that takes place here: bringing data from shape "points_shape + extra_shape" (as returned by interpolator()) to shape "interp_shape + extra_shape" (= "final_shape").

+        # (i.e. 'interp_points.shape[:-1]'). We reshape it to match the shape
Contributor:
This is not true if there is more than 1 dimension not being interpolated (i.e. len(extra_shape) > 1).

@fnattino (Contributor, Author):
Uhm, I have looked at the example you have shared in the comment above (very useful, thanks!), where you have 4 dimensions, 2 of which are not interpolated.

At this point of the code, I get the following shapes:

interp_points.shape
# (6, 2) --> 6 points with 2 coordinates each

interp_points.shape[:-1]
# (6, ) --> this is "points_shape"

extra_shape
# (18, 36)

result.shape
# (6, 18, 36) --> thus equal to "points_shape + extra_shape"

Or am I missing something? Maybe I should rephrase the comment in the following way? I have just used the same description as previously provided in the docstring.

# The interpolated result has now shape "(num_points, ) + extra_shape"
# where "num_points" is the number of points for which we are carrying
# out the interpolation. We reshape it to match the shape of the
# interpolated dimensions.

+        # of the interpolated dimensions.
+        result = result.reshape(final_shape)

if result.dtype != data.dtype:
# Cast the data dtype to be as expected. Note that, the dtype
@@ -346,13 +370,11 @@ def _interpolate(self, data, interp_points):
# `data` is not a masked array.
src_mask = np.ma.getmaskarray(data)
# Switch the extrapolation to work with mask values.
-            self._interpolator.fill_value = mode.mask_fill_value
-            self._interpolator.values = src_mask
-            mask_fraction = self._interpolator(interp_points)
+            interpolator.fill_value = mode.mask_fill_value
+            interpolator.values = src_mask
+            mask_fraction = interpolator(interp_points)
             new_mask = mask_fraction > 0
-            if ma.isMaskedArray(data) or np.any(new_mask):
-                result = np.ma.MaskedArray(result, new_mask)
-
+            result = np.ma.MaskedArray(result, new_mask)
return result

def _resample_coord(self, sample_points, coord, coord_dims):
@@ -530,7 +552,7 @@ def _points(self, sample_points, data, data_dims=None):
_, src_order = zip(*sorted(dmap.items(), key=operator.itemgetter(0)))

# Prepare the sample points for interpolation and calculate the
-        # shape of the interpolated result.
+        # shape of the interpolated dimensions.
interp_points = []
interp_shape = []
for index, points in enumerate(sample_points):
Expand All @@ -539,10 +561,6 @@ def _points(self, sample_points, data, data_dims=None):
interp_points.append(points)
interp_shape.append(points.size)

-        interp_shape.extend(
-            length for dim, length in enumerate(data.shape) if dim not in di
-        )

# Convert the interpolation points into a cross-product array
# with shape (n_cross_points, n_dims)
interp_points = np.asarray([pts for pts in product(*interp_points)])
@@ -554,9 +572,20 @@ def _points(self, sample_points, data, data_dims=None):
# Transpose data in preparation for interpolation.
data = np.transpose(data, interp_order)

-        # Interpolate and reshape the data ...
-        result = self._interpolate(data, interp_points)
-        result = result.reshape(interp_shape)
+        # Interpolate the data, ensuring the interpolated dimensions
+        # are not chunked.
+        dims_not_chunked = [dmap[d] for d in di]
+        result = map_complete_blocks(
Contributor:
TODO for @trexfeathers: confirm that this use of map_complete_blocks() does not suffer the memory leaks described in #5767.

+            data,
+            self._interpolate,
+            dims=dims_not_chunked,
+            out_sizes=interp_shape,
+            src_points=self._src_points,
+            interp_points=interp_points,
+            interp_shape=interp_shape,
+            method=self._method,
+            extrapolation_mode=self._mode,
Comment on lines +583 to +587
Contributor:
Have you considered using self=self or something like that instead? I.e. not needing to refactor _interpolate() to be passed all these values as arguments, but instead to just access them from the existing instance of RectilinearInterpolator. I'm not 100% sure that this works but it would be worth trying?

Contributor:
More specifically I think it would be like this, since self is a positional argument (not a keyword argument):

Suggested change:
-            src_points=self._src_points,
-            interp_points=interp_points,
-            interp_shape=interp_shape,
-            method=self._method,
-            extrapolation_mode=self._mode,
+            args=[self],
+            interp_shape=interp_shape,

@fnattino (Contributor, Author):
Uhm, I haven't thought about that possibility. The main reason why I have refactored _interpolate() is that the existing instance contains a copy of the data to be interpolated (self._src_cube), so I thought it was safer to make _interpolate() a static method and only pass the required input arguments.

In particular, I was afraid that serializing the RectilinearInterpolator instance and copying it to the Dask workers might cause the full data array to be loaded and copied everywhere (even though I am not 100% sure whether this is what would actually happen).

My refactored implementation gives _interpolate() access to only the data chunk that it should work on, so it looked "safer" to me? What do you think? But if you think we should really try to avoid this refactoring, I will give passing the instance to _interpolate() via map_complete_blocks a try.

+        )
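For readers following the call above: the chunk-mapping pattern behind map_complete_blocks() can be sketched with plain dask.array.map_blocks. Everything below (the per-block function, array sizes, chunking) is illustrative rather than Iris's actual implementation: the dimensions being interpolated are rechunked into single blocks, then a NumPy function is mapped over each block with the output block sizes declared up front.

import dask.array as da
import numpy as np

def interp_rows(block):
    # Illustrative stand-in for _interpolate: maps axis 0 from
    # size 4 to size 8, leaving trailing dimensions untouched.
    return np.repeat(block, 2, axis=0)

arr = da.ones((4, 100), chunks=(2, 25))

# The interpolated dimension must not be chunked, since each block
# needs the complete set of source points along that axis...
arr = arr.rechunk({0: -1})

# ...after which every block is processed independently and lazily,
# with the new size of axis 0 declared up front.
result = da.map_blocks(interp_rows, arr, chunks=((8,), arr.chunks[1]))
print(result.shape)  # (8, 100) - still lazy until .compute()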

if src_order != dims:
# Restore the interpolated result to the original
@@ -592,7 +621,7 @@ def __call__(self, sample_points, collapse_scalar=True):

sample_points = _canonical_sample_points(self._src_coords, sample_points)

-        data = self._src_cube.data
+        data = self._src_cube.core_data()
# Interpolate the cube payload.
interpolated_data = self._points(sample_points, data)

Contributor:
TODO for @trexfeathers: review the tests.

@@ -499,24 +499,37 @@ def test_orthogonal_cube_squash(self):
self.assertEqual(result_cube, non_collapsed_cube[0, ...])


+class Test___call___real_data(ThreeDimCube):
+    def test_src_cube_data_loaded(self):
+        # If the source cube has real data when the interpolator is
+        # instantiated, then the interpolated result should also have
+        # real data.
+        self.assertFalse(self.cube.has_lazy_data())
+
+        # Perform interpolation and check the data is real.
+        interpolator = RectilinearInterpolator(
+            self.cube, ["latitude"], LINEAR, EXTRAPOLATE
+        )
+        res = interpolator([[1.5]])
+        self.assertFalse(res.has_lazy_data())


class Test___call___lazy_data(ThreeDimCube):
def test_src_cube_data_loaded(self):
@fnattino (Contributor, Author):
This test used to check that setting up the interpolator would trigger data loading in the source cube, so that new interpolations would not have to load the same data again and again (see #1222). I have modified the test in favour of a couple of tests that should make more sense for a lazy interpolator.

+        # RectilinearInterpolator operates using a snapshot of the source cube.
         # If the source cube has lazy data when the interpolator is
-        # instantiated we want to make sure the source cube's data is
-        # loaded as a consequence of interpolation to avoid the risk
-        # of loading it again and again.
+        # instantiated, then the interpolated result should also have
+        # lazy data.

# Modify self.cube to have lazy data.
self.cube.data = as_lazy_data(self.data)
self.assertTrue(self.cube.has_lazy_data())

-        # Perform interpolation and check the data has been loaded.
+        # Perform interpolation and check the data is lazy.
interpolator = RectilinearInterpolator(
self.cube, ["latitude"], LINEAR, EXTRAPOLATE
)
-        interpolator([[1.5]])
-        self.assertFalse(self.cube.has_lazy_data())
+        res = interpolator([[1.5]])
+        self.assertTrue(res.has_lazy_data())


class Test___call___time(tests.IrisTest):