Add _apply_weight_threshold() method
tomvothecoder committed Jul 30, 2024
1 parent cdae826 commit f84fe1f
Showing 1 changed file with 77 additions and 6 deletions.
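A minimal usage sketch of the new argument (illustrative, not part of the commit; the file path and the "tas" variable name are placeholders). The intent is that averages backed by less than the requested fraction of non-missing temporal weight are returned as NaN.

import xcdat as xc

ds = xc.open_dataset("path/to/dataset.nc")

# Pass the new argument through the public API: require at least 70% of the
# temporal weight to come from non-missing data, otherwise mask the result.
ds_avg = ds.temporal.average("tas", weighted=True, required_weight_pct=0.7)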
83 changes: 77 additions & 6 deletions xcdat/temporal.py
@@ -154,7 +154,13 @@ class TemporalAccessor:
def __init__(self, dataset: xr.Dataset):
self._dataset: xr.Dataset = dataset

def average(self, data_var: str, weighted: bool = True, keep_weights: bool = False):
def average(
self,
data_var: str,
weighted: bool = True,
keep_weights: bool = False,
required_weight_pct: float | None = None,
):
"""
Returns a Dataset with the average of a data variable and the time
dimension removed.
@@ -195,6 +201,9 @@ def average(self, data_var: str, weighted: bool = True, keep_weights: bool = Fal
keep_weights : bool, optional
If calculating averages using weights, keep the weights in the
final dataset output, by default False.
required_weight_pct : float | None
Fraction of data coverage (i.e., weight) needed to return an
average value. Value must range from 0 to 1.
Returns
-------
@@ -230,7 +239,12 @@ def average(self, data_var: str, weighted: bool = True, keep_weights: bool = Fal
freq = _infer_freq(self._dataset[self.dim])

return self._averager(
data_var, "average", freq, weighted=weighted, keep_weights=keep_weights
data_var,
"average",
freq,
weighted=weighted,
keep_weights=keep_weights,
required_weight_pct=required_weight_pct,
)

def group_average(
@@ -824,10 +838,13 @@ def _averager(
keep_weights: bool = False,
reference_period: Optional[Tuple[str, str]] = None,
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
required_weight_pct: float | None = None,
) -> xr.Dataset:
"""Averages a data variable based on the averaging mode and frequency."""
ds = self._dataset.copy()
self._set_arg_attrs(mode, freq, weighted, reference_period, season_config)
self._set_arg_attrs(
mode, freq, weighted, reference_period, season_config, required_weight_pct
)

# Preprocess the dataset based on method argument values.
ds = self._preprocess_dataset(ds)
@@ -908,6 +925,7 @@ def _set_arg_attrs(
weighted: bool,
reference_period: Optional[Tuple[str, str]] = None,
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
required_weight_pct: float | None = None,
):
"""Validates method arguments and sets them as object attributes.
@@ -923,6 +941,9 @@ def _set_arg_attrs(
A dictionary for "season" frequency configurations. If configs for
predefined seasons are passed, configs for custom seasons are
ignored and vice versa, by default DEFAULT_SEASON_CONFIG.
required_weight_pct : float | None
Fraction of data coverage (i.e., weight) needed to return an
average value. Value must range from 0 to 1.
Raises
------
@@ -950,6 +971,7 @@ def _set_arg_attrs(
self._mode = mode
self._freq = freq
self._weighted = weighted
self._required_weight_pct = self._set_required_weight_pct(required_weight_pct)

self._reference_period = None
if reference_period is not None:
@@ -1208,14 +1230,18 @@ def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
# achieve this, first broadcast the one-dimensional (temporal
# dimension) shape of the `weights` DataArray to the
# multi-dimensional shape of its corresponding data variable.
weights, _ = xr.broadcast(self._weights, dv)
weights = xr.where(dv.copy().isnull(), 0.0, weights)
masked_weights, _ = xr.broadcast(self._weights, dv)
masked_weights = xr.where(dv.copy().isnull(), 0.0, masked_weights)

# Perform weighted average using the formula
# WA = sum(data*weights) / sum(weights). The denominator must be
# included to take into account zero weight for missing data.
with xr.set_options(keep_attrs=True):
dv = self._group_data(dv).sum() / self._group_data(weights).sum()
dv = self._group_data(dv).sum() / self._group_data(masked_weights).sum()

# Apply the weight threshold using the required percentage (if set).
if self._required_weight_pct > 0.0:
dv = self._apply_weight_threshold(dv, masked_weights)

# Restore the data variable's name.
dv.name = data_var
@@ -1300,6 +1326,51 @@ def _get_weights(self, time_bounds: xr.DataArray) -> xr.DataArray:

return weights

def _set_required_weight_pct(self, required_weight_pct: float | None) -> float:
"""Validate ``required_weight_pct`` and default it to 0.0 when None."""
if required_weight_pct is None:
required_weight_pct = 0.0
elif required_weight_pct < 0.0:
raise ValueError(
"required_weight_pct argument is less than zero. "
"required_weight_pct must be between 0 and 1."
)
elif required_weight_pct > 1.0:
raise ValueError(
"required_weight_pct argument is greater than one. "
"required_weight_pct must be between 0 and 1."
)

return required_weight_pct

def _apply_weight_threshold(
self, dv: xr.DataArray, masked_weights: xr.DataArray
) -> xr.DataArray:
"""Nan out values that don't meet the weight threshold percentage.
Parameters
----------
dv : xr.DataArray
The weighted variable.
masked_weights : xr.DataArray
The weights with missing values masked with 0.0.
Returns
-------
xr.DataArray
The weighted variable with the weight threshold applied.
"""
# Sum all weights, including zero for missing values.
weight_sum_all = self._weights.sum(dim=self.dim)
weight_sum_masked = masked_weights.sum(dim=self.dim) # type: ignore

# Get fraction of the available weight.
frac = weight_sum_masked / weight_sum_all

# NaN out values that don't meet the specified weight threshold.
dv_new = xr.where(frac >= self._required_weight_pct, dv, np.nan)

return dv_new

def _group_data(self, data_var: xr.DataArray) -> DataArrayGroupBy:
"""Groups a data variable.
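For readers skimming the diff, a small self-contained sketch of the thresholding idea behind _apply_weight_threshold() (illustrative values only, plain xarray, not code from the commit):

import numpy as np
import xarray as xr

# Three time steps with equal weights; the middle value is missing.
data = xr.DataArray([1.0, np.nan, 3.0], dims="time")
weights = xr.DataArray([1 / 3, 1 / 3, 1 / 3], dims="time")

# Zero out weights where data is missing (the "masked_weights" idea above).
masked_weights = xr.where(data.isnull(), 0.0, weights)

# Weighted average and the fraction of weight backed by non-missing data.
avg = (data * masked_weights).sum("time") / masked_weights.sum("time")  # 2.0
frac = masked_weights.sum("time") / weights.sum("time")  # ~0.67

# With required_weight_pct=0.7, 0.67 < 0.7, so the average becomes NaN.
result = xr.where(frac >= 0.7, avg, np.nan)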
