From 95add680bbf804db40cb3ee41f5192b024e4994c Mon Sep 17 00:00:00 2001
From: Miguel Andres-Martinez
Date: Fri, 15 Nov 2024 10:58:20 +0100
Subject: [PATCH] add sorting dymensions

---
 src/pymorize/generic.py        | 15 ++++++++++++
 tests/fixtures/sample_rules.py | 44 ++++++++++++++++++++++++++++++++++
 tests/unit/template.py         |  9 +++++++
 tests/unit/test_array_order.py | 13 ++++++++++
 4 files changed, 81 insertions(+)
 create mode 100644 tests/unit/template.py
 create mode 100644 tests/unit/test_array_order.py

diff --git a/src/pymorize/generic.py b/src/pymorize/generic.py
index d3209a0..80e3087 100644
--- a/src/pymorize/generic.py
+++ b/src/pymorize/generic.py
@@ -250,3 +250,18 @@ def trigger_compute(data, rule_spec, *args, **kwargs):
         return data.compute()
     # Data doesn't have a compute method, do nothing
     return data
+
+
+def sort_dimensions(data, rule_spec):
+    """Sorts the dimensions of a DataArray based on the array_order in the rule_spec."""
+    dryrun = rule_spec.get("dryrun", False)
+
+    missing_dims = rule_spec.get("sort_dimensions_missing_dims", "raise")
+
+    logger.info(
+        f"Transposing dimensions of data from {data.dims} to {rule_spec.array_order}"
+    )
+    if not dryrun:
+        data = data.transpose(*rule_spec.array_order, missing_dims=missing_dims)
+
+    return data
diff --git a/tests/fixtures/sample_rules.py b/tests/fixtures/sample_rules.py
index 716ea38..3088d0e 100644
--- a/tests/fixtures/sample_rules.py
+++ b/tests/fixtures/sample_rules.py
@@ -88,3 +88,47 @@ def rule_with_units():
     )
     r.data_request_variable = r.data_request_variables[0]
     return r
+
+
+@pytest.fixture
+def rule_with_unsorted_data():
+    return Rule(
+        array_order=["time", "lat", "lon"],
+        inputs=[
+            {
+                "path": "/some/files/containing/",
+                "pattern": "var1.*.nc",
+            },
+            {
+                "path": "/some/other/files/containing/",
+                "pattern": r"var1_(?P<year>\d{4}).nc",
+            },
+        ],
+        cmor_variable="var1",
+        pipelines=["pymorize.pipeline.TestingPipeline"],
+        data_request_variables=[
+            DataRequestVariable(
+                variable_id="var1",
+                unit="kg m-2 s-1",
+                description="Some description",
+                time_method="instant",
+                table="Some Table",
+                frequency="mon",
+                realms=["atmos"],
+                standard_name="some_standard_name",
+                cell_methods="time: mean",
+                cell_measures="area: areacella",
+            )
+        ],
+    )
+
+
+@pytest.fixture
+def dummy_array():
+    import numpy
+    import xarray as xr
+
+    return xr.DataArray(
+        numpy.random.rand(10, 10, 10),
+        dims=["lat", "lon", "time"],
+    )
diff --git a/tests/unit/template.py b/tests/unit/template.py
new file mode 100644
index 0000000..cd8872b
--- /dev/null
+++ b/tests/unit/template.py
@@ -0,0 +1,9 @@
+import numpy as np
+import pytest
+import xarray as xr
+
+
+# @pytest.mark.parametrize("test_input", a_list)
+def test_name(test_input):
+    """Docstring"""
+    pass
diff --git a/tests/unit/test_array_order.py b/tests/unit/test_array_order.py
new file mode 100644
index 0000000..16a7bd1
--- /dev/null
+++ b/tests/unit/test_array_order.py
@@ -0,0 +1,13 @@
+import numpy as np
+import pytest
+import xarray as xr
+
+
+# @pytest.mark.parametrize("test_input", a_list)
+def test_sort_dimensions(dummy_array, rule_with_unsorted_data):
+    """Test to check that dimensions are sorted correctly"""
+    from pymorize.generic import sort_dimensions
+
+    dummy_array = sort_dimensions(dummy_array, rule_with_unsorted_data)
+
+    assert dummy_array.dims == tuple(rule_with_unsorted_data.array_order)