Skip to content

Commit

Permalink
FEAT: crps_for_ensemble
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-taggart committed Dec 19, 2023
1 parent ee69685 commit 2285409
Show file tree
Hide file tree
Showing 9 changed files with 460 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Currently Included Metrics:

| continuous | probability | categorical | statistical tests |
| ---------- | ----------- | ----------- | ----------------- |
| MAE, MSE, RMSE, Flip Flop Index, Quantile Score | CRPS, ROC | FIRM, POD, POFD | Diebold Mariano (with the Harvey et al. 1997 and the Hering and Genton 2011 modifications) |
| MAE, MSE, RMSE, Flip Flop Index, Quantile Score | CRPS for CDF, CRPS for ensemble, ROC | FIRM, POD, POFD | Diebold Mariano (with the Harvey et al. 1997 and the Hering and Genton 2011 modifications) |


**Notice -- This repository is currently undergoing initial construction and maintenance. It is not yet recommended for use. This notice will be removed after the first feature release. In the meantime, please feel free to look around, and don't hesitate to get in touch with any questions (see the contributing guide for how).**
Expand Down
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
.. autofunction:: scores.probability.crps_cdf
.. autofunction:: scores.probability.adjust_fcst_for_crps
.. autofunction:: scores.probability.crps_cdf_brier_decomposition
.. autofunction:: scores.probability.crps_for_ensemble
.. autofunction:: scores.probability.murphy_score
.. autofunction:: scores.probability.murphy_thetas
.. autofunction:: scores.probability.roc_curve_data
Expand Down
2 changes: 1 addition & 1 deletion docs/summary_table_of_scores.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
| continuous | probability | categorical | statistical tests |
| ---------- | ----------- | ----------- | ----------- |
| MAE, MSE, RMSE, Murphy score, Quantile Score | CRPS_CDF, Murphy score, ROC | FIRM, POD, POFD, Flip Flop Index | Diebold Mariano (with the Harvey et al. 1997 and the Hering and Genton 2011 modifications)|
| MAE, MSE, RMSE, Murphy score, Quantile Score | CRPS for CDF, CRPS for ensemble, Murphy score, ROC | FIRM, POD, POFD, Flip Flop Index | Diebold Mariano (with the Harvey et al. 1997 and the Hering and Genton 2011 modifications)|
1 change: 1 addition & 0 deletions src/scores/probability/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
adjust_fcst_for_crps,
crps_cdf,
crps_cdf_brier_decomposition,
crps_for_ensemble,
)
from scores.probability.roc_impl import roc_curve_data
89 changes: 87 additions & 2 deletions src/scores/probability/crps_impl.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""
This module supports the implementation of the CRPS scoring function, drawing from additional functions.
The primary method, `crps_cdf` is imported into the probability module to be part of the probability API
The two primary functions, `crps_cdf` and `crps_for_ensemble`, are imported into
the probability module to be part of the probability API.
"""
from collections.abc import Iterable
from typing import Literal, Optional
from typing import Literal, Optional, Sequence

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -735,3 +736,87 @@ def crps_step_threshold_weight(
weight = 1 - weight

return weight


def crps_for_ensemble(
    fcst: xr.DataArray,
    obs: xr.DataArray,
    ensemble_member_dim: str,
    method: Literal["ecdf", "fair"] = "ecdf",
    reduce_dims: Optional[Sequence[str]] = None,
    preserve_dims: Optional[Sequence[str]] = None,
    weights: Optional[xr.DataArray] = None,
) -> xr.DataArray:
    """
    Calculates the continuous ranked probability score (CRPS) given an ensemble of forecasts.

    An ensemble of forecasts can also be thought of as a random sample from the predictive
    distribution.

    Given an observation y, and ensemble member values {x_i} (for 1 <= i <= M), the CRPS is
    calculated by the formula

        CRPS({x_i}, y) = (1 / M) * sum(|x_i - y|) - (1 / (2 * K)) * sum(|x_i - x_j|),

    where the first sum is iterated over 1 <= i <= M and the second sum is iterated over
    1 <= i <= M and 1 <= j <= M.

    The value of the constant K in this formula depends on the method.
        - If `method="ecdf"` then K = M ** 2. In this case the CRPS value returned is
          the exact CRPS value for the empirical cumulative distribution function
          constructed using the ensemble values.
        - If `method="fair"` then K = M * (M - 1). In this case the CRPS value returned
          is the approximated CRPS where the ensemble values can be interpreted as a
          random sample from the underlying predictive distribution. This interpretation
          stems from the formula CRPS(F, Y) = E|X - Y| - E|X - X'|/2, where X and X'
          are independent samples of the predictive distribution F, Y is the observation
          (possibly unknown) and E denotes the expectation. This choice of K gives an
          unbiased estimate for the second expectation.

    M is counted per forecast case as the number of non-NaN ensemble members, so
    missing members are left out of both terms. With `method="fair"` a forecast case
    that has only one valid member has K = 0 and therefore scores NaN.

    Args:
        fcst: Forecast data. Must have a dimension `ensemble_member_dim`.
        obs: Observation data.
        ensemble_member_dim: the dimension that specifies the ensemble member or the sample
            from the predictive distribution.
        method: Either "ecdf" or "fair".
        reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions.
        preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions.
        weights: Weights for calculating a weighted mean of individual scores.

    Returns:
        xarray object of (weighted mean) CRPS values.

    Raises:
        ValueError: when method is not one of "ecdf" or "fair".

    References:
        - C. Ferro (2014), "Fair scores for ensemble forecasts", Q J R Meteorol Soc
          140(683):1917-1923.
        - T. Gneiting and A. Raftery (2007), "Strictly proper scoring rules, prediction,
          and estimation", J Am Stat Assoc, 102(477):359-378.
        - M. Zamo and P. Naveau (2018), "Estimation of the Continuous Ranked Probability
          Score with Limited Information and Applications to Ensemble Weather Forecasts",
          Math Geosci 50:209-234, https://doi.org/10.1007/s11004-017-9709-7
    """
    if method not in ["ecdf", "fair"]:
        raise ValueError("`method` must be one of 'ecdf' or 'fair'")

    # The ensemble dimension is consumed when individual scores are computed, so it is
    # passed as a "special" forecast dimension rather than one to reduce or preserve.
    dims_for_mean = scores.utils.gather_dimensions2(fcst, obs, weights, reduce_dims, preserve_dims, ensemble_member_dim)

    # A second copy of the forecast with the ensemble dimension renamed lets
    # broadcasting form every (i, j) pair of members in a single subtraction.
    ensemble_member_dim1 = scores.utils.tmp_coord_name(fcst)
    fcst_copy = fcst.rename({ensemble_member_dim: ensemble_member_dim1})

    # calculate forecast spread contribution: sum over all pairs of |x_i - x_j|
    fcst_spread_term = np.abs(fcst - fcst_copy).sum([ensemble_member_dim, ensemble_member_dim1])
    ens_count = fcst.count(ensemble_member_dim)
    if method == "ecdf":
        fcst_spread_term = fcst_spread_term / (2 * ens_count**2)
    else:  # method == "fair"; any other value was rejected above
        fcst_spread_term = fcst_spread_term / (2 * ens_count * (ens_count - 1))

    # calculate final CRPS for each forecast case
    fcst_obs_term = np.abs(fcst - obs).mean(ensemble_member_dim)
    result = fcst_obs_term - fcst_spread_term

    # apply weights and take means across specified dims
    result = scores.functions.apply_weights(result, weights).mean(dim=dims_for_mean)

    return result
118 changes: 118 additions & 0 deletions src/scores/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,23 @@
It is ambiguous how to proceed therefore an exception has been raised instead.
"""

# Raised by `gather_dimensions2` when `preserve_dims` names a dimension that is not
# among the scoring dimensions of fcst, obs or weights.
ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION2 = """
You are requesting to preserve a dimension which does not appear in your data
(fcst, obs or weights). It is ambiguous how to proceed therefore an exception has been
raised instead.
"""

# Variant of the reduce-dimension message that mentions only fcst and obs.
# NOTE(review): presumably the one raised by `gather_dimensions` - confirm against its body.
ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION = """
You are requesting to reduce a dimension which does not appear in your data (fcst or obs).
It is ambiguous how to proceed therefore an exception has been raised instead.
"""

# Raised by `gather_dimensions2` when `reduce_dims` names a dimension that is not
# among the scoring dimensions of fcst, obs or weights.
ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION2 = """
You are requesting to reduce a dimension which does not appear in your data
(fcst, obs or weights). It is ambiguous how to proceed therefore an exception has been
raised instead.
"""

ERROR_OVERSPECIFIED_PRESERVE_REDUCE = """
You have specified both preserve_dims and reduce_dims. This method doesn't know how
to properly interpret that, therefore an exception has been raised.
Expand Down Expand Up @@ -112,6 +124,95 @@ def gather_dimensions( # pylint: disable=too-many-branches
return reduce_dims


def gather_dimensions2(
    fcst: xr.DataArray,
    obs: xr.DataArray,
    weights: xr.DataArray = None,
    reduce_dims: FlexibleDimensionTypes = None,
    preserve_dims: FlexibleDimensionTypes = None,
    special_fcst_dims: FlexibleDimensionTypes = None,
) -> set[Hashable]:
    """
    Validates the dimension arguments of a (mean) score function and works out
    which dimensions the mean should be taken over.

    Args:
        fcst: Forecast data
        obs: Observation data
        weights: Weights for calculating a weighted mean of scores
        reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions.
        preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions.
        special_fcst_dims: Dimension(s) in `fcst` that are reduced to calculate individual
            scores. Must not appear as a dimension in `obs`, `weights`, `reduce_dims` or
            `preserve_dims`. e.g. the ensemble member dimension if calculating CRPS for
            ensembles, or the threshold dimension if calculating CRPS for CDFs.

    Returns:
        Set of dimensions over which to take the mean once the checks are passed.

    Raises:
        ValueError: when `preserve_dims` and `reduce_dims` are both specified.
        ValueError: when `special_fcst_dims` is not a subset of `fcst.dims`.
        ValueError: when `obs.dims`, `weights.dims`, `reduce_dims` or `preserve_dims`
            contains elements from `special_fcst_dims`.
        ValueError: when `preserve_dims` or `reduce_dims` contain elements not among
            dimensions of the data (`fcst`, `obs` or `weights`).
    """
    data_dims = set(fcst.dims) | set(obs.dims)
    if weights is not None:
        data_dims |= set(weights.dims)

    # Dimensions still present once the individual scores have been computed.
    scoring_dims = set(data_dims)

    if preserve_dims is not None and reduce_dims is not None:
        raise ValueError(ERROR_OVERSPECIFIED_PRESERVE_REDUCE)

    # At most one of the two is set at this point.
    requested = preserve_dims or reduce_dims
    wants_all = requested == "all"

    if wants_all:
        # "all" is a keyword here, which is ambiguous if the data has a dim called "all".
        if "all" in data_dims:
            warnings.warn(WARN_ALL_DATA_CONFLICT_MSG)
    elif isinstance(requested, str):
        # a single dimension name was given; treat it as a one-element list
        requested = [requested]

    has_explicit_dims = requested is not None and not wants_all

    # The special dims may live in `fcst` only and must not be named anywhere else.
    if special_fcst_dims is not None:
        if isinstance(special_fcst_dims, str):
            special_fcst_dims = [special_fcst_dims]
        special = set(special_fcst_dims)
        if not special <= set(fcst.dims):
            raise ValueError("`special_fcst_dims` must be a subset of `fcst` dimensions")
        if special & set(obs.dims):
            raise ValueError("`obs.dims` must not contain any `special_fcst_dims`")
        if weights is not None and special & set(weights.dims):
            raise ValueError("`weights.dims` must not contain any `special_fcst_dims`")
        if has_explicit_dims and special & set(requested):
            raise ValueError("`reduce_dims` and `preserve_dims` must not contain any `special_fcst_dims`")
        # the special dims are gone by the time the mean is taken
        scoring_dims -= special

    if has_explicit_dims and not set(requested) <= scoring_dims:
        if preserve_dims is not None:
            raise ValueError(ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION2)
        raise ValueError(ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION2)

    # All checks passed; translate the request into the set of dims to reduce.
    if requested is None or reduce_dims == "all":
        return scoring_dims
    if reduce_dims is not None:
        return set(requested)
    if preserve_dims == "all":
        return set()
    return scoring_dims - set(requested)


def dims_complement(data, dims=None) -> list[str]:
"""Returns the complement of data.dims and dims
Expand Down Expand Up @@ -215,3 +316,20 @@ def check_dims(xr_data: XarrayLike, expected_dims: Sequence[str], mode: Optional
f"Dimensions {list(xr_data[data_var].dims)} of data variable "
f"'{data_var}' are not {mode} to the dimensions {sorted(dims_set)}"
)


def tmp_coord_name(xr_data: xr.DataArray) -> str:
    """
    Generates a temporary coordinate name that is not among the coordinate or dimension
    names of `xr_data`.

    Args:
        xr_data: Input xarray data array

    Returns:
        A string which is the concatenation of 'new' with all coordinate and
        dimension names in the input array. Because the result contains every
        existing string name plus the extra prefix, it is strictly longer than
        each of them and so cannot equal any of them.
    """
    existing_names = [*xr_data.dims, *xr_data.coords]
    # Names may be any hashable, not only strings, so convert each one before
    # joining; for string names this leaves the result unchanged.
    return "new" + "".join(str(name) for name in existing_names)
30 changes: 30 additions & 0 deletions tests/probabilty/crps_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,33 @@
)

EXP_CRPS_BD2 = xr.merge([EXP_TOTAL_CRPS_BD2, EXP_UNDER_CRPS_BD2, EXP_OVER_CRPS_BD2])

# test data for CRPS for ensembles

# Five stations with up to four ensemble members each. Station 103 has one missing
# member and station 105 has only a single valid member.
DA_FCST_CRPSENS = xr.DataArray(
data=[[0.0, 4, 3, 7], [1, -1, 2, 4], [0, 1, 4, np.nan], [2, 3, 4, 1], [2, np.nan, np.nan, np.nan]],
dims=["stn", "ens_member"],
coords={"stn": [101, 102, 103, 104, 105], "ens_member": [1, 2, 3, 4]},
)
# Observations per station; station 104 has no observation, so its expected score is NaN.
DA_OBS_CRPSENS = xr.DataArray(data=[2.0, 3, 1, np.nan, 4], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]})
# Per-station weights used for the weighted-mean test case.
DA_WT_CRPSENS = xr.DataArray(data=[1, 2, 1, 0, 2], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]})

# first and second (spread) terms from crps for ensembles, "ecdf" and "fair" methods
# FIRST_TERM is the mean over valid members of |member - obs|, e.g. station 101:
# (2 + 2 + 1 + 5) / 4 = 10 / 4.
FIRST_TERM = xr.DataArray(
data=[10 / 4, 8 / 4, 4 / 3, np.nan, 2], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]}
)
# Each numerator adds, member by member, the absolute differences to every other member.
# The "ecdf" denominator is 2 * M**2 with M valid members: 32 for M = 4, 18 for M = 3.
SPREAD_ECDF = xr.DataArray(
data=[(14 + 8 + 8 + 14) / 32, (6 + 10 + 6 + 10) / 32, (5 + 4 + 7) / 18, np.nan, 0],
dims=["stn"],
coords={"stn": [101, 102, 103, 104, 105]},
)
# The "fair" denominator is 2 * M * (M - 1): 24 for M = 4, 12 for M = 3. For station 105
# (M = 1) the denominator is zero, hence the NaN.
SPREAD_FAIR = xr.DataArray(
data=[(14 + 8 + 8 + 14) / 24, (6 + 10 + 6 + 10) / 24, (5 + 4 + 7) / 12, np.nan, np.nan],
dims=["stn"],
coords={"stn": [101, 102, 103, 104, 105]},
)

# expected results: CRPS is the first term minus the spread term; the weighted case
# multiplies the "ecdf" scores by the weights and then averages over stations
EXP_CRPSENS_ECDF = FIRST_TERM - SPREAD_ECDF
EXP_CRPSENS_FAIR = FIRST_TERM - SPREAD_FAIR
EXP_CRPSENS_WT = (EXP_CRPSENS_ECDF * DA_WT_CRPSENS).mean("stn")
45 changes: 45 additions & 0 deletions tests/probabilty/test_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import dask.array
import numpy as np
import pytest
import xarray as xr

from scores.probability import (
adjust_fcst_for_crps,
crps_cdf,
crps_cdf_brier_decomposition,
crps_for_ensemble,
)
from scores.probability.crps_impl import (
crps_cdf_exact,
Expand Down Expand Up @@ -604,3 +606,46 @@ def test_crps_cdf_brier_decomposition(dims, expected):
crps_test_data.DA_FCST_CRPS_BD, crps_test_data.DA_OBS_CRPS_BD, "x", preserve_dims=dims
)
assert_dataset_equal(result, expected, decimals=7)


def test_crps_for_ensemble():
    """Checks `crps_for_ensemble` for both methods per station and for a weighted mean."""
    fcst = crps_test_data.DA_FCST_CRPSENS
    obs = crps_test_data.DA_OBS_CRPSENS

    # Compute every result first, then compare, so no comparison hides a later crash.
    per_station = {
        method: crps_for_ensemble(fcst, obs, "ens_member", method=method, preserve_dims="all")
        for method in ("ecdf", "fair")
    }
    weighted_mean = crps_for_ensemble(
        fcst,
        obs,
        "ens_member",
        method="ecdf",
        weights=crps_test_data.DA_WT_CRPSENS,
    )

    assert_dataarray_equal(per_station["ecdf"], crps_test_data.EXP_CRPSENS_ECDF, decimals=7)
    assert_dataarray_equal(per_station["fair"], crps_test_data.EXP_CRPSENS_FAIR, decimals=7)
    assert_dataarray_equal(weighted_mean, crps_test_data.EXP_CRPSENS_WT, decimals=7)


def test_crps_for_ensemble_raises():
    """Checks that an unknown `method` is rejected with the documented message."""
    dummy_fcst = xr.DataArray(data=[1])
    dummy_obs = xr.DataArray(data=[1])
    # The message has no regex metacharacters, so `match` does a plain substring search.
    with pytest.raises(ValueError, match="`method` must be one of 'ecdf' or 'fair'"):
        crps_for_ensemble(dummy_fcst, dummy_obs, "ens_member", "unfair")


def test_crps_for_ensemble_dask():
    """Checks `crps_for_ensemble` stays lazy on dask inputs and computes the expected values."""
    lazy_fcst = crps_test_data.DA_FCST_CRPSENS.chunk()
    lazy_obs = crps_test_data.DA_OBS_CRPSENS.chunk()

    lazy_result = crps_for_ensemble(
        fcst=lazy_fcst,
        obs=lazy_obs,
        ensemble_member_dim="ens_member",
        method="ecdf",
        preserve_dims="all",
    )
    # chunked inputs must produce a dask-backed result, not an eagerly computed one
    assert isinstance(lazy_result.data, dask.array.Array)

    computed = lazy_result.compute()
    assert isinstance(computed.data, np.ndarray)
    assert_dataarray_equal(computed, crps_test_data.EXP_CRPSENS_ECDF, decimals=7)
Loading

0 comments on commit 2285409

Please sign in to comment.