Skip to content

Commit

Permalink
added dm test stat tutorial notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasloveday committed Aug 23, 2023
1 parent 8918027 commit 7fbb5c3
Show file tree
Hide file tree
Showing 5 changed files with 643 additions and 16 deletions.
5 changes: 0 additions & 5 deletions src/scores/stats/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
"""
Import the functions from the implementations into the public API
"""

from .confidence_intervals_impl import dm_test_stats
5 changes: 5 additions & 0 deletions src/scores/stats/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
Import the functions from the implementations into the public API
"""

from .diebold_mariano_impl import diebold_mariano
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Functions for calculating confidence statistics
Functions for calculating a modified Deibold-Mariano test statistic
"""
from typing import Literal

Expand All @@ -12,7 +12,7 @@
from scores.utils import dims_complement


def dm_test_stats(
def diebold_mariano(
da_timeseries: xr.DataArray,
ts_dim: str,
h_coord: str,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""
This module contains unit tests for scores.stats.confidence_intervals
This module contains unit tests for scores.stats.tests.diebold_mariano_impl
"""
import numpy as np
import pytest
import xarray as xr

from scores.stats.confidence_intervals_impl import (
from scores.stats.tests.diebold_mariano_impl import (
_dm_gamma_hat_k,
_dm_test_statistic,
_dm_v_hat,
_hg_func,
_hg_method_stat,
_hln_method_stat,
dm_test_stats,
diebold_mariano,
)


Expand Down Expand Up @@ -161,7 +161,7 @@
),
],
)
def test_dm_test_stats_raises(
def test_diebold_mariano_raises(
da_timeseries,
ts_dim,
h_coord,
Expand All @@ -170,9 +170,9 @@ def test_dm_test_stats_raises(
statistic_distribution,
error_msg,
):
"""Tests that dm_test_stats raises a ValueError as expected."""
"""Tests that diebold_mariano raises a ValueError as expected."""
with pytest.raises(ValueError, match=error_msg):
dm_test_stats(
diebold_mariano(
da_timeseries,
ts_dim,
h_coord,
Expand Down Expand Up @@ -345,8 +345,8 @@ def test__dm_test_statistic_raises(diff, h, method, error_msg):
("normal", DM_TEST_STATS_NORMAL_EXP),
],
)
def test_dm_test_stats(distribution, expected):
"""Tests that dm_test_stats gives results as expected."""
def test_diebold_mariano(distribution, expected):
"""Tests that diebold_mariano gives results as expected."""
da_timeseries = xr.DataArray(
data=[[1, 2, 3.0, 4, np.nan], [2.0, 1, -3, -1, 0], [1.0, 1, 1, 1, 1]],
dims=["lead_day", "valid_date"],
Expand All @@ -356,7 +356,7 @@ def test_dm_test_stats(distribution, expected):
"h": ("lead_day", [2, 3, 4]),
},
)
result = dm_test_stats(
result = diebold_mariano(
da_timeseries,
"lead_day",
"h",
Expand Down
Loading

0 comments on commit 7fbb5c3

Please sign in to comment.