diff --git a/tests/conftest.py b/tests/conftest.py index 21074024e..c8663297b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ from tests.fixtures.fixture_regridder import ( disk, + disk_layered, expected_results_centroid, expected_results_linear, expected_results_overlap, diff --git a/tests/fixtures/fixture_regridder.py b/tests/fixtures/fixture_regridder.py index a79c77ed2..75e804b69 100644 --- a/tests/fixtures/fixture_regridder.py +++ b/tests/fixtures/fixture_regridder.py @@ -23,6 +23,14 @@ def disk(): return xu.data.disk()["face_z"] +@pytest.fixture(scope="function") +def disk_layered(disk): + layer = xr.DataArray([1.0, 2.0, 3.0], coords={"layer": [1, 2, 3]}, dims=("layer",)) + # Disk is first in multiplication, to ensure that object is promoted to UgridDataArray + disk_layered = disk * layer + return disk_layered.transpose("layer", disk.ugrid.grid.face_dimension) + + @pytest.fixture(scope="function") def quads_0_25(): dx = 0.25 diff --git a/tests/test_regrid/test_regridder.py b/tests/test_regrid/test_regridder.py index 391d2873e..6eb0ed2ee 100644 --- a/tests/test_regrid/test_regridder.py +++ b/tests/test_regrid/test_regridder.py @@ -156,6 +156,25 @@ def test_regridder_from_weights(cls, disk, quads_1): assert new_result.equals(result) +@pytest.mark.parametrize( + "cls", + [ + CentroidLocatorRegridder, + OverlapRegridder, + RelativeOverlapRegridder, + BarycentricInterpolator, + ], +) +def test_regridder_from_weights_layered(cls, disk, disk_layered, quads_1): + square = quads_1 + regridder = cls(source=disk, target=square) + result = regridder.regrid(disk) + weights = regridder.weights + new_regridder = cls.from_weights(weights, target=square) + new_result = new_regridder.regrid(disk_layered) + assert np.array_equal(new_result.sel(layer=1).values, result.values, equal_nan=True) + + @pytest.mark.parametrize( "cls", [ diff --git a/tests/test_regrid/test_structured.py b/tests/test_regrid/test_structured.py index 7ddcc9573..8ed4a70df 100644 --- a/tests/test_regrid/test_structured.py +++ b/tests/test_regrid/test_structured.py @@ -266,45 +266,59 @@ def test_linear_weights_1d( # -------- # source target weight - # 1 1 20% - # 0 1 80% - # 2 2 20% - # 1 2 80% + # 1 1 80% + # 0 1 20% + # 2 2 80% + # 1 2 20% # -------- source, target, weights = grid_data_a_1d.linear_weights(grid_data_c_1d) sorter = np.argsort(target) assert np.array_equal(source[sorter], np.array([1, 0, 2, 1])) assert np.array_equal(target[sorter], np.array([1, 1, 2, 2])) - assert np.allclose(weights[sorter], np.array([0.2, 0.8, 0.2, 0.8])) + assert np.allclose(weights[sorter], np.array([0.8, 0.2, 0.8, 0.2])) # -------- # source target weight - # 0 1 10% - # 1 1 90% - # 0 2 60% - # 1 2 40% *reversed in output - # 1 3 10% - # 2 3 90% + # 0 1 90% + # 1 1 10% + # 0 2 40% + # 1 2 60% *reversed in output + # 1 3 90% + # 2 3 10% # -------- source, target, weights = grid_data_a_1d.linear_weights(grid_data_d_1d) sorter = np.argsort(target) assert np.array_equal(source[sorter], np.array([0, 1, 1, 0, 1, 2])) assert np.array_equal(target[sorter], np.array([1, 1, 2, 2, 3, 3])) - assert np.allclose(weights[sorter], np.array([0.1, 0.9, 0.4, 0.6, 0.1, 0.9])) + assert np.allclose(weights[sorter], np.array([0.9, 0.1, 0.6, 0.4, 0.9, 0.1])) # non-equidistant # -------- # source target weight - # 0 1 35% - # 1 1 65% - # 1 2 10% - # 2 2 90% + # 0 1 65% + # 1 1 35% + # 1 2 90% + # 2 2 10% # -------- source, target, weights = grid_data_a_1d.linear_weights(grid_data_e_1d) sorter = np.argsort(target) assert np.array_equal(source[sorter], np.array([0, 1, 1, 2])) assert np.array_equal(target[sorter], np.array([1, 1, 2, 2])) - assert np.allclose(weights[sorter], np.array([0.35, 0.65, 0.1, 0.9])) + assert np.allclose(weights[sorter], np.array([0.65, 0.35, 0.9, 0.1])) + + # 1-1 grid + # -------- + # source target weight + # 1 1 100% + # 0 1 0% + # 2 2 100% + # 1 2 0% + # -------- + source, target, weights = grid_data_b_1d.linear_weights(grid_data_b_1d) + sorter = np.argsort(target) + assert np.array_equal(source[sorter], np.array([1, 0, 2, 1])) + assert np.array_equal(target[sorter], np.array([1, 1, 2, 2])) + assert np.allclose(weights[sorter], np.array([1.0, 0.0, 1.0, 0.0])) def test_linear_weights_2d( @@ -342,4 +356,24 @@ def test_linear_weights_2d( assert np.array_equal( target[sorter], np.array([5, 5, 5, 5, 6, 6, 6, 6, 9, 9, 9, 9, 10, 10, 10, 10]) ) - assert np.allclose(weights[sorter], np.array([0.1, 0.4, 0.1, 0.4] * 4)) + assert np.allclose(weights[sorter], np.array([0.4, 0.1, 0.4, 0.1] * 4)) + + # 1-1 + # -------- + # source targets weight + # 5 4, 5, 8, 9 0% 100% 0% 0% + # 6 5, 6, 9,10 0% 100% 0% 0% + # 9 8, 9,12,13 0% 100% 0% 0% + # 10 9,10,13,14 0% 100% 0% 0% + # -------- + source, target, weights = grid_data_b_2d.linear_weights(grid_data_b_2d) + sorter = np.argsort(target) + assert np.array_equal( + source[sorter], np.array([5, 4, 9, 8, 6, 5, 10, 9, 9, 8, 13, 12, 10, 9, 14, 13]) + ) + assert np.array_equal( + target[sorter], np.array([5, 5, 5, 5, 6, 6, 6, 6, 9, 9, 9, 9, 10, 10, 10, 10]) + ) + assert np.allclose( + weights[sorter], np.array([1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]) + ) diff --git a/xugrid/regrid/regridder.py b/xugrid/regrid/regridder.py index 8ae06a22c..91e735e25 100644 --- a/xugrid/regrid/regridder.py +++ b/xugrid/regrid/regridder.py @@ -231,21 +231,27 @@ def to_dataset(self) -> xr.Dataset: @staticmethod def _csr_from_dataset(dataset: xr.Dataset) -> WeightMatrixCSR: + """ + variable n and nnz are expected to be scalar variable + """ return WeightMatrixCSR( dataset["__regrid_data"].values, dataset["__regrid_indices"].values, dataset["__regrid_indptr"].values, - dataset["__regrid_n"].values, - dataset["__regrid_nnz"].values, + dataset["__regrid_n"].values[()], + dataset["__regrid_nnz"].values[()], ) @staticmethod def _coo_from_dataset(dataset: xr.Dataset) -> WeightMatrixCOO: + """ + variable nnz is expected to be scalar variable + """ return WeightMatrixCOO( dataset["__regrid_data"].values, dataset["__regrid_row"].values, dataset["__regrid_col"].values, - dataset["__regrid_nnz"].values, + dataset["__regrid_nnz"].values[()], ) @abc.abstractclassmethod diff --git a/xugrid/regrid/structured.py b/xugrid/regrid/structured.py index c56d52533..e7ea47a60 100644 --- a/xugrid/regrid/structured.py +++ b/xugrid/regrid/structured.py @@ -238,7 +238,7 @@ def maybe_reverse_index(self, index: IntArray) -> IntArray: else: return index - def compute_distance_to_centroids( + def compute_linear_weights_to_centroids( self, other: "StructuredGrid1d", source_index: IntArray, target_index: IntArray ) -> Tuple[FloatArray, IntArray]: """ @@ -263,10 +263,10 @@ def compute_distance_to_centroids( source_midpoint_index = self.maybe_reverse_index(source_index) target_midpoints_index = other.maybe_reverse_index(target_index) neighbor = np.ones(target_midpoints_index.size, dtype=int) - # cases where midpoint target < midpoint source + # cases where midpoint target <= midpoint source condition = ( other.midpoints[target_midpoints_index] - < self.midpoints[source_midpoint_index] + <= self.midpoints[source_midpoint_index] ) neighbor[condition] = -neighbor[condition] @@ -275,15 +275,21 @@ def compute_distance_to_centroids( f"Coordinate {self.name} has size: {self.midpoints.size}. " "At least two points are required for interpolation." ) - weights = ( - other.midpoints[target_midpoints_index] - - self.midpoints[source_midpoint_index] - ) / ( - self.midpoints[source_midpoint_index + neighbor] - - self.midpoints[source_midpoint_index] + weights = 1 - ( + ( + other.midpoints[target_midpoints_index] + - self.midpoints[source_midpoint_index] + ) + / ( + self.midpoints[source_midpoint_index + neighbor] + - self.midpoints[source_midpoint_index] + ) ) - weights[weights < 0.0] = 0.0 - weights[weights > 1.0] = 1.0 + condition = np.logical_and(weights < 0.0, weights > 1.0) + if condition.any(): + raise ValueError( + f"Computed invalid weights for dimensions: {self.name} at coords: {self.midpoints[condition]}" + ) return weights, neighbor def sorted_output( @@ -374,7 +380,7 @@ def linear_weights( weights: np.array """ source_index, target_index = self.valid_nodes_within_bounds_and_extend(other) - weights, neighbour = self.compute_distance_to_centroids( + weights, neighbour = self.compute_linear_weights_to_centroids( other, source_index, target_index ) source_index, target_index, weights = self.centroids_to_linear_sets(