Skip to content

Commit

Permalink
tidy up some docstrings and type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-taggart committed Dec 11, 2023
1 parent 017aa10 commit 49b7b79
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 34 deletions.
27 changes: 16 additions & 11 deletions src/scores/probability/crps_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,14 +753,22 @@ def crps_for_ensemble(
where X, X' are independent samples of the predictive distribution F and Y is the observation (possibly unknown).
Samples from F and Y are drawn from the fcst_sample_dim and obs_sample_dim respectively.
Other dimensions are broadcast using xr broadcast rules.
Args:
fcst: Forecast data.
obs: Observation data
weights: Weights for calculating a weighted mean of scores
reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions.
preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions.
special_fcst_dims: Dimension(s) in `fcst` that are reduced to calculate individual scores.
Must not appear as a dimension in `obs`, `weights`, `reduce_dims` or `preserve_dims`.
e.g. the ensemble member dimension if calculating CRPS for ensembles, or the
threshold dimension if calculating CRPS for CDFs.
"""
# all_dims =
# if weight not None:
# all_dims = set(weights)
# if method not in ["ecdf", "fair"]:
# raise ValueError("`method` must be one of 'ecdf' or 'fair'")
# if ensemble_member_dim in obs.dims or (weight is not None and weights.dims):
# raise ValueError("`ensemble_member_dim` cannot be a dimension of `obs` or `weights`")
if method not in ["ecdf", "fair"]:
raise ValueError("`method` must be one of 'ecdf' or 'fair'")

dims_for_mean = scores.utils.gather_dimensions2(fcst, obs, weights, reduce_dims, preserve_dims, ensemble_member_dim)

ensemble_member_dim1 = scores.utils.tmp_coord_name(fcst)

Expand All @@ -779,9 +787,6 @@ def crps_for_ensemble(
result = fcst_obs_term - fcst_spread_term

# apply weights and take means across specified dims
fcst_dims = [x for x in fcst.dims if x != ensemble_member_dim]
reduce_dims = scores.utils.gather_dimensions(fcst_dims, obs.dims, reduce_dims, preserve_dims) # type: ignore[assignment]
result = scores.functions.apply_weights(result, weights)
result = result.mean(dim=reduce_dims)
result = scores.functions.apply_weights(result, weights).mean(dim=dims_for_mean)

return result
49 changes: 27 additions & 22 deletions src/scores/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,34 +112,39 @@ def gather_dimensions( # pylint: disable=too-many-branches
return reduce_dims


def gather_dimensions2(fcst, obs, weights=None, reduce_dims=None, preserve_dims=None, special_fcst_dims=None):
def gather_dimensions2(
fcst: XarrayLike,
obs: XarrayLike,
weights: XarrayLike = None,
reduce_dims: FlexibleDimensionTypes = None,
preserve_dims: FlexibleDimensionTypes = None,
special_fcst_dims: FlexibleDimensionTypes = None,
) -> set[Hashable]:
"""
Performs a standard dimensions check for a function that calculates (mean) scores.
Returns a list of the dimensions to reduce.
Performs a standard dimensions check for inputs of a function that calculates (mean) scores.
Returns a set of the dimensions to reduce.
special_fcst_dims are dims in fcst that will be collapsed while calculating individual scores
(e.g. the threshold dimension of a CDF, or the ensemble member dimension, when calculating CRPS)
and is never present for the step of calculating mean scores.
Checks that:
- reduce_dims and preserve_dims are not both specified
- specified dims (reduce_dims or preserve_dims) are a subset of fcst.dims,
obs.dims and (if not None) weights.dims.
- special_fcst_dims are not in obs.dims or weights.dims
- specified dims with a value of "all" is handled correctly
- specified dims with a string value is handled correctly
Args:
fcst: Forecast data
obs: Observation data
weights: Weights for calculating a weighted mean of scores
reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions.
preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions.
special_fcst_dims: Dimension(s) in `fcst` that are reduced to calculate individual scores.
Must not appear as a dimension in `obs`, `weights`, `reduce_dims` or `preserve_dims`.
e.g. the ensemble member dimension if calculating CRPS for ensembles, or the
threshold dimension if calculating CRPS for CDFs.
Returns:
list of dimensions over which to take the mean once the checks are passed.
Set of dimensions over which to take the mean once the checks are passed.
Raises:
ValueError:
- when `preserve_dims` and `reduce_dims` are both specified
- when `special_fcst_dims` is not a subset of `fcst.dims`
- when `obs.dims`, `weights.dims`, `reduce_dims` or `preserve_dims`
contains elements from `special_fcst_dims`
- when `preserve_dims` and `reduce_dims` contain elements not among dimensions
of the data (`fcst`, `obs` or `weights`)
ValueError: when `preserve_dims` and `reduce_dims` are both specified.
ValueError: when `special_fcst_dims` is not a subset of `fcst.dims`.
ValueError: when `obs.dims`, `weights.dims`, `reduce_dims` or `preserve_dims`
contains elements from `special_fcst_dims`.
ValueError: when `preserve_dims` and `reduce_dims` contain elements not among dimensions
of the data (`fcst`, `obs` or `weights`).
"""
all_data_dims = set(fcst.dims).union(set(obs.dims))
if weights is not None:
Expand Down
9 changes: 8 additions & 1 deletion tests/probabilty/test_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,15 @@ def test_crps_for_ensemble():

result_ecdf = crps_for_ensemble(fcst, obs, "ens_member", method="ecdf", preserve_dims="all")
result_fair = crps_for_ensemble(fcst, obs, "ens_member", method="fair", preserve_dims="all")
result_weighted_mean = crps_for_ensemble(fcst, obs, "ens_member", method="ecdf", preserve_dims=None, weights=weight)
result_weighted_mean = crps_for_ensemble(fcst, obs, "ens_member", method="ecdf", weights=weight)

assert_dataarray_equal(result_ecdf, expected_ecdf, decimals=7)
assert_dataarray_equal(result_fair, expected_fair, decimals=7)
assert_dataarray_equal(result_weighted_mean, expected_weighted_mean, decimals=7)


def test_crps_for_ensemble_raises():
    """Check that `crps_for_ensemble` raises a ValueError for an unsupported `method`."""
    # Minimal single-element inputs: the call is expected to fail on the
    # `method` argument ("unfair"), so the data values are irrelevant here.
    fcst = xr.DataArray(data=[1])
    obs = xr.DataArray(data=[1])
    expected_message = "`method` must be one of 'ecdf' or 'fair'"

    with pytest.raises(ValueError) as raised:
        crps_for_ensemble(fcst, obs, "ens_member", "unfair")

    # Inspect the captured exception only after the context manager exits.
    assert expected_message in str(raised.value)

0 comments on commit 49b7b79

Please sign in to comment.