diff --git a/tests/mmm/test_fourier.py b/tests/mmm/test_fourier.py index 64120fc8..d10c18b6 100644 --- a/tests/mmm/test_fourier.py +++ b/tests/mmm/test_fourier.py @@ -20,6 +20,7 @@ import xarray as xr from pymc_marketing.mmm.fourier import ( + FourierBase, MonthlyFourier, YearlyFourier, generate_fourier_modes, @@ -282,71 +283,79 @@ def test_serialization_to_json() -> None: @pytest.fixture def yearly_fourier() -> YearlyFourier: - prior = Prior("Normal", dims=("fourier",)) - return YearlyFourier(n_order=2, prior=prior) + prior = Prior("Laplace", mu=0, b=1, dims="fourier") + return YearlyFourier(n_order=2, days_in_period=365, prior=prior) @pytest.fixture def monthly_fourier() -> MonthlyFourier: - prior = Prior("Normal", dims=("fourier",)) - return MonthlyFourier(n_order=2, prior=prior) + prior = Prior("Laplace", mu=0, b=1, dims="fourier") + return MonthlyFourier(n_order=2, days_in_period=30, prior=prior) -def test_get_default_start_date_none(yearly_fourier) -> None: - today = datetime.datetime(2023, 5, 15) - expected_start_date = datetime.datetime(2023, 1, 1) - assert yearly_fourier.get_default_start_date(today, None) == expected_start_date +def test_get_default_start_date_none_yearly(yearly_fourier: YearlyFourier): + current_year = datetime.datetime.now().year + expected_start_date = datetime.datetime(year=current_year, month=1, day=1) + actual_start_date = yearly_fourier.get_default_start_date() + assert actual_start_date == expected_start_date -def test_get_default_start_date_str(yearly_fourier) -> None: - today = datetime.datetime.now() +def test_get_default_start_date_none_monthly(monthly_fourier: MonthlyFourier): + now = datetime.datetime.now() + expected_start_date = datetime.datetime(year=now.year, month=now.month, day=1) + actual_start_date = monthly_fourier.get_default_start_date() + assert actual_start_date == expected_start_date + + +def test_get_default_start_date_str_yearly(yearly_fourier: YearlyFourier): start_date_str = "2023-02-01" - assert ( - yearly_fourier.get_default_start_date(today, start_date_str) == start_date_str - ) + actual_start_date = yearly_fourier.get_default_start_date(start_date=start_date_str) + assert actual_start_date == start_date_str -def test_get_default_start_date_datetime(yearly_fourier) -> None: - today = datetime.datetime.now() +def test_get_default_start_date_datetime_yearly(yearly_fourier: YearlyFourier): start_date_dt = datetime.datetime(2023, 3, 1) - assert yearly_fourier.get_default_start_date(today, start_date_dt) == start_date_dt + actual_start_date = yearly_fourier.get_default_start_date(start_date=start_date_dt) + assert actual_start_date == start_date_dt -def test_get_default_start_date_invalid_type(yearly_fourier) -> None: - today = datetime.datetime.now() - invalid_start_date = 12345 # Invalid type +def test_get_default_start_date_invalid_type_yearly(yearly_fourier: YearlyFourier): + invalid_start_date = 12345 # Invalid type again with pytest.raises(TypeError) as exc_info: - yearly_fourier.get_default_start_date(today, invalid_start_date) + yearly_fourier.get_default_start_date(start_date=invalid_start_date) assert "start_date must be a datetime.datetime object, a string, or None" in str( exc_info.value ) -def test_get_default_start_date_none_monthly(monthly_fourier) -> None: - today = datetime.datetime(2023, 5, 15) - expected_start_date = datetime.datetime(2023, 5, 1) - assert monthly_fourier.get_default_start_date(today, None) == expected_start_date - - -def test_get_default_start_date_str_monthly(monthly_fourier): - today = datetime.datetime.now() +def test_get_default_start_date_str_monthly(monthly_fourier: MonthlyFourier): start_date_str = "2023-06-15" - assert ( - monthly_fourier.get_default_start_date(today, start_date_str) == start_date_str + actual_start_date = monthly_fourier.get_default_start_date( + start_date=start_date_str ) + assert actual_start_date == start_date_str -def test_get_default_start_date_datetime_monthly(monthly_fourier) -> None: - today = datetime.datetime.now() +def test_get_default_start_date_datetime_monthly(monthly_fourier: MonthlyFourier): start_date_dt = datetime.datetime(2023, 7, 1) - assert monthly_fourier.get_default_start_date(today, start_date_dt) == start_date_dt + actual_start_date = monthly_fourier.get_default_start_date(start_date=start_date_dt) + assert actual_start_date == start_date_dt -def test_get_default_start_date_invalid_type_monthly(monthly_fourier) -> None: - today = datetime.datetime.now() +def test_get_default_start_date_invalid_type_monthly(monthly_fourier: MonthlyFourier): invalid_start_date = [2023, 1, 1] with pytest.raises(TypeError) as exc_info: - monthly_fourier.get_default_start_date(today, invalid_start_date) + monthly_fourier.get_default_start_date(start_date=invalid_start_date) assert "start_date must be a datetime.datetime object, a string, or None" in str( exc_info.value ) + + +def test_fourier_base_instantiation(): + with pytest.raises(TypeError) as exc_info: + FourierBase( + n_order=2, + days_in_period=365, + prior=Prior("Laplace", mu=0, b=1, dims="fourier"), + ) + assert "Can't instantiate abstract class FourierBase" in str(exc_info.value)