Skip to content

Commit

Permalink
firm can now take a list of xr.datarrays for the thresholds (#661)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasloveday authored Sep 9, 2024
1 parent 170c503 commit 2480bc1
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 9 deletions.
11 changes: 8 additions & 3 deletions src/scores/categorical/multicategorical_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
This module contains methods which may be used for scoring multicategorical forecasts
"""

from collections.abc import Sequence
from typing import Optional, Union

Expand All @@ -16,7 +17,7 @@ def firm( # pylint: disable=too-many-arguments
fcst: xr.DataArray,
obs: xr.DataArray,
risk_parameter: float,
categorical_thresholds: Sequence[float],
categorical_thresholds: Union[Sequence[float], Sequence[xr.DataArray]],
threshold_weights: Sequence[Union[float, xr.DataArray]],
*, # Force keywords arguments to be keyword-only
discount_distance: Optional[float] = 0,
Expand All @@ -37,7 +38,9 @@ def firm( # pylint: disable=too-many-arguments
risk_parameter: Risk parameter (alpha) for the FIRM score. The value must
satisfy 0 < `risk_parameter` < 1.
categorical_thresholds: Category thresholds (thetas) to delineate the
categories.
categories. A sequence of xr.DataArrays may be supplied to allow
for different thresholds at each coordinate (e.g., thresholds
determined by climatology).
threshold_weights: Weights that specify the relative importance of forecasting on
the correct side of each category threshold. Either a positive
float can be supplied for each categorical threshold or an
Expand Down Expand Up @@ -165,7 +168,7 @@ def _single_category_score(
fcst: xr.DataArray,
obs: xr.DataArray,
risk_parameter: float,
categorical_threshold: float,
categorical_threshold: Union[float, xr.DataArray],
*, # Force keywords arguments to be keyword-only
discount_distance: Optional[float] = None,
threshold_assignment: Optional[str] = "lower",
Expand Down Expand Up @@ -216,8 +219,10 @@ def _single_category_score(
# Bring back NaNs
condition1 = condition1.where(~np.isnan(fcst))
condition1 = condition1.where(~np.isnan(obs))
condition1 = condition1.where(~np.isnan(categorical_threshold))
condition2 = condition2.where(~np.isnan(fcst))
condition2 = condition2.where(~np.isnan(obs))
condition2 = condition2.where(~np.isnan(categorical_threshold))

if discount_distance:
scale_1 = np.minimum(categorical_threshold - obs, discount_distance)
Expand Down
70 changes: 70 additions & 0 deletions tests/categorical/multicategorical_test_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Test data for testing scores.categorical.multicategorical functions
"""

import numpy as np
import xarray as xr

Expand Down Expand Up @@ -37,6 +38,14 @@
"i": [1, 2, 3, 4, 5, 6],
},
)
# Threshold fixture whose values vary over the ("j", "k") dimensions, used as
# the `categorical_threshold` argument in the `_single_category_score` tests.
# The rows of `data` follow the "j" order given in `coords` below
# (100001 first, then 10000). The NaN entry sits at j=10000, k=11, where the
# matching expected fixtures hold NaN.
DA_THRESHOLD_SC = xr.DataArray(
data=[[5, 5], [-200, np.nan]],
dims=["j", "k"],
coords={
"j": [100001, 10000], # coords in different order to forecast
"k": [10, 11],
},
)

EXP_SC_TOTAL_CASE0 = xr.DataArray(
data=[[[np.nan, 0.3], [0.7, np.nan]], [[0.0, 0], [0, np.nan]]],
Expand Down Expand Up @@ -222,6 +231,41 @@
}
)

# Expected results for the `_single_category_score` test case that passes
# DA_THRESHOLD_SC as the threshold (discount distance 0, "lower" threshold
# assignment). All three arrays share the same ("i", "j", "k") dims/coords.

# Expected total score; every k=11 entry is NaN.
EXP_SC_TOTAL_CASE6 = xr.DataArray(
data=[[[np.nan, np.nan], [0.7, np.nan]], [[0.0, np.nan], [0, np.nan]]],
dims=["i", "j", "k"],
coords={
"i": [1, 2],
"j": [10000, 100001],
"k": [10, 11],
},
)
# Expected under-forecast penalty; identical to the total in this case because
# the over-forecast penalty below is zero wherever it is not NaN.
EXP_SC_UNDER_CASE6 = xr.DataArray(
data=[[[np.nan, np.nan], [0.7, np.nan]], [[0.0, np.nan], [0, np.nan]]],
dims=["i", "j", "k"],
coords={
"i": [1, 2],
"j": [10000, 100001],
"k": [10, 11],
},
)
# Expected over-forecast penalty; zero at every non-NaN point.
EXP_SC_OVER_CASE6 = xr.DataArray(
data=[[[np.nan, np.nan], [0, np.nan]], [[0.0, np.nan], [0, np.nan]]],
dims=["i", "j", "k"],
coords={
"i": [1, 2],
"j": [10000, 100001],
"k": [10, 11],
},
)
# Bundle the three expected arrays into one Dataset under the variable names
# "firm_score", "underforecast_penalty" and "overforecast_penalty".
EXP_SC_CASE6 = xr.Dataset(
{
"firm_score": EXP_SC_TOTAL_CASE6,
"underforecast_penalty": EXP_SC_UNDER_CASE6,
"overforecast_penalty": EXP_SC_OVER_CASE6,
}
)

DA_FCST_FIRM = xr.DataArray(
data=[
[[np.nan, 7, 4], [-100, 0, 1], [0, -100, 1]],
Expand All @@ -243,6 +287,32 @@
},
)

# Two category thresholds supplied as DataArrays over "j" rather than floats.
# Each threshold is constant along "j" (0 everywhere, then 5 everywhere).
DA_THRESHOLD_FIRM = [
xr.DataArray(
data=[0, 0, 0],
dims=["j"],
coords={"j": [10000, 100001, 900000]},
),
xr.DataArray(
data=[5, 5, 5],
dims=["j"],
coords={"j": [10000, 100001, 900000]},
),
]

# Same two threshold values, but varying along "j": at j=100001 the values are
# swapped, so the first array holds 5 and the second holds 0 at that coordinate.
DA_THRESHOLD_FIRM2 = [
xr.DataArray(
data=[0, 5, 0],
dims=["j"],
coords={"j": [10000, 100001, 900000]},
),
xr.DataArray(
data=[5, 0, 5],
dims=["j"],
coords={"j": [10000, 100001, 900000]},
),
]

LIST_WEIGHTS_FIRM0 = [
xr.DataArray(
data=[2, 2, 2],
Expand Down
27 changes: 27 additions & 0 deletions tests/categorical/test_multicategorical.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Contains unit tests for scores.categorical
"""

try:
import dask
import dask.array
Expand Down Expand Up @@ -36,6 +37,8 @@
(mtd.DA_FCST_SC2, mtd.DA_OBS_SC2, 2, None, "lower", mtd.EXP_SC_CASE4),
# Test upper/left assignment
(mtd.DA_FCST_SC2, mtd.DA_OBS_SC2, 2, None, "upper", mtd.EXP_SC_CASE5),
# Threshold xr.DataArray, discount = 0, preserve all dims
(mtd.DA_FCST_SC, mtd.DA_OBS_SC, mtd.DA_THRESHOLD_SC, 0, "lower", mtd.EXP_SC_CASE6),
],
)
def test__single_category_score(fcst, obs, categorical_threshold, discount_distance, threshold_assignment, expected):
Expand Down Expand Up @@ -228,6 +231,30 @@ def test__single_category_score(fcst, obs, categorical_threshold, discount_dista
0,
mtd.EXP_FIRM_CASE6,
),
# 2 categories defined with xr.DataArrays for thresholds that don't vary by coord
(
mtd.DA_FCST_FIRM,
mtd.DA_OBS_FIRM,
0.7,
mtd.DA_THRESHOLD_FIRM,
[1, 1],
None,
["i", "j", "k"],
0,
mtd.EXP_FIRM_CASE3,
),
# 2 categories defined with xr.DataArrays for thresholds that do vary by coord
(
mtd.DA_FCST_FIRM,
mtd.DA_OBS_FIRM,
0.7,
mtd.DA_THRESHOLD_FIRM2,
[1, 1],
None,
["i", "j", "k"],
0,
mtd.EXP_FIRM_CASE3,
),
],
)
def test_firm(
Expand Down
11 changes: 5 additions & 6 deletions tutorials/FIRM.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,15 @@
"# Create observations for 100 dates\n",
"obs = 50 * np.random.random((100, 100))\n",
"obs = xr.DataArray(\n",
" data=obs, \n",
" dims=[\"time\", \"x\"],\n",
" coords={\"time\": pd.date_range(\"2023-01-01\", \"2023-04-10\"), \"x\": np.arange(0, 100)}\n",
" data=obs, dims=[\"time\", \"x\"], coords={\"time\": pd.date_range(\"2023-01-01\", \"2023-04-10\"), \"x\": np.arange(0, 100)}\n",
")\n",
"\n",
"# Create forecasts for 7 lead days\n",
"fcst = xr.DataArray(data=[1]*7, dims=\"lead_day\", coords={\"lead_day\": np.arange(1, 8)})\n",
"fcst = xr.DataArray(data=[1] * 7, dims=\"lead_day\", coords={\"lead_day\": np.arange(1, 8)})\n",
"fcst = fcst * obs\n",
"\n",
"# Create two forecasts. Forecast A has no bias compared to the observations, but \n",
"# Forecast B is biased upwards to align better with forecast directive of forecast \n",
"# Create two forecasts. Forecast A has no bias compared to the observations, but\n",
"# Forecast B is biased upwards to align better with forecast directive of forecast\n",
"# the highest category that has at least a 30% chance of occuring\n",
"noise = 4 * norm.rvs(size=(7, 100, 100))\n",
"fcst_a = fcst + noise\n",
Expand Down Expand Up @@ -938,6 +936,7 @@
"metadata": {},
"source": [
"## Further extensions\n",
"- Instead of passing in a list of floats for the threshold arg, pass in a list of xr.DataArrays where the thresholds vary spatially and are based on climatological values \n",
"- Test the impact of varying the risk threshold\n",
"- Solve for the optimal way to bias the forecasts to optimize the verification score\n",
"- Test the impact of the discount distance parameter\n",
Expand Down

0 comments on commit 2480bc1

Please sign in to comment.