diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4e93eadf..1c629481 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,6 @@
 # Release Notes
 
-## v0.10.0 UNRELEASED
+## v0.10.0 Adding N-HiTS network (N-BEATS successor) (23/03/2022)
 
 ### Added
 
@@ -8,6 +8,7 @@
 - Allow using [torchmetrics](https://torchmetrics.readthedocs.io/) as loss metrics (#776)
 - Enable fitting `EncoderNormalizer()` with limited data history using `max_length` argument (#782)
 - More flexible `MultiEmbedding()` with convenience `output_size` and `input_size` properties (#829)
+- Fix concatenation of attention (#902)
 
 ### Fixed
 
diff --git a/pytorch_forecasting/metrics.py b/pytorch_forecasting/metrics.py
index 371c3c83..f24bfef0 100644
--- a/pytorch_forecasting/metrics.py
+++ b/pytorch_forecasting/metrics.py
@@ -745,9 +745,9 @@ def to_quantiles(self, out: Dict[str, torch.Tensor], quantiles=None):
         else:
             quantiles = self.quantiles
         predictions = super().to_prediction(out)
-        return torch.stack([torch.tensor(scipy.stats.poisson(predictions.cpu()).ppf(q)) for q in quantiles], dim=-1).to(
-            predictions.device
-        )
+        return torch.stack(
+            [torch.tensor(scipy.stats.poisson(predictions.detach().cpu().numpy()).ppf(q)) for q in quantiles], dim=-1
+        ).to(predictions.device)
 
 
 class QuantileLoss(MultiHorizonMetric):
diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py
index 79126c4c..f66a05e9 100644
--- a/pytorch_forecasting/models/base_model.py
+++ b/pytorch_forecasting/models/base_model.py
@@ -6,6 +6,7 @@
 from copy import deepcopy
 import inspect
 from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+import warnings
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -75,7 +76,24 @@ def _torch_cat_na(x: List[torch.Tensor]) -> torch.Tensor:
             )
             for xi in x
         ]
-    return torch.cat(x, dim=0)
+
+    # check if remaining dimensions are all equal
+    if x[0].ndim > 2:
+        remaining_dimensions_equal = all([all([xi.size(i) == x[0].size(i) for xi in x]) for i in range(2, x[0].ndim)])
+    else:
+        remaining_dimensions_equal = True
+
+    # concatenate into a single tensor if possible, otherwise deaggregate into a list of samples
+    if remaining_dimensions_equal:
+        return torch.cat(x, dim=0)
+    else:
+        # make list instead but warn
+        warnings.warn(
+            f"Tensors have unequal shapes beyond the first two dimensions. Example tensor shape: {x[0].shape}. "
" + "Returning list instead of torch.Tensor.", + UserWarning, + ) + return [xii for xi in x for xii in xi] def _concatenate_output( diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index a716539f..9012d8b9 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -23,7 +23,7 @@ InterpretableMultiHeadAttention, VariableSelectionNetwork, ) -from pytorch_forecasting.utils import autocorrelation, create_mask, detach, integer_histogram, padded_stack, to_list +from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list class TemporalFusionTransformer(BaseModelWithCovariates): @@ -501,7 +501,8 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return self.to_network_output( prediction=self.transform_output(output, target_scale=x["target_scale"]), - attention=attn_output_weights, + encoder_attention=attn_output_weights[..., :max_encoder_length], + decoder_attention=attn_output_weights[..., max_encoder_length:], static_variables=static_variable_selection, encoder_variables=encoder_sparse_weights, decoder_variables=decoder_sparse_weights, @@ -540,7 +541,6 @@ def interpret_output( out: Dict[str, torch.Tensor], reduction: str = "none", attention_prediction_horizon: int = 0, - attention_as_autocorrelation: bool = False, ) -> Dict[str, torch.Tensor]: """ interpret output of model @@ -550,12 +550,77 @@ def interpret_output( reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for normalizing by encode lengths attention_prediction_horizon: which prediction horizon to use for attention - attention_as_autocorrelation: if to record attention as autocorrelation - this should be set to true in - case of ``reduction != "none"`` and differing prediction times of the samples. 
 
         Returns:
             interpretations that can be plotted with ``plot_interpretation()``
         """
+        # take attention and concatenate if a list to proper attention object
+        if isinstance(out["decoder_attention"], (list, tuple)):
+            batch_size = len(out["decoder_attention"])
+            # start with decoder attention
+            # assume issue is in last dimension, we need to find max
+            max_last_dimension = max(x.size(-1) for x in out["decoder_attention"])
+            first_elm = out["decoder_attention"][0]
+            # create new attention tensor into which we will scatter
+            decoder_attention = torch.full(
+                (batch_size, *first_elm.shape[:-1], max_last_dimension),
+                float("nan"),
+                dtype=first_elm.dtype,
+                device=first_elm.device,
+            )
+            # scatter into tensor
+            for idx, x in enumerate(out["decoder_attention"]):
+                decoder_length = out["decoder_lengths"][idx]
+                decoder_attention[idx, :, :, :decoder_length] = x[..., :decoder_length]
+
+            # same game for encoder attention
+            # create new attention tensor into which we will scatter
+            first_elm = out["encoder_attention"][0]
+            encoder_attention = torch.full(
+                (batch_size, *first_elm.shape[:-1], self.hparams.max_encoder_length),
+                float("nan"),
+                dtype=first_elm.dtype,
+                device=first_elm.device,
+            )
+            # scatter into tensor
+            for idx, x in enumerate(out["encoder_attention"]):
+                encoder_length = out["encoder_lengths"][idx]
+                encoder_attention[idx, :, :, self.hparams.max_encoder_length - encoder_length :] = x[
+                    ..., :encoder_length
+                ]
+        else:
+            decoder_attention = out["decoder_attention"]
+            decoder_mask = create_mask(out["decoder_attention"].size(1), out["decoder_lengths"])
+            decoder_attention[decoder_mask[..., None, None].expand_as(decoder_attention)] = float("nan")
+            # roll encoder attention (so that the last encoder value is on the right)
+            encoder_attention = out["encoder_attention"]
+            shifts = encoder_attention.size(3) - out["encoder_lengths"]
+            new_index = (
+                torch.arange(encoder_attention.size(3))[None, None, None].expand_as(encoder_attention)
+                - shifts[:, None, None, None]
+            ) % encoder_attention.size(3)
+            encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)
+            # expand encoder_attention to full size
+            if encoder_attention.size(-1) < self.hparams.max_encoder_length:
+                encoder_attention = torch.concat(
+                    [
+                        torch.full(
+                            (
+                                *encoder_attention.shape[:-1],
+                                self.hparams.max_encoder_length - out["encoder_lengths"].max(),
+                            ),
+                            float("nan"),
+                            dtype=encoder_attention.dtype,
+                            device=encoder_attention.device,
+                        ),
+                        encoder_attention,
+                    ],
+                    dim=-1,
+                )
+
+        # combine attention vector
+        attention = torch.concat([encoder_attention, decoder_attention], dim=-1)
+        attention[attention < 1e-5] = float("nan")
+
 
         # histogram of decode and encode lengths
         encoder_length_histogram = integer_histogram(out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
@@ -582,53 +647,25 @@
         static_variables = out["static_variables"].squeeze(1)
         # attention is batch x time x heads x time_to_attend
         # average over heads + only keep prediction attention and attention on observed timesteps
-        attention = out["attention"][
-            :, attention_prediction_horizon, :, : out["encoder_lengths"].max() + attention_prediction_horizon
-        ].mean(1)
+        attention = masked_op(
+            attention[
+                :, attention_prediction_horizon, :, : self.hparams.max_encoder_length + attention_prediction_horizon
+            ],
+            op="mean",
+            dim=1,
+        )
 
         if reduction != "none":  # if to average over batches
             static_variables = static_variables.sum(dim=0)
             encoder_variables = encoder_variables.sum(dim=0)
             decoder_variables = decoder_variables.sum(dim=0)
 
-            # reorder attention or averaging
-            for i in range(len(attention)):  # very inefficient but does the trick
-                if 0 < out["encoder_lengths"][i] < attention.size(1) - attention_prediction_horizon - 1:
-                    relevant_attention = attention[
-                        i, : out["encoder_lengths"][i] + attention_prediction_horizon
-                    ].clone()
-                    if attention_as_autocorrelation:
-                        relevant_attention = autocorrelation(relevant_attention)
-                    attention[i, -out["encoder_lengths"][i] - attention_prediction_horizon :] = relevant_attention
-                    attention[i, : attention.size(1) - out["encoder_lengths"][i] - attention_prediction_horizon] = 0.0
-                elif attention_as_autocorrelation:
-                    attention[i] = autocorrelation(attention[i])
-
-            attention = attention.sum(dim=0)
-            if reduction == "mean":
-                attention = attention / encoder_length_histogram[1:].flip(0).cumsum(0).clamp(1)
-                attention = attention / attention.sum(-1).unsqueeze(-1)  # renormalize
-            elif reduction == "sum":
-                pass
-            else:
-                raise ValueError(f"Unknown reduction {reduction}")
-
-            attention = torch.zeros(
-                self.hparams.max_encoder_length + attention_prediction_horizon, device=self.device
-            ).scatter(
-                dim=0,
-                index=torch.arange(
-                    self.hparams.max_encoder_length + attention_prediction_horizon - attention.size(-1),
-                    self.hparams.max_encoder_length + attention_prediction_horizon,
-                    device=self.device,
-                ),
-                src=attention,
-            )
+            attention = masked_op(attention, dim=0, op=reduction)
         else:
-            attention = attention / attention.sum(-1).unsqueeze(-1)  # renormalize
+            attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze(-1)  # renormalize
 
         interpretation = dict(
-            attention=attention,
+            attention=attention.masked_fill(torch.isnan(attention), 0.0),
             static_variables=static_variables,
             encoder_variables=encoder_variables,
             decoder_variables=decoder_variables,
@@ -677,15 +714,15 @@ def plot_prediction(
 
         # add attention on secondary axis
         if plot_attention:
-            interpretation = self.interpret_output(out)
+            interpretation = self.interpret_output(out.iget(slice(idx, idx + 1)))
             for f in to_list(fig):
                 ax = f.axes[0]
                 ax2 = ax.twinx()
                 ax2.set_ylabel("Attention")
-                encoder_length = x["encoder_lengths"][idx]
+                encoder_length = x["encoder_lengths"][0]
                 ax2.plot(
                     torch.arange(-encoder_length, 0),
-                    interpretation["attention"][idx, :encoder_length].detach().cpu(),
+                    interpretation["attention"][0, -encoder_length:].detach().cpu(),
                     alpha=0.2,
                     color="k",
                 )
diff --git a/pytorch_forecasting/utils.py b/pytorch_forecasting/utils.py
index 407aa428..8e8f3420 100644
--- a/pytorch_forecasting/utils.py
+++ b/pytorch_forecasting/utils.py
@@ -338,6 +338,17 @@ def items(self):
     def keys(self):
         return self._fields
 
+    def iget(self, idx: Union[int, slice]):
+        """Select item(s) row-wise.
+
+        Args:
+            idx ([int, slice]): item to select
+
+        Returns:
+            Output of selected item(s).
+        """
+        return self.__class__(*(x[idx] for x in self))
+
 
 class TupleOutputMixIn:
     """MixIn to give output a namedtuple-like access capabilities with ``to_network_output() function``."""
@@ -436,3 +447,28 @@
         return [detach(xi) for xi in x]
     else:
         return x
+
+
+def masked_op(tensor: torch.Tensor, op: str = "mean", dim: int = 0, mask: torch.Tensor = None) -> torch.Tensor:
+    """Calculate operation on masked tensor.
+
+    Args:
+        tensor (torch.Tensor): tensor to conduct operation over
+        op (str): operation to apply. One of ["mean", "sum"]. Defaults to "mean".
+        dim (int, optional): dimension to reduce over. Defaults to 0.
+        mask (torch.Tensor, optional): boolean mask to apply (True=include in operation, False=ignore).
+            Masks nan values by default.
+
+    Returns:
+        torch.Tensor: tensor with the specified dimension reduced
+    """
+    if mask is None:
+        mask = ~torch.isnan(tensor)
+    masked = tensor.masked_fill(~mask, 0.0)
+    summed = masked.sum(dim=dim)
+    if op == "mean":
+        return summed / mask.sum(dim=dim)  # find the average over unmasked entries
+    elif op == "sum":
+        return summed
+    else:
+        raise ValueError(f"unknown operation {op}")
diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py
index ef81907b..e367f27e 100644
--- a/tests/test_models/test_temporal_fusion_transformer.py
+++ b/tests/test_models/test_temporal_fusion_transformer.py
@@ -2,6 +2,7 @@
 import shutil
 import sys
 
+import numpy as np
 import pytest
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
@@ -256,6 +257,52 @@ def test_prediction_with_dataloder(model, dataloaders_with_covariates, kwargs):
     model.predict(val_dataloader, fast_dev_run=True, **kwargs)
 
 
+def test_prediction_with_dataloder_raw(data_with_covariates, tmp_path):
+    # tests correct concatenation of raw output
+    test_data = data_with_covariates.copy()
+    np.random.seed(2)
+    test_data = test_data.sample(frac=0.5)
+
+    dataset = TimeSeriesDataSet(
+        test_data,
+        time_idx="time_idx",
+        max_encoder_length=8,
+        max_prediction_length=10,
+        min_prediction_length=1,
+        min_encoder_length=1,
+        target="volume",
+        group_ids=["agency", "sku"],
+        constant_fill_strategy=dict(volume=0.0),
+        allow_missing_timesteps=True,
+        time_varying_unknown_reals=["volume"],
+        time_varying_known_reals=["time_idx"],
+        target_normalizer=GroupNormalizer(groups=["agency", "sku"]),
+    )
+
+    net = TemporalFusionTransformer.from_dataset(
+        dataset,
+        learning_rate=1e-6,
+        hidden_size=4,
+        attention_head_size=1,
+        dropout=0.2,
+        hidden_continuous_size=2,
+        log_interval=1,
+        log_val_interval=1,
+        log_gradient_flow=True,
+    )
+    logger = TensorBoardLogger(tmp_path)
+    trainer = pl.Trainer(max_epochs=1, gradient_clip_val=1e-6, logger=logger)
+    trainer.fit(net, train_dataloaders=dataset.to_dataloader(batch_size=4, num_workers=0))
+
+    # choose small batch size to provoke issue
+    res = net.predict(dataset.to_dataloader(batch_size=2, num_workers=0), mode="raw")
+    # check that interpretation works
+    net.interpret_output(res)["attention"]
+    assert net.interpret_output(res.iget(slice(1)))["attention"].size() == torch.Size(
+        (1, net.hparams.max_encoder_length)
+    )
+
+
 def test_prediction_with_dataset(model, dataloaders_with_covariates):
     val_dataloader = dataloaders_with_covariates["val"]
     model.predict(val_dataloader.dataset, fast_dev_run=True)
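
Illustrative usage sketch (not part of the patch): assuming a fitted TemporalFusionTransformer `net` and a
TimeSeriesDataSet `dataset` set up as in the test above, the changed raw-prediction path and the new
`iget()` and `masked_op()` helpers are expected to be exercised roughly as follows.

    import torch
    from pytorch_forecasting.utils import masked_op

    # raw predictions keep per-sample attention; iget() row-slices the namedtuple-like output
    raw = net.predict(dataset.to_dataloader(batch_size=2, num_workers=0), mode="raw")
    single = raw.iget(slice(0, 1))
    interpretation = net.interpret_output(single)
    # interpretation["attention"] has shape (1, net.hparams.max_encoder_length); NaNs are filled with 0.0

    # masked_op reduces over a dimension while ignoring NaN entries by default
    x = torch.tensor([[1.0, float("nan")], [3.0, 4.0]])
    print(masked_op(x, op="mean", dim=0))  # tensor([2., 4.])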