Skip to content

Commit

Permalink
feat(test_fourier.py): new tests after removing today from params in …
Browse files Browse the repository at this point in the history
…fourier.py
  • Loading branch information
Ishaanjolly committed Sep 30, 2024
1 parent a31e716 commit b4f5cf5
Showing 1 changed file with 45 additions and 36 deletions.
81 changes: 45 additions & 36 deletions tests/mmm/test_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import xarray as xr

from pymc_marketing.mmm.fourier import (
FourierBase,
MonthlyFourier,
YearlyFourier,
generate_fourier_modes,
Expand Down Expand Up @@ -282,71 +283,79 @@ def test_serialization_to_json() -> None:

@pytest.fixture
def yearly_fourier() -> YearlyFourier:
prior = Prior("Normal", dims=("fourier",))
return YearlyFourier(n_order=2, prior=prior)
prior = Prior("Laplace", mu=0, b=1, dims="fourier")
return YearlyFourier(n_order=2, days_in_period=365, prior=prior)


@pytest.fixture
def monthly_fourier() -> MonthlyFourier:
prior = Prior("Normal", dims=("fourier",))
return MonthlyFourier(n_order=2, prior=prior)
prior = Prior("Laplace", mu=0, b=1, dims="fourier")
return MonthlyFourier(n_order=2, days_in_period=30, prior=prior)


def test_get_default_start_date_none(yearly_fourier) -> None:
today = datetime.datetime(2023, 5, 15)
expected_start_date = datetime.datetime(2023, 1, 1)
assert yearly_fourier.get_default_start_date(today, None) == expected_start_date
def test_get_default_start_date_none_yearly(yearly_fourier: YearlyFourier):
current_year = datetime.datetime.now().year
expected_start_date = datetime.datetime(year=current_year, month=1, day=1)
actual_start_date = yearly_fourier.get_default_start_date()
assert actual_start_date == expected_start_date


def test_get_default_start_date_str(yearly_fourier) -> None:
today = datetime.datetime.now()
def test_get_default_start_date_none_monthly(monthly_fourier: MonthlyFourier):
now = datetime.datetime.now()
expected_start_date = datetime.datetime(year=now.year, month=now.month, day=1)
actual_start_date = monthly_fourier.get_default_start_date()
assert actual_start_date == expected_start_date


def test_get_default_start_date_str_yearly(yearly_fourier: YearlyFourier):
start_date_str = "2023-02-01"
assert (
yearly_fourier.get_default_start_date(today, start_date_str) == start_date_str
)
actual_start_date = yearly_fourier.get_default_start_date(start_date=start_date_str)
assert actual_start_date == start_date_str


def test_get_default_start_date_datetime(yearly_fourier) -> None:
today = datetime.datetime.now()
def test_get_default_start_date_datetime_yearly(yearly_fourier: YearlyFourier):
start_date_dt = datetime.datetime(2023, 3, 1)
assert yearly_fourier.get_default_start_date(today, start_date_dt) == start_date_dt
actual_start_date = yearly_fourier.get_default_start_date(start_date=start_date_dt)
assert actual_start_date == start_date_dt


def test_get_default_start_date_invalid_type(yearly_fourier) -> None:
today = datetime.datetime.now()
invalid_start_date = 12345 # Invalid type
def test_get_default_start_date_invalid_type_yearly(yearly_fourier: YearlyFourier):
invalid_start_date = 12345 # Invalid type again
with pytest.raises(TypeError) as exc_info:
yearly_fourier.get_default_start_date(today, invalid_start_date)
yearly_fourier.get_default_start_date(start_date=invalid_start_date)
assert "start_date must be a datetime.datetime object, a string, or None" in str(
exc_info.value
)


def test_get_default_start_date_none_monthly(monthly_fourier) -> None:
today = datetime.datetime(2023, 5, 15)
expected_start_date = datetime.datetime(2023, 5, 1)
assert monthly_fourier.get_default_start_date(today, None) == expected_start_date


def test_get_default_start_date_str_monthly(monthly_fourier):
today = datetime.datetime.now()
def test_get_default_start_date_str_monthly(monthly_fourier: MonthlyFourier):
start_date_str = "2023-06-15"
assert (
monthly_fourier.get_default_start_date(today, start_date_str) == start_date_str
actual_start_date = monthly_fourier.get_default_start_date(
start_date=start_date_str
)
assert actual_start_date == start_date_str


def test_get_default_start_date_datetime_monthly(monthly_fourier) -> None:
today = datetime.datetime.now()
def test_get_default_start_date_datetime_monthly(monthly_fourier: MonthlyFourier):
start_date_dt = datetime.datetime(2023, 7, 1)
assert monthly_fourier.get_default_start_date(today, start_date_dt) == start_date_dt
actual_start_date = monthly_fourier.get_default_start_date(start_date=start_date_dt)
assert actual_start_date == start_date_dt


def test_get_default_start_date_invalid_type_monthly(monthly_fourier) -> None:
today = datetime.datetime.now()
def test_get_default_start_date_invalid_type_monthly(monthly_fourier: MonthlyFourier):
invalid_start_date = [2023, 1, 1]
with pytest.raises(TypeError) as exc_info:
monthly_fourier.get_default_start_date(today, invalid_start_date)
monthly_fourier.get_default_start_date(start_date=invalid_start_date)
assert "start_date must be a datetime.datetime object, a string, or None" in str(
exc_info.value
)


def test_fourier_base_instantiation():
with pytest.raises(TypeError) as exc_info:
FourierBase(
n_order=2,
days_in_period=365,
prior=Prior("Laplace", mu=0, b=1, dims="fourier"),
)
assert "Can't instantiate abstract class FourierBase" in str(exc_info.value)

0 comments on commit b4f5cf5

Please sign in to comment.