Skip to content

Commit

Permalink
new makefig context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
ejolly committed Oct 22, 2023
1 parent 0315a50 commit 6d52336
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
40 changes: 40 additions & 0 deletions utilz/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
from toolz import curry
from typing import Union
from contextlib import contextmanager


def mpinit(figsize: tuple = (8, 6), subplots: tuple = (1, 1)):
Expand Down Expand Up @@ -269,3 +270,42 @@ def newax(*args, **kwargs):
"""Short hand for a new axis on a new figure. Usueful for calling multiple plotting
routines in a pipe() but wanting separate figures."""
return plt.subplots()[1]


@contextmanager
def makefig(figorax, **kwargs):
"""
A context manager to handling a figure and optionally saving it.
Handles all kwargs to plt.subplots() as well as a save kwarg that
should point to a file to auto-saving
Args:
figorax (str): 'fig' or 'ax'
Yields:
figure or axis handle
Examples:
>>> x = np.random.randn(10)
>>> with makefig('ax', figsize=(3,3)) as ax:
>>> ax.plot(x)
>>> with makefig('ax', save='myfig.jpg') as ax:
>>> ax.plot(x)
>>> # figure is saved to 'myfig.jpg'
"""

save = kwargs.pop("save", None)
bbox_inches = kwargs.pop("bbox_inches", "tight")

f, ax = plt.subplots(**kwargs)
if figorax == "fig":
yield f
elif figorax == "ax":
yield ax
else:
raise ValueError("Fist arg to makefig() must be 'fig' or 'ax'")

if save is not None:
f.savefig(save, bbox_inches=bbox_inches)
22 changes: 21 additions & 1 deletion utilz/tests/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from utilz.plot import mpinit, stripbarplot, savefig
from utilz.plot import mpinit, stripbarplot, savefig, makefig
from utilz.boilerplate import randdf
import matplotlib.pyplot as plt
from pathlib import Path
Expand Down Expand Up @@ -51,3 +51,23 @@ def test_savefig(tmp_path: Path):

dir_save_raster.unlink()
dir_save_vector.unlink()


def test_makefig(tmp_path: Path):
x = np.random.randn(10)

# Smoke test axis
with makefig("ax") as ax:
ax.plot(x)

# Smoke test figure
with makefig("fig") as f:
f.get_axes()[0].plot(x)

# Saving
with makefig("ax", save=tmp_path / "myfig.jpg") as ax:
ax.plot(x)

fpath = tmp_path / "myfig.jpg"
assert fpath.exists()
fpath.unlink()

0 comments on commit 6d52336

Please sign in to comment.