diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py
index a856afca..6af643a1 100644
--- a/pytorch_forecasting/models/nhits/__init__.py
+++ b/pytorch_forecasting/models/nhits/__init__.py
@@ -154,8 +154,12 @@ def __init__(
         if pooling_sizes is None:
             pooling_sizes = np.exp2(np.round(np.linspace(0.49, np.log2(prediction_length / 2), n_stacks)))
             pooling_sizes = [int(x) for x in pooling_sizes[::-1]]
+            # clamp element-wise to >= 1: int(exp2(...)) yields 0 for short horizons (#1571)
+            pooling_sizes = [max(pooling_size, 1) for pooling_size in pooling_sizes]
         if downsample_frequencies is None:
             downsample_frequencies = [min(prediction_length, int(np.power(x, 1.5))) for x in pooling_sizes]
+            # clamp element-wise to >= 1 so no downsample frequency is zero
+            downsample_frequencies = [max(frequency, 1) for frequency in downsample_frequencies]
 
         # set static hidden size
         if static_hidden_size is None:
diff --git a/tests/test_models/test_nhits.py b/tests/test_models/test_nhits.py
index 8698a345..2c1bbe29 100644
--- a/tests/test_models/test_nhits.py
+++ b/tests/test_models/test_nhits.py
@@ -4,8 +4,11 @@
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks import EarlyStopping
 from lightning.pytorch.loggers import TensorBoardLogger
+import numpy as np
+import pandas as pd
 import pytest
 
+from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
 from pytorch_forecasting.metrics import MQF2DistributionLoss, QuantileLoss
 from pytorch_forecasting.metrics.distributions import ImplicitQuantileNetworkDistributionLoss
 from pytorch_forecasting.models import NHiTS
@@ -70,7 +73,12 @@ def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs):
     finally:
         shutil.rmtree(tmp_path, ignore_errors=True)
 
-    net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
+    net.predict(
+        val_dataloader,
+        fast_dev_run=True,
+        return_index=True,
+        return_decoder_lengths=True,
+    )
 
 
 @pytest.mark.parametrize(
@@ -143,3 +151,48 @@ def test_interpretation(model, dataloaders_with_covariates):
     raw_predictions = model.predict(dataloaders_with_covariates["val"], mode="raw", return_x=True, fast_dev_run=True)
     model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0, add_loss_to_title=True)
     model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=0)
+
+
+# Bug when max_prediction_length=1 #1571
+@pytest.mark.parametrize("max_prediction_length", [1, 5])
+def test_prediction_length(max_prediction_length: int):
+    n_timeseries = 10
+    time_points = 10
+    data = pd.DataFrame(
+        data={
+            "target": np.random.rand(time_points * n_timeseries),
+            "time_varying_known_real_1": np.random.rand(time_points * n_timeseries),
+            "time_idx": np.tile(np.arange(time_points), n_timeseries),
+            "group_id": np.repeat(np.arange(n_timeseries), time_points),
+        }
+    )
+    training_dataset = TimeSeriesDataSet(
+        data=data,
+        time_idx="time_idx",
+        target="target",
+        group_ids=["group_id"],
+        time_varying_unknown_reals=["target"],
+        time_varying_known_reals=(["time_varying_known_real_1"]),
+        max_prediction_length=max_prediction_length,
+        max_encoder_length=3,
+    )
+    training_data_loader = training_dataset.to_dataloader(train=True)
+    forecaster = NHiTS.from_dataset(training_dataset, log_val_interval=1)
+    trainer = pl.Trainer(
+        accelerator="cpu",
+        max_epochs=3,
+        min_epochs=2,
+        limit_train_batches=10,
+    )
+    trainer.fit(
+        forecaster,
+        train_dataloaders=training_data_loader,
+    )
+    validation_dataset = TimeSeriesDataSet.from_dataset(training_dataset, data, stop_randomization=True, predict=True)
+    validation_data_loader = validation_dataset.to_dataloader(train=False)
+    forecaster.predict(
+        validation_data_loader,
+        fast_dev_run=True,
+        return_index=True,
+        return_decoder_lengths=True,
+    )