Skip to content

Commit

Permalink
[BUG] remove zero from pooling_sizes and downsample_frequencies (#1577)
Browse files Browse the repository at this point in the history
This PR partly fixes #1571 for NHiTS.

If `max_prediction_length = 1`, pooling_sizes and downsample_frequencies will contain zero, which causes a divide-by-zero error. This PR replaces the zeros with ones.
  • Loading branch information
XinyuWuu authored Aug 27, 2024
1 parent 119fa89 commit 4df80af
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pytorch_forecasting/models/nhits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,12 @@ def __init__(
if pooling_sizes is None:
    # Geometric schedule of pooling sizes, largest first (hence the [::-1]).
    pooling_sizes = np.exp2(np.round(np.linspace(0.49, np.log2(prediction_length / 2), n_stacks)))
    pooling_sizes = [int(x) for x in pooling_sizes[::-1]]
# Clamp each pooling size to at least 1: when prediction_length == 1 the
# exponents go negative and int(exp2(...)) truncates to 0, which later
# causes a divide-by-zero.
# NOTE: builtin max(list_a, list_b) compares the lists lexicographically and
# returns ONE of them wholesale — it is not an element-wise maximum, so the
# clamp must be done per element.
pooling_sizes = [max(size, 1) for size in pooling_sizes]
if downsample_frequencies is None:
    downsample_frequencies = [min(prediction_length, int(np.power(x, 1.5))) for x in pooling_sizes]
# Same element-wise clamp: int(x ** 1.5) is 0 when x == 0.
downsample_frequencies = [max(freq, 1) for freq in downsample_frequencies]

# set static hidden size
if static_hidden_size is None:
Expand Down
55 changes: 54 additions & 1 deletion tests/test_models/test_nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import pandas as pd
import pytest

from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
from pytorch_forecasting.metrics import MQF2DistributionLoss, QuantileLoss
from pytorch_forecasting.metrics.distributions import ImplicitQuantileNetworkDistributionLoss
from pytorch_forecasting.models import NHiTS
Expand Down Expand Up @@ -70,7 +73,12 @@ def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs):
finally:
shutil.rmtree(tmp_path, ignore_errors=True)

net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
net.predict(
val_dataloader,
fast_dev_run=True,
return_index=True,
return_decoder_lengths=True,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -143,3 +151,48 @@ def test_interpretation(model, dataloaders_with_covariates):
raw_predictions = model.predict(dataloaders_with_covariates["val"], mode="raw", return_x=True, fast_dev_run=True)
model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0, add_loss_to_title=True)
model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=0)


# Regression test for #1571: NHiTS raised a divide-by-zero when
# max_prediction_length == 1 because pooling_sizes / downsample_frequencies
# contained zeros. Train and predict for both the edge case and a normal case.
@pytest.mark.parametrize("max_prediction_length", [1, 5])
def test_prediction_length(max_prediction_length: int):
    num_series = 10
    series_length = 10
    total = num_series * series_length
    frame = pd.DataFrame(
        {
            "target": np.random.rand(total),
            "time_varying_known_real_1": np.random.rand(total),
            "time_idx": np.tile(np.arange(series_length), num_series),
            "group_id": np.repeat(np.arange(num_series), series_length),
        }
    )
    train_dataset = TimeSeriesDataSet(
        data=frame,
        time_idx="time_idx",
        target="target",
        group_ids=["group_id"],
        time_varying_unknown_reals=["target"],
        time_varying_known_reals=["time_varying_known_real_1"],
        max_prediction_length=max_prediction_length,
        max_encoder_length=3,
    )
    train_loader = train_dataset.to_dataloader(train=True)
    model = NHiTS.from_dataset(train_dataset, log_val_interval=1)
    trainer = pl.Trainer(
        accelerator="cpu",
        max_epochs=3,
        min_epochs=2,
        limit_train_batches=10,
    )
    trainer.fit(model, train_dataloaders=train_loader)
    # Predict on a fresh validation split to exercise the inference path too.
    val_dataset = TimeSeriesDataSet.from_dataset(
        train_dataset, frame, stop_randomization=True, predict=True
    )
    val_loader = val_dataset.to_dataloader(train=False)
    model.predict(
        val_loader,
        fast_dev_run=True,
        return_index=True,
        return_decoder_lengths=True,
    )

0 comments on commit 4df80af

Please sign in to comment.