Skip to content

Commit

Permalink
feat: pylab.style support (is scoped per virtual kernel)
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbreddels committed Mar 7, 2024
1 parent 1630e3a commit 2de1d21
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 1 deletion.
55 changes: 54 additions & 1 deletion solara/server/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,66 @@ def wrapper():

@once
def patch_matplotlib():
import matplotlib
import matplotlib._pylab_helpers

prev = matplotlib._pylab_helpers.Gcf.figs
matplotlib._pylab_helpers.Gcf.figs = context_dict_user("matplotlib.pylab.figure_managers", prev) # type: ignore

RcParamsOriginal = matplotlib.RcParams
counter = 0
lock = threading.Lock()

class RcParamsScoped(context_dict, matplotlib.RcParams):
_was_initialized = False
_without_kernel_dict: Dict[Any, Any]

def __init__(self, *args, **kwargs) -> None:
self._init()
RcParamsOriginal.__init__(self, *args, **kwargs)

def _init(self):
nonlocal counter
with lock:
counter += 1
self._user_dict_name = f"matplotlib.rcParams:{counter}"
# this creates a copy of the CPython side of the dict
self._without_kernel_dict = dict(zip(dict.keys(self), dict.values(self)))
self._was_initialized = True

def _set(self, key, val):
# in matplotlib this directly calls dict.__setitem__
# which would not call context_dict.__setitem__
self[key] = val

def _get(self, key):
# same as _get
return self[key]

def _get_context_dict(self) -> dict:
if not self._was_initialized:
# since we monkey patch the class after __init__ was called
# we may have to do that later on
self._init()
if kernel_context.has_current_context():
context = kernel_context.get_current_context()
if self._user_dict_name not in context.user_dicts:
# copy over the global settings when needed
context.user_dicts[self._user_dict_name] = self._without_kernel_dict.copy()
return context.user_dicts[self._user_dict_name]
else:
return self._without_kernel_dict

matplotlib.RcParams = RcParamsScoped
matplotlib.rcParams.__class__ = RcParamsScoped
# we chose to monkeypatch the class, instead of re-assiging to reParams for 2 reasons:
# 1. the RcParams object could be imported in different namespaces
# 2. the rcParams has extra methods, which means we have to otherwise monkeypatch the context_dict

def cleanup():
matplotlib._pylab_helpers.Gcf.figs = prev
matplotlib.RcParams = RcParamsOriginal
matplotlib.rcParams.__class__ = RcParamsOriginal

return cleanup

Expand Down Expand Up @@ -382,7 +435,7 @@ def patch():
if "MPLBACKEND" not in os.environ:
if ipykernel_version_major < 6:
# changed in https://github.com/ipython/ipykernel/pull/591
os.environ["MPLBACKEND"] = "ipykernel.pylab.backend_inline"
os.environ["MPLBACKEND"] = "module://ipykernel.pylab.backend_inline"
else:
os.environ["MPLBACKEND"] = "module://matplotlib_inline.backend_inline"

Expand Down
25 changes: 25 additions & 0 deletions tests/unit/matplotlib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,45 @@ def test_pylab(no_kernel_context):

assert len(Gcf.get_all_fig_managers()) == 0
plt.figure()

assert len(Gcf.get_all_fig_managers()) == 1

default_color = (1, 1, 1, 0)
white = "white"
black = "black"
assert plt.rcParams["figure.facecolor"] in [default_color, white]
plt.style.use("default")
assert plt.rcParams["figure.facecolor"] == white

plt.style.use("dark_background")
assert plt.rcParams["figure.facecolor"] == black

with context_1:
assert len(Gcf.get_all_fig_managers()) == 0
plt.figure()
assert plt.rcParams["figure.facecolor"] == black
assert len(Gcf.get_all_fig_managers()) == 1
plt.style.use("default")
assert plt.rcParams["figure.facecolor"] == white

assert plt.rcParams["figure.facecolor"] == black
assert len(Gcf.get_all_fig_managers()) == 1
plt.style.use("default")
assert plt.rcParams["figure.facecolor"] == white

with context_2:
assert plt.rcParams["figure.facecolor"] == white
assert len(Gcf.get_all_fig_managers()) == 0
plt.figure()
assert len(Gcf.get_all_fig_managers()) == 1
plt.figure()
assert len(Gcf.get_all_fig_managers()) == 2
plt.style.use("dark_background")
assert plt.rcParams["figure.facecolor"] == black
with context_1:
assert len(Gcf.get_all_fig_managers()) == 1
assert plt.rcParams["figure.facecolor"] == white
assert plt.rcParams["figure.facecolor"] == white

finally:
cleanup()
Expand Down

0 comments on commit 2de1d21

Please sign in to comment.