Merge pull request #15 from jdb78/fix/gpus
Fix using GPUs
jdb78 authored Aug 23, 2020
2 parents 389d3ed + 33c9c2e commit a3fbb80
Showing 8 changed files with 65 additions and 39 deletions.
2 changes: 1 addition & 1 deletion examples/ar.py
@@ -25,7 +25,7 @@
from data import generate_ar_data

data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100)
data["static"] = 2
data["static"] = "2"
data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")
validation = data.series.sample(20)

2 changes: 1 addition & 1 deletion pytorch_forecasting/metrics.py
@@ -118,7 +118,7 @@ def forward(self, y_pred: Dict[str, torch.Tensor], target: Union[torch.Tensor, r
# batch sizes reside on the CPU by default -> we need to bring them to GPU
lengths = lengths.to(target.device)
else:
lengths = torch.LongTensor([target.size(1)], device=target.device).expand(target.size(0))
lengths = torch.ones(target.size(0), device=target.device, dtype=torch.long) * target.size(1)
assert not target.requires_grad

# calculate loss with "none" reduction
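A minimal sketch of the device-placement pattern the new line uses (the target tensor below is a stand-in, not taken from the library): the legacy torch.LongTensor constructor cannot reliably create data on a CUDA device, so the lengths tensor is built with torch.ones directly on the target's device.

import torch

# stand-in decoder target of shape (batch, time)
target = torch.randn(4, 7)
# every sequence is assumed to span the full time dimension; creating the lengths
# tensor on target.device keeps the loss working when the target lives on a GPU
lengths = torch.ones(target.size(0), device=target.device, dtype=torch.long) * target.size(1)
assert lengths.tolist() == [7, 7, 7, 7]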
6 changes: 5 additions & 1 deletion pytorch_forecasting/models/base_model.py
@@ -159,6 +159,10 @@ def step(self, x: Dict[str, torch.Tensor], y: torch.Tensor, batch_idx: int, labe
if label == "train" and len(self.hparams.monotone_constaints) > 0:
# calculate gradient with respect to continous decoder features
x["decoder_cont"].requires_grad_(True)
assert not torch._C._get_cudnn_enabled(), (
"To use monotone constraints, wrap model and training in context "
"`torch.backends.cudnn.flags(enable=False)`"
)
out = self(x)
out["prediction"] = self.transform_output(out)
prediction = out["prediction"]
@@ -176,7 +180,7 @@ def step(self, x: Dict[str, torch.Tensor], y: torch.Tensor, batch_idx: int, labe
[self.hparams.x_reals.index(name) for name in self.hparams.monotone_constaints.keys()]
)
monotonicity = torch.tensor(
[val for val in self.hparams.monotone_constaints.values()], dtype=gradient.dtype
[val for val in self.hparams.monotone_constaints.values()], dtype=gradient.dtype, device=gradient.device
)
# add additionl loss if gradient points in wrong direction
gradient = gradient[..., indices] * monotonicity[None, None]
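A minimal sketch of the context the new assertion asks for (fit_with_monotone_constraints is a hypothetical helper; trainer, net and the dataloaders stand in for the Lightning trainer, model and data): the monotonicity penalty differentiates the prediction with respect to the decoder inputs inside the training step, and that extra backward pass is expected to run with cuDNN disabled.

import torch

# hypothetical wrapper: run training inside the cudnn.flags context the assertion expects
def fit_with_monotone_constraints(trainer, net, train_dataloader, val_dataloader):
    with torch.backends.cudnn.flags(enabled=False):
        trainer.fit(net, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)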
6 changes: 3 additions & 3 deletions pytorch_forecasting/models/nbeats/__init__.py
@@ -225,7 +225,8 @@ def plot_interpretation(
"""
Plot interpretation.
[extended_summary]
Plot two pannels: prediction and backcast vs actuals and
decomposition of prediction into trend, seasonality and generic forecast.
Args:
x (Dict[str, torch.Tensor]): network input
@@ -237,8 +238,7 @@ def plot_interpretation(
generic forecast on secondary axis in second panel. Defaults to False.
Returns:
plt.Figure: matplotlib figure with two pannels: prediction and backcast vs actuals and
decomposition of prediction into trend, seasonality and generic forecast
plt.Figure: matplotlib figure
"""
if ax is None:
fig, ax = plt.subplots(2, 1, figsize=(6, 8))
15 changes: 12 additions & 3 deletions pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -457,8 +457,10 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
}
static_embedding, static_variable_selection = self.static_variable_selection(static_embedding)
else:
static_embedding = torch.zeros((x_cont.size(0), self.hparams.hidden_size), dtype=self.dtype)
static_variable_selection = torch.zeros((x_cont.size(0), 0), dtype=self.dtype)
static_embedding = torch.zeros(
(x_cont.size(0), self.hparams.hidden_size), dtype=self.dtype, device=self.device
)
static_variable_selection = torch.zeros((x_cont.size(0), 0), dtype=self.dtype, device=self.device)

static_context_variable_selection = self.expand_static_context(
self.static_context_variable_selection(static_embedding), timesteps
@@ -806,7 +808,14 @@ def _log_interpretation(self, outputs, label="train"):
attention_occurances = interpretation["encoder_length_histogram"][1:].flip(0).cumsum(0).float()
attention_occurances = attention_occurances / attention_occurances.max()
attention_occurances = torch.cat(
[attention_occurances, torch.ones(interpretation["attention"].size(0) - attention_occurances.size(0))],
[
attention_occurances,
torch.ones(
interpretation["attention"].size(0) - attention_occurances.size(0),
dtype=attention_occurances.dtype,
device=attention_occurances.device,
),
],
dim=0,
)
interpretation["attention"] = interpretation["attention"] / attention_occurances.pow(2).clamp(1.0)
9 changes: 9 additions & 0 deletions tests/test_models/conftest.py
@@ -1,10 +1,19 @@
import pytest
import numpy as np
import torch
from data import get_stallion_data, generate_ar_data
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer, NaNLabelEncoder, EncoderNormalizer


@pytest.fixture
def gpus():
if torch.cuda.is_available():
return [0]
else:
return 0


@pytest.fixture
def data_with_covariates():
data = get_stallion_data()
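A sketch of how the new gpus fixture is meant to be consumed (the test name and the fast_dev_run flag are illustrative): the fixture returns [0] when CUDA is available and 0 otherwise, so the same test exercises the GPU path on CUDA machines and stays on the CPU elsewhere.

import pytorch_lightning as pl

def test_uses_gpus_fixture(gpus):
    trainer = pl.Trainer(fast_dev_run=True, gpus=gpus)
    assert trainer is not None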
4 changes: 2 additions & 2 deletions tests/test_models/test_nbeats.py
@@ -6,7 +6,7 @@
from pytorch_forecasting.models import NBeats


def test_integration(dataloaders_fixed_window_without_coveratiates, tmp_path):
def test_integration(dataloaders_fixed_window_without_coveratiates, tmp_path, gpus):
train_dataloader = dataloaders_fixed_window_without_coveratiates["train"]
val_dataloader = dataloaders_fixed_window_without_coveratiates["val"]
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
@@ -16,7 +16,7 @@ def test_integration(dataloaders_fixed_window_without_coveratiates, tmp_path):
trainer = pl.Trainer(
checkpoint_callback=checkpoint,
max_epochs=3,
gpus=0,
gpus=gpus,
weights_summary="top",
gradient_clip_val=0.1,
early_stop_callback=early_stop_callback,
60 changes: 32 additions & 28 deletions tests/test_models/test_temporal_fusion_transformer.py
@@ -1,14 +1,14 @@
from pytorch_forecasting.data import TimeSeriesDataSet
import pytest
import torch
import shutil
from contextlib import nullcontext
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models import TemporalFusionTransformer


def test_integration(dataloaders_with_coveratiates, tmp_path):
def test_integration(dataloaders_with_coveratiates, tmp_path, gpus):
train_dataloader = dataloaders_with_coveratiates["train"]
val_dataloader = dataloaders_with_coveratiates["val"]
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
@@ -19,7 +19,7 @@ def test_integration(dataloaders_with_coveratiates, tmp_path):
trainer = pl.Trainer(
checkpoint_callback=checkpoint,
max_epochs=3,
gpus=0,
gpus=gpus,
weights_summary="top",
gradient_clip_val=0.1,
early_stop_callback=early_stop_callback,
@@ -29,32 +29,36 @@
# test monotone constraints automatically
if "discount_in_percent" in dataloaders_with_coveratiates["train"].dataset.reals:
monotone_constaints = {"discount_in_percent": +1}
cuda_context = torch.backends.cudnn.flags(enabled=False)
else:
monotone_constaints = {}
net = TemporalFusionTransformer.from_dataset(
train_dataloader.dataset,
learning_rate=0.15,
hidden_size=4,
attention_head_size=1,
dropout=0.2,
hidden_continuous_size=2,
loss=QuantileLoss(),
log_interval=5,
log_val_interval=1,
log_gradient_flow=True,
monotone_constaints=monotone_constaints,
)
net.size()
try:
trainer.fit(
net, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
cuda_context = nullcontext()

with cuda_context:
net = TemporalFusionTransformer.from_dataset(
train_dataloader.dataset,
learning_rate=0.15,
hidden_size=4,
attention_head_size=1,
dropout=0.2,
hidden_continuous_size=2,
loss=QuantileLoss(),
log_interval=5,
log_val_interval=1,
log_gradient_flow=True,
monotone_constaints=monotone_constaints,
)
net.size()
try:
trainer.fit(
net, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
)

# check loading
fname = f"{trainer.checkpoint_callback.dirpath}/epoch=0.ckpt"
net = TemporalFusionTransformer.load_from_checkpoint(fname)
# check loading
fname = f"{trainer.checkpoint_callback.dirpath}/epoch=0.ckpt"
net = TemporalFusionTransformer.load_from_checkpoint(fname)

# check prediction
net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
finally:
shutil.rmtree(tmp_path, ignore_errors=True)
# check prediction
net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
finally:
shutil.rmtree(tmp_path, ignore_errors=True)
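A minimal sketch of the conditional-context pattern this test switches to (use_constraints stands in for the dataset check on discount_in_percent): cuDNN is disabled only when monotone constraints are exercised; otherwise nullcontext() keeps the rest of the test unchanged.

import torch
from contextlib import nullcontext

use_constraints = True  # stand-in for: "discount_in_percent" in dataset.reals
cuda_context = torch.backends.cudnn.flags(enabled=False) if use_constraints else nullcontext()
with cuda_context:
    pass  # placeholder for building the model and calling trainer.fit(...)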
