diff --git a/docs/api.md b/docs/api.md index bf477f36..59004689 100644 --- a/docs/api.md +++ b/docs/api.md @@ -43,6 +43,7 @@ .. autofunction:: scores.probability.crps_for_ensemble .. autofunction:: scores.probability.tw_crps_for_ensemble .. autofunction:: scores.probability.tail_tw_crps_for_ensemble +.. autofunction:: scores.probability.interval_tw_crps_for_ensemble .. autofunction:: scores.probability.murphy_score .. autofunction:: scores.probability.murphy_thetas .. autofunction:: scores.probability.roc_curve_data diff --git a/docs/included.md b/docs/included.md index 6d7d5161..01a02620 100644 --- a/docs/included.md +++ b/docs/included.md @@ -184,6 +184,10 @@ - [API](api.md#scores.probability.crps_for_ensemble) - [Tutorial](project:./tutorials/CRPS_for_Ensembles.md) - [Ferro (2014)](https://doi.org/10.1002/qj.2270); [Gneiting And Raftery (2007)](https://doi.org/10.1198/016214506000001437); [Zamo and Naveau (2018)](https://doi.org/10.1007/s11004-017-9709-7) +* - Interval Threshold Weighted Continuous Ranked Probability Score (twCRPS) for Ensembles + - [API](api.md#scores.probability.interval_tw_crps_for_ensemble) + - — + - [Allen et al. 
(2023)](https://doi.org/10.1137/22M1532184) * - Isotonic Fit, *see Isotonic Regression* - — - — diff --git a/src/scores/probability/__init__.py b/src/scores/probability/__init__.py index e794c740..c851b74f 100644 --- a/src/scores/probability/__init__.py +++ b/src/scores/probability/__init__.py @@ -10,6 +10,7 @@ crps_cdf_brier_decomposition, crps_for_ensemble, crps_step_threshold_weight, + interval_tw_crps_for_ensemble, tail_tw_crps_for_ensemble, tw_crps_for_ensemble, ) @@ -29,4 +30,5 @@ "crps_step_threshold_weight", "tw_crps_for_ensemble", "tail_tw_crps_for_ensemble", + "interval_tw_crps_for_ensemble", ] diff --git a/src/scores/probability/crps_impl.py b/src/scores/probability/crps_impl.py index 39fdd305..3242597d 100644 --- a/src/scores/probability/crps_impl.py +++ b/src/scores/probability/crps_impl.py @@ -968,6 +968,7 @@ def tw_crps_for_ensemble( See also: :py:func:`scores.probability.crps_for_ensemble` :py:func:`scores.probability.tail_tw_crps_for_ensemble` + :py:func:`scores.probability.interval_tw_crps_for_ensemble` :py:func:`scores.probability.crps_cdf` @@ -1055,6 +1056,7 @@ def tail_tw_crps_for_ensemble( See also: :py:func:`scores.probability.tw_crps_for_ensemble` + :py:func:`scores.probability.interval_tw_crps_for_ensemble` :py:func:`scores.probability.crps_for_ensemble` :py:func:`scores.probability.crps_cdf` @@ -1094,3 +1096,102 @@ def _vfunc(x, threshold=threshold): weights=weights, ) return result + + +def interval_tw_crps_for_ensemble( + fcst: XarrayLike, + obs: XarrayLike, + ensemble_member_dim: str, + lower_threshold: Union[xr.DataArray, float], + upper_threshold: Union[xr.DataArray, float], + *, # Force keywords arguments to be keyword-only + method: Literal["ecdf", "fair"] = "ecdf", + reduce_dims: Optional[Sequence[str]] = None, + preserve_dims: Optional[Sequence[str]] = None, + weights: Optional[XarrayLike] = None, +) -> XarrayLike: + """ + Calculates the threshold weighted continuous ranked probability score (twCRPS) + weighted for an interval 
across the distribution from ensemble input. + + A threshold weight of 1 is assigned for values within the interval and a threshold weight of 0 otherwise. + The threshold values that define the bounds of the interval are given by the + ``lower_threshold`` and ``upper_threshold`` arguments. + For example, if we only want to focus on the temperatures between -20 and -10 degrees C + where aircraft icing is most likely, we can set ``lower_threshold=-20`` and ``upper_threshold=-10``. + + + For more flexible weighting options and the relevant equations, see the + :py:func:`scores.probability.tw_crps_for_ensemble` function. + + Args: + fcst: Forecast data. Must have a dimension `ensemble_member_dim`. + obs: Observation data. + ensemble_member_dim: the dimension that specifies the ensemble member or the sample + from the predictive distribution. + lower_threshold: the threshold value for where the interval begins. It can either be a float + for a single threshold or an xarray object if the threshold varies across + dimensions (e.g., climatological values). + upper_threshold: the threshold value for where the interval ends. It can either be a float + for a single threshold or an xarray object if the threshold varies across + dimensions (e.g., climatological values). + method: Either "ecdf" or "fair". See :py:func:`scores.probability.tw_crps_for_ensemble` + for more details. + reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions. + preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions. + weights: Weights for calculating a weighted mean of individual scores. Note that + these weights are different to threshold weighting which is done by decision + threshold. + + Returns: + xarray object of twCRPS values that has been weighted based on an interval. + + Raises: + ValueError: when ``lower_threshold`` is not less than ``upper_threshold``. + ValueError: when ``method`` is not one of "ecdf" or "fair". 
+ + References: + Allen, S., Ginsbourger, D., & Ziegel, J. (2023). Evaluating forecasts for high-impact + events using transformed kernel scores. SIAM/ASA Journal on Uncertainty + Quantification, 11(3), 906-940. https://doi.org/10.1137/22M1532184 + + See also: + :py:func:`scores.probability.tw_crps_for_ensemble` + :py:func:`scores.probability.tail_tw_crps_for_ensemble` + :py:func:`scores.probability.crps_for_ensemble` + :py:func:`scores.probability.crps_cdf` + + Examples: + Calculate the twCRPS for an ensemble where we assign a threshold weight of 1 + to thresholds between -20 and -10 and a threshold weight of 0 to thresholds outside + that interval. + + >>> import numpy as np + >>> import xarray as xr + >>> from scores.probability import interval_tw_crps_for_ensemble + >>> fcst = xr.DataArray(np.random.uniform(-40, 10, size=(10, 10)), dims=['time', 'ensemble']) + >>> obs = xr.DataArray(np.random.uniform(-40, 10, size=10), dims=['time']) + >>> interval_tw_crps_for_ensemble(fcst, obs, 'ensemble', -20, -10) + """ + if isinstance(lower_threshold, xr.DataArray) or isinstance(upper_threshold, xr.DataArray): + if (lower_threshold >= upper_threshold).any().values.item(): + raise ValueError("`lower_threshold` must be less than `upper_threshold`") + elif lower_threshold >= upper_threshold: + raise ValueError("`lower_threshold` must be less than `upper_threshold`") + + def _vfunc(x, lower_threshold=lower_threshold, upper_threshold=upper_threshold): + + return np.minimum(np.maximum(x, lower_threshold), upper_threshold) + + result = tw_crps_for_ensemble( + fcst, + obs, + ensemble_member_dim, + _vfunc, + chainging_func_kwargs={"lower_threshold": lower_threshold, "upper_threshold": upper_threshold}, + method=method, + reduce_dims=reduce_dims, + preserve_dims=preserve_dims, + weights=weights, + ) + return result diff --git a/tests/probabilty/crps_test_data.py b/tests/probabilty/crps_test_data.py index 6fa23cea..032d1426 100644 --- a/tests/probabilty/crps_test_data.py +++ 
b/tests/probabilty/crps_test_data.py @@ -589,6 +589,10 @@ ) DA_WT_CRPSENS = xr.DataArray(data=[1, 2, 1, 0, 2], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]}) DA_T_TWCRPSENS = xr.DataArray(data=[np.nan, 1, 10, 1, -2], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]}) +DA_LI_TWCRPSENS = xr.DataArray(data=[np.nan, 2, 2, 100, -200], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]}) +DA_UI_TWCRPSENS = xr.DataArray(data=[np.nan, 5, 5, 200, -100], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]}) +DA_LI_CONS_TWCRPSENS = xr.DataArray(data=[2, 2, 2, 2, 2], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]}) +DA_UI_CONS_TWCRPSENS = xr.DataArray(data=[5, 5, 5, 5, 5], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]}) DS_FCST_CRPSENS = xr.Dataset({"a": DA_FCST_CRPSENS, "b": DA_FCST_CRPSENS}) DS_OBS_CRPSENS = xr.Dataset({"a": DA_OBS_CRPSENS, "b": DA_OBS_CRPSENS}) @@ -698,3 +702,18 @@ ) EXP_VAR_THRES_CRPSENS_DA = VAR_THRES_FIRST_TERM_DA - VAR_THRES_SPREAD_ECDF_DA EXP_VAR_THRES_CRPSENS_DS = xr.Dataset({"a": EXP_VAR_THRES_CRPSENS_DA, "b": EXP_VAR_THRES_CRPSENS_DA}) + +# exp test data for interval twCRPS +INTERVAL_FIRST_TERM_DA = xr.DataArray( + data=[6 / 4, 4 / 4, 2 / 3, np.nan, 2], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]} +) +INTERVAL_SPREAD_ECDF_DA = xr.DataArray( + data=[(6 + 4 + 4 + 6) / 32, (2 + 2 + 2 + 6) / 32, (2 + 2 + 4) / 18, np.nan, 0], + dims=["stn"], + coords={"stn": [101, 102, 103, 104, 105]}, +) +EXP_INTERVAL_CRPSENS_ECDF_DA = INTERVAL_FIRST_TERM_DA - INTERVAL_SPREAD_ECDF_DA + +EXP_VAR_INTERVAL_CRPSENS_ECDF_DA = EXP_INTERVAL_CRPSENS_ECDF_DA * xr.DataArray( + data=[np.nan, 1, 1, np.nan, 0], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]} +) diff --git a/tests/probabilty/test_crps.py b/tests/probabilty/test_crps.py index 7fe50905..b2c33a03 100644 --- a/tests/probabilty/test_crps.py +++ b/tests/probabilty/test_crps.py @@ -18,6 +18,7 @@ crps_cdf, crps_cdf_brier_decomposition, crps_for_ensemble, + 
interval_tw_crps_for_ensemble, tail_tw_crps_for_ensemble, tw_crps_for_ensemble, ) @@ -776,7 +777,7 @@ def test_crps_for_ensemble_dask(): None, crps_test_data.EXP_VAR_THRES_CRPSENS_DA, ), - # test that passing in xr.DataSets with an xr.Dataset for the threshold arg works + # test that passing in xr.Datasets with an xr.Dataset for the threshold arg works ( crps_test_data.DS_FCST_CRPSENS, crps_test_data.DS_OBS_CRPSENS, @@ -788,7 +789,7 @@ def test_crps_for_ensemble_dask(): None, crps_test_data.EXP_VAR_THRES_CRPSENS_DS, ), - # test that passing in xr.DataSets with an xr.DataArray for the threshold arg works + # test that passing in xr.Datasets with an xr.DataArray for the threshold arg works ( crps_test_data.DS_FCST_CRPSENS, crps_test_data.DS_OBS_CRPSENS, @@ -1033,3 +1034,211 @@ def test_tw_crps_for_ensemble_dask(): result_ds = result_ds.compute() assert isinstance(result_ds["a"].data, np.ndarray) xr.testing.assert_allclose(result_ds, crps_test_data.EXP_UPPER_TAIL_CRPSENS_ECDF_DS) + + +@pytest.mark.parametrize( + ( + "fcst", + "obs", + "method", + "lower_threshold", + "upper_threshold", + "preserve_dims", + "reduce_dims", + "weights", + "expected", + ), + [ + # Test interval + ( + crps_test_data.DA_FCST_CRPSENS, + crps_test_data.DA_OBS_CRPSENS, + "ecdf", + 2, + 5, + "all", + None, + None, + crps_test_data.EXP_INTERVAL_CRPSENS_ECDF_DA, + ), + # test that it equals the standard CRPS when the interval contains all thresholds with ecdf + ( + crps_test_data.DA_FCST_CRPSENS, + crps_test_data.DA_OBS_CRPSENS, + "ecdf", + -np.inf, + np.inf, + "all", + None, + None, + crps_test_data.EXP_CRPSENS_ECDF, + ), + # test that it equals the standard CRPS when the interval contains all thresholds with fair + ( + crps_test_data.DA_FCST_CRPSENS, + crps_test_data.DA_OBS_CRPSENS, + "fair", + -np.inf, + np.inf, + "all", + None, + None, + crps_test_data.EXP_CRPSENS_FAIR, + ), + # Test broadcast + ( + crps_test_data.DA_FCST_CRPSENS_LT, + crps_test_data.DA_OBS_CRPSENS, + "ecdf", + -np.inf, + 
np.inf, + "all", + None, + None, + crps_test_data.EXP_CRPSENS_ECDF_BC, + ), + # Test with weights and reduce dims + ( + crps_test_data.DA_FCST_CRPSENS, + crps_test_data.DA_OBS_CRPSENS, + "ecdf", + -np.inf, + np.inf, + None, + "stn", + crps_test_data.DA_WT_CRPSENS, + crps_test_data.EXP_CRPSENS_WT, + ), + # test that passing an xarray object for the threshold args works + ( + crps_test_data.DA_FCST_CRPSENS, + crps_test_data.DA_OBS_CRPSENS, + "ecdf", + crps_test_data.DA_LI_TWCRPSENS, + crps_test_data.DA_UI_TWCRPSENS, + "all", + None, + None, + crps_test_data.EXP_VAR_INTERVAL_CRPSENS_ECDF_DA, + ), + # Test that a float for lower_threshold and an xr.DataArray for upper_threshold works + ( + crps_test_data.DA_FCST_CRPSENS, + crps_test_data.DA_OBS_CRPSENS, + "ecdf", + 2, + crps_test_data.DA_UI_CONS_TWCRPSENS, + "all", + None, + None, + crps_test_data.EXP_INTERVAL_CRPSENS_ECDF_DA, + ), + # Test that an xr.DataArray for lower_threshold and a float for upper_threshold works + ( + crps_test_data.DA_FCST_CRPSENS, + crps_test_data.DA_OBS_CRPSENS, + "ecdf", + crps_test_data.DA_LI_CONS_TWCRPSENS, + 5, + "all", + None, + None, + crps_test_data.EXP_INTERVAL_CRPSENS_ECDF_DA, + ), + # test that passing in xr.Datasets for fcst, obs and weights works + ( + crps_test_data.DS_FCST_CRPSENS, + crps_test_data.DS_OBS_CRPSENS, + "ecdf", + -np.inf, + np.inf, + None, + None, + crps_test_data.DS_WT_CRPSENS, + crps_test_data.EXP_CRPSENS_WT_DS, + ), + ], +) +def test_interval_tw_crps_for_ensemble( + fcst, obs, method, lower_threshold, upper_threshold, preserve_dims, reduce_dims, weights, expected +): + """Tests interval_tw_crps_for_ensemble""" + result = interval_tw_crps_for_ensemble( + fcst, + obs, + ensemble_member_dim="ens_member", + lower_threshold=lower_threshold, + upper_threshold=upper_threshold, + method=method, + preserve_dims=preserve_dims, + reduce_dims=reduce_dims, + weights=weights, + ) + + xr.testing.assert_allclose(result, expected) + + +@pytest.mark.parametrize( + 
("lower_threshold", "upper_threshold"), + [ + (1, 1), + (2, 1), + (xr.DataArray(data=[1, np.nan]), xr.DataArray(data=[1, np.nan])), + (xr.DataArray(data=[2, np.nan]), xr.DataArray(data=[1, np.nan])), + (1, xr.DataArray(data=[1, np.nan])), + (xr.DataArray(data=[1, np.nan]), 1), + ], +) +def test_interval_tw_crps_for_ensemble_raises(lower_threshold, upper_threshold): + """Tests if interval_tw_crps_for_ensemble raises an error when lower_threshold >= upper_threshold""" + with pytest.raises(ValueError) as excinfo: + interval_tw_crps_for_ensemble( + fcst=crps_test_data.DA_FCST_CRPSENS, + obs=crps_test_data.DA_OBS_CRPSENS, + ensemble_member_dim="ens_member", + lower_threshold=lower_threshold, + upper_threshold=upper_threshold, + method="ecdf", + ) + assert "`lower_threshold` must be less than `upper_threshold`" in str(excinfo.value) + + +def test_interval_tw_crps_for_ensemble_dask(): + """Tests `interval_tw_crps_for_ensemble` works with dask.""" + + if dask == "Unavailable": # pragma: no cover + pytest.skip("Dask unavailable, could not run test") # pragma: no cover + + # Check that it works with xr.DataArrays + result = interval_tw_crps_for_ensemble( + fcst=crps_test_data.DA_FCST_CRPSENS.chunk(), + obs=crps_test_data.DA_OBS_CRPSENS.chunk(), + ensemble_member_dim="ens_member", + lower_threshold=2, + upper_threshold=5, + method="ecdf", + preserve_dims="all", + reduce_dims=None, + weights=None, + ) + assert isinstance(result.data, dask.array.Array) + result = result.compute() + assert isinstance(result.data, np.ndarray) + xr.testing.assert_allclose(result, crps_test_data.EXP_INTERVAL_CRPSENS_ECDF_DA) + + # Check that it works with xr.Datasets + result_ds = interval_tw_crps_for_ensemble( + fcst=crps_test_data.DS_FCST_CRPSENS.chunk(), + obs=crps_test_data.DS_OBS_CRPSENS.chunk(), + ensemble_member_dim="ens_member", + lower_threshold=-np.inf, + upper_threshold=np.inf, + method="ecdf", + preserve_dims=None, + reduce_dims=None, + weights=crps_test_data.DS_WT_CRPSENS.chunk(), + ) + 
assert isinstance(result_ds["a"].data, dask.array.Array) + result_ds = result_ds.compute() + assert isinstance(result_ds["a"].data, np.ndarray) + xr.testing.assert_allclose(result_ds, crps_test_data.EXP_CRPSENS_WT_DS)