diff --git a/xcdat/temporal.py b/xcdat/temporal.py
index 834eb394..772b442f 100644
--- a/xcdat/temporal.py
+++ b/xcdat/temporal.py
@@ -154,7 +154,13 @@ class TemporalAccessor:
     def __init__(self, dataset: xr.Dataset):
         self._dataset: xr.Dataset = dataset
 
-    def average(self, data_var: str, weighted: bool = True, keep_weights: bool = False):
+    def average(
+        self,
+        data_var: str,
+        weighted: bool = True,
+        keep_weights: bool = False,
+        required_weight_pct: float | None = None,
+    ):
         """
         Returns a Dataset with the average of a data variable and the time
         dimension removed.
@@ -195,6 +201,9 @@ def average(self, data_var: str, weighted: bool = True, keep_weights: bool = Fal
         keep_weights : bool, optional
            If calculating averages using weights, keep the weights in the
            final dataset output, by default False.
+        required_weight_pct : float | None, optional
+            Fraction of data coverage (i.e., weight) needed to return an
+            average value, by default None. Must be between 0 and 1.
 
         Returns
         -------
@@ -230,7 +239,12 @@ def average(self, data_var: str, weighted: bool = True, keep_weights: bool = Fal
         freq = _infer_freq(self._dataset[self.dim])
 
         return self._averager(
-            data_var, "average", freq, weighted=weighted, keep_weights=keep_weights
+            data_var,
+            "average",
+            freq,
+            weighted=weighted,
+            keep_weights=keep_weights,
+            required_weight_pct=required_weight_pct,
         )
 
     def group_average(
@@ -824,10 +838,13 @@ def _averager(
         keep_weights: bool = False,
         reference_period: Optional[Tuple[str, str]] = None,
         season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
+        required_weight_pct: float | None = None,
     ) -> xr.Dataset:
         """Averages a data variable based on the averaging mode and frequency."""
         ds = self._dataset.copy()
-        self._set_arg_attrs(mode, freq, weighted, reference_period, season_config)
+        self._set_arg_attrs(
+            mode, freq, weighted, reference_period, season_config, required_weight_pct
+        )
 
         # Preprocess the dataset based on method argument values.
         ds = self._preprocess_dataset(ds)
@@ -908,6 +925,7 @@ def _set_arg_attrs(
         weighted: bool,
         reference_period: Optional[Tuple[str, str]] = None,
         season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
+        required_weight_pct: float | None = None,
     ):
         """Validates method arguments and sets them as object attributes.
 
@@ -923,6 +941,9 @@ def _set_arg_attrs(
             A dictionary for "season" frequency configurations. If configs for
             predefined seasons are passed, configs for custom seasons are
             ignored and vice versa, by default DEFAULT_SEASON_CONFIG.
+        required_weight_pct : float | None, optional
+            Fraction of data coverage (i.e., weight) needed to return an
+            average value, by default None. Must be between 0 and 1.
 
         Raises
         ------
@@ -950,6 +971,7 @@ def _set_arg_attrs(
         self._mode = mode
         self._freq = freq
         self._weighted = weighted
+        self._required_weight_pct = self._set_required_weight_pct(required_weight_pct)
 
         self._reference_period = None
         if reference_period is not None:
@@ -1208,14 +1230,18 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
             # achieve this, first broadcast the one-dimensional (temporal
             # dimension) shape of the `weights` DataArray to the
             # multi-dimensional shape of its corresponding data variable.
-            weights, _ = xr.broadcast(self._weights, dv)
-            weights = xr.where(dv.copy().isnull(), 0.0, weights)
+            masked_weights, _ = xr.broadcast(self._weights, dv)
+            masked_weights = xr.where(dv.copy().isnull(), 0.0, masked_weights)
 
             # Perform weighted average using the formula
             # WA = sum(data*weights) / sum(weights). The denominator must be
             # included to take into account zero weight for missing data.
             with xr.set_options(keep_attrs=True):
-                dv = self._group_data(dv).sum() / self._group_data(weights).sum()
+                dv = self._group_data(dv).sum() / self._group_data(masked_weights).sum()
+
+            # Apply the weight threshold using the required percentage (if set).
+            if self._required_weight_pct > 0.0:
+                dv = self._apply_weight_threshold(dv, masked_weights)
 
             # Restore the data variable's name.
             dv.name = data_var
@@ -1300,6 +1326,51 @@ def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:
 
         return weights
 
+    def _set_required_weight_pct(self, required_weight_pct: float | None) -> float:
+        if required_weight_pct is None:
+            required_weight_pct = 0.0
+        elif required_weight_pct < 0.0:
+            raise ValueError(
+                "required_weight_pct argument is less than zero. "
+                "required_weight_pct must be between 0 and 1."
+            )
+        elif required_weight_pct > 1.0:
+            raise ValueError(
+                "required_weight_pct argument is greater than one. "
+                "required_weight_pct must be between 0 and 1."
+            )
+
+        return required_weight_pct
+
+    def _apply_weight_threshold(
+        self, dv: xr.DataArray, masked_weights: xr.DataArray
+    ) -> xr.DataArray:
+        """NaN out values that do not meet the weight threshold percentage.
+
+        Parameters
+        ----------
+        dv : xr.DataArray
+            The grouped and weighted variable.
+        masked_weights : xr.DataArray
+            The weights with missing values masked with 0.0.
+
+        Returns
+        -------
+        xr.DataArray
+            The weighted variable with the weight threshold applied.
+        """
+        # Sum the weights per group, both in total and with missing data masked.
+        weight_sum_all = self._group_data(self._weights).sum()
+        weight_sum_masked = self._group_data(masked_weights).sum()
+
+        # Get the fraction of available weight per group.
+        frac = weight_sum_masked / weight_sum_all
+
+        # NaN out values that don't meet the specified weight threshold.
+        dv_new = xr.where(frac >= self._required_weight_pct, dv, np.nan)
+
+        return dv_new
+
     def _group_data(self, data_var: xr.DataArray) -> DataArrayGroupBy:
         """Groups a data variable.
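
A minimal usage sketch of the new `required_weight_pct` parameter (not part of the diff above; the file name and variable name are hypothetical), assuming a dataset opened with xcdat that contains a monthly variable with missing values:

    import xcdat

    # Hypothetical input: monthly near-surface temperature with data gaps.
    ds = xcdat.open_dataset("tas_monthly.nc")

    # Time-weighted average with a coverage requirement. With
    # required_weight_pct=0.7, output cells whose non-missing weights cover
    # less than 70% of the total temporal weight are set to NaN instead of
    # returning a partial average.
    ds_avg = ds.temporal.average("tas", weighted=True, required_weight_pct=0.7)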