Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] change image logging from matplotlib figure to binary image #1417

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions pytorch_forecasting/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,19 @@ def result(self) -> Prediction:
return None


def fig2img(fig):
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
"""Convert a Matplotlib figure to a PIL Image and return it"""
import io
buf = io.BytesIO()
fig.savefig(buf)
buf.seek(0)
img = Image.open(buf)
return img


class BaseModel(InitialParameterRepresenterMixIn, LightningModule, TupleOutputMixIn):
"""
BaseModel from which new timeseries models should inherit from.
Expand Down Expand Up @@ -985,16 +998,16 @@ def log_prediction(
tag += f" of item {idx} in batch {batch_idx}"
if isinstance(fig, (list, tuple)):
for idx, f in enumerate(fig):
self.logger.experiment.add_figure(
f"{self.target_names[idx]} {tag}",
f,
global_step=self.global_step,
self.logger.experiment.log_image(
run_id=self.logger.run_id,
image=fig2img(f),
artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
)
else:
self.logger.experiment.add_figure(
tag,
fig,
global_step=self.global_step,
self.logger.experiment.log_image(
run_id=self.logger.run_id,
image=fig2img(fig),
artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
)

def plot_prediction(
Expand Down Expand Up @@ -1157,7 +1170,7 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None:
ax.set_ylabel("Average gradient")
ax.set_yscale("log")
ax.set_title("Gradient flow")
self.logger.experiment.add_figure("Gradient flow", fig, global_step=self.global_step)
self.logger.experiment.log_image(run_id=self.logger.run_id, image=fig2img(fig), artifact_file=f"gradient_flow.png")

def on_after_backward(self):
"""
Expand Down
30 changes: 21 additions & 9 deletions pytorch_forecasting/models/nhits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@
from pytorch_forecasting.utils._dependencies import _check_matplotlib


def fig2img(fig):
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
"""Convert a Matplotlib figure to a PIL Image and return it"""
import io
buf = io.BytesIO()
fig.savefig(buf)
buf.seek(0)
img = Image.open(buf)
return img

class NHiTS(BaseModelWithCovariates):
def __init__(
self,
Expand Down Expand Up @@ -552,17 +564,17 @@ def log_interpretation(self, x, out, batch_idx):
name += f"step {self.global_step}"
else:
name += f"batch {batch_idx}"
self.logger.experiment.add_figure(name, fig, global_step=self.global_step)
self.logger.experiment.log_image(image=fig, artifact_file=f"{name}.png")
if isinstance(fig, (list, tuple)):
for idx, f in enumerate(fig):
self.logger.experiment.add_figure(
f"{self.target_names[idx]} {name}",
f,
global_step=self.global_step,
self.logger.experiment.log_image(
run_id=self.logger.run_id,
image=fig2img(f),
artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
)
else:
self.logger.experiment.add_figure(
name,
fig,
global_step=self.global_step,
self.logger.experiment.log_image(
run_id=self.logger.run_id,
image=fig2img(fig),
artifact_file=f"{self.target_names[idx]}_{tag}_step_{self.global_step}.png"
)
24 changes: 20 additions & 4 deletions pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@
from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list
from pytorch_forecasting.utils._dependencies import _check_matplotlib

def fig2img(fig):
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
"""Convert a Matplotlib figure to a PIL Image and return it"""
import io
buf = io.BytesIO()
fig.savefig(buf)
buf.seek(0)
img = Image.open(buf)
return img


class TemporalFusionTransformer(BaseModelWithCovariates):
def __init__(
Expand Down Expand Up @@ -827,8 +839,10 @@ def log_interpretation(self, outputs):
label = self.current_stage
# log to tensorboard
for name, fig in figs.items():
self.logger.experiment.add_figure(
f"{label.capitalize()} {name} importance", fig, global_step=self.global_step
self.logger.experiment.log_image(
run_id=self.logger.run_id,
image=fig2img(fig),
artifact_file=f"{label.capitalize()}_{name}_step_{self.global_step}.png"
)

# log lengths of encoder/decoder
Expand All @@ -849,8 +863,10 @@ def log_interpretation(self, outputs):
ax.set_ylabel("Number of samples")
ax.set_title(f"{type.capitalize()} length distribution in {label} epoch")

self.logger.experiment.add_figure(
f"{label.capitalize()} {type} length distribution", fig, global_step=self.global_step
self.logger.experiment.log_image(
run_id=self.logger.run_id,
image=fig2img(fig),
artifact_file=f"{label.capitalize()}_{type}_length_distribution_step_{self.global_step}.png",
)

def log_embeddings(self):
Expand Down
Loading