Merge pull request #890 from jdb78/feature/n-hits
N-Hits
jdb78 authored Mar 22, 2022
2 parents 33fbc60 + 8823531 commit eea8c16
Showing 20 changed files with 1,112 additions and 67 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,9 +1,10 @@
# Release Notes

## v0.9.3 UNRELEASED
## v0.10.0 UNRELEASED

### Added

- Added new `N-HiTS` network that has consistently beaten `N-BEATS` (#890)
- Allow using [torchmetrics](https://torchmetrics.readthedocs.io/) as loss metrics (#776)
- Enable fitting `EncoderNormalizer()` with limited data history using `max_length` argument (#782)
- More flexible `MultiEmbedding()` with convenience `output_size` and `input_size` properties (#829)
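
For orientation, a minimal sketch of how the new pieces listed above might be combined. The toy data, hyperparameters and the choice of `MeanAbsoluteError` are illustrative assumptions, not part of this commit:

```python
import numpy as np
import pandas as pd
from torchmetrics import MeanAbsoluteError

from pytorch_forecasting import NHiTS, TimeSeriesDataSet
from pytorch_forecasting.data import EncoderNormalizer

# toy data: a single series with 100 time steps (illustrative only)
data = pd.DataFrame(
    {
        "value": np.sin(np.arange(100) / 10) + np.random.normal(scale=0.1, size=100),
        "time_idx": np.arange(100),
        "series": "a",
    }
)

dataset = TimeSeriesDataSet(
    data,
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=30,
    max_prediction_length=10,
    time_varying_unknown_reals=["value"],
    # normalize the target on the encoder window, using at most the last 30 observations (#782)
    target_normalizer=EncoderNormalizer(max_length=30),
)

# new N-HiTS network (#890) with a torchmetrics metric used as loss (#776)
net = NHiTS.from_dataset(dataset, loss=MeanAbsoluteError())
```
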
1 change: 1 addition & 0 deletions README.md
@@ -69,6 +69,7 @@ The documentation provides a [comparison of available models](https://pytorch-fo
- [N-BEATS: Neural basis expansion analysis for interpretable time series forecasting](http://arxiv.org/abs/1905.10437)
which has (if used as an ensemble) outperformed all other methods including ensembles of traditional statistical
methods in the M4 competition. The M4 competition is arguably the most important benchmark for univariate time series forecasting.
- [N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting](http://arxiv.org/abs/2201.12886) which supports covariates and has consistently beaten N-BEATS. It is also particularly well-suited for long-horizon forecasting.
- [DeepAR: Probabilistic forecasting with autoregressive recurrent networks](https://www.sciencedirect.com/science/article/pii/S0169207019301888)
which is one of the most popular forecasting algorithms and is often used as a baseline
- Simple standard networks for baselining: LSTM and GRU networks as well as an MLP on the decoder
7 changes: 6 additions & 1 deletion docs/source/models.rst
@@ -27,6 +27,7 @@ and you should take into account. Here is an overview of the pros and cons of
:py:class:`~pytorch_forecasting.models.rnn.RecurrentNetwork`, "x", "x", "x", "", "", "", "", "x", "", 2
:py:class:`~pytorch_forecasting.models.mlp.DecoderMLP`, "x", "x", "x", "x", "", "x", "", "x", "x", 1
:py:class:`~pytorch_forecasting.models.nbeats.NBeats`, "", "", "x", "", "", "", "", "", "", 1
:py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1
:py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "", "x", "", 3
:py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4

@@ -85,6 +86,9 @@ multiple targets and even heterogeneous targets where some are continuous variables and others categorical,
i.e. regression and classification at the same time. :py:class:`~pytorch_forecasting.models.deepar.DeepAR`
can handle multiple targets but only works for regression tasks.

For long forecast horizons, :py:class:`~pytorch_forecasting.models.nhits.NHiTS` is an excellent choice
as its hierarchical interpolation keeps computation and parameter counts manageable.

Supporting uncertainty
~~~~~~~~~~~~~~~~~~~~~~~

@@ -123,7 +127,8 @@ the lifetime of a model.

The :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer` is
a rather large model and can require substantial resources to train; more lightweight architectures may train considerably faster.
For example, :py:class:`~pytorch_forecasting.models.nbeats.NBeats` is an efficient model.
For example, :py:class:`~pytorch_forecasting.models.nbeats.NBeats` or :py:class:`~pytorch_forecasting.models.nhits.NHiTS` are
efficient models.
Autoregressive models such as :py:class:`~pytorch_forecasting.models.deepar.DeepAR` might be quick to train
but might be slow at inference time (in case of :py:class:`~pytorch_forecasting.models.deepar.DeepAR` this is
driven by sampling results probabilistically multiple times, effectively increasing the computational burden linearly with the
4 changes: 2 additions & 2 deletions docs/source/tutorials/ar.ipynb
@@ -334,7 +334,7 @@
],
"source": [
"# find optimal learning rate\n",
"res = trainer.tuner.lr_find(net, train_dataloader=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5)\n",
"res = trainer.tuner.lr_find(net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5)\n",
"print(f\"suggested learning rate: {res.suggestion()}\")\n",
"fig = res.plot(show=True, suggest=True)\n",
"fig.show()\n",
@@ -617,7 +617,7 @@
"\n",
"trainer.fit(\n",
" net,\n",
" train_dataloader=train_dataloader,\n",
" train_dataloaders=train_dataloader,\n",
" val_dataloaders=val_dataloader,\n",
")"
]
2 changes: 1 addition & 1 deletion docs/source/tutorials/building.ipynb
@@ -3803,7 +3803,7 @@
"\n",
"model = FullyConnectedForDistributionLossModel.from_dataset(dataset, hidden_size=10, n_hidden_layers=2, log_interval=1)\n",
"trainer = Trainer(fast_dev_run=True)\n",
"trainer.fit(model, train_dataloader=dataloader, val_dataloaders=dataloader)"
"trainer.fit(model, train_dataloaders=dataloader, val_dataloaders=dataloader)"
]
}
],
4 changes: 2 additions & 2 deletions docs/source/tutorials/stallion.ipynb
@@ -1049,7 +1049,7 @@
"# find optimal learning rate\n",
"res = trainer.tuner.lr_find(\n",
" tft,\n",
" train_dataloader=train_dataloader,\n",
" train_dataloaders=train_dataloader,\n",
" val_dataloaders=val_dataloader,\n",
" max_lr=10.0,\n",
" min_lr=1e-6,\n",
@@ -1577,7 +1577,7 @@
"# fit network\n",
"trainer.fit(\n",
" tft,\n",
" train_dataloader=train_dataloader,\n",
" train_dataloaders=train_dataloader,\n",
" val_dataloaders=val_dataloader,\n",
")"
]
2 changes: 2 additions & 0 deletions pytorch_forecasting/__init__.py
@@ -40,6 +40,7 @@
DeepAR,
MultiEmbedding,
NBeats,
NHiTS,
RecurrentNetwork,
TemporalFusionTransformer,
get_rnn,
@@ -66,6 +67,7 @@
"MultiNormalizer",
"TemporalFusionTransformer",
"NBeats",
"NHiTS",
"Baseline",
"DeepAR",
"BaseModel",
17 changes: 14 additions & 3 deletions pytorch_forecasting/metrics.py
@@ -264,19 +264,30 @@ def __len__(self) -> int:
"""
return len(self.metrics)

def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor):
def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs):
"""
Update composite metric
Args:
y_pred: network output
y_actual: actual values
**kwargs: arguments to update function
Returns:
torch.Tensor: metric value on which backpropagation can be applied
"""
for idx, metric in enumerate(self.metrics):
metric.update(y_pred[idx], (y_actual[0][idx], y_actual[1]))
try:
metric.update(
y_pred[idx],
(y_actual[0][idx], y_actual[1]),
**{
name: value[idx] if isinstance(value, (list, tuple)) else value
for name, value in kwargs.items()
},
)
except TypeError: # silently update without kwargs if not supported
metric.update(y_pred[idx], (y_actual[0][idx], y_actual[1]))

def compute(self) -> torch.Tensor:
"""
@@ -949,7 +960,7 @@ def update(
self._update_losses_and_lengths(losses, lengths)

def loss(self, y_pred, target, scaling):
return (y_pred - target).abs() / scaling.unsqueeze(-1)
return (self.to_prediction(y_pred) - target).abs() / scaling.unsqueeze(-1)

def calculate_scaling(self, target, lengths, encoder_target, encoder_lengths):
# calculate mean(abs(diff(targets)))
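
To make the new keyword-argument handling in `MultiLoss.update()` concrete, a hedged sketch with made-up tensors: list-valued kwargs are split per target, while metrics that do not accept a kwarg fall back to a plain update.

```python
import torch

from pytorch_forecasting.metrics import MAE, MASE, MultiLoss

loss = MultiLoss([MAE(), MASE()])  # one metric per target

# one prediction and one target tensor per target, shape (batch, decoder length)
y_pred = [torch.randn(4, 6), torch.randn(4, 6)]
y_actual = ([torch.randn(4, 6), torch.randn(4, 6)], None)  # (targets, weight)

loss.update(
    y_pred,
    y_actual,
    # list -> the idx-th entry goes to the idx-th metric (here only MASE uses it)
    encoder_target=[torch.randn(4, 10), torch.randn(4, 10)],
    # non-list -> passed unchanged to every metric that accepts it; MAE silently ignores it
    encoder_lengths=torch.full((4,), 10, dtype=torch.long),
)
print(loss.compute())
```
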
2 changes: 2 additions & 0 deletions pytorch_forecasting/models/__init__.py
@@ -11,12 +11,14 @@
from pytorch_forecasting.models.deepar import DeepAR
from pytorch_forecasting.models.mlp import DecoderMLP
from pytorch_forecasting.models.nbeats import NBeats
from pytorch_forecasting.models.nhits import NHiTS
from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn
from pytorch_forecasting.models.rnn import RecurrentNetwork
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer

__all__ = [
"NBeats",
"NHiTS",
"TemporalFusionTransformer",
"RecurrentNetwork",
"DeepAR",
113 changes: 70 additions & 43 deletions pytorch_forecasting/models/base_model.py
@@ -36,9 +36,11 @@
QuantileLoss,
convert_torchmetric_to_pytorch_forecasting_metric,
)
from pytorch_forecasting.models.nn.embeddings import MultiEmbedding
from pytorch_forecasting.optim import Ranger
from pytorch_forecasting.utils import (
OutputMixIn,
TupleOutputMixIn,
apply_to_list,
create_mask,
get_embedding_size,
@@ -131,7 +133,7 @@ def _concatenate_output(
}


class BaseModel(LightningModule):
class BaseModel(LightningModule, TupleOutputMixIn):
"""
BaseModel from which new timeseries models should inherit from.
The ``hparams`` of the created object will default to the parameters indicated in :py:meth:`~__init__`.
@@ -192,6 +194,7 @@ def __init__(
loss: Metric = SMAPE(),
logging_metrics: nn.ModuleList = nn.ModuleList([]),
reduce_on_plateau_patience: int = 1000,
reduce_on_plateau_reduction: float = 2.0,
reduce_on_plateau_min_lr: float = 1e-5,
weight_decay: float = 0.0,
optimizer_params: Dict[str, Any] = None,
@@ -215,6 +218,7 @@
Defaults to [].
reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10. Defaults
to 1000
reduce_on_plateau_reduction (float): reduction in learning rate when encountering plateau. Defaults to 2.0.
reduce_on_plateau_min_lr (float): minimum learning rate for reduce on plateau learning rate scheduler.
Defaults to 1e-5
weight_decay (float): weight decay. Defaults to 0.0.
@@ -513,7 +517,7 @@ def step(
# multiply monotinicity loss by large number to ensure relevance and take to the power of 2
# for smoothness of loss function
monotinicity_loss = 10 * torch.pow(monotinicity_loss, 2)
if isinstance(self.loss, MASE):
if isinstance(self.loss, (MASE, MultiLoss)):
loss = self.loss(
prediction, y, encoder_target=x["encoder_target"], encoder_lengths=x["encoder_lengths"]
)
@@ -526,10 +530,9 @@

# calculate loss
prediction = out["prediction"]
if isinstance(self.loss, MASE):
loss = self.loss(
prediction, y, encoder_target=x["encoder_target"], encoder_lengths=x["encoder_lengths"]
)
if isinstance(self.loss, (MASE, MultiLoss)):
mase_kwargs = dict(encoder_target=x["encoder_target"], encoder_lengths=x["encoder_lengths"])
loss = self.loss(prediction, y, **mase_kwargs)
else:
loss = self.loss(prediction, y)

@@ -595,27 +598,6 @@ def log_metrics(
batch_size=len(x["decoder_target"]),
)

def to_network_output(self, **results):
"""
Convert output into a named (and immutable) tuple.
This allows tracing the modules as graphs and prevents modifying the output.
Returns:
named tuple
"""
if hasattr(self, "_output_class"):
Output = self._output_class
else:
OutputTuple = namedtuple("output", results)

class Output(OutputMixIn, OutputTuple):
pass

self._output_class = Output

return self._output_class(**results)

def forward(
self, x: Dict[str, Union[torch.Tensor, List[torch.Tensor]]]
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
@@ -956,7 +938,7 @@ def configure_optimizers(self):
"scheduler": ReduceLROnPlateau(
optimizer,
mode="min",
factor=0.2,
factor=1.0 / self.hparams.reduce_on_plateau_reduction,
patience=self.hparams.reduce_on_plateau_patience,
cooldown=self.hparams.reduce_on_plateau_patience,
min_lr=self.hparams.reduce_on_plateau_min_lr,
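
The new `reduce_on_plateau_reduction` hyperparameter feeds into the scheduler factor above (factor = 1 / reduction). A hedged usage sketch; the model class, the `dataset` object and the concrete values are placeholders, not part of this commit:

```python
from pytorch_forecasting import TemporalFusionTransformer

# `dataset` is assumed to be an existing TimeSeriesDataSet;
# previously the factor was hard-coded to 0.2, now it is 1 / reduce_on_plateau_reduction
model = TemporalFusionTransformer.from_dataset(
    dataset,
    reduce_on_plateau_patience=10,     # validation checks to wait before reducing
    reduce_on_plateau_reduction=4.0,   # divide the learning rate by 4 on a plateau
    reduce_on_plateau_min_lr=1e-6,
)
```
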
@@ -1347,6 +1329,25 @@ class BaseModelWithCovariates(BaseModel):
as bag of embeddings
"""

@property
def target_positions(self) -> torch.LongTensor:
"""
Positions of target variable(s) in covariates.
Returns:
torch.LongTensor: tensor of positions.
"""
# todo: expand for categorical targets
if "target" in self.hparams:
target = self.hparams.target
else:
target = self.dataset_parameters["target"]
return torch.tensor(
[self.hparams.x_reals.index(name) for name in to_list(target)],
device=self.device,
dtype=torch.long,
)
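
For orientation, a hedged illustration of what the relocated `target_positions` property computes; the variable names here are made up:

```python
import torch

# if the model's continuous inputs (hparams.x_reals) were, say,
x_reals = ["volume", "price"]
# and the dataset target were "price", the property would return the index of
# each target among the continuous variables
positions = torch.tensor([x_reals.index(name) for name in ["price"]], dtype=torch.long)
print(positions)  # tensor([1])
```
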

@property
def reals(self) -> List[str]:
"""List of all continuous variables in model"""
@@ -1454,6 +1455,47 @@ def from_dataset(
new_kwargs.update(kwargs)
return super().from_dataset(dataset, **new_kwargs)

def extract_features(
self,
x,
embeddings: MultiEmbedding = None,
period: str = "all",
) -> torch.Tensor:
"""
Extract features
Args:
x (Dict[str, torch.Tensor]): input from the dataloader
embeddings (MultiEmbedding): embeddings for categorical variables
period (str, optional): One of "encoder", "decoder" or "all". Defaults to "all".
Returns:
torch.Tensor: tensor with selected variables
"""
# select period
if period == "encoder":
x_cat = x["encoder_cat"]
x_cont = x["encoder_cont"]
elif period == "decoder":
x_cat = x["decoder_cat"]
x_cont = x["decoder_cont"]
elif period == "all":
x_cat = torch.cat([x["encoder_cat"], x["decoder_cat"]], dim=1) # concatenate in time dimension
x_cont = torch.cat([x["encoder_cont"], x["decoder_cont"]], dim=1) # concatenate in time dimension
else:
raise ValueError(f"Unknown period: {period}")

# create dictionary of encoded vectors
input_vectors = embeddings(x_cat)
input_vectors.update(
{
name: x_cont[..., idx].unsqueeze(-1)
for idx, name in enumerate(self.hparams.x_reals)
if name in self.reals
}
)
return input_vectors

def calculate_prediction_actual_by_variable(
self,
x: Dict[str, torch.Tensor],
Expand Down Expand Up @@ -1983,21 +2025,6 @@ class AutoRegressiveBaseModelWithCovariates(BaseModelWithCovariates, AutoRegress
as bag of embeddings
"""

@property
def target_positions(self) -> torch.LongTensor:
"""
Positions of target variable(s) in covariates.
Returns:
torch.LongTensor: tensor of positions.
"""
# todo: expand for categorical targets
return torch.tensor(
[self.hparams.x_reals.index(name) for name in to_list(self.hparams.target)],
device=self.device,
dtype=torch.long,
)

@property
def lagged_target_positions(self) -> Dict[int, torch.LongTensor]:
"""
3 changes: 3 additions & 0 deletions pytorch_forecasting/models/nbeats/__init__.py
@@ -48,6 +48,9 @@ def __init__(
the most
important benchmark for univariate time series forecasting.
The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently been shown to consistently outperform
N-BEATS.
Args:
stack_types: One of the following values: “generic”, “seasonality" or “trend". A list of strings
of length 1 or ‘num_stacks’. Default and recommended value