diff --git a/c3s_eqc_automatic_quality_control/diagnostics.py b/c3s_eqc_automatic_quality_control/diagnostics.py index 08771f5..56dc3bc 100644 --- a/c3s_eqc_automatic_quality_control/diagnostics.py +++ b/c3s_eqc_automatic_quality_control/diagnostics.py @@ -30,6 +30,8 @@ "annual_weighted_mean", "annual_weighted_std", "grid_cell_area", + "monthly_weighted_mean", + "monthly_weighted_std", "regrid", "rolling_weighted_filter", "seasonal_weighted_mean", @@ -239,6 +241,37 @@ def attrs_func(attrs: dict[str, Any]) -> dict[str, Any]: return _apply_attrs_func(coverage, obj, attrs_func) +def monthly_weighted_mean( + obj: xr.DataArray | xr.Dataset, + time_name: Hashable | None = None, + weights: xr.DataArray | bool = True, + **kwargs: Any, +) -> xr.DataArray | xr.Dataset: + """ + Calculate monthly weighted mean. + + Parameters + ---------- + obj: DataArray or Dataset + Input data + time_name: str, optional + Name of time coordinate + weights: DataArray, bool, default: True + Weights to apply: + - True: weights are the number of days in each month + - False: unweighted + - DataArray: custom weights + + Returns + ------- + DataArray or Dataset + Reduced object + """ + return _time_weighted.TimeWeighted(obj, time_name, weights).reduce( + "mean", "month", **kwargs + ) + + def seasonal_weighted_mean( obj: xr.DataArray | xr.Dataset, time_name: Hashable | None = None, @@ -270,6 +303,37 @@ def seasonal_weighted_mean( ) +def monthly_weighted_std( + obj: xr.DataArray | xr.Dataset, + time_name: Hashable | None = None, + weights: xr.DataArray | bool = True, + **kwargs: Any, +) -> xr.DataArray | xr.Dataset: + """ + Calculate monthly weighted std. + + Parameters + ---------- + obj: DataArray or Dataset + Input data + time_name: str, optional + Name of time coordinate + weights: DataArray, bool, default: True + Weights to apply: + - True: weights are the number of days in each month + - False: unweighted + - DataArray: custom weights + + Returns + ------- + DataArray or Dataset + Reduced object + """ + return _time_weighted.TimeWeighted(obj, time_name, weights).reduce( + "std", "month", **kwargs + ) + + def seasonal_weighted_std( obj: xr.DataArray | xr.Dataset, time_name: Hashable | None = None, diff --git a/tests/test_21_time_diagnostics.py b/tests/test_21_time_diagnostics.py index 4efbc9c..454f7b4 100644 --- a/tests/test_21_time_diagnostics.py +++ b/tests/test_21_time_diagnostics.py @@ -74,6 +74,17 @@ def test_time_weighted_std( actual = diagnostics.time_weighted_std(obj, weights=weights) xr.testing.assert_equal(expected, actual) + def test_monthly_weighted_mean( + self, obj: xr.DataArray | xr.Dataset, weights: bool + ) -> None: + if weights: + expected = (obj).groupby("time.month").map(weighted_mean) + expected = expected.where(expected != 0) + else: + expected = obj.groupby("time.month").mean("time") + actual = diagnostics.monthly_weighted_mean(obj, weights=weights) + xr.testing.assert_equal(expected, actual) + def test_seasonal_weighted_mean( self, obj: xr.DataArray | xr.Dataset, weights: bool ) -> None: @@ -85,6 +96,17 @@ def test_seasonal_weighted_mean( actual = diagnostics.seasonal_weighted_mean(obj, weights=weights) xr.testing.assert_equal(expected, actual) + def test_monthly_weighted_std( + self, obj: xr.DataArray | xr.Dataset, weights: bool + ) -> None: + if weights: + expected = obj.groupby("time.month").map(weighted_std) + expected = expected.where(expected != 0) + else: + expected = obj.groupby("time.month").std("time") + actual = diagnostics.monthly_weighted_std(obj, weights=weights) + xr.testing.assert_equal(expected, actual) + def test_seasonal_weighted_std( self, obj: xr.DataArray | xr.Dataset, weights: bool ) -> None: