Skip to content

Commit

Permalink
add examples
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Aug 29, 2024
1 parent 75510e9 commit d70e77e
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions pymc_marketing/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,24 @@ def per_observation_crps(y_true: npt.NDArray, y_pred: npt.NDArray) -> npt.NDArra
array-like
The CRPS for each observation.
Examples
--------
.. code-block:: python
import numpy as np
from pymc_marketing.metrics import per_observation_crps
# y_true shape is (3,)
y_true = np.array([1, 1, 1])
# y_pred shape is (10, 3). The extra dimension on the left is the number of samples.
y_pred = np.repeat(np.array([[0, 1, 0]]), 10, axis=0)
# The result has shape (3,), one value per observation.
per_observation_crps(y_true, y_pred)
>> array([1., 0., 1.])
References
----------
- This implementation is a minimal adaptation from the one in the Pyro project: https://docs.pyro.ai/en/dev/_modules/pyro/ops/stats.html#crps_empirical
Expand Down Expand Up @@ -85,6 +103,24 @@ def crps(
float
The CRPS value as a (possibly weighted) average of the per-observation CRPS values.
Examples
--------
.. code-block:: python
import numpy as np
from pymc_marketing.metrics import crps
# y_true shape is (3,)
y_true = np.array([1, 1, 1])
# y_pred shape is (10, 3). The extra dimension on the left is the number of samples.
y_pred = np.repeat(np.array([[0, 1, 0]]), 10, axis=0)
# The result is a scalar.
crps(y_true, y_pred)
>> 0.666
References
----------
- This implementation is a minimal adaptation from the one in the Pyro project: https://docs.pyro.ai/en/dev/_modules/pyro/ops/stats.html#crps_empirical
Expand Down

0 comments on commit d70e77e

Please sign in to comment.