Commit 334aec9

Merge pull request #759 from danielgafni/support-testing
Add support for testing with lightning (`trainer.test`)
jdb78 authored Nov 29, 2021
2 parents bbd51ea + 175aa47 commit 334aec9
Showing 11 changed files with 215 additions and 16 deletions.
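
For orientation before the file-by-file diff, here is a minimal end-to-end sketch of the workflow this commit enables: fit a small model, then evaluate it with `trainer.test`. The sketch is not part of the commit; the toy dataframe, the choice of `TemporalFusionTransformer`, and all hyperparameters are illustrative assumptions, while the `to_dataloader`, `trainer.fit`, and `trainer.test(..., test_dataloaders=...)` calls mirror the ones used in the tests below (pytorch-lightning 1.5-era keywords).

```python
import numpy as np
import pandas as pd
from pytorch_lightning import Trainer

from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

# toy data: two series of 100 steps each (illustrative only)
data = pd.DataFrame(
    {
        "time_idx": np.tile(np.arange(100), 2),
        "series": np.repeat(["a", "b"], 100),
        "value": np.sin(np.arange(200) / 10).astype("float32"),
    }
)

max_encoder_length, max_prediction_length = 10, 5
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    data[data["time_idx"] <= training_cutoff],
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    time_varying_unknown_reals=["value"],
)
# like the test fixtures below, reuse the validation dataset for testing
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)

train_dataloader = training.to_dataloader(train=True, batch_size=4, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=4, num_workers=0)
test_dataloader = validation.to_dataloader(train=False, batch_size=4, num_workers=0)

net = TemporalFusionTransformer.from_dataset(training, hidden_size=8, learning_rate=0.03)

trainer = Trainer(
    max_epochs=1,
    gpus=0,
    limit_train_batches=2,
    limit_val_batches=2,
    limit_test_batches=2,
)
trainer.fit(net, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)

# new in this release: test metrics are logged under a "test_" prefix
test_outputs = trainer.test(net, test_dataloaders=test_dataloader)
print(test_outputs)  # list with one dict of logged metrics, e.g. "test_loss"
```
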
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -4,13 +4,18 @@

### Added

- Added support for running `pytorch_lightning.trainer.test` (#759)

### Fixed

- Fix unintended mutation of `x_cont` (#732).
- Compatibility with pytorch-lightning 1.5 (#758)

### Contributors

- eavae
- danielgafni
- jdb78

## v0.9.1 Maintenance Release (26/09/2021)

150 changes: 144 additions & 6 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -67,6 +67,8 @@ pytest-github-actions-annotate-failures = {version = "*", optional = true}

[tool.poetry.dev-dependencies]
# checks and make tools
pre-commit = "^2.15.0"

invoke = "*"
flake8 = "*"
mypy = "*"
39 changes: 33 additions & 6 deletions pytorch_forecasting/models/base_model.py
@@ -12,6 +12,7 @@
from numpy.lib.function_base import iterable
import pandas as pd
from pytorch_lightning import LightningModule
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.parsing import AttributeDict, get_init_args
import scipy.stats
import torch
@@ -120,6 +121,15 @@ def _concatenate_output(
return output_cat


STAGE_STATES = {
RunningStage.TRAINING: "train",
RunningStage.VALIDATING: "val",
RunningStage.TESTING: "test",
RunningStage.PREDICTING: "predict",
RunningStage.SANITY_CHECKING: "sanity_check",
}


class BaseModel(LightningModule):
"""
BaseModel from which new timeseries models should inherit from.
@@ -261,6 +271,14 @@ def __init__(
del self._hparams[k]
del self._hparams_initial[k]

@property
def current_stage(self) -> str:
"""
Available inside lightning loops.
:return: current trainer stage. One of ["train", "val", "test", "predict", "sanity_check"]
"""
return STAGE_STATES[self.trainer.state.stage]

@property
def n_targets(self) -> int:
"""
@@ -371,6 +389,18 @@ def validation_step(self, batch, batch_idx):
log.update(self.create_log(x, y, out, batch_idx))
return log

def validation_epoch_end(self, outputs):
self.epoch_end(outputs)

def test_step(self, batch, batch_idx):
x, y = batch
log, out = self.step(x, y, batch_idx)
log.update(self.create_log(x, y, out, batch_idx))
return log

def test_epoch_end(self, outputs):
self.epoch_end(outputs)

def create_log(
self,
x: Dict[str, torch.Tensor],
@@ -404,9 +434,6 @@ def create_log(
)
return {}

def validation_epoch_end(self, outputs):
self.epoch_end(outputs)

def step(
self, x: Dict[str, torch.Tensor], y: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, **kwargs
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
@@ -500,7 +527,7 @@ def step(
else:
loss = self.loss(prediction, y)

self.log(f"{['val', 'train'][self.training]}_loss", loss, on_step=self.training, on_epoch=True, prog_bar=True)
self.log(f"{self.current_stage}_loss", loss, on_step=self.training, on_epoch=True, prog_bar=True)
log = {"loss": loss, "n_samples": x["decoder_lengths"].size(0)}
return log, out

@@ -548,7 +575,7 @@ def log_metrics(
else:
target_tag = ""
self.log(
f"{target_tag}{['val', 'train'][self.training]}_{metric.name}",
f"{target_tag}{self.current_stage}_{metric.name}",
loss_value,
on_step=self.training,
on_epoch=True,
@@ -662,7 +689,7 @@ def log_prediction(
log_indices = [0]
for idx in log_indices:
fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
tag = f"{['Val', 'Train'][self.training]} prediction"
tag = f"{self.current_stage} prediction"
if self.training:
tag += f" of item {idx} in global batch {self.global_step}"
else:
@@ -769,7 +769,7 @@ def log_interpretation(self, outputs):
interpretation["attention"] = interpretation["attention"] / interpretation["attention"].sum()

figs = self.plot_interpretation(interpretation) # make interpretation figures
label = ["val", "train"][self.training]
label = self.current_stage
# log to tensorboard
for name, fig in figs.items():
self.logger.experiment.add_figure(
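
The mechanism behind the `base_model.py` changes above is a stage-to-prefix lookup: `current_stage` maps `self.trainer.state.stage` through `STAGE_STATES`, so the shared `step`, `log_metrics`, `log_prediction`, and `log_interpretation` code now emits `train_`, `val_`, or `test_` names instead of indexing `["val", "train"][self.training]`, which could not express the test stage. A standalone sketch of that lookup, assuming pytorch-lightning >= 1.5 (where `RunningStage` lives in `pytorch_lightning.trainer.states`):

```python
from pytorch_lightning.trainer.states import RunningStage

# same mapping as the STAGE_STATES dict added to base_model.py above
STAGE_STATES = {
    RunningStage.TRAINING: "train",
    RunningStage.VALIDATING: "val",
    RunningStage.TESTING: "test",
    RunningStage.PREDICTING: "predict",
    RunningStage.SANITY_CHECKING: "sanity_check",
}

# BaseModel.current_stage returns STAGE_STATES[self.trainer.state.stage];
# that prefix then feeds calls such as self.log(f"{self.current_stage}_loss", ...)
for stage in (RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING):
    print(f"{STAGE_STATES[stage]}_loss")
# -> train_loss, val_loss, test_loss
```

Because the prefix is derived from the trainer state rather than from `self.training`, the new `test_step` and `test_epoch_end` hooks can simply delegate to the existing `step` and `epoch_end` logic and still log under the correct names.
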
6 changes: 4 additions & 2 deletions tests/test_models/conftest.py
@@ -70,8 +70,9 @@ def make_dataloaders(data_with_covariates, **kwargs):
batch_size = 4
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
test_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

return dict(train=train_dataloader, val=val_dataloader)
return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader)


@pytest.fixture(
@@ -165,5 +166,6 @@ def dataloaders_fixed_window_without_covariates():
batch_size = 4
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
test_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

return dict(train=train_dataloader, val=val_dataloader)
return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader)
6 changes: 6 additions & 0 deletions tests/test_models/test_deepar.py
@@ -33,8 +33,11 @@ def _integration(
)
data_loader_default_kwargs.update(data_loader_kwargs)
dataloaders_with_covariates = make_dataloaders(data_with_covariates, **data_loader_default_kwargs)

train_dataloader = dataloaders_with_covariates["train"]
val_dataloader = dataloaders_with_covariates["val"]
test_dataloader = dataloaders_with_covariates["test"]

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")

logger = TensorBoardLogger(tmp_path)
@@ -48,6 +51,7 @@
default_root_dir=tmp_path,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
logger=logger,
)

@@ -67,6 +71,8 @@
train_dataloader=train_dataloader,
val_dataloaders=val_dataloader,
)
test_outputs = trainer.test(net, test_dataloaders=test_dataloader)
assert len(test_outputs) > 0
# check loading
net = DeepAR.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

6 changes: 5 additions & 1 deletion tests/test_models/test_mlp.py
@@ -24,6 +24,7 @@ def _integration(data_with_covariates, tmp_path, gpus, data_loader_kwargs={}, tr
dataloaders_with_covariates = make_dataloaders(data_with_covariates, **data_loader_default_kwargs)
train_dataloader = dataloaders_with_covariates["train"]
val_dataloader = dataloaders_with_covariates["val"]
test_dataloader = dataloaders_with_covariates["test"]
early_stop_callback = EarlyStopping(
monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min", strict=False
)
@@ -39,6 +40,7 @@ def _integration(data_with_covariates, tmp_path, gpus, data_loader_kwargs={}, tr
default_root_dir=tmp_path,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
logger=logger,
)

@@ -60,6 +62,9 @@ def _integration(data_with_covariates, tmp_path, gpus, data_loader_kwargs={}, tr

# check prediction
net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
# check test dataloader
test_outputs = trainer.test(net, test_dataloaders=test_dataloader)
assert len(test_outputs) > 0
finally:
shutil.rmtree(tmp_path, ignore_errors=True)

@@ -86,7 +91,6 @@ def _integration(data_with_covariates, tmp_path, gpus, data_loader_kwargs={}, tr
),
dict(optimizer="SGD", weight_decay=1e-3),
dict(optimizer=lambda params, lr: SGD(params, lr=lr, weight_decay=1e-3)),
dict(optimizer=SGD, weight_decay=1e-4),
],
)
def test_integration(data_with_covariates, tmp_path, gpus, kwargs):
5 changes: 5 additions & 0 deletions tests/test_models/test_nbeats.py
@@ -12,6 +12,8 @@
def test_integration(dataloaders_fixed_window_without_covariates, tmp_path, gpus):
train_dataloader = dataloaders_fixed_window_without_covariates["train"]
val_dataloader = dataloaders_fixed_window_without_covariates["val"]
test_dataloader = dataloaders_fixed_window_without_covariates["test"]

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")

logger = TensorBoardLogger(tmp_path)
@@ -25,6 +27,7 @@ def test_integration(dataloaders_fixed_window_without_covariates, tmp_path, gpus
default_root_dir=tmp_path,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
logger=logger,
)

@@ -43,6 +46,8 @@ def test_integration(dataloaders_fixed_window_without_covariates, tmp_path, gpus
train_dataloader=train_dataloader,
val_dataloaders=val_dataloader,
)
test_outputs = trainer.test(net, test_dataloaders=test_dataloader)
assert len(test_outputs) > 0
# check loading
net = NBeats.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

5 changes: 5 additions & 0 deletions tests/test_models/test_rnn_model.py
@@ -31,6 +31,8 @@ def _integration(
dataloaders_with_covariates = make_dataloaders(data_with_covariates, **data_loader_default_kwargs)
train_dataloader = dataloaders_with_covariates["train"]
val_dataloader = dataloaders_with_covariates["val"]
test_dataloader = dataloaders_with_covariates["test"]

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")

logger = TensorBoardLogger(tmp_path)
@@ -44,6 +46,7 @@
default_root_dir=tmp_path,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
logger=logger,
)

@@ -62,6 +65,8 @@
train_dataloader=train_dataloader,
val_dataloaders=val_dataloader,
)
test_outputs = trainer.test(net, test_dataloaders=test_dataloader)
assert len(test_outputs) > 0
# check loading
net = RecurrentNetwork.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

5 changes: 5 additions & 0 deletions tests/test_models/test_temporal_fusion_transformer.py
@@ -56,6 +56,8 @@ def test_distribution_loss(data_with_covariates, tmp_path, gpus):
def _integration(dataloader, tmp_path, gpus, loss=None):
train_dataloader = dataloader["train"]
val_dataloader = dataloader["val"]
test_dataloader = dataloader["test"]

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")

# check training
@@ -70,6 +72,7 @@ def _integration(dataloader, tmp_path, gpus, loss=None):
default_root_dir=tmp_path,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
logger=logger,
)
# test monotone constraints automatically
@@ -114,6 +117,8 @@ def _integration(dataloader, tmp_path, gpus, loss=None):
train_dataloader=train_dataloader,
val_dataloaders=val_dataloader,
)
test_outputs = trainer.test(net, test_dataloaders=test_dataloader)
assert len(test_outputs) > 0

# check loading
net = TemporalFusionTransformer.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
