diff --git a/solara/server/patch.py b/solara/server/patch.py index 929148b37..d2b3c6f80 100644 --- a/solara/server/patch.py +++ b/solara/server/patch.py @@ -345,13 +345,66 @@ def wrapper(): @once def patch_matplotlib(): + import matplotlib import matplotlib._pylab_helpers prev = matplotlib._pylab_helpers.Gcf.figs matplotlib._pylab_helpers.Gcf.figs = context_dict_user("matplotlib.pylab.figure_managers", prev) # type: ignore + RcParamsOriginal = matplotlib.RcParams + counter = 0 + lock = threading.Lock() + + class RcParamsScoped(context_dict, matplotlib.RcParams): + _was_initialized = False + _without_kernel_dict: Dict[Any, Any] + + def __init__(self, *args, **kwargs) -> None: + self._init() + RcParamsOriginal.__init__(self, *args, **kwargs) + + def _init(self): + nonlocal counter + with lock: + counter += 1 + self._user_dict_name = f"matplotlib.rcParams:{counter}" + # this creates a copy of the CPython side of the dict + self._without_kernel_dict = dict(zip(dict.keys(self), dict.values(self))) + self._was_initialized = True + + def _set(self, key, val): + # in matplotlib this directly calls dict.__setitem__ + # which would not call context_dict.__setitem__ + self[key] = val + + def _get(self, key): + # same as _get + return self[key] + + def _get_context_dict(self) -> dict: + if not self._was_initialized: + # since we monkey patch the class after __init__ was called + # we may have to do that later on + self._init() + if kernel_context.has_current_context(): + context = kernel_context.get_current_context() + if self._user_dict_name not in context.user_dicts: + # copy over the global settings when needed + context.user_dicts[self._user_dict_name] = self._without_kernel_dict.copy() + return context.user_dicts[self._user_dict_name] + else: + return self._without_kernel_dict + + matplotlib.RcParams = RcParamsScoped + matplotlib.rcParams.__class__ = RcParamsScoped + # we chose to monkeypatch the class, instead of re-assiging to reParams for 2 reasons: + # 1. the RcParams object could be imported in different namespaces + # 2. the rcParams has extra methods, which means we have to otherwise monkeypatch the context_dict + def cleanup(): matplotlib._pylab_helpers.Gcf.figs = prev + matplotlib.RcParams = RcParamsOriginal + matplotlib.rcParams.__class__ = RcParamsOriginal return cleanup @@ -383,7 +436,7 @@ def patch(): if "MPLBACKEND" not in os.environ: if ipykernel_version_major < 6: # changed in https://github.com/ipython/ipykernel/pull/591 - os.environ["MPLBACKEND"] = "ipykernel.pylab.backend_inline" + os.environ["MPLBACKEND"] = "module://ipykernel.pylab.backend_inline" else: os.environ["MPLBACKEND"] = "module://matplotlib_inline.backend_inline" diff --git a/tests/unit/matplotlib_test.py b/tests/unit/matplotlib_test.py index 203e4c839..754caf9cd 100644 --- a/tests/unit/matplotlib_test.py +++ b/tests/unit/matplotlib_test.py @@ -31,20 +31,45 @@ def test_pylab(no_kernel_context): assert len(Gcf.get_all_fig_managers()) == 0 plt.figure() + assert len(Gcf.get_all_fig_managers()) == 1 + + default_color = (1, 1, 1, 0) + white = "white" + black = "black" + assert plt.rcParams["figure.facecolor"] in [default_color, white] + plt.style.use("default") + assert plt.rcParams["figure.facecolor"] == white + + plt.style.use("dark_background") + assert plt.rcParams["figure.facecolor"] == black + with context_1: assert len(Gcf.get_all_fig_managers()) == 0 plt.figure() + assert plt.rcParams["figure.facecolor"] == black assert len(Gcf.get_all_fig_managers()) == 1 + plt.style.use("default") + assert plt.rcParams["figure.facecolor"] == white + + assert plt.rcParams["figure.facecolor"] == black assert len(Gcf.get_all_fig_managers()) == 1 + plt.style.use("default") + assert plt.rcParams["figure.facecolor"] == white + with context_2: + assert plt.rcParams["figure.facecolor"] == white assert len(Gcf.get_all_fig_managers()) == 0 plt.figure() assert len(Gcf.get_all_fig_managers()) == 1 plt.figure() assert len(Gcf.get_all_fig_managers()) == 2 + plt.style.use("dark_background") + assert plt.rcParams["figure.facecolor"] == black with context_1: assert len(Gcf.get_all_fig_managers()) == 1 + assert plt.rcParams["figure.facecolor"] == white + assert plt.rcParams["figure.facecolor"] == white finally: cleanup()