diff --git a/README.md b/README.md index b922bc98..1fe27f3e 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Currently Included Metrics: | continuous | probability | categorical | statistical tests | | ---------- | ----------- | ----------- | ----------------- | -| MAE, MSE, RMSE, Flip Flop Index, Quantile Score | CRPS, ROC | FIRM, POD, POFD | Diebold Mariano (with the Harvey et al. 1997 and the Hering and Genton 2011 modifications) | +| MAE, MSE, RMSE, Flip Flop Index, Quantile Score | CRPS for CDF, CRPS for ensemble, ROC | FIRM, POD, POFD | Diebold Mariano (with the Harvey et al. 1997 and the Hering and Genton 2011 modifications) | **Notice -- This repository is currently undergoing initial construction and maintenance. It is not yet recommended for use. This notice will be removed after the first feature release. In the meantime, please feel free to look around, and don't hesitate to get in touch with any questions (see the contributing guide for how).** diff --git a/docs/api.md b/docs/api.md index 4b146e40..8c3a536c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -25,6 +25,7 @@ .. autofunction:: scores.probability.crps_cdf .. autofunction:: scores.probability.adjust_fcst_for_crps .. autofunction:: scores.probability.crps_cdf_brier_decomposition +.. autofunction:: scores.probability.crps_for_ensemble .. autofunction:: scores.probability.murphy_score .. autofunction:: scores.probability.murphy_thetas .. autofunction:: scores.probability.roc_curve_data diff --git a/docs/summary_table_of_scores.md b/docs/summary_table_of_scores.md index b035332a..77469bf6 100644 --- a/docs/summary_table_of_scores.md +++ b/docs/summary_table_of_scores.md @@ -1,3 +1,3 @@ | continuous | probability | categorical | statistical tests | | ---------- | ----------- | ----------- | ----------- | -| MAE, MSE, RMSE, Murphy score, Quantile Score | CRPS_CDF, Murphy score, ROC | FIRM, POD, POFD, Flip Flop Index | Diebold Mariano (with the Harvey et al. 1997 and the Hering and Genton 2011 modifications)| +| MAE, MSE, RMSE, Murphy score, Quantile Score | CRPS for CDF, CRPS for ensemble, Murphy score, ROC | FIRM, POD, POFD, Flip Flop Index | Diebold Mariano (with the Harvey et al. 1997 and the Hering and Genton 2011 modifications)| diff --git a/src/scores/probability/__init__.py b/src/scores/probability/__init__.py index 59b34baa..5afb304d 100644 --- a/src/scores/probability/__init__.py +++ b/src/scores/probability/__init__.py @@ -7,5 +7,6 @@ adjust_fcst_for_crps, crps_cdf, crps_cdf_brier_decomposition, + crps_for_ensemble, ) from scores.probability.roc_impl import roc_curve_data diff --git a/src/scores/probability/crps_impl.py b/src/scores/probability/crps_impl.py index ae78b7f1..4edab259 100644 --- a/src/scores/probability/crps_impl.py +++ b/src/scores/probability/crps_impl.py @@ -1,9 +1,10 @@ """ This module supports the implementation of the CRPS scoring function, drawing from additional functions. -The primary method, `crps_cdf` is imported into the probability module to be part of the probability API +The two primary methods, `crps_cdf` and `crps_for_ensemble` are imported into +the probability module to be part of the probability API. """ from collections.abc import Iterable -from typing import Literal, Optional +from typing import Literal, Optional, Sequence import numpy as np import pandas as pd @@ -735,3 +736,87 @@ def crps_step_threshold_weight( weight = 1 - weight return weight + + +def crps_for_ensemble( + fcst: xr.DataArray, + obs: xr.DataArray, + ensemble_member_dim: str, + method: Literal["ecdf", "fair"] = "ecdf", + reduce_dims: Optional[Sequence[str]] = None, + preserve_dims: Optional[Sequence[str]] = None, + weights: xr.DataArray = None, +) -> xr.DataArray: + """ + Calculates the continuous ranked probability score (CRPS) given an ensemble of forecasts. + An ensemble of forecasts can also be thought of as a random sample from the predictive + distribution. + + Given an observation y, and ensemble member values {x_i} (for 1 <= i <= M), the CRPS is + calculated by the formula + CRPS({x_i}, y) = (1 / M) * sum(|x_i - y|) - (1 / 2 * K) * sum(|x_i - x_j|), + where the first sum is iterated over 1 <= i <= M and the second sum is iterated over + 1 <= i <= M and 1 <= j <= M. + + The value of the constant K in this formula depends on the method. + - If `method="ecdf"` then K = M ** 2. In this case the CRPS value returned is + the exact CRPS value for the emprical cumulation distribution function + constructed using the ensemble values. + - If `method="fair"` then K = M * (M - 1). In this case the CRPS value returned + is the approximated CRPS where the ensemble values can be interpreted as a + random sample from the underlying predictive distribution. This interpretation + stems from the formula CRPS(F, Y) = E|X - Y| - E|X - X'|/2, where X and X' + are independent samples of the predictive distribution F, Y is the observation + (possibly unknown) and E denotes the expectation. This choice of K gives an + unbiased estimate for the second expectation. + + Args: + fcst: Forecast data. Must have a dimension `ensemble_member_dim`. + obs: Observation data. + ensemble_member_dim: the dimension that specifies the ensemble member or the sample + from the predictive distribution. + method: Either "ecdf" or "fair". + reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions. + preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions. + weights: Weights for calculating a weighted mean of individual scores. + + Returns: + xarray object of (weighted mean) CRPS values. + + Raises: + ValueError: when method is not one of "ecdf" or "fair". + + References: + - C. Ferro (2014), "Fair scores for ensemble forecasts", Q J R Meteorol Soc + 140(683):1917-1923. + - T. Gneiting T and A. Raftery (2007), "Strictly proper scoring rules, prediction, + and estimation", J Am Stat Assoc, 102(477):359-37. + - M. Zamo and P. Naveau (2018), "Estimation of the Continuous Ranked Probability + Score with Limited Information and Applications to Ensemble Weather Forecasts", + Math Geosci 50:209-234, https://doi.org/10.1007/s11004-017-9709-7 + """ + if method not in ["ecdf", "fair"]: + raise ValueError("`method` must be one of 'ecdf' or 'fair'") + + dims_for_mean = scores.utils.gather_dimensions2(fcst, obs, weights, reduce_dims, preserve_dims, ensemble_member_dim) + + ensemble_member_dim1 = scores.utils.tmp_coord_name(fcst) + + # calculate forecast spread contribution + fcst_copy = fcst.rename({ensemble_member_dim: ensemble_member_dim1}) + + fcst_spread_term = np.abs(fcst - fcst_copy).sum([ensemble_member_dim, ensemble_member_dim1]) + ens_count = fcst.count(ensemble_member_dim) + if method == "ecdf": + fcst_spread_term = fcst_spread_term / (2 * ens_count**2) + if method == "fair": + fcst_spread_term = fcst_spread_term / (2 * ens_count * (ens_count - 1)) + + # calculate final CRPS for each forecast case + fcst_obs_term = np.abs(fcst - obs).mean(ensemble_member_dim) + result = fcst_obs_term - fcst_spread_term + + # apply weights and take means across specified dims + result = scores.functions.apply_weights(result, weights).mean(dim=dims_for_mean) + + return result diff --git a/src/scores/utils.py b/src/scores/utils.py index 074fb457..027c1e4b 100644 --- a/src/scores/utils.py +++ b/src/scores/utils.py @@ -23,11 +23,23 @@ It is ambiguous how to proceed therefore an exception has been raised instead. """ +ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION2 = """ +You are requesting to preserve a dimension which does not appear in your data +(fcst, obs or weights). It is ambiguous how to proceed therefore an exception has been +raised instead. +""" + ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION = """ You are requesting to reduce a dimension which does not appear in your data (fcst or obs). It is ambiguous how to proceed therefore an exception has been raised instead. """ +ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION2 = """ +You are requesting to reduce a dimension which does not appear in your data +(fcst, obs or weights). It is ambiguous how to proceed therefore an exception has been +raised instead. +""" + ERROR_OVERSPECIFIED_PRESERVE_REDUCE = """ You have specified both preserve_dims and reduce_dims. This method doesn't know how to properly interpret that, therefore an exception has been raised. @@ -112,6 +124,95 @@ def gather_dimensions( # pylint: disable=too-many-branches return reduce_dims +def gather_dimensions2( + fcst: xr.DataArray, + obs: xr.DataArray, + weights: xr.DataArray = None, + reduce_dims: FlexibleDimensionTypes = None, + preserve_dims: FlexibleDimensionTypes = None, + special_fcst_dims: FlexibleDimensionTypes = None, +) -> set[Hashable]: + """ + Performs standard dimensions checks for inputs of functions that calculate (mean) scores. + Returns a set of the dimensions to reduce. + + Args: + fcst: Forecast data + obs: Observation data + weights: Weights for calculating a weighted mean of scores + reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions. + preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions. + special_fcst_dims: Dimension(s) in `fcst` that are reduced to calculate individual scores. + Must not appear as a dimension in `obs`, `weights`, `reduce_dims` or `preserve_dims`. + e.g. the ensemble member dimension if calculating CRPS for ensembles, or the + threshold dimension of calculating CRPS for CDFs. + + Returns: + Set of dimensions over which to take the mean once the checks are passed. + + Raises: + ValueError: when `preserve_dims and `reduce_dims` are both specified. + ValueError: when `special_fcst_dims` is not a subset of `fcst.dims`. + ValueError: when `obs.dims`, `weights.dims`, `reduce_dims` or `preserve_dims` + contains elements from `special_fcst_dims`. + ValueError: when `preserve_dims and `reduce_dims` contain elements not among dimensions + of the data (`fcst`, `obs` or `weights`). + """ + all_data_dims = set(fcst.dims).union(set(obs.dims)) + if weights is not None: + all_data_dims = all_data_dims.union(set(weights.dims)) + + # all_scoring_dims is the set of dims remaining after individual scores are computed. + all_scoring_dims = all_data_dims.copy() + + # Handle error conditions related to specified dimensions + if preserve_dims is not None and reduce_dims is not None: + raise ValueError(ERROR_OVERSPECIFIED_PRESERVE_REDUCE) + + specified_dims = preserve_dims or reduce_dims + + if specified_dims == "all": + if "all" in all_data_dims: + warnings.warn(WARN_ALL_DATA_CONFLICT_MSG) + elif specified_dims is not None: + if isinstance(specified_dims, str): + specified_dims = [specified_dims] + + # check that special_fcst_dims are in fcst.dims only + if special_fcst_dims is not None: + if isinstance(special_fcst_dims, str): + special_fcst_dims = [special_fcst_dims] + if not set(special_fcst_dims).issubset(set(fcst.dims)): + raise ValueError("`special_fcst_dims` must be a subset of `fcst` dimensions") + if len(set(obs.dims).intersection(set(special_fcst_dims))) > 0: + raise ValueError("`obs.dims` must not contain any `special_fcst_dims`") + if weights is not None: + if len(set(weights.dims).intersection(set(special_fcst_dims))) > 0: + raise ValueError("`weights.dims` must not contain any `special_fcst_dims`") + if specified_dims is not None and specified_dims != "all": + if len(set(specified_dims).intersection(set(special_fcst_dims))) > 0: + raise ValueError("`reduce_dims` and `preserve_dims` must not contain any `special_fcst_dims`") + # remove special_fcst_dims from all_scoring_dims + all_scoring_dims = all_scoring_dims.difference(set(special_fcst_dims)) + + if specified_dims is not None and specified_dims != "all": + if not set(specified_dims).issubset(all_scoring_dims): + if preserve_dims is not None: + raise ValueError(ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION2) + raise ValueError(ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION2) + + # all errors have been captured, so now return list of dims to reduce + if specified_dims is None: + return all_scoring_dims + if reduce_dims is not None: + if reduce_dims == "all": + return all_scoring_dims + return set(specified_dims) + if preserve_dims == "all": + return set([]) + return all_scoring_dims.difference(set(specified_dims)) + + def dims_complement(data, dims=None) -> list[str]: """Returns the complement of data.dims and dims @@ -215,3 +316,20 @@ def check_dims(xr_data: XarrayLike, expected_dims: Sequence[str], mode: Optional f"Dimensions {list(xr_data[data_var].dims)} of data variable " f"'{data_var}' are not {mode} to the dimensions {sorted(dims_set)}" ) + + +def tmp_coord_name(xr_data: xr.DataArray) -> str: + """ + Generates a temporary coordinate name that is not among the coordinate or dimension + names of `xr_data`. + + Args: + xr_data: Input xarray data array + + Returns: + A string which is the concatenation of 'new' with all coordinate and + dimension names in the input array. + """ + all_names = ["new"] + list(xr_data.dims) + list(xr_data.coords) + result = "".join(all_names) + return result diff --git a/tests/probabilty/crps_test_data.py b/tests/probabilty/crps_test_data.py index 7975c976..b7a8ff3b 100644 --- a/tests/probabilty/crps_test_data.py +++ b/tests/probabilty/crps_test_data.py @@ -562,3 +562,33 @@ ) EXP_CRPS_BD2 = xr.merge([EXP_TOTAL_CRPS_BD2, EXP_UNDER_CRPS_BD2, EXP_OVER_CRPS_BD2]) + +# test data for CRPS for ensembles + +DA_FCST_CRPSENS = xr.DataArray( + data=[[0.0, 4, 3, 7], [1, -1, 2, 4], [0, 1, 4, np.nan], [2, 3, 4, 1], [2, np.nan, np.nan, np.nan]], + dims=["stn", "ens_member"], + coords={"stn": [101, 102, 103, 104, 105], "ens_member": [1, 2, 3, 4]}, +) +DA_OBS_CRPSENS = xr.DataArray(data=[2.0, 3, 1, np.nan, 4], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]}) +DA_WT_CRPSENS = xr.DataArray(data=[1, 2, 1, 0, 2], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]}) + +# first and second (spread) terms from crps for ensembles, "ecdf" and "fair" methods +FIRST_TERM = xr.DataArray( + data=[10 / 4, 8 / 4, 4 / 3, np.nan, 2], dims=["stn"], coords={"stn": [101, 102, 103, 104, 105]} +) +SPREAD_ECDF = xr.DataArray( + data=[(14 + 8 + 8 + 14) / 32, (6 + 10 + 6 + 10) / 32, (5 + 4 + 7) / 18, np.nan, 0], + dims=["stn"], + coords={"stn": [101, 102, 103, 104, 105]}, +) +SPREAD_FAIR = xr.DataArray( + data=[(14 + 8 + 8 + 14) / 24, (6 + 10 + 6 + 10) / 24, (5 + 4 + 7) / 12, np.nan, np.nan], + dims=["stn"], + coords={"stn": [101, 102, 103, 104, 105]}, +) + +# expected results +EXP_CRPSENS_ECDF = FIRST_TERM - SPREAD_ECDF +EXP_CRPSENS_FAIR = FIRST_TERM - SPREAD_FAIR +EXP_CRPSENS_WT = (EXP_CRPSENS_ECDF * DA_WT_CRPSENS).mean("stn") diff --git a/tests/probabilty/test_crps.py b/tests/probabilty/test_crps.py index c20fb841..1f4f1821 100644 --- a/tests/probabilty/test_crps.py +++ b/tests/probabilty/test_crps.py @@ -6,11 +6,13 @@ import dask.array import numpy as np import pytest +import xarray as xr from scores.probability import ( adjust_fcst_for_crps, crps_cdf, crps_cdf_brier_decomposition, + crps_for_ensemble, ) from scores.probability.crps_impl import ( crps_cdf_exact, @@ -604,3 +606,46 @@ def test_crps_cdf_brier_decomposition(dims, expected): crps_test_data.DA_FCST_CRPS_BD, crps_test_data.DA_OBS_CRPS_BD, "x", preserve_dims=dims ) assert_dataset_equal(result, expected, decimals=7) + + +def test_crps_for_ensemble(): + """Tests `crps_for_ensemble` returns as expected.""" + result_ecdf = crps_for_ensemble( + crps_test_data.DA_FCST_CRPSENS, crps_test_data.DA_OBS_CRPSENS, "ens_member", method="ecdf", preserve_dims="all" + ) + result_fair = crps_for_ensemble( + crps_test_data.DA_FCST_CRPSENS, crps_test_data.DA_OBS_CRPSENS, "ens_member", method="fair", preserve_dims="all" + ) + result_weighted_mean = crps_for_ensemble( + crps_test_data.DA_FCST_CRPSENS, + crps_test_data.DA_OBS_CRPSENS, + "ens_member", + method="ecdf", + weights=crps_test_data.DA_WT_CRPSENS, + ) + + assert_dataarray_equal(result_ecdf, crps_test_data.EXP_CRPSENS_ECDF, decimals=7) + assert_dataarray_equal(result_fair, crps_test_data.EXP_CRPSENS_FAIR, decimals=7) + assert_dataarray_equal(result_weighted_mean, crps_test_data.EXP_CRPSENS_WT, decimals=7) + + +def test_crps_for_ensemble_raises(): + """Tests `crps_for_ensemble` raises exception as expected.""" + with pytest.raises(ValueError) as excinfo: + crps_for_ensemble(xr.DataArray(data=[1]), xr.DataArray(data=[1]), "ens_member", "unfair") + assert "`method` must be one of 'ecdf' or 'fair'" in str(excinfo.value) + + +def test_crps_for_ensemble_dask(): + """Tests `crps_for_ensemble` works with dask.""" + result = crps_for_ensemble( + fcst=crps_test_data.DA_FCST_CRPSENS.chunk(), + obs=crps_test_data.DA_OBS_CRPSENS.chunk(), + ensemble_member_dim="ens_member", + method="ecdf", + preserve_dims="all", + ) + assert isinstance(result.data, dask.array.Array) + result = result.compute() + assert isinstance(result.data, np.ndarray) + assert_dataarray_equal(result, crps_test_data.EXP_CRPSENS_ECDF, decimals=7) diff --git a/tests/test_utils.py b/tests/test_utils.py index 1dcc4c24..3d2beae2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,10 +3,12 @@ """ import pytest +import xarray as xr from scores import utils from scores.utils import DimensionError from scores.utils import gather_dimensions as gd +from scores.utils import gather_dimensions2 from tests import utils_test_data @@ -501,3 +503,177 @@ def test_gather_dimensions_exceptions(): # Preserve "all" as a string but named dimension present in data with pytest.warns(UserWarning): assert gd(fcst_dims_conflict, obs_dims, reduce_dims="all") == fcst_dims_conflict + + +@pytest.mark.parametrize( + ("fcst", "obs", "weights", "reduce_dims", "preserve_dims", "special_fcst_dims", "error_msg_snippet"), + [ + ( + utils_test_data.DA_RGB, + utils_test_data.DA_R, + None, + ["red"], + ["blue"], + None, + utils.ERROR_OVERSPECIFIED_PRESERVE_REDUCE, + ), + # checks for special_fcst_dims + ( + utils_test_data.DA_RGB, + utils_test_data.DA_R, + None, + None, + None, + ["black"], + "`special_fcst_dims` must be a subset of `fcst` dimensions", + ), + ( + utils_test_data.DA_RGB, + utils_test_data.DA_R, + None, + None, + None, + ["red"], + "`obs.dims` must not contain any `special_fcst_dims`", + ), + ( + utils_test_data.DA_RGB, + utils_test_data.DA_R, + None, + None, + None, + "red", + "`obs.dims` must not contain any `special_fcst_dims`", + ), + ( + utils_test_data.DA_RGB, + utils_test_data.DA_R, + utils_test_data.DA_G, + None, + None, + "green", + "`weights.dims` must not contain any `special_fcst_dims`", + ), + ( + utils_test_data.DA_RGB, + utils_test_data.DA_R, + utils_test_data.DA_G, + "blue", + None, + "blue", + "`reduce_dims` and `preserve_dims` must not contain any `special_fcst_dims`", + ), + ( + utils_test_data.DA_RGB, + utils_test_data.DA_R, + utils_test_data.DA_G, + None, + ["blue"], + "blue", + "`reduce_dims` and `preserve_dims` must not contain any `special_fcst_dims`", + ), + ( + utils_test_data.DA_RGB, + utils_test_data.DA_R, + utils_test_data.DA_G, + None, + ["blue", "yellow"], + None, + utils.ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION2, + ), + ( + utils_test_data.DA_RGB, + utils_test_data.DA_R, + utils_test_data.DA_G, + "yellow", + None, + "blue", + utils.ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION2, + ), + ], +) +def test_gather_dimensions2_exceptions( + fcst, obs, weights, reduce_dims, preserve_dims, special_fcst_dims, error_msg_snippet +): + """ + Confirm `gather_dimensions2` raises exceptions as expected. + """ + with pytest.raises(ValueError) as excinfo: + gather_dimensions2( + fcst, + obs, + weights=weights, + reduce_dims=reduce_dims, + preserve_dims=preserve_dims, + special_fcst_dims=special_fcst_dims, + ) + assert error_msg_snippet in str(excinfo.value) + + +def test_gather_dimensions2_warnings(): + """Tests that gather_dimensions2 warns as expected with correct output.""" + # Preserve "all" as a string but named dimension present in data + with pytest.warns(UserWarning): + result = gather_dimensions2( + utils_test_data.DA_R.rename({"red": "all"}), utils_test_data.DA_R, preserve_dims="all" + ) + assert result == set([]) + + with pytest.warns(UserWarning): + result = gather_dimensions2( + utils_test_data.DA_R.rename({"red": "all"}), utils_test_data.DA_R, reduce_dims="all" + ) + assert result == {"red", "all"} + + +@pytest.mark.parametrize( + ("fcst", "obs", "weights", "reduce_dims", "preserve_dims", "special_fcst_dims", "expected"), + [ + # test that fcst and obs dims are returned + (utils_test_data.DA_B, utils_test_data.DA_R, None, None, None, None, {"blue", "red"}), + # test that fcst, obs and weights dims are returned + (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, None, None, None, {"blue", "red", "green"}), + # two tests that fcst, obs and weights dims are returned, without the special fcst dim + (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, None, None, "blue", {"red", "green"}), + (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, None, None, ["blue"], {"red", "green"}), + # test that reduce_dims="all" behaves as expected + (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, "all", None, None, {"blue", "red", "green"}), + # three tests for reduce_dims + (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, "blue", None, None, {"blue"}), + (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, ["blue"], None, None, {"blue"}), + (utils_test_data.DA_RGB, utils_test_data.DA_R, utils_test_data.DA_G, ["green"], None, "blue", {"green"}), + # test for preserve_dims="all" + (utils_test_data.DA_RGB, utils_test_data.DA_B, None, None, "all", "red", set([])), + # three tests for preserve_dims + (utils_test_data.DA_RGB, utils_test_data.DA_R, utils_test_data.DA_G, None, "green", None, {"red", "blue"}), + (utils_test_data.DA_RGB, utils_test_data.DA_R, None, None, ["green"], None, {"red", "blue"}), + (utils_test_data.DA_RGB, utils_test_data.DA_B, None, None, ["green"], "red", {"blue"}), + ], +) +def test_gather_dimensions2_examples(fcst, obs, weights, reduce_dims, preserve_dims, special_fcst_dims, expected): + """ + Test that `gather_dimensions2` gives outputs as expected. + """ + result = gather_dimensions2( + fcst, + obs, + weights=weights, + reduce_dims=reduce_dims, + preserve_dims=preserve_dims, + special_fcst_dims=special_fcst_dims, + ) + assert result == expected + + +def test_tmp_coord_name(): + """ + Tests that `tmp_coord_name` returns as expected. + """ + data = xr.DataArray(data=[1, 2, 3]) + assert utils.tmp_coord_name(data) == "newdim_0" + + data = xr.DataArray(data=[1, 2, 3], dims=["stn"], coords=dict(stn=[101, 202, 304])) + assert utils.tmp_coord_name(data) == "newstnstn" + + data = xr.DataArray(data=[1, 2, 3], dims=["stn"], coords=dict(stn=[101, 202, 304], elevation=("stn", [0, 3, 24]))) + assert utils.tmp_coord_name(data) == "newstnstnelevation"