Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PatchTST: Add support for time features #3167

Merged
merged 7 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 57 additions & 15 deletions src/gluonts/torch/model/patch_tst/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
TestSplitSampler,
ExpectedNumInstanceSampler,
SelectFields,
RenameFields,
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
Expand Down Expand Up @@ -74,6 +75,8 @@ class PatchTSTEstimator(PyTorchLightningEstimator):
Number of attention heads in the Transformer encoder which must divide d_model.
dim_feedforward
Size of hidden layers in the Transformer encoder.
num_feat_dynamic_real
Number of dynamic real features in the data (default: 0).
dropout
Dropout probability in the Transformer encoder.
activation
Expand Down Expand Up @@ -115,6 +118,7 @@ def __init__(
d_model: int = 32,
nhead: int = 4,
dim_feedforward: int = 128,
num_feat_dynamic_real: int = 0,
dropout: float = 0.1,
activation: str = "relu",
norm_first: bool = False,
Expand Down Expand Up @@ -151,6 +155,7 @@ def __init__(
self.d_model = d_model
self.nhead = nhead
self.dim_feedforward = dim_feedforward
self.num_feat_dynamic_real = num_feat_dynamic_real
self.dropout = dropout
self.activation = activation
self.norm_first = norm_first
Expand All @@ -166,17 +171,26 @@ def __init__(
)

def create_transformation(self) -> Transformation:
return SelectFields(
[
FieldName.ITEM_ID,
FieldName.INFO,
FieldName.START,
FieldName.TARGET,
],
allow_missing=True,
) + AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
return (
SelectFields(
[
FieldName.ITEM_ID,
FieldName.INFO,
FieldName.START,
FieldName.TARGET,
]
+ (
[FieldName.FEAT_DYNAMIC_REAL]
if self.num_feat_dynamic_real > 0
else []
),
allow_missing=True,
)
+ RenameFields({FieldName.FEAT_DYNAMIC_REAL: FieldName.FEAT_TIME})
+ AddObservedValuesIndicator(
target_field=FieldName.TARGET,
output_field=FieldName.OBSERVED_VALUES,
)
)

def create_lightning_module(self) -> pl.LightningModule:
Expand All @@ -192,6 +206,7 @@ def create_lightning_module(self) -> pl.LightningModule:
"d_model": self.d_model,
"nhead": self.nhead,
"dim_feedforward": self.dim_feedforward,
"num_feat_dynamic_real": self.num_feat_dynamic_real,
"dropout": self.dropout,
"activation": self.activation,
"norm_first": self.norm_first,
Expand Down Expand Up @@ -220,7 +235,10 @@ def _create_instance_splitter(
instance_sampler=instance_sampler,
past_length=self.context_length,
future_length=self.prediction_length,
time_series_fields=[FieldName.OBSERVED_VALUES],
time_series_fields=[FieldName.OBSERVED_VALUES]
+ (
[FieldName.FEAT_TIME] if self.num_feat_dynamic_real > 0 else []
),
dummy_value=self.distr_output.value_in_support,
)

Expand All @@ -239,7 +257,15 @@ def create_training_data_loader(
instances,
batch_size=self.batch_size,
shuffle_buffer_length=shuffle_buffer_length,
field_names=TRAINING_INPUT_NAMES,
field_names=TRAINING_INPUT_NAMES
+ (
[
f"past_{FieldName.FEAT_TIME}",
f"future_{FieldName.FEAT_TIME}",
]
if self.num_feat_dynamic_real > 0
else []
),
output_type=torch.tensor,
num_batches_per_epoch=self.num_batches_per_epoch,
)
Expand All @@ -253,7 +279,15 @@ def create_validation_data_loader(
return as_stacked_batches(
instances,
batch_size=self.batch_size,
field_names=TRAINING_INPUT_NAMES,
field_names=TRAINING_INPUT_NAMES
+ (
[
f"past_{FieldName.FEAT_TIME}",
f"future_{FieldName.FEAT_TIME}",
]
if self.num_feat_dynamic_real > 0
else []
),
output_type=torch.tensor,
)

Expand All @@ -264,7 +298,15 @@ def create_predictor(

return PyTorchPredictor(
input_transform=transformation + prediction_splitter,
input_names=PREDICTION_INPUT_NAMES,
input_names=PREDICTION_INPUT_NAMES
+ (
[
f"past_{FieldName.FEAT_TIME}",
f"future_{FieldName.FEAT_TIME}",
]
if self.num_feat_dynamic_real > 0
else []
),
prediction_net=module,
forecast_generator=self.distr_output.forecast_generator,
batch_size=self.batch_size,
Expand Down
69 changes: 64 additions & 5 deletions src/gluonts/torch/model/patch_tst/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Tuple
from typing import Optional, Tuple

import numpy as np
import torch
Expand All @@ -21,7 +21,7 @@
from gluonts.model import Input, InputSpec
from gluonts.torch.distributions import StudentTOutput
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
from gluonts.torch.util import unsqueeze_expand, weighted_average
from gluonts.torch.util import take_last, unsqueeze_expand, weighted_average
from gluonts.torch.model.simple_feedforward import make_linear_layer


Expand Down Expand Up @@ -85,6 +85,8 @@ class PatchTSTModel(nn.Module):
Number of time points to predict.
context_length
Number of time steps prior to prediction time that the model.
num_feat_dynamic_real
Number of dynamic real features in the data (default: 0).
distr_output
Distribution to use to evaluate observations and sample predictions.
Default: ``StudentTOutput()``.
Expand All @@ -101,6 +103,7 @@ def __init__(
d_model: int,
nhead: int,
dim_feedforward: int,
num_feat_dynamic_real: int,
dropout: float,
activation: str,
norm_first: bool,
Expand All @@ -120,6 +123,7 @@ def __init__(
self.d_model = d_model
self.padding_patch = padding_patch
self.distr_output = distr_output
self.num_feat_dynamic_real = num_feat_dynamic_real

if scaling == "mean":
self.scaler = MeanScaler(keepdim=True)
Expand All @@ -133,8 +137,11 @@ def __init__(
self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))
self.patch_num += 1

# project from patch_len + 2 features (loc and scale) to d_model
self.patch_proj = make_linear_layer(patch_len + 2, d_model)
# project from `patch_len` + 2 features (`loc` and `scale`) +
# `num_feat_dynamic_real` x `patch_len` to d_model
self.patch_proj = make_linear_layer(
patch_len + 2 + self.num_feat_dynamic_real * patch_len, d_model
)

self.positional_encoding = SinusoidalPositionalEmbedding(
self.patch_num, d_model
Expand Down Expand Up @@ -163,6 +170,28 @@ def __init__(
self.args_proj = self.distr_output.get_args_proj(d_model)

def describe_inputs(self, batch_size=1) -> InputSpec:
if self.num_feat_dynamic_real > 0:
input_spec_feat = {
"past_time_feat": Input(
shape=(
batch_size,
self.context_length,
self.num_feat_dynamic_real,
),
dtype=torch.float,
),
"future_time_feat": Input(
shape=(
batch_size,
self.prediction_length,
self.num_feat_dynamic_real,
),
dtype=torch.float,
),
}
else:
input_spec_feat = {}

return InputSpec(
{
"past_target": Input(
Expand All @@ -171,6 +200,7 @@ def describe_inputs(self, batch_size=1) -> InputSpec:
"past_observed_values": Input(
shape=(batch_size, self.context_length), dtype=torch.float
),
**input_spec_feat,
},
torch.zeros,
)
Expand All @@ -179,6 +209,8 @@ def forward(
self,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
past_time_feat: Optional[torch.Tensor] = None,
future_time_feat: Optional[torch.Tensor] = None,
) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]:
# scale the input
past_target_scaled, loc, scale = self.scaler(
Expand All @@ -192,6 +224,25 @@ def forward(
dimension=1, size=self.patch_len, step=self.stride
)

# do patching for time features as well
if self.num_feat_dynamic_real > 0:
# shift time features by `prediction_length` so that they are
# aligned with the target input.
time_feat = take_last(
torch.cat((past_time_feat, future_time_feat), dim=1),
dim=1,
num=self.context_length,
)

# (bs x T x d) --> (bs x d x T) because the 1D padding is done on
# the last dimension.
time_feat = self.padding_patch_layer(
time_feat.transpose(-2, -1)
).transpose(-2, -1)
time_feat_patches = time_feat.unfold(
dimension=1, size=self.patch_len, step=self.stride
).flatten(-2, -1)

# add loc and scale to past_target_patches as additional features
log_abs_loc = loc.abs().log1p()
log_scale = scale.log()
Expand All @@ -202,6 +253,9 @@ def forward(
)
inputs = torch.cat((past_target_patches, expanded_static_feat), dim=-1)

if self.num_feat_dynamic_real > 0:
inputs = torch.cat((inputs, time_feat_patches), dim=-1)

# project patches
enc_in = self.patch_proj(inputs)
embed_pos = self.positional_encoding(enc_in.size())
Expand All @@ -224,9 +278,14 @@ def loss(
past_observed_values: torch.Tensor,
future_target: torch.Tensor,
future_observed_values: torch.Tensor,
past_time_feat: Optional[torch.Tensor] = None,
future_time_feat: Optional[torch.Tensor] = None,
) -> torch.Tensor:
distr_args, loc, scale = self(
past_target=past_target, past_observed_values=past_observed_values
past_target=past_target,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
future_time_feat=future_time_feat,
)
loss = self.distr_output.loss(
target=future_target, distr_args=distr_args, loc=loc, scale=scale
Expand Down
19 changes: 19 additions & 0 deletions test/torch/model/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,25 @@ def test_estimator_constant_dataset(
num_batches_per_epoch=3,
epochs=2,
),
lambda freq, prediction_length: PatchTSTEstimator(
prediction_length=prediction_length,
context_length=2 * prediction_length,
num_feat_dynamic_real=3,
patch_len=16,
batch_size=4,
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
lambda freq, prediction_length: PatchTSTEstimator(
prediction_length=prediction_length,
context_length=2 * prediction_length,
num_feat_dynamic_real=3,
distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]),
patch_len=16,
batch_size=4,
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
lambda freq, prediction_length: WaveNetEstimator(
freq=freq,
prediction_length=prediction_length,
Expand Down
Loading