Skip to content

Commit

Permalink
feat: improve matplotlib support - mimic Jupyter (#540)
Browse files Browse the repository at this point in the history
feat: improve matplotlib support - as Jupyter


Several improvements:

## Default backend
We do this by checking if the env var 'MPLBACKEND' is not set, and if so, we set it to
module://matplotlib_inline.backend_inline

## Support for display(figure)

This makes it behave more like the Jupyter notebook environment.

## Support matplotlib pylab interface in solara server

This makes code that 'just works' in the notebook also work in solara
server. We patch global dicts with scoped dicts.

See #539
  • Loading branch information
maartenbreddels authored Mar 11, 2024
1 parent 83e398b commit fce107f
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 21 deletions.
34 changes: 23 additions & 11 deletions solara/components/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,44 @@ def FigureMatplotlib(
):
"""Display a matplotlib figure.
We recomment not to use the pyplot interface, but rather to create a figure directly, e.g:
```python
import reacton
import solara as sol
## Example
```solara
import solara
from matplotlib.figure import Figure
@solara.component
def Page():
# do this instead of plt.figure()
fig = Figure()
ax = fig.subplots()
ax.plot([1, 2, 3], [1, 4, 9])
return solara.FigureMatplotlib(fig)
```
You should also avoid drawing using the pyplot interface, as it is not thread-safe. If you do use it,
your drawing might be corrupted due to another thread/user drawing at the same time.
If you still must use pyplot to create the figure, make sure you call `plt.switch_backend("agg")`
before creating the figure, to avoid starting an interactive backend.
When running under solara-server, we by default configure the same 'inline' backend as in the Jupyter notebook.
For performance reasons, you might want to pass in a list of dependencies that indicate when
the figure changed, to avoid re-rendering it on every render.
## Example using pyplot
Note that it is also possible to use the pyplot interface, but be sure to close the figure not to leak memory.
```solara
import solara
import matplotlib.pyplot as plt
@solara.component
def Page():
plt.figure()
plt.plot([1, 2, 3], [1, 4, 9])
plt.show()
plt.close()
```
Note that the figure is not the same size using the pyplot interface, due to the default figure size being different.
## Arguments
Expand Down
8 changes: 7 additions & 1 deletion solara/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import solara

from . import kernel_context, reload, settings
from . import kernel_context, patch, reload, settings
from .kernel import Kernel
from .utils import pdb_guard

Expand Down Expand Up @@ -68,6 +68,9 @@ def __init__(self, name, default_app_name="Page"):
with dummy_kernel_context:
app = self._execute()

# We now ran the app, now we can check for patches that require heavy imports
patch.patch_heavy_imports()

self._first_execute_app = app
reload.reloader.root_path = self.directory
if self.type == AppType.MODULE:
Expand Down Expand Up @@ -198,6 +201,9 @@ def run(self):
self._first_execute_app = None
self._first_execute_app = self._execute()
print("Re-executed app", self.name) # noqa
# We now ran the app again, might contain new imports
patch.patch_heavy_imports()

return self._first_execute_app

def on_file_change(self, name):
Expand Down
109 changes: 109 additions & 0 deletions solara/server/patch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import functools
import logging
import os
import pdb
import sys
import threading
Expand All @@ -25,6 +27,7 @@
if patch_display is not None and sys.platform != "emscripten":
patch_display()
ipywidget_version_major = int(ipywidgets.__version__.split(".")[0])
ipykernel_version_major = int(ipykernel.__version__.split(".")[0])


class FakeIPython:
Expand Down Expand Up @@ -196,6 +199,16 @@ def __len__(self):
def __setitem__(self, key, value):
self._get_context_dict().__setitem__(key, value)

# support OrderedDict API for matplotlib
def move_to_end(self, key, last=True):
assert last, "only last=True is supported"
item = self.pop(key)
self[key] = item

# matplotlib assumes .values() returns a list
def values(self):
return list(self._get_context_dict().values())


class context_dict_widgets(context_dict):
def _get_context_dict(self) -> dict:
Expand Down Expand Up @@ -316,6 +329,95 @@ def patch_ipyreact():
ipyreact.importmap._update_import_map = lambda: None


def once(f):
called = False
return_value = None

@functools.wraps(f)
def wrapper():
nonlocal called
nonlocal return_value
if called:
return return_value
called = True
return_value = f()
return return_value

return wrapper


@once
def patch_matplotlib():
import matplotlib
import matplotlib._pylab_helpers

prev = matplotlib._pylab_helpers.Gcf.figs
matplotlib._pylab_helpers.Gcf.figs = context_dict_user("matplotlib.pylab.figure_managers", prev) # type: ignore

RcParamsOriginal = matplotlib.RcParams
counter = 0
lock = threading.Lock()

class RcParamsScoped(context_dict, matplotlib.RcParams):
_was_initialized = False
_without_kernel_dict: Dict[Any, Any]

def __init__(self, *args, **kwargs) -> None:
self._init()
RcParamsOriginal.__init__(self, *args, **kwargs)

def _init(self):
nonlocal counter
with lock:
counter += 1
self._user_dict_name = f"matplotlib.rcParams:{counter}"
# this creates a copy of the CPython side of the dict
self._without_kernel_dict = dict(zip(dict.keys(self), dict.values(self)))
self._was_initialized = True

def _set(self, key, val):
# in matplotlib this directly calls dict.__setitem__
# which would not call context_dict.__setitem__
self[key] = val

def _get(self, key):
# same as _get
return self[key]

def _get_context_dict(self) -> dict:
if not self._was_initialized:
# since we monkey patch the class after __init__ was called
# we may have to do that later on
self._init()
if kernel_context.has_current_context():
context = kernel_context.get_current_context()
if self._user_dict_name not in context.user_dicts:
# copy over the global settings when needed
context.user_dicts[self._user_dict_name] = self._without_kernel_dict.copy()
return context.user_dicts[self._user_dict_name]
else:
return self._without_kernel_dict

matplotlib.RcParams = RcParamsScoped
matplotlib.rcParams.__class__ = RcParamsScoped
# we chose to monkeypatch the class, instead of re-assiging to reParams for 2 reasons:
# 1. the RcParams object could be imported in different namespaces
# 2. the rcParams has extra methods, which means we have to otherwise monkeypatch the context_dict

def cleanup():
matplotlib._pylab_helpers.Gcf.figs = prev
matplotlib.RcParams = RcParamsOriginal
matplotlib.rcParams.__class__ = RcParamsOriginal

return cleanup


def patch_heavy_imports():
# patches that we only want to do if a package is imported, because they may slow down startup
if "matplotlib" in sys.modules:
patch_matplotlib()


def patch():
global _patched
global global_widgets_dict
Expand All @@ -334,6 +436,13 @@ def patch():
else:
patch_ipyreact()

if "MPLBACKEND" not in os.environ:
if ipykernel_version_major < 6:
# changed in https://github.com/ipython/ipykernel/pull/591
os.environ["MPLBACKEND"] = "module://ipykernel.pylab.backend_inline"
else:
os.environ["MPLBACKEND"] = "module://matplotlib_inline.backend_inline"

# the ipyvue.Template module cannot be accessed like ipyvue.Template
# because the import in ipvue overrides it
template_mod = sys.modules["ipyvue.Template"]
Expand Down
17 changes: 17 additions & 0 deletions solara/server/shell.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
import sys
from binascii import b2a_base64
from threading import local
from unittest.mock import Mock

Expand Down Expand Up @@ -184,8 +186,23 @@ def init_history(self):

def init_display_formatter(self):
super().init_display_formatter()
assert self.display_formatter is not None
self.display_formatter.ipython_display_formatter = reacton.patch_display.ReactonDisplayFormatter()

# matplotlib support for display(figure)
# IPython.core.pylabtools has support for this, but it requires importing matplotlib
# which would slow down startup, so we do it here using for_type using a string as argument.
def encode_png(figure, **kwargs):
f = io.BytesIO()
format = "png"
figure.savefig(f, format=format, **kwargs)
bytes_data = f.getvalue()
base64_data = b2a_base64(bytes_data, newline=False).decode("ascii")
return base64_data

formatter = self.display_formatter.formatters["image/png"]
formatter.for_type("matplotlib.figure.Figure", encode_png)

def init_display_pub(self):
super().init_display_pub()
self.display_pub.register_hook(self.display_in_reacton_hook)
Expand Down
5 changes: 0 additions & 5 deletions solara/website/components/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@
import solara.toestand
from solara.components.markdown import ExceptionGuard

if solara._using_solara_server():
from matplotlib import pyplot as plt

plt.switch_backend("agg")

HERE = Path(__file__).parent


Expand Down
3 changes: 0 additions & 3 deletions solara/website/pages/examples/general/live_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@

import solara

# ensure that an interactive backend doesn't start when plotting with matplotlib
plt.switch_backend("agg")


@solara.component
def Page():
Expand Down
63 changes: 62 additions & 1 deletion tests/unit/matplotlib_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import solara
from matplotlib.figure import Figure

import solara
import solara.server.patch
from solara.server import kernel


@solara.component
def Page():
Expand All @@ -14,3 +17,61 @@ def Page():
def test_render():
box, rc = solara.render(Page(), handle_error=False)
assert len(box.children) == 1


def test_pylab(no_kernel_context):
cleanup = solara.server.patch.patch_matplotlib()
try:
kernel_1 = kernel.Kernel()
context_1 = solara.server.kernel_context.VirtualKernelContext(id="1", kernel=kernel_1, session_id="session-1")
kernel_2 = kernel.Kernel()
context_2 = solara.server.kernel_context.VirtualKernelContext(id="2", kernel=kernel_2, session_id="session-1")
import matplotlib.pyplot as plt
from matplotlib._pylab_helpers import Gcf

assert len(Gcf.get_all_fig_managers()) == 0
plt.figure()

assert len(Gcf.get_all_fig_managers()) == 1

default_color = (1, 1, 1, 0)
white = "white"
black = "black"
assert plt.rcParams["figure.facecolor"] in [default_color, white]
plt.style.use("default")
assert plt.rcParams["figure.facecolor"] == white

plt.style.use("dark_background")
assert plt.rcParams["figure.facecolor"] == black

with context_1:
assert len(Gcf.get_all_fig_managers()) == 0
plt.figure()
assert plt.rcParams["figure.facecolor"] == black
assert len(Gcf.get_all_fig_managers()) == 1
plt.style.use("default")
assert plt.rcParams["figure.facecolor"] == white

assert plt.rcParams["figure.facecolor"] == black
assert len(Gcf.get_all_fig_managers()) == 1
plt.style.use("default")
assert plt.rcParams["figure.facecolor"] == white

with context_2:
assert plt.rcParams["figure.facecolor"] == white
assert len(Gcf.get_all_fig_managers()) == 0
plt.figure()
assert len(Gcf.get_all_fig_managers()) == 1
plt.figure()
assert len(Gcf.get_all_fig_managers()) == 2
plt.style.use("dark_background")
assert plt.rcParams["figure.facecolor"] == black
with context_1:
assert len(Gcf.get_all_fig_managers()) == 1
assert plt.rcParams["figure.facecolor"] == white
assert plt.rcParams["figure.facecolor"] == white

finally:
cleanup()
context_1.close()
context_2.close()

0 comments on commit fce107f

Please sign in to comment.