
Multiple Target Prediction Plotting Bug #1317

Open · terbed wants to merge 4 commits into main
Conversation

@terbed commented May 27, 2023

Description

When plot_prediction() is called in PyTorch Forecasting with multiple targets, the function reuses the same axes for every target, so the plots for different targets overlap instead of each target getting its own figure. This pull request fixes that.

Issue

#1314

The faulty code

def plot_prediction(
    self,
    x: Dict[str, torch.Tensor],
    out: Dict[str, torch.Tensor],
    idx: int = 0,
    add_loss_to_title: Union[Metric, torch.Tensor, bool] = False,
    show_future_observed: bool = True,
    ax=None,
    quantiles_kwargs: Dict[str, Any] = {},
    prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:

    #...

    # for each target, plot
    figs = []
    for y_raw, y_hat, y_quantile, encoder_target, decoder_target in zip(
        y_raws, y_hats, y_quantiles, encoder_targets, decoder_targets
    ):
        # ...

        # create figure
        # BUG: after the first iteration ax is no longer None, so every
        # subsequent target is drawn onto the first target's axes
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.get_figure()
        
        # ...

        figs.append(fig)
    
    return figs

Expected behavior:
Each target should be plotted on a separate figure.

Actual behavior:
All targets are plotted on the same figure, resulting in overlapped plots.
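
This is a general matplotlib pitfall rather than anything forecasting-specific. A minimal standalone sketch of the same mistake, with made-up data (not the library's code):

import matplotlib.pyplot as plt

targets = {"target_a": [1, 2, 3], "target_b": [30, 20, 10]}

figs = []
ax = None  # the caller did not supply axes
for name, series in targets.items():
    if ax is None:  # True only on the first iteration
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()  # every later target reuses the first axes
    ax.plot(series, label=name)
    figs.append(fig)

assert figs[0] is figs[1]  # both "figures" are the same object: the plots overlap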

Solution

In the snippet above, ax should be reset for each target inside the loop, but after the first iteration the same ax is reused because it is no longer None. The proposed fix:

def plot_prediction(
    self,
    x: Dict[str, torch.Tensor],
    out: Dict[str, torch.Tensor],
    idx: int = 0,
    add_loss_to_title: Union[Metric, torch.Tensor, bool] = False,
    show_future_observed: bool = True,
    ax=None,
    quantiles_kwargs: Dict[str, Any] = {},
    prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:

    # ...

    # for each target, plot
    figs = []
    ax_provided = ax is not None  # remember whether the caller passed axes
    for y_raw, y_hat, y_quantile, encoder_target, decoder_target in zip(
        y_raws, y_hats, y_quantiles, encoder_targets, decoder_targets
    ):
        # ...

        # create figure: without caller-supplied axes, open a fresh one per target
        if (ax is None) or (not ax_provided):
            fig, ax = plt.subplots()
        else:
            fig = ax.get_figure()

        # ...

        figs.append(fig)

    return figs
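
With the fix, a caller that does not pass ax gets one independent figure per target. A hypothetical usage sketch (predict()'s exact return type varies across PyTorch Forecasting versions, so the attribute access below is an assumption):

# hypothetical: model is a trained multi-target model
raw = model.predict(val_dataloader, mode="raw", return_x=True)
figs = model.plot_prediction(raw.x, raw.output, idx=0)
for i, fig in enumerate(figs):
    fig.savefig(f"target_{i}.png")  # each target lands on its own figure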

Bonus

Also corrected a mistake in the documentation: the encoder's log1p transformation was incorrectly called logp1.
#1247
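
For reference, the transformation in question is log1p, i.e. log(1 + x); there is no logp1 function in torch or numpy. A quick check:

import torch

x = torch.tensor([0.0, 1.0, 9.0])
print(torch.log1p(x))               # tensor([0.0000, 0.6931, 2.3026]) == log(1 + x)
print(torch.expm1(torch.log1p(x)))  # expm1 inverts it, recovering [0., 1., 9.]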

@terbed (Author) commented Jun 14, 2023

Is this repo not maintained?

@@ -1213,7 +1217,7 @@ def configure_optimizers(self):
                 min_lr=self.hparams.reduce_on_plateau_min_lr,
             ),
             "monitor": "val_loss",  # Default: val_loss
-            "interval": "epoch",
+            "interval": "step",
Collaborator: I don't think this makes sense; reducing on every step is very aggressive.
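
For context, the interval key in a Lightning lr_scheduler config controls how often the scheduler steps. A paraphrased sketch of the surrounding configure_optimizers() return value (the hparams names come from the diff above; the rest is an assumption, not the exact library code):

import torch

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        min_lr=self.hparams.reduce_on_plateau_min_lr,
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",
            # "epoch": step the scheduler once per epoch;
            # "step": step it after every batch, so the plateau patience is
            # measured in batches, which is why the reviewer calls it aggressive
            "interval": "epoch",
        },
    }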

            capsize=1.0,
        )
    except ValueError:
        print(f"Warning: could not plot error bars. Quantiles: {quantiles}, y: {y}, yerr: {quantiles - y[-n_pred]}")
Collaborator: This would have to be a logger.warning() instead of a print().
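
A sketch of the suggested change; the module-level logger is the standard logging pattern, and quantiles, y, and n_pred are the values from the surrounding plotting code:

import logging

logger = logging.getLogger(__name__)  # standard module-level logger

try:
    ...  # the errorbar call with capsize=1.0 that can raise
except ValueError:
    # lazy %-formatting defers building the string until the record is emitted
    logger.warning(
        "Could not plot error bars. Quantiles: %s, y: %s, yerr: %s",
        quantiles, y, quantiles - y[-n_pred],
    )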

Collaborator: Why does the error actually happen? This seems to solve a different problem.

@@ -1012,7 +1013,7 @@ def plot_prediction(
             # move to cpu
             y = y.detach().cpu()
             # create figure
-            if ax is None:
+            if (ax is None) or (not ax_provided):
Collaborator: This is confusing. I would say ax_not_provided = ax is None before the loop, and then if ax_not_provided: inside the loop.
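
A sketch of the naming the reviewer suggests, applied to the loop from the Solution section above:

figs = []
ax_not_provided = ax is None  # decided once, before the loop
for y_raw, y_hat, y_quantile, encoder_target, decoder_target in zip(
    y_raws, y_hats, y_quantiles, encoder_targets, decoder_targets
):
    # ...
    if ax_not_provided:
        fig, ax = plt.subplots()  # fresh figure for every target
    else:
        fig = ax.get_figure()  # caller-supplied axes: draw everything there
    # ...
    figs.append(fig)

return figs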

@codecov-commenter commented Sep 10, 2023

Codecov Report

Patch coverage: 33.33% and project coverage change: -0.08% ⚠️

Comparison is base (9995d0a) 90.13% compared to head (120f3e4) 90.05%.
Report is 2 commits behind head on master.

@@            Coverage Diff             @@
##           master    #1317      +/-   ##
==========================================
- Coverage   90.13%   90.05%   -0.08%     
==========================================
  Files          30       30              
  Lines        4712     4716       +4     
==========================================
  Hits         4247     4247              
- Misses        465      469       +4     
Flag     Coverage Δ
cpu      90.05% <33.33%> (-0.08%) ⬇️
pytest   90.05% <33.33%> (-0.08%) ⬇️


Files Changed                               Coverage Δ
pytorch_forecasting/data/encoders.py        87.25% <ø> (ø)
pytorch_forecasting/models/base_model.py    87.77% <33.33%> (-0.42%) ⬇️

