diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index 07f238bcb..30e245992 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -38,6 +38,8 @@ jobs:
 
     steps:
       - uses: actions/checkout@v2
+      - name: Setup Graphviz
+        uses: ts-graphviz/setup-graphviz@v1
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v2
         with:
diff --git a/pyproject.toml b/pyproject.toml
index 4d5bf0c57..ee46d6ddc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,6 +80,7 @@ dev = [
     "dask[dataframe]; python_version < '3.7'",
     "playwright; python_version > '3.6'",
     "pytest-playwright; python_version > '3.6'",
+    "objgraph",
 ]
 assets = [
     "solara-assets==1.22.0"
diff --git a/solara/server/app.py b/solara/server/app.py
index ee9bdb611..7577b6450 100644
--- a/solara/server/app.py
+++ b/solara/server/app.py
@@ -5,6 +5,7 @@
 import sys
 import threading
 import traceback
+import weakref
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, List, Optional, cast
@@ -125,7 +126,7 @@ def add_path():
         else:
             # the module itself will be added by reloader
             # automatically
-            with reload.reloader.watch():
+            with kernel_context.without_context(), reload.reloader.watch():
                 self.type = AppType.MODULE
                 try:
                     spec = importlib.util.find_spec(self.name)
@@ -345,6 +346,9 @@ def solara_comm_target(comm, msg_first):
 
     def on_msg(msg):
         nonlocal app
+        comm = comm_ref()
+        assert comm is not None
+        context = kernel_context.get_current_context()
         data = msg["content"]["data"]
         method = data["method"]
         if method == "run":
@@ -378,9 +382,10 @@ def on_msg(msg):
         else:
             logger.error("Unknown comm method called on solara.control comm: %s", method)
 
-    comm.on_msg(on_msg)
-
     def reload():
+        comm = comm_ref()
+        assert comm is not None
+        context = kernel_context.get_current_context()
         # we don't reload the app ourself, we send a message to the client
         # this ensures that we don't run code of any client that for some reason is connected
         # but not working anymore. And it indirectly passes a message from the current thread
@@ -388,8 +393,11 @@ def reload():
         logger.debug(f"Send reload to client: {context.id}")
         comm.send({"method": "reload"})
 
-    context = kernel_context.get_current_context()
-    context.reload = reload
+    comm.on_msg(on_msg)
+    comm_ref = weakref.ref(comm)
+    del comm
+
+    kernel_context.get_current_context().reload = reload
 
 
 def register_solara_comm_target(kernel: Kernel):
diff --git a/solara/server/kernel.py b/solara/server/kernel.py
index b9aa206e1..057667f57 100644
--- a/solara/server/kernel.py
+++ b/solara/server/kernel.py
@@ -290,6 +290,37 @@ def __init__(self):
         self.shell.display_pub.session = self.session
         self.shell.display_pub.pub_socket = self.iopub_socket
 
+    def close(self):
+        if self.comm_manager is None:
+            raise RuntimeError("Kernel already closed")
+        self.session.close()
+        self._cleanup_references()
+
+    def _cleanup_references(self):
+        try:
+            # all of these reduce the circular references
+            # making it easier for the garbage collector to clean up
+            self.shell_handlers.clear()
+            self.control_handlers.clear()
+            for comm_object in list(self.comm_manager.comms.values()):  # type: ignore
+                comm_object.close()
+            self.comm_manager.targets.clear()  # type: ignore
+            # self.comm_manager.kernel points to us, but we cannot set it to None
+            # so we remove the circular reference by setting the comm_manager to None
+            self.comm_manager = None  # type: ignore
+            self.session.parent = None  # type: ignore
+
+            self.shell.display_pub.session = None  # type: ignore
+            self.shell.display_pub.pub_socket = None  # type: ignore
+            self.shell = None  # type: ignore
+            self.session.stream = None  # type: ignore
+            self.session = None  # type: ignore
+            self.stream.session = None  # type: ignore
+            self.stream = None  # type: ignore
+            self.iopub_socket = None  # type: ignore
+        except Exception:
+            logger.exception("Error cleaning up references from kernel, not fatal")
+
     async def _flush_control_queue(self):
         pass
diff --git a/solara/server/kernel_context.py b/solara/server/kernel_context.py
index 41609ce32..ff84fc30f 100644
--- a/solara/server/kernel_context.py
+++ b/solara/server/kernel_context.py
@@ -1,4 +1,5 @@
 import asyncio
+import contextlib
 import dataclasses
 import enum
 import logging
@@ -63,6 +64,7 @@ class VirtualKernelContext:
     # only used for testing
     _last_kernel_cull_task: "Optional[asyncio.Future[None]]" = None
     closed_event: threading.Event = dataclasses.field(default_factory=threading.Event)
+    lock: threading.RLock = dataclasses.field(default_factory=threading.RLock)
 
     def display(self, *args):
         print(args)  # noqa
@@ -81,6 +83,9 @@ def __exit__(self, *args):
 
     def close(self):
         logger.info("Shut down virtual kernel: %s", self.id)
+        with self.lock:
+            for key in self.page_status:
+                self.page_status[key] = PageStatus.CLOSED
         with self:
             if self.app_object is not None:
                 if isinstance(self.app_object, reacton.core._RenderContext):
@@ -93,9 +98,11 @@ def close(self):
             # what if we reference each other
             # import gc
             # gc.collect()
-            self.kernel.session.close()
+            self.kernel.close()
+            self.kernel = None  # type: ignore
         if self.id in contexts:
             del contexts[self.id]
+        del current_context[get_current_thread_key()]
         self.closed_event.set()
 
     def _state_reset(self):
@@ -123,10 +130,12 @@ def state_save(self, state_directory: os.PathLike):
 
     def page_connect(self, page_id: str):
         logger.info("Connect page %s for kernel %s", page_id, self.id)
-        assert self.page_status.get(page_id) != PageStatus.CLOSED, "cannot connect with the same page_id after a close"
-        self.page_status[page_id] = PageStatus.CONNECTED
-        if self._last_kernel_cull_task:
-            self._last_kernel_cull_task.cancel()
+        with self.lock:
+            if page_id in self.page_status and self.page_status.get(page_id) == PageStatus.CLOSED:
+                raise RuntimeError("Cannot connect a page that is already closed")
+            self.page_status[page_id] = PageStatus.CONNECTED
+            if self._last_kernel_cull_task:
+                self._last_kernel_cull_task.cancel()
 
     def page_disconnect(self, page_id: str) -> "asyncio.Future[None]":
         """Signal that a page has disconnected, and schedule a kernel cull if needed.
@@ -139,7 +148,13 @@ def page_disconnect(self, page_id: str) -> "asyncio.Future[None]":
         """
         logger.info("Disconnect page %s for kernel %s", page_id, self.id)
         future: "asyncio.Future[None]" = asyncio.Future()
-        self.page_status[page_id] = PageStatus.DISCONNECTED
+        with self.lock:
+            if self.page_status[page_id] == PageStatus.CLOSED:
+                logger.info("Page %s already closed for kernel %s", page_id, self.id)
+                future.set_result(None)
+                return future
+            assert self.page_status[page_id] == PageStatus.CONNECTED, "cannot disconnect a page that is in state: %r" % self.page_status[page_id]
+            self.page_status[page_id] = PageStatus.DISCONNECTED
         current_event_loop = asyncio.get_event_loop()
 
         async def kernel_cull():
@@ -147,15 +162,22 @@ async def kernel_cull():
             cull_timeout_sleep_seconds = solara.util.parse_timedelta(solara.server.settings.kernel.cull_timeout)
             logger.info("Scheduling kernel cull, will wait for max %s before shutting down the virtual kernel %s", cull_timeout_sleep_seconds, self.id)
             await asyncio.sleep(cull_timeout_sleep_seconds)
-            has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
-            if has_connected_pages:
-                logger.info("We have (re)connected pages, keeping the virtual kernel %s alive", self.id)
-            else:
-                logger.info("No connected pages, and timeout reached, shutting down virtual kernel %s", self.id)
-                self.close()
-            current_event_loop.call_soon_threadsafe(future.set_result, None)
+            with self.lock:
+                has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
+                if has_connected_pages:
+                    logger.info("We have (re)connected pages, keeping the virtual kernel %s alive", self.id)
+                else:
+                    logger.info("No connected pages, and timeout reached, shutting down virtual kernel %s", self.id)
+                    self.close()
+            try:
+                current_event_loop.call_soon_threadsafe(future.set_result, None)
+            except RuntimeError:
+                pass  # event loop already closed, happens during testing
         except asyncio.CancelledError:
-            current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled")
+            try:
+                current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled")
+            except RuntimeError:
+                pass  # event loop already closed, happens during testing
             raise
 
         has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
@@ -168,7 +190,10 @@ async def create_task():
                 task = asyncio.create_task(kernel_cull())
                 # create a reference to the task so we can cancel it later
                 self._last_kernel_cull_task = task
-                await task
+                try:
+                    await task
+                except RuntimeError:
+                    pass  # event loop already closed, happens during testing
 
             asyncio.run_coroutine_threadsafe(create_task(), keep_alive_event_loop)
         else:
@@ -182,15 +207,21 @@ def page_close(self, page_id: str):
 
         different from a websocket/page disconnect, which we might want to recover from.
         """
-        self.page_status[page_id] = PageStatus.CLOSED
-        logger.info("Close page %s for kernel %s", page_id, self.id)
-        has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
-        has_disconnected_pages = PageStatus.DISCONNECTED in self.page_status.values()
-        if not (has_connected_pages or has_disconnected_pages):
-            logger.info("No connected or disconnected pages, shutting down virtual kernel %s", self.id)
-            if self._last_kernel_cull_task:
-                self._last_kernel_cull_task.cancel()
-            self.close()
+
+        logger.info("page status: %s", self.page_status)
+        with self.lock:
+            if self.page_status[page_id] == PageStatus.CLOSED:
+                logger.info("Page %s already closed for kernel %s", page_id, self.id)
+                return
+            self.page_status[page_id] = PageStatus.CLOSED
+            logger.info("Close page %s for kernel %s", page_id, self.id)
+            has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
+            has_disconnected_pages = PageStatus.DISCONNECTED in self.page_status.values()
+            if not (has_connected_pages or has_disconnected_pages):
+                logger.info("No connected or disconnected pages, shutting down virtual kernel %s", self.id)
+                if self._last_kernel_cull_task:
+                    self._last_kernel_cull_task.cancel()
+                self.close()
 
 
 try:
@@ -267,6 +298,21 @@ def set_current_context(context: Optional[VirtualKernelContext]):
     current_context[thread_key] = context
 
 
+@contextlib.contextmanager
+def without_context():
+    context = None
+    try:
+        context = get_current_context()
+    except RuntimeError:
+        pass
+    thread_key = get_current_thread_key()
+    current_context[thread_key] = None
+    try:
+        yield
+    finally:
+        current_context[thread_key] = context
+
+
 def initialize_virtual_kernel(session_id: str, kernel_id: str, websocket: websocket.WebsocketWrapper):
     from solara.server import app as appmodule
 
diff --git a/solara/server/patch.py b/solara/server/patch.py
index 354b70636..c677ac833 100644
--- a/solara/server/patch.py
+++ b/solara/server/patch.py
@@ -13,6 +13,8 @@
 import ipywidgets.widgets.widget_output
 from IPython.core.interactiveshell import InteractiveShell
 
+import solara.util
+
 from . import app, kernel_context, reload, settings
 from .utils import pdb_guard
 
@@ -235,7 +237,8 @@ def auto_watch_get_template(get_template):
 
     def wrapper(abs_path):
         template = get_template(abs_path)
-        reload.reloader.watcher.add_file(abs_path)
+        with kernel_context.without_context():
+            reload.reloader.watcher.add_file(abs_path)
         return template
 
     return wrapper
@@ -255,9 +258,8 @@ def WidgetContextAwareThread__init__(self, *args, **kwargs):
 
 
 def Thread_debug_run(self):
-    if self.current_context:
-        kernel_context.set_context_for_thread(self.current_context, self)
-    with pdb_guard():
+    context = self.current_context or solara.util.nullcontext()
+    with pdb_guard(), context:
         Thread__run(self)
 
 
diff --git a/solara/server/shell.py b/solara/server/shell.py
index 4a1d0df57..6def31ef3 100644
--- a/solara/server/shell.py
+++ b/solara/server/shell.py
@@ -1,3 +1,4 @@
+import atexit
 import sys
 from threading import local
 from unittest.mock import Mock
@@ -175,10 +176,20 @@ class SolaraInteractiveShell(InteractiveShell):
     display_pub_class = Type(SolaraDisplayPublisher)
     history_manager = Any()  # type: ignore
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        atexit.unregister(self.atexit_operations)
+
+        magic = self.magics_manager.registry["ScriptMagics"]
+        atexit.unregister(magic.kill_bg_processes)
+
     def set_parent(self, parent):
         """Tell the children about the parent message."""
         self.display_pub.set_parent(parent)
 
+    def init_sys_modules(self):
+        pass  # don't create a __main__, it will cause a mem leak
+
     def init_history(self):
         self.history_manager = Mock()  # type: ignore
 
diff --git a/tests/integration/memleak_test.py b/tests/integration/memleak_test.py
new file mode 100644
index 000000000..0a68454d1
--- /dev/null
+++ b/tests/integration/memleak_test.py
@@ -0,0 +1,87 @@
+import gc
+import time
+import weakref
+from pathlib import Path
+from typing import Optional
+
+import objgraph
+import playwright.sync_api
+import pytest
+
+import solara
+import solara.server.kernel_context
+
+HERE = Path(__file__).parent
+
+
+set_value = None
+context: Optional["solara.server.kernel_context.VirtualKernelContext"] = None
+
+
+@pytest.fixture
+def no_cull_timeout():
+    cull_timeout_previous = solara.server.settings.kernel.cull_timeout
+    solara.server.settings.kernel.cull_timeout = "0.0001s"
+    try:
+        yield
+    finally:
+        solara.server.settings.kernel.cull_timeout = cull_timeout_previous
+
+
+def _scoped_test_memleak(
+    page_session: playwright.sync_api.Page,
+    solara_server,
+    solara_app,
+    extra_include_path,
+):
+    with solara_app("solara.website.pages"):
+        page_session.goto(solara_server.base_url)
+        page_session.locator("text=Examples").first.wait_for()
+        assert len(solara.server.kernel_context.contexts) == 1
+        context = list(solara.server.kernel_context.contexts.values())[0]
+        # we should not have created a new context
+        assert len(solara.server.kernel_context.contexts) == 1
+        kernel = context.kernel
+        shell = kernel.shell
+        session = kernel.session
+        page_session.goto("about:blank")
+        assert context.closed_event.wait(10)
+        del shell.__dict__
+        context_ref = weakref.ref(context)
+        kernel_ref = weakref.ref(kernel)
+        shell_ref = weakref.ref(shell)
+        session_ref = weakref.ref(session)
+        del context, kernel, shell, session
+        return context_ref, kernel_ref, shell_ref, session_ref
+
+
+def test_memleak(
+    pytestconfig,
+    request,
+    browser: playwright.sync_api.Browser,
+    page_session: playwright.sync_api.Page,
+    solara_server,
+    solara_app,
+    extra_include_path,
+    no_cull_timeout,
+):
+    # for unknown reasons, del does not work in CI
+    context_ref, kernel_ref, shell_ref, session_ref = _scoped_test_memleak(page_session, solara_server, solara_app, extra_include_path)
+
+    for i in range(200):
+        time.sleep(0.1)
+        for gen in [2, 1, 0]:
+            gc.collect(gen)
+        if context_ref() is None and kernel_ref() is None and shell_ref() is None and session_ref() is None:
+            break
+    else:
+        name = solara_server.__class__.__name__
+        output_path = Path(pytestconfig.getoption("--output")) / f"mem-leak-{name}.pdf"
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        print("output to", output_path, output_path.resolve())  # noqa
+        objgraph.show_backrefs([context_ref(), kernel_ref(), shell_ref(), session_ref()], filename=str(output_path), max_depth=15, too_many=15)
+
+    assert context_ref() is None
+    assert kernel_ref() is None
+    assert shell_ref() is None
+    assert session_ref() is None
diff --git a/tests/unit/lifecycle_test.py b/tests/unit/lifecycle_test.py
index 92baab1d2..112cbeaa7 100644
--- a/tests/unit/lifecycle_test.py
+++ b/tests/unit/lifecycle_test.py
@@ -84,7 +84,7 @@ async def test_kernel_lifecycle_close_single(close_first, short_cull_timeout):
 async def test_kernel_lifecycle_close_while_disconnected(close_first, short_cull_timeout):
     # a reconnect should be possible within the reconnect window
     websocket = Mock()
-    context = kernel_context.initialize_virtual_kernel("session-id-1", "kernel-id-1", websocket)
+    context = kernel_context.initialize_virtual_kernel(f"session-id-1-{close_first}", f"kernel-id-1-{close_first}", websocket)
     context.page_connect("page-id-1")
     cull_task_1 = context.page_disconnect("page-id-1")
     await asyncio.sleep(0.1)
@@ -93,7 +93,8 @@ async def test_kernel_lifecycle_close_while_disconnected(close_first, short_cull
     if close_first:
         context.page_close("page-id-2")
         await asyncio.sleep(0.01)
-        cull_task_2 = context.page_disconnect("page-id-2")
+        context.page_connect("page-id-3")
+        cull_task_2 = context.page_disconnect("page-id-3")
     else:
         cull_task_2 = context.page_disconnect("page-id-2")
         await asyncio.sleep(0.01)
diff --git a/tests/unit/patch_test.py b/tests/unit/patch_test.py
index 15976bc30..c00a47475 100644
--- a/tests/unit/patch_test.py
+++ b/tests/unit/patch_test.py
@@ -20,9 +20,10 @@ def test_widget_error_message_outside_context(no_kernel_context):
 
 
 def test_widget_dict(no_kernel_context):
-    kernel_shared = kernel.Kernel()
-    context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel_shared, session_id="session-1")
-    context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel_shared, session_id="session-2")
+    kernel1 = kernel.Kernel()
+    kernel2 = kernel.Kernel()
+    context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1, session_id="session-1")
+    context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel2, session_id="session-2")
    with context1:
        btn1 = widgets.Button(description="context1")
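
The solara/server/app.py hunk above keeps the comm reachable from its callbacks only through a weakref, so the callbacks no longer pin the comm (and, through it, the kernel) in memory. Below is a minimal, self-contained sketch of that pattern; it is not part of the patch, and the Comm class and install_handlers function are hypothetical stand-ins, not ipykernel's API.

# Sketch only: weakref-based comm handling, assuming a toy Comm stand-in.
import weakref


class Comm:
    # minimal stand-in so the sketch runs without ipykernel
    def __init__(self):
        self.callback = None

    def on_msg(self, callback):
        self.callback = callback

    def send(self, data):
        print("send:", data)


def install_handlers(comm):
    def on_msg(msg):
        # resolve the weak reference only while handling a message
        comm = comm_ref()
        assert comm is not None
        comm.send({"echo": msg})

    comm.on_msg(on_msg)
    comm_ref = weakref.ref(comm)
    # drop the strong reference: the closure now only holds the weakref,
    # so registering the handler does not create a reference cycle
    del comm


c = Comm()
install_handlers(c)
c.callback({"hello": "world"})  # prints: send: {'echo': {'hello': 'world'}}
del c  # only the weakref remains, so the Comm can be garbage collected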