Skip to content

Commit

Permalink
add chaining_func_kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasloveday committed Sep 9, 2024
1 parent d691902 commit 11692bc
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/scores/probability/crps_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
"""

from collections.abc import Iterable
from typing import Callable, Literal, Optional, Sequence, Union
from typing import Any, Callable, Literal, Optional, Sequence, Union

import numpy as np
import pandas as pd
import xarray as xr

import scores.utils
from scores.probability.checks import coords_increasing
from scores.processing import broadcast_and_match_nan
from scores.processing.cdf import (
add_thresholds,
cdf_envelope,
Expand Down Expand Up @@ -882,6 +883,7 @@ def tw_crps_for_ensemble(
ensemble_member_dim: str,
chaining_func: Callable[[XarrayLike], XarrayLike],
*, # Force keywords arguments to be keyword-only
chainging_func_kwargs: Optional[dict[str, Any]] = {},
method: Literal["ecdf", "fair"] = "ecdf",
reduce_dims: Optional[Sequence[str]] = None,
preserve_dims: Optional[Sequence[str]] = None,
Expand Down Expand Up @@ -940,6 +942,7 @@ def tw_crps_for_ensemble(
ensemble_member_dim: the dimension that specifies the ensemble member or the sample
from the predictive distribution.
chaining_func: the chaining function.
chainging_func_kwargs: keyword arguments for the chaining function.
method: Either "ecdf" for the empirical twCRPS or "fair" for the Fair twCRPS.
reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions.
preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions.
Expand Down Expand Up @@ -981,9 +984,9 @@ def tw_crps_for_ensemble(
>>> tw_crps_for_ensemble(fcst, obs, 'ensemble', lambda x: np.maximum(x, 0.5))
"""
obs = chaining_func(obs, **chainging_func_kwargs)
fcst = chaining_func(fcst, **chainging_func_kwargs)

fcst = chaining_func(fcst)
obs = chaining_func(obs)
result = crps_for_ensemble(
fcst,
obs,
Expand Down Expand Up @@ -1014,7 +1017,6 @@ def tail_tw_crps_for_ensemble(
A threshold weight of 1 is assigned for values of the tail and a threshold weight of 0 otherwise.
The threshold value of where the tail begins is specified by the ``threshold`` argument.
The tail does not include the threshold value itself.
The ``tail`` argument specifies whether the tail is the upper or lower tail.
For example, if we only care about values above 40 degrees C, we can set ``threshold=40`` and ``tail="upper"``.
Expand Down Expand Up @@ -1072,19 +1074,20 @@ def tail_tw_crps_for_ensemble(
raise ValueError(f"'{tail}' is not one of 'upper' or 'lower'")
if tail == "upper":

def _vfunc(x):
def _vfunc(x, threshold=threshold):
return np.maximum(x, threshold)

else:

def _vfunc(x):
def _vfunc(x, threshold=threshold):
return np.minimum(x, threshold)

result = tw_crps_for_ensemble(
fcst,
obs,
ensemble_member_dim,
_vfunc,
chainging_func_kwargs={"threshold": threshold},
method=method,
reduce_dims=reduce_dims,
preserve_dims=preserve_dims,
Expand Down

0 comments on commit 11692bc

Please sign in to comment.