diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 6e2d6049..3cceb830 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -121,7 +121,7 @@ def test_averages_for_yearly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) # Test unweighted averages result = ds.temporal.average("ts", weighted=False) @@ -139,7 +139,7 @@ def test_averages_for_yearly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_averages_for_monthly_time_series(self): # Set up dataset @@ -293,7 +293,7 @@ def test_averages_for_daily_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) # Test unweighted averages result = ds.temporal.average("ts", weighted=False) @@ -310,7 +310,7 @@ def test_averages_for_daily_time_series(self): "weighted": "False", }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_averages_for_hourly_time_series(self): ds = xr.Dataset( @@ -378,7 +378,7 @@ def test_averages_for_hourly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) # Test unweighted averages result = ds.temporal.average("ts", weighted=False) @@ -396,7 +396,7 @@ def test_averages_for_hourly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class TestGroupAverage: @@ -489,14 +489,6 @@ def test_weighted_annual_averages(self): cftime.DatetimeGregorian(2001, 1, 1), ], ), - coords={ - "time": np.array( - [ - cftime.DatetimeGregorian(2000, 1, 1), - cftime.DatetimeGregorian(2001, 1, 1), - ], - ) - }, dims=["time"], attrs={ "axis": "T", @@ -540,14 +532,6 @@ def test_weighted_annual_averages_with_chunking(self): cftime.DatetimeGregorian(2001, 1, 1), ], ), - coords={ - "time": np.array( - [ - cftime.DatetimeGregorian(2000, 1, 1), - cftime.DatetimeGregorian(2001, 
1, 1), - ], - ) - }, dims=["time"], attrs={ "axis": "T", @@ -571,6 +555,195 @@ def test_weighted_annual_averages_with_chunking(self): assert result.ts.attrs == expected.ts.attrs assert result.time.attrs == expected.time.attrs + def test_weighted_annual_averages_with_masked_data_and_min_weight_threshold_of_100_percent( + self, + ): + # Set up dataset + ds = xr.Dataset( + coords={ + "lat": [-90], + "lon": [0], + "time": xr.DataArray( + data=np.array( + [ + "2000-01-01T00:00:00.000000000", + "2000-02-01T00:00:00.000000000", + "2001-01-01T00:00:00.000000000", + "2001-02-01T00:00:00.000000000", + "2002-01-01T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + } + ) + ds.time.encoding = {"calendar": "standard"} + + ds["time_bnds"] = xr.DataArray( + name="time_bnds", + data=np.array( + [ + ["2000-01-01T00:00:00.000000000", "2000-02-01T00:00:00.000000000"], + ["2000-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"], + ["2001-01-01T00:00:00.000000000", "2001-02-01T00:00:00.000000000"], + ["2001-02-01T00:00:00.000000000", "2001-03-01T00:00:00.000000000"], + ["2002-01-01T00:00:00.000000000", "2002-02-01T00:00:00.000000000"], + ], + dtype="datetime64[ns]", + ), + coords={"time": ds.time}, + dims=["time", "bnds"], + attrs={"xcdat_bounds": "True"}, + ) + + ds["ts"] = xr.DataArray( + data=np.array([[[2]], [[np.nan]], [[1]], [[1]], [[0.5]]]), + coords={"lat": ds.lat, "lon": ds.lon, "time": ds.time}, + dims=["time", "lat", "lon"], + ) + + # NOTE: If a cell has a missing value for any of the years, the average + # for that year should be masked with a min_weight threshold of 100%. 
+ result = ds.temporal.group_average("ts", "year", min_weight=1.0) + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[np.nan]], [[1]], [[0.5]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(2000, 1, 1), + cftime.DatetimeGregorian(2001, 1, 1), + cftime.DatetimeGregorian(2002, 1, 1), + ], + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "year", + "weighted": "True", + }, + ) + + xr.testing.assert_allclose(result, expected) + + def test_weighted_annual_averages_with_masked_data_and_min_weight_threshold_of_50_percent( + self, + ): + # Set up dataset + ds = xr.Dataset( + coords={ + "lat": [-90], + "lon": [0], + "time": xr.DataArray( + data=np.array( + [ + "2000-01-01T00:00:00.000000000", + "2000-02-01T00:00:00.000000000", + "2001-01-01T00:00:00.000000000", + "2001-02-01T00:00:00.000000000", + "2002-01-01T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + } + ) + ds.time.encoding = {"calendar": "standard"} + + ds["time_bnds"] = xr.DataArray( + name="time_bnds", + data=np.array( + [ + ["2000-01-01T00:00:00.000000000", "2000-02-01T00:00:00.000000000"], + ["2000-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"], + ["2001-01-01T00:00:00.000000000", "2001-02-01T00:00:00.000000000"], + ["2001-02-01T00:00:00.000000000", "2001-03-01T00:00:00.000000000"], + ["2002-01-01T00:00:00.000000000", "2002-02-01T00:00:00.000000000"], + ], + dtype="datetime64[ns]", + ), + coords={"time": ds.time}, + dims=["time", "bnds"], + attrs={"xcdat_bounds": "True"}, + ) + + ds["ts"] = 
xr.DataArray( + data=np.array([[[2]], [[np.nan]], [[1]], [[1]], [[0.5]]]), + coords={"lat": ds.lat, "lon": ds.lon, "time": ds.time}, + dims=["time", "lat", "lon"], + ) + + # NOTE: The second cell of "ts" has missing data, but the first cell + # has more weight (due to more days in the month of Jan vs. Feb) so the + # average for the year is not masked with a min_weight threshold of 50%. + result = ds.temporal.group_average("ts", "year", min_weight=0.50) + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[2.0]], [[1]], [[0.5]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(2000, 1, 1), + cftime.DatetimeGregorian(2001, 1, 1), + cftime.DatetimeGregorian(2002, 1, 1), + ], + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "year", + "weighted": "True", + }, + ) + + xr.testing.assert_allclose(result, expected) + def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): ds = self.ds.copy() @@ -619,12 +792,20 @@ def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons( self, ): ds = self.ds.copy() + ds["ts"] = xr.DataArray( + data=np.array( + [[[2.0]], [[1.0]], [[1.0]], [[1.0]], [[2.0]]], dtype="float64" + ), + coords={"time": self.ds.time, "lat": self.ds.lat, "lon": self.ds.lon}, + dims=["time", "lat", "lon"], + attrs={"test_attr": "test"}, + ) result = ds.temporal.group_average( "ts", @@ -670,7 +851,7 @@ def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons }, ) 
- assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_averages_with_JFD(self): ds = self.ds.copy() @@ -698,17 +879,102 @@ def test_weighted_seasonal_averages_with_JFD(self): cftime.DatetimeGregorian(2001, 1, 1), ], ), - coords={ - "time": np.array( - [ - cftime.DatetimeGregorian(2000, 1, 1), - cftime.DatetimeGregorian(2000, 4, 1), - cftime.DatetimeGregorian(2000, 7, 1), - cftime.DatetimeGregorian(2000, 10, 1), - cftime.DatetimeGregorian(2001, 1, 1), - ], - ) + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "season", + "weighted": "True", + "dec_mode": "JFD", + }, + ) + + xr.testing.assert_identical(result, expected) + + def test_weighted_seasonal_averages_with_JFD_with_min_weight_threshold_of_100_percent( + self, + ): + time = xr.DataArray( + data=np.array( + [ + "2000-01-16T12:00:00.000000000", + "2000-02-15T12:00:00.000000000", + "2000-03-16T12:00:00.000000000", + "2000-06-16T00:00:00.000000000", + "2000-12-16T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ), + dims=["time"], + attrs={"axis": "T", "long_name": "time", "standard_name": "time"}, + ) + time.encoding = {"calendar": "standard"} + time_bnds = xr.DataArray( + name="time_bnds", + data=np.array( + [ + ["2000-01-01T00:00:00.000000000", "2000-02-01T00:00:00.000000000"], + ["2000-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"], + ["2000-03-01T00:00:00.000000000", "2000-04-01T00:00:00.000000000"], + ["2000-06-01T00:00:00.000000000", "2000-07-01T00:00:00.000000000"], + ["2000-11-01T00:00:00.000000000", "2001-01-01T00:00:00.000000000"], + ], + dtype="datetime64[ns]", + ), + coords={"time": time}, + dims=["time", "bnds"], + attrs={"xcdat_bounds": "True"}, + ) + + ds = xr.Dataset( + data_vars={"time_bnds": time_bnds}, + 
coords={"lat": [-90], "lon": [0], "time": time}, + ) + ds.time.attrs["bounds"] = "time_bnds" + + ds["ts"] = xr.DataArray( + data=np.array( + [[[np.nan]], [[1.0]], [[1.0]], [[1.0]], [[2.0]]], dtype="float64" + ), + coords={"time": self.ds.time, "lat": self.ds.lat, "lon": self.ds.lon}, + dims=["time", "lat", "lon"], + attrs={"test_attr": "test"}, + ) + + # NOTE: If a cell has a missing value for any of the seasons, the average + # for that season should be masked with a min_weight threshold of 100%. + result = ds.temporal.group_average( + "ts", + "season", + season_config={"dec_mode": "JFD"}, + min_weight=1.0, + ) + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[np.nan]], [[1.0]], [[1.0]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(2000, 1, 1), + cftime.DatetimeGregorian(2000, 4, 1), + cftime.DatetimeGregorian(2000, 7, 1), + ], + ), dims=["time"], attrs={ "axis": "T", @@ -729,7 +995,7 @@ def test_weighted_seasonal_averages_with_JFD(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_custom_seasonal_averages(self): ds = self.ds.copy() @@ -787,7 +1053,7 @@ def test_weighted_custom_seasonal_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_raises_error_with_incorrect_custom_seasons_argument(self): # Test raises error with non-3 letter strings @@ -873,7 +1139,7 @@ def test_weighted_monthly_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_monthly_averages_with_masked_data(self): ds = self.ds.copy() @@ -924,7 +1190,103 @@ def test_weighted_monthly_averages_with_masked_data(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) + + def 
test_weighted_monthly_averages_with_masked_data_and_min_weight_threshold_of_100_percent( + self, + ): + # Set up dataset + ds = xr.Dataset( + coords={ + "lat": [-90], + "lon": [0], + "time": xr.DataArray( + data=np.array( + [ + "2000-01-01T00:00:00.000000000", + "2000-02-01T00:00:00.000000000", + "2000-02-15T00:00:00.000000000", + "2000-04-01T00:00:00.000000000", + "2001-02-01T00:00:00.000000000", + ], + dtype="datetime64[ns]", + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + } + ) + ds.time.encoding = {"calendar": "standard"} + + ds["time_bnds"] = xr.DataArray( + name="time_bnds", + data=np.array( + [ + ["2000-01-01T00:00:00.000000000", "2000-02-01T00:00:00.000000000"], + ["2000-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"], + ["2000-02-01T00:00:00.000000000", "2000-03-01T00:00:00.000000000"], + ["2000-04-01T00:00:00.000000000", "2000-05-01T00:00:00.000000000"], + ["2001-02-01T00:00:00.000000000", "2001-03-01T00:00:00.000000000"], + ], + dtype="datetime64[ns]", + ), + coords={"time": ds.time}, + dims=["time", "bnds"], + attrs={"xcdat_bounds": "True"}, + ) + + ds["ts"] = xr.DataArray( + data=np.array([[[2]], [[np.nan]], [[1]], [[1]], [[1]]]), + coords={"lat": ds.lat, "lon": ds.lon, "time": ds.time}, + dims=["time", "lat", "lon"], + attrs={"test_attr": "test"}, + ) + + # NOTE: If a cell has a missing value for any of the months, the average + # for that month should be masked with a min_weight threshold of 100%. 
+ result = ds.temporal.group_average("ts", "month", min_weight=1.0) + expected = ds.copy() + expected = expected.drop_dims("time") + expected["ts"] = xr.DataArray( + name="ts", + data=np.array([[[2.0]], [[np.nan]], [[1.0]], [[1.0]]]), + coords={ + "lat": expected.lat, + "lon": expected.lon, + "time": xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(2000, 1, 1), + cftime.DatetimeGregorian(2000, 2, 1), + cftime.DatetimeGregorian(2000, 4, 1), + cftime.DatetimeGregorian(2001, 2, 1), + ], + ), + dims=["time"], + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ), + }, + dims=["time", "lat", "lon"], + attrs={ + "test_attr": "test", + "operation": "temporal_avg", + "mode": "group_average", + "freq": "month", + "weighted": "True", + }, + ) + + xr.testing.assert_identical(result, expected) def test_weighted_daily_averages(self): ds = self.ds.copy() @@ -967,7 +1329,7 @@ def test_weighted_daily_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_hourly_averages(self): ds = self.ds.copy() @@ -1011,7 +1373,7 @@ def test_weighted_hourly_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class TestClimatology: @@ -1105,7 +1467,7 @@ def test_subsets_climatology_based_on_reference_period(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_climatology_with_DJF(self): ds = self.ds.copy() @@ -1159,7 +1521,7 @@ def test_weighted_seasonal_climatology_with_DJF(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) @requires_dask def test_chunked_weighted_seasonal_climatology_with_DJF(self): @@ -1214,7 +1576,7 @@ def test_chunked_weighted_seasonal_climatology_with_DJF(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def 
test_weighted_seasonal_climatology_with_JFD(self): ds = self.ds.copy() @@ -1265,7 +1627,7 @@ def test_weighted_seasonal_climatology_with_JFD(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_custom_seasonal_climatology(self): ds = self.ds.copy() @@ -1328,7 +1690,7 @@ def test_weighted_custom_seasonal_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_monthly_climatology(self): result = self.ds.temporal.climatology("ts", "month") @@ -1391,7 +1753,7 @@ def test_weighted_monthly_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_unweighted_monthly_climatology(self): result = self.ds.temporal.climatology("ts", "month", weighted=False) @@ -1453,7 +1815,7 @@ def test_unweighted_monthly_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_daily_climatology(self): result = self.ds.temporal.climatology("ts", "day", weighted=True) @@ -1515,7 +1877,7 @@ def test_weighted_daily_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_daily_climatology_drops_leap_days_with_matching_calendar(self): time = xr.DataArray( @@ -1606,7 +1968,7 @@ def test_weighted_daily_climatology_drops_leap_days_with_matching_calendar(self) }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_unweighted_daily_climatology(self): result = self.ds.temporal.climatology("ts", "day", weighted=False) @@ -1668,7 +2030,7 @@ def test_unweighted_daily_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class TestDepartures: @@ -1810,7 +2172,7 @@ def test_seasonal_departures_relative_to_climatology_reference_period(self): }, ) - assert result.identical(expected) + 
xr.testing.assert_identical(result, expected) def test_monthly_departures_relative_to_climatology_reference_period_with_same_output_freq( self, @@ -1895,7 +2257,7 @@ def test_monthly_departures_relative_to_climatology_reference_period_with_same_o }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_departures_with_DJF(self): ds = self.ds.copy() @@ -1945,7 +2307,7 @@ def test_weighted_seasonal_departures_with_DJF(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): ds = self.ds.copy() @@ -2021,7 +2383,7 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): dims=["time_original"], ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_unweighted_seasonal_departures_with_DJF(self): ds = self.ds.copy() @@ -2071,7 +2433,7 @@ def test_unweighted_seasonal_departures_with_DJF(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_unweighted_seasonal_departures_with_JFD(self): ds = self.ds.copy() @@ -2121,7 +2483,7 @@ def test_unweighted_seasonal_departures_with_JFD(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_daily_departures_drops_leap_days_with_matching_calendar(self): time = xr.DataArray( @@ -2214,7 +2576,7 @@ def test_weighted_daily_departures_drops_leap_days_with_matching_calendar(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class Test_GetWeights: diff --git a/xcdat/spatial.py b/xcdat/spatial.py index 2c50595a..0d2d978b 100644 --- a/xcdat/spatial.py +++ b/xcdat/spatial.py @@ -1,4 +1,6 @@ """Module containing geospatial averaging functions.""" +from __future__ import annotations + from functools import reduce from typing import ( Callable, diff --git a/xcdat/temporal.py 
b/xcdat/temporal.py index 4e608729..5ff2c277 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -19,11 +19,7 @@ from xcdat._logger import _setup_custom_logger from xcdat.axis import get_dim_coords from xcdat.dataset import _get_data_var -from xcdat.utils import ( - _get_masked_weights, - _validate_min_weight, - mask_var_with_weight_threshold, -) +from xcdat.utils import _get_masked_weights, _validate_min_weight logger = _setup_custom_logger(__name__) @@ -166,7 +162,6 @@ def average( data_var: str, weighted: bool = True, keep_weights: bool = False, - min_weight: float | None = None, ): """ Returns a Dataset with the average of a data variable and the time @@ -208,10 +203,6 @@ def average( keep_weights : bool, optional If calculating averages using weights, keep the weights in the final dataset output, by default False. - min_weight : float | None, optional - Fraction of data coverage (i..e, weight) needed to return a - spatial average value. Value must range from 0 to 1, by default None - (equivalent to ``min_weight=0.0``). Returns ------- @@ -252,7 +243,6 @@ def average( freq, weighted=weighted, keep_weights=keep_weights, - min_weight=min_weight, ) def group_average( @@ -262,6 +252,7 @@ def group_average( weighted: bool = True, keep_weights: bool = False, season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG, + min_weight: float | None = None, ): """Returns a Dataset with average of a data variable by time group. @@ -340,6 +331,10 @@ >>> ["Jul", "Aug", "Sep"], # "JulAugSep" >>> ["Oct", "Nov", "Dec"], # "OctNovDec" >>> ] + min_weight : float | None, optional + Fraction of data coverage (i.e., weight) needed to return a + temporal average value. Value must range from 0 to 1, by default + None (equivalent to ``min_weight=0.0``). 
Returns ------- @@ -418,6 +413,7 @@ weighted=weighted, keep_weights=keep_weights, season_config=season_config, + min_weight=min_weight, ) def climatology( @@ -937,31 +933,31 @@ def _set_arg_attrs( ): """Validates method arguments and sets them as object attributes. - Parameters - ---------- - mode : Mode - The mode for temporal averaging. - freq : Frequency - The frequency of time to group by. - weighted : bool - Calculate averages using weights. - season_config: Optional[SeasonConfigInput] - A dictionary for "season" frequency configurations. If configs for - predefined seasons are passed, configs for custom seasons are - ignored and vice versa, by default DEFAULT_SEASON_CONFIG. - min_weight : float | None, optional - Fraction of data coverage (i..e, weight) needed to return a - spatial average value. Value must range from 0 to 1, by default None - (equivalent to ``min_weight=0.0``). - - Raises - ------ - KeyError - If the Dataset does not have a time dimension. - ValueError - If an incorrect ``freq`` arg was passed. - ValueError - If an incorrect ``dec_mode`` arg was passed. + Parameters + ---------- + mode : Mode + The mode for temporal averaging. + freq : Frequency + The frequency of time to group by. + weighted : bool + Calculate averages using weights. + season_config: Optional[SeasonConfigInput] + A dictionary for "season" frequency configurations. If configs for + predefined seasons are passed, configs for custom seasons are + ignored and vice versa, by default DEFAULT_SEASON_CONFIG. + min_weight : float | None, optional + Fraction of data coverage (i.e., weight) needed to return a + temporal average value. Value must range from 0 to 1, by default + None (equivalent to ``min_weight=0.0``). + + Raises + ------ + KeyError + If the Dataset does not have a time dimension. 
+ ValueError + If an incorrect ``freq`` arg was passed. + ValueError + If an incorrect ``dec_mode`` arg was passed. """ # General configuration attributes. if mode not in list(MODES): @@ -1198,11 +1199,6 @@ def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray: self._weights = self._get_weights(time_bounds) dv = dv.weighted(self._weights).mean(dim=self.dim) - - if self._min_weight > 0.0: - dv = mask_var_with_weight_threshold( - dv, self._weights, self._min_weight - ) else: dv = dv.mean(dim=self.dim) @@ -1235,23 +1231,33 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray: time_bounds = ds.bounds.get_bounds("T", var_key=data_var) self._weights = self._get_weights(time_bounds) - # Weight the data variable. - dv *= self._weights - - # Perform weighted average using the formula - # WA = sum(data*weights) / sum(masked weights). The denominator must - # be included to take into account zero weight for missing data. - masked_weights = _get_masked_weights(dv, self._weights) with xr.set_options(keep_attrs=True): - dv = self._group_data(dv).sum() / self._group_data(masked_weights).sum() + dv_weighted = dv * self._weights + + # Perform weighted average using the formula + # # WA = sum(data*weights) / sum(masked weights). + # The denominator must be included to take into account zero + # weight for missing data. + dv_group_sum = self._group_data(dv_weighted).sum() + weights_masked = _get_masked_weights(dv_weighted, self._weights) + weights_masked_group_sum = self._group_data(weights_masked).sum() + + dv_avg = dv_group_sum / weights_masked_group_sum + # Mask the data variable values with weights below the minimum + # weight threshold (if specified). if self._min_weight > 0.0: - dv = mask_var_with_weight_threshold(dv, self._weights, self._min_weight) + dv_avg = xr.where( + weights_masked_group_sum >= self._min_weight, + dv_avg, + np.nan, + keep_attrs=True, + ) # Restore the data variable's name. 
- dv.name = data_var + dv_avg.name = data_var else: - dv = self._group_data(dv).mean() + dv_avg = self._group_data(dv).mean() # After grouping and aggregating the data variable values, the # original time dimension is replaced with the grouped time dimension. @@ -1259,18 +1265,18 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray: # with "year_season". This dimension needs to be renamed back to # the original time dimension name before the data variable is added # back to the dataset so that the original name is preserved. - dv = dv.rename({self._labeled_time.name: self.dim}) + dv_avg = dv_avg.rename({self._labeled_time.name: self.dim}) # After grouping and aggregating, the grouped time dimension's # attributes are removed. Xarray's `keep_attrs=True` option only keeps # attributes for data variables and not their coordinates, so the # coordinate attributes have to be restored manually. - dv[self.dim].attrs = self._labeled_time.attrs - dv[self.dim].encoding = self._labeled_time.encoding + dv_avg[self.dim].attrs = self._labeled_time.attrs + dv_avg[self.dim].encoding = self._labeled_time.encoding - dv = self._add_operation_attrs(dv) + dv_avg = self._add_operation_attrs(dv_avg) - return dv + return dv_avg def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray: """Calculates weights for a data variable using time bounds. diff --git a/xcdat/utils.py b/xcdat/utils.py index b18bd719..1f7f250e 100644 --- a/xcdat/utils.py +++ b/xcdat/utils.py @@ -174,7 +174,8 @@ def mask_var_with_weight_threshold( frac = weight_sum_masked / weight_sum_all # Nan out values that don't meet specified weight threshold. - dv_new = xr.where(frac >= min_weight, dv, np.nan) + dv_new = xr.where(frac >= min_weight, dv, np.nan, keep_attrs=True) + dv_new.name = dv.name return dv_new @@ -196,8 +197,7 @@ def _get_masked_weights(dv: xr.DataArray, weights: xr.DataArray) -> xr.DataArray xr.DataArray The masked weights. 
""" - masked_weights, _ = xr.broadcast(weights, dv) - masked_weights = xr.where(dv.copy().isnull(), 0.0, masked_weights) + masked_weights = xr.where(dv.copy().isnull(), 0.0, weights) return masked_weights