Skip to content

Commit

Permalink
feat: support matplotlib pylab interface in solara server
Browse files Browse the repository at this point in the history
This makes code that 'just works' in the notebook also work in solara
server.
  • Loading branch information
maartenbreddels committed Mar 10, 2024
1 parent b5793b5 commit 57ce0b2
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 11 deletions.
31 changes: 22 additions & 9 deletions solara/components/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,44 @@ def FigureMatplotlib(
):
"""Display a matplotlib figure.
We recomment not to use the pyplot interface, but rather to create a figure directly, e.g:
```python
import reacton
import solara as sol
## Example
```solara
import solara
from matplotlib.figure import Figure
@solara.component
def Page():
# do this instead of plt.figure()
fig = Figure()
ax = fig.subplots()
ax.plot([1, 2, 3], [1, 4, 9])
return solara.FigureMatplotlib(fig)
```
You should also avoid drawing using the pyplot interface, as it is not thread-safe. If you do use it,
your drawing might be corrupted due to another thread/user drawing at the same time.
When running under solara-server, we by default configure the same 'inline' backend as in the Jupyter notebook.
For performance reasons, you might want to pass in a list of dependencies that indicate when
the figure changed, to avoid re-rendering it on every render.
## Example using pyplot
Note that it is also possible to use the pyplot interface, but be sure to close the figure not to leak memory.
```solara
import solara
import matplotlib.pyplot as plt
@solara.component
def Page():
plt.figure()
plt.plot([1, 2, 3], [1, 4, 9])
plt.show()
plt.close()
```
Note that the figure is not the same size using the pyplot interface, due to the default figure size being different.
## Arguments
Expand Down
8 changes: 7 additions & 1 deletion solara/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import solara

from . import kernel_context, reload, settings
from . import kernel_context, patch, reload, settings
from .kernel import Kernel
from .utils import pdb_guard

Expand Down Expand Up @@ -68,6 +68,9 @@ def __init__(self, name, default_app_name="Page"):
with dummy_kernel_context:
app = self._execute()

# We now ran the app, now we can check for patches that require heavy imports
patch.patch_heavy_imports()

self._first_execute_app = app
reload.reloader.root_path = self.directory
if self.type == AppType.MODULE:
Expand Down Expand Up @@ -198,6 +201,9 @@ def run(self):
self._first_execute_app = None
self._first_execute_app = self._execute()
print("Re-executed app", self.name) # noqa
# We now ran the app again, might contain new imports
patch.patch_heavy_imports()

return self._first_execute_app

def on_file_change(self, name):
Expand Down
47 changes: 47 additions & 0 deletions solara/server/patch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import logging
import os
import pdb
Expand Down Expand Up @@ -198,6 +199,16 @@ def __len__(self):
def __setitem__(self, key, value):
self._get_context_dict().__setitem__(key, value)

# support OrderedDict API for matplotlib
def move_to_end(self, key, last=True):
assert last, "only last=True is supported"
item = self.pop(key)
self[key] = item

# matplotlib assumes .values() returns a list
def values(self):
return list(self._get_context_dict().values())


class context_dict_widgets(context_dict):
def _get_context_dict(self) -> dict:
Expand Down Expand Up @@ -315,6 +326,42 @@ def patch_ipyreact():
ipyreact.importmap._update_import_map = lambda: None


def once(f):
called = False
return_value = None

@functools.wraps(f)
def wrapper():
nonlocal called
nonlocal return_value
if called:
return return_value
called = True
return_value = f()
return return_value

return wrapper


@once
def patch_matplotlib():
import matplotlib._pylab_helpers

prev = matplotlib._pylab_helpers.Gcf.figs
matplotlib._pylab_helpers.Gcf.figs = context_dict_user("matplotlib.pylab.figure_managers", prev) # type: ignore

def cleanup():
matplotlib._pylab_helpers.Gcf.figs = prev

return cleanup


def patch_heavy_imports():
# patches that we only want to do if a package is imported, because they may slow down startup
if "matplotlib" in sys.modules:
patch_matplotlib()


def patch():
global _patched
global global_widgets_dict
Expand Down
38 changes: 37 additions & 1 deletion tests/unit/matplotlib_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import solara
from matplotlib.figure import Figure

import solara
import solara.server.patch
from solara.server import kernel


@solara.component
def Page():
Expand All @@ -14,3 +17,36 @@ def Page():
def test_render():
box, rc = solara.render(Page(), handle_error=False)
assert len(box.children) == 1


def test_pylab(no_kernel_context):
cleanup = solara.server.patch.patch_matplotlib()
try:
kernel_1 = kernel.Kernel()
context_1 = solara.server.kernel_context.VirtualKernelContext(id="1", kernel=kernel_1, session_id="session-1")
kernel_2 = kernel.Kernel()
context_2 = solara.server.kernel_context.VirtualKernelContext(id="2", kernel=kernel_2, session_id="session-1")
import matplotlib.pyplot as plt
from matplotlib._pylab_helpers import Gcf

assert len(Gcf.get_all_fig_managers()) == 0
plt.figure()
assert len(Gcf.get_all_fig_managers()) == 1
with context_1:
assert len(Gcf.get_all_fig_managers()) == 0
plt.figure()
assert len(Gcf.get_all_fig_managers()) == 1
assert len(Gcf.get_all_fig_managers()) == 1
with context_2:
assert len(Gcf.get_all_fig_managers()) == 0
plt.figure()
assert len(Gcf.get_all_fig_managers()) == 1
plt.figure()
assert len(Gcf.get_all_fig_managers()) == 2
with context_1:
assert len(Gcf.get_all_fig_managers()) == 1

finally:
cleanup()
context_1.close()
context_2.close()

0 comments on commit 57ce0b2

Please sign in to comment.