diff --git a/solara/components/matplotlib.py b/solara/components/matplotlib.py index c527e89c5..b887ca351 100644 --- a/solara/components/matplotlib.py +++ b/solara/components/matplotlib.py @@ -14,31 +14,44 @@ def FigureMatplotlib( ): """Display a matplotlib figure. - We recomment not to use the pyplot interface, but rather to create a figure directly, e.g: - ```python - import reacton - import solara as sol + ## Example + + ```solara + import solara from matplotlib.figure import Figure @solara.component def Page(): - # do this instead of plt.figure() fig = Figure() ax = fig.subplots() ax.plot([1, 2, 3], [1, 4, 9]) return solara.FigureMatplotlib(fig) - ``` - You should also avoid drawing using the pyplot interface, as it is not thread-safe. If you do use it, - your drawing might be corrupted due to another thread/user drawing at the same time. - When running under solara-server, we by default configure the same 'inline' backend as in the Jupyter notebook. For performance reasons, you might want to pass in a list of dependencies that indicate when the figure changed, to avoid re-rendering it on every render. + ## Example using pyplot + + Note that it is also possible to use the pyplot interface, but be sure to close the figure not to leak memory. + + ```solara + import solara + import matplotlib.pyplot as plt + + @solara.component + def Page(): + plt.figure() + plt.plot([1, 2, 3], [1, 4, 9]) + plt.show() + plt.close() + ``` + + Note that the figure is not the same size using the pyplot interface, due to the default figure size being different. + ## Arguments diff --git a/solara/server/app.py b/solara/server/app.py index 4b2138798..31ff174f8 100644 --- a/solara/server/app.py +++ b/solara/server/app.py @@ -15,7 +15,7 @@ import solara -from . import kernel_context, reload, settings +from . import kernel_context, patch, reload, settings from .kernel import Kernel from .utils import pdb_guard @@ -68,6 +68,9 @@ def __init__(self, name, default_app_name="Page"): with dummy_kernel_context: app = self._execute() + # We now ran the app, now we can check for patches that require heavy imports + patch.patch_heavy_imports() + self._first_execute_app = app reload.reloader.root_path = self.directory if self.type == AppType.MODULE: @@ -198,6 +201,9 @@ def run(self): self._first_execute_app = None self._first_execute_app = self._execute() print("Re-executed app", self.name) # noqa + # We now ran the app again, might contain new imports + patch.patch_heavy_imports() + return self._first_execute_app def on_file_change(self, name): diff --git a/solara/server/patch.py b/solara/server/patch.py index 7eef7cd93..d9e079724 100644 --- a/solara/server/patch.py +++ b/solara/server/patch.py @@ -1,3 +1,4 @@ +import functools import logging import os import pdb @@ -198,6 +199,16 @@ def __len__(self): def __setitem__(self, key, value): self._get_context_dict().__setitem__(key, value) + # support OrderedDict API for matplotlib + def move_to_end(self, key, last=True): + assert last, "only last=True is supported" + item = self.pop(key) + self[key] = item + + # matplotlib assumes .values() returns a list + def values(self): + return list(self._get_context_dict().values()) + class context_dict_widgets(context_dict): def _get_context_dict(self) -> dict: @@ -315,6 +326,41 @@ def patch_ipyreact(): ipyreact.importmap._update_import_map = lambda: None +def once(f): + called = False + return_value = None + + @functools.wraps(f) + def wrapper(): + nonlocal called + nonlocal return_value + if called: + return return_value + called = True + return_value = f() + return return_value + + return wrapper + + +@once +def patch_matplotlib(): + import matplotlib._pylab_helpers + + prev = matplotlib._pylab_helpers.Gcf.figs + matplotlib._pylab_helpers.Gcf.figs = context_dict_user("matplotlib.pylab.figure_managers", prev) # type: ignore + + def cleanup(): + matplotlib._pylab_helpers.Gcf.figs = prev + + return cleanup + + +def patch_heavy_imports(): + # patches that we only want to do if a package is imported, because they may slow down startup + patch_matplotlib() + + def patch(): global _patched global global_widgets_dict diff --git a/tests/unit/matplotlib_test.py b/tests/unit/matplotlib_test.py index 60115d6ac..203e4c839 100644 --- a/tests/unit/matplotlib_test.py +++ b/tests/unit/matplotlib_test.py @@ -1,6 +1,9 @@ -import solara from matplotlib.figure import Figure +import solara +import solara.server.patch +from solara.server import kernel + @solara.component def Page(): @@ -14,3 +17,36 @@ def Page(): def test_render(): box, rc = solara.render(Page(), handle_error=False) assert len(box.children) == 1 + + +def test_pylab(no_kernel_context): + cleanup = solara.server.patch.patch_matplotlib() + try: + kernel_1 = kernel.Kernel() + context_1 = solara.server.kernel_context.VirtualKernelContext(id="1", kernel=kernel_1, session_id="session-1") + kernel_2 = kernel.Kernel() + context_2 = solara.server.kernel_context.VirtualKernelContext(id="2", kernel=kernel_2, session_id="session-1") + import matplotlib.pyplot as plt + from matplotlib._pylab_helpers import Gcf + + assert len(Gcf.get_all_fig_managers()) == 0 + plt.figure() + assert len(Gcf.get_all_fig_managers()) == 1 + with context_1: + assert len(Gcf.get_all_fig_managers()) == 0 + plt.figure() + assert len(Gcf.get_all_fig_managers()) == 1 + assert len(Gcf.get_all_fig_managers()) == 1 + with context_2: + assert len(Gcf.get_all_fig_managers()) == 0 + plt.figure() + assert len(Gcf.get_all_fig_managers()) == 1 + plt.figure() + assert len(Gcf.get_all_fig_managers()) == 2 + with context_1: + assert len(Gcf.get_all_fig_managers()) == 1 + + finally: + cleanup() + context_1.close() + context_2.close()