From ff1f8e23de07acffb919acfc17585fd27e95c27d Mon Sep 17 00:00:00 2001 From: Syama Sundar Rangapuram Date: Wed, 24 Apr 2024 15:41:15 +0200 Subject: [PATCH] Add Seasonal Aggregate Predictor (#3162) --- src/gluonts/model/seasonal_agg/__init__.py | 16 + src/gluonts/model/seasonal_agg/_predictor.py | 133 ++++++ test/model/seasonal_agg/test_seasonal_agg.py | 402 +++++++++++++++++++ 3 files changed, 551 insertions(+) create mode 100644 src/gluonts/model/seasonal_agg/__init__.py create mode 100644 src/gluonts/model/seasonal_agg/_predictor.py create mode 100644 test/model/seasonal_agg/test_seasonal_agg.py diff --git a/src/gluonts/model/seasonal_agg/__init__.py b/src/gluonts/model/seasonal_agg/__init__.py new file mode 100644 index 0000000000..845871551b --- /dev/null +++ b/src/gluonts/model/seasonal_agg/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from ._predictor import SeasonalAggregatePredictor + +__all__ = ["SeasonalAggregatePredictor"] diff --git a/src/gluonts/model/seasonal_agg/_predictor.py b/src/gluonts/model/seasonal_agg/_predictor.py new file mode 100644 index 0000000000..5f41fc2789 --- /dev/null +++ b/src/gluonts/model/seasonal_agg/_predictor.py @@ -0,0 +1,133 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. 
class SeasonalAggregatePredictor(RepresentablePredictor):
    r"""
    Seasonal aggregate forecaster.

    For each time series :math:`y`, this predictor produces a forecast
    :math:`\tilde{y}(T+k) = f\big(y(T+k-h), y(T+k-2h), ...,
    y(T+k-mh)\big)`, where :math:`T` is the forecast time,
    :math:`k = 0, ..., h - 1`, :math:`m =` ``num_seasons``,
    :math:`h =` ``season_length`` and :math:`f =` ``agg_fun``.

    If ``prediction_length > season_length``, the seasonal aggregate is
    repeated: step :math:`k` of the forecast reuses the value computed for
    step :math:`k \bmod h`. If a time series is shorter than
    ``season_length * num_seasons``, then ``agg_fun`` is applied to the
    full time series and the resulting value is used for every step.

    Parameters
    ----------
    prediction_length
        Number of time points to predict.
    season_length
        Seasonality used to make predictions. If this is an integer, then a
        fixed seasonality is applied; if this is a function, then it will be
        called on each given entry's ``freq`` attribute of the ``"start"``
        field, and the returned seasonality will be used.
    num_seasons
        Number of seasons to aggregate.
    agg_fun
        Aggregate function. It is called as ``agg_fun(values, axis=0)`` on
        a 2D array with one row per season, and as ``agg_fun(values)`` on
        the full series when the series is too short.
    imputation_method
        The imputation method to use in case of missing values.
        Defaults to :py:class:`LastValueImputation` which replaces each
        missing value with the last value that was not missing.
    """

    @validated()
    def __init__(
        self,
        prediction_length: int,
        season_length: Union[int, Callable],
        num_seasons: int,
        agg_fun: Callable = np.nanmean,
        imputation_method: MissingValueImputation = LastValueImputation(),
    ) -> None:
        super().__init__(prediction_length=prediction_length)

        # A callable `season_length` can only be checked once it has been
        # evaluated on an item; see `predict_item`.
        assert (
            not isinstance(season_length, int) or season_length > 0
        ), "The value of `season_length` should be > 0"

        assert (
            isinstance(num_seasons, int) and num_seasons > 0
        ), "The value of `num_seasons` should be > 0"

        self.prediction_length = prediction_length
        self.season_length = season_length
        self.num_seasons = num_seasons
        self.agg_fun = agg_fun
        self.imputation_method = imputation_method

    def predict_item(self, item: DataEntry) -> Forecast:
        """
        Compute the seasonal-aggregate forecast for a single data entry.

        Parameters
        ----------
        item
            Data entry; the ``"target"`` and ``"start"`` fields are read,
            ``"item_id"`` and ``"info"`` are forwarded when present.

        Returns
        -------
        Forecast
            A ``SampleForecast`` holding a single sample path of shape
            ``(1, prediction_length)``.
        """
        if isinstance(self.season_length, int):
            season_length = self.season_length
        else:
            season_length = self.season_length(item["start"].freq)
            # Same constraint as for a fixed integer seasonality; without
            # it a zero value fails in the modulo below and a negative
            # value silently selects the wrong observations.
            assert (
                season_length > 0
            ), "The value of `season_length` should be > 0"

        target = np.asarray(item[FieldName.TARGET], np.float32)
        len_ts = len(target)
        forecast_start_time = forecast_start(item)

        assert (
            len_ts >= 1
        ), "all time series should have at least one data point"

        if np.isnan(target).any():
            # `np.asarray` may hand back the caller's own array, so work on
            # a copy before handing it to the imputation method.
            target = target.copy()
            target = self.imputation_method(target)

        if len_ts >= season_length * self.num_seasons:
            # `indices` is a 2D integer array where each row collects the
            # positions taken from one past season: row `j` points into the
            # season starting at `len_ts - (j + 1) * season_length`, and
            # column `k` wraps around via `k % season_length`.
            # An explicit ndarray is used instead of a nested list because
            # older NumPy releases interpret a list of lists as a tuple of
            # per-dimension indices.
            offsets = np.arange(self.prediction_length) % season_length
            season_starts = len_ts - season_length * np.arange(
                1, self.num_seasons + 1
            )
            indices = season_starts[:, np.newaxis] + offsets[np.newaxis, :]
            samples = self.agg_fun(target[indices], axis=0).reshape(
                (1, self.prediction_length)
            )
        else:
            # Not enough history for `num_seasons` full seasons: aggregate
            # everything that is available and repeat that single value.
            samples = np.full(
                shape=(1, self.prediction_length),
                fill_value=self.agg_fun(target),
            )

        return SampleForecast(
            samples=samples,
            start_date=forecast_start_time,
            item_id=item.get("item_id", None),
            info=item.get("info", None),
        )
FREQ = "D"
START_DATE = "2023"


def get_prediction(
    target,
    prediction_length=1,
    season_length=1,
    num_seasons=1,
    agg_fun=np.nanmean,
    imputation_method=None,
):
    """
    Build a ``SeasonalAggregatePredictor`` and forecast a single series.

    ``target`` is wrapped into a data entry starting at ``START_DATE`` with
    frequency ``FREQ``. When ``imputation_method`` is ``None`` a fresh
    ``LastValueImputation`` is used, so no instance is shared between calls.
    """
    if imputation_method is None:
        imputation_method = LastValueImputation()

    pred = SeasonalAggregatePredictor(
        prediction_length=prediction_length,
        season_length=season_length,
        num_seasons=num_seasons,
        agg_fun=agg_fun,
        imputation_method=imputation_method,
    )
    item = {
        "target": np.asarray(target),
        "start": pd.Period(START_DATE, freq=FREQ),
    }
    forecast = pred.predict_item(item)

    return forecast


@pytest.mark.parametrize(
    "data, expected_output, prediction_length, season_length, "
    "num_seasons, agg_fun, imputation_method",
    [
        # same as seasonal naive
        ([1, 1, 1], [1], 1, 1, 1, np.nanmean, LastValueImputation()),
        (
            [1, 10, 2, 20],
            [1.5, 15],
            2,
            2,
            2,
            np.nanmean,
            LastValueImputation(),
        ),
        # check predictions repeat seasonally
        (
            [1, 10, 2, 20],
            [1.5, 15, 1.5, 15],
            4,
            2,
            2,
            np.nanmean,
            LastValueImputation(),
        ),
        (
            [1, 10, 2, 20],
            [1.5, 15, 1.5],
            3,
            2,
            2,
            np.nanmean,
            LastValueImputation(),
        ),
        # check `nanmedian`
        (
            [1, 10, 2, 20, 3, 30],
            [2, 20, 2, 20],
            4,
            2,
            3,
            np.nanmedian,
            LastValueImputation(),
        ),
        (
            [1, 10, 2, 20, 3, 30],
            [2, 20, 2],
            3,
            2,
            3,
            np.nanmedian,
            LastValueImputation(),
        ),
        # check `nanmax`
        (
            [1, 10, 2, 20, 3, 30],
            [3, 30, 3, 30],
            4,
            2,
            3,
            np.nanmax,
            LastValueImputation(),
        ),
        # check `nanmin`
        (
            [1, 10, 2, 20, 3, 30],
            [1, 10, 1, 10],
            4,
            2,
            3,
            np.nanmin,
            LastValueImputation(),
        ),
        # data is shorter than season length
        ([1, 2, 3], [2], 1, 4, 1, np.nanmean, LastValueImputation()),
        ([10, 1, 100], [10], 1, 4, 1, np.nanmedian, LastValueImputation()),
        ([10, 1, 100], [100], 1, 4, 1, np.nanmax, LastValueImputation()),
        ([10, 1, 100], [1], 1, 4, 1, np.nanmin, LastValueImputation()),
        # data not available for all seasons
        ([1, 2, 3, 4, 5], [3] * 4, 4, 4, 2, np.nanmean, LastValueImputation()),
        (
            [10, 20, 40, 50, 21],
            [21] * 4,
            4,
            4,
            2,
            np.nanmedian,
            LastValueImputation(),
        ),
        (
            [10, 20, 40, 50, 21],
            [50] * 4,
            4,
            4,
            2,
            np.nanmax,
            LastValueImputation(),
        ),
        (
            [10, 20, 40, 50, 21],
            [10] * 4,
            4,
            4,
            2,
            np.nanmin,
            LastValueImputation(),
        ),
        # missing values with imputation
        ([np.nan], [0], 1, 1, 2, np.nanmean, LastValueImputation()),
        ([np.nan], [0], 1, 1, 2, np.nanmedian, LastValueImputation()),
        ([1, 4, np.nan], [3], 1, 3, 2, np.nanmean, LastValueImputation()),
        ([1, 4, np.nan], [4], 1, 3, 2, np.nanmedian, LastValueImputation()),
        (
            [1, 10, np.nan, 1, 10, np.nan],
            [1, 10, 10],
            3,
            3,
            2,
            np.nanmean,
            LastValueImputation(),
        ),
        (
            [1, 10, np.nan, 1, 10, np.nan],
            [1, 10, 10],
            3,
            3,
            2,
            np.nanmedian,
            LastValueImputation(),
        ),
        (
            [1, 10, np.nan, 1, 10, np.nan],
            [1, 10, 10, 1, 10],
            5,
            3,
            2,
            np.nanmax,
            LastValueImputation(),
        ),
        (
            [1, 10, np.nan, 1, 10, np.nan],
            [1, 10, 10, 1, 10],
            5,
            3,
            2,
            np.nanmin,
            LastValueImputation(),
        ),
        # missing values without imputation
        ([1, 3, np.nan], [np.nan], 1, 1, 1, np.nanmean, LeavesMissingValues()),
        (
            [1, 3, np.nan],
            [np.nan],
            1,
            1,
            1,
            np.nanmedian,
            LeavesMissingValues(),
        ),
        ([1, 3, np.nan], [np.nan], 1, 1, 1, np.nanmax, LeavesMissingValues()),
        ([1, 3, np.nan], [np.nan], 1, 1, 1, np.nanmin, LeavesMissingValues()),
        (
            [1, 3, np.nan],
            [np.nan] * 2,
            2,
            1,
            1,
            np.nanmean,
            LeavesMissingValues(),
        ),
        (
            [1, 3, np.nan],
            [np.nan] * 2,
            2,
            1,
            1,
            np.nanmedian,
            LeavesMissingValues(),
        ),
        (
            [1, 3, np.nan],
            [np.nan] * 2,
            2,
            1,
            1,
            np.nanmax,
            LeavesMissingValues(),
        ),
        (
            [1, 3, np.nan],
            [np.nan] * 2,
            2,
            1,
            1,
            np.nanmin,
            LeavesMissingValues(),
        ),
        ([1, 3, np.nan], [3], 1, 1, 2, np.nanmean, LeavesMissingValues()),
        ([1, 3, np.nan], [3], 1, 1, 2, np.nanmedian, LeavesMissingValues()),
        ([1, 3, np.nan], [3], 1, 1, 2, np.nanmax, LeavesMissingValues()),
        ([1, 3, np.nan], [3], 1, 1, 2, np.nanmin, LeavesMissingValues()),
        ([1, 3, np.nan], [3], 1, 2, 1, np.nanmean, LeavesMissingValues()),
        ([1, 3, np.nan], [3], 1, 2, 1, np.nanmedian, LeavesMissingValues()),
        ([1, 3, np.nan], [3], 1, 2, 1, np.nanmax, LeavesMissingValues()),
        ([1, 3, np.nan], [3], 1, 2, 1, np.nanmin, LeavesMissingValues()),
        (
            [1, 3, np.nan],
            [3, np.nan],
            2,
            2,
            1,
            np.nanmean,
            LeavesMissingValues(),
        ),
        (
            [1, 3, np.nan],
            [3, np.nan],
            2,
            2,
            1,
            np.nanmedian,
            LeavesMissingValues(),
        ),
        (
            [1, 3, np.nan],
            [3, np.nan],
            2,
            2,
            1,
            np.nanmax,
            LeavesMissingValues(),
        ),
        (
            [1, 3, np.nan],
            [3, np.nan],
            2,
            2,
            1,
            np.nanmin,
            LeavesMissingValues(),
        ),
        # check if `nanmean` works when some seasons have missing values
        ([1, 3, np.nan], [3, 3], 2, 1, 2, np.nanmean, LeavesMissingValues()),
        (
            [1, 3, np.nan],
            [3, 3, 3],
            3,
            1,
            2,
            np.nanmean,
            LeavesMissingValues(),
        ),
        # check if `mean` works when some seasons have missing values
        (
            [1, 3, np.nan],
            [np.nan] * 2,
            2,
            1,
            2,
            np.mean,
            LeavesMissingValues(),
        ),
        (
            [1, 3, np.nan],
            [np.nan] * 3,
            3,
            1,
            2,
            np.mean,
            LeavesMissingValues(),
        ),
        # check if `nanmedian` works when some seasons have missing values
        ([1, 3, np.nan], [3, 3], 2, 1, 2, np.nanmedian, LeavesMissingValues()),
        (
            [1, 3, np.nan],
            [3, 3, 3],
            3,
            1,
            2,
            np.nanmedian,
            LeavesMissingValues(),
        ),
        # check if `nanmax` works when some seasons have missing values
        ([1, 3, np.nan], [3, 3], 2, 1, 2, np.nanmax, LeavesMissingValues()),
        ([1, 3, np.nan], [3, 3, 3], 3, 1, 2, np.nanmax, LeavesMissingValues()),
        # check if `nanmin` works when some seasons have missing values
        ([1, 3, np.nan], [3, 3], 2, 1, 2, np.nanmin, LeavesMissingValues()),
        ([1, 3, np.nan], [3, 3, 3], 3, 1, 2, np.nanmin, LeavesMissingValues()),
        # check if `median` works when some seasons have missing values
        (
            [1, 3, np.nan],
            [np.nan] * 2,
            2,
            1,
            2,
            np.median,
            LeavesMissingValues(),
        ),
        (
            [1, 3, np.nan],
            [np.nan] * 3,
            3,
            1,
            2,
            np.median,
            LeavesMissingValues(),
        ),
        # check if `max` works when some seasons have missing values
        ([1, 3, np.nan], [np.nan] * 2, 2, 1, 2, np.max, LeavesMissingValues()),
        ([1, 3, np.nan], [np.nan] * 3, 3, 1, 2, np.max, LeavesMissingValues()),
        # check if `min` works when some seasons have missing values
        ([1, 3, np.nan], [np.nan] * 2, 2, 1, 2, np.min, LeavesMissingValues()),
        ([1, 3, np.nan], [np.nan] * 3, 3, 1, 2, np.min, LeavesMissingValues()),
    ],
)
def test_predictor(
    data,
    expected_output,
    prediction_length,
    season_length,
    num_seasons,
    agg_fun,
    imputation_method,
):
    """
    Check the forecast shape and values for each parametrized scenario.

    ``np.testing.assert_equal`` is used for the value comparison because it
    treats NaNs in matching positions as equal, which the cases without
    imputation rely on.
    """
    prediction = get_prediction(
        data,
        prediction_length=prediction_length,
        season_length=season_length,
        num_seasons=num_seasons,
        agg_fun=agg_fun,
        imputation_method=imputation_method,
    )
    assert prediction.samples.shape == (1, prediction_length)

    np.testing.assert_equal(prediction.mean, expected_output)