From 53e1d6d8c49ac6b4d2b4a7ece3189a29bd2d4c43 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Mon, 13 Nov 2023 12:30:11 +0100 Subject: [PATCH] fix: test and avoid memory leaks by checking references to virtual kernel --- .github/workflows/integration.yml | 2 + pyproject.toml | 1 + solara/server/app.py | 18 +++-- solara/server/kernel.py | 39 +++++++++- solara/server/kernel_context.py | 124 ++++++++++++++++++++---------- solara/server/patch.py | 10 ++- solara/server/server.py | 16 ++-- solara/server/shell.py | 11 +++ tests/integration/memleak_test.py | 83 ++++++++++++++++++++ tests/unit/lifecycle_test.py | 5 +- tests/unit/patch_test.py | 7 +- 11 files changed, 257 insertions(+), 59 deletions(-) create mode 100644 tests/integration/memleak_test.py diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 07f238bcb..30e245992 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -38,6 +38,8 @@ jobs: steps: - uses: actions/checkout@v2 + - name: Setup Graphviz + uses: ts-graphviz/setup-graphviz@v1 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: diff --git a/pyproject.toml b/pyproject.toml index 4d5bf0c57..ee46d6ddc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ dev = [ "dask[dataframe]; python_version < '3.7'", "playwright; python_version > '3.6'", "pytest-playwright; python_version > '3.6'", + "objgraph", ] assets = [ "solara-assets==1.22.0" diff --git a/solara/server/app.py b/solara/server/app.py index ee9bdb611..7577b6450 100644 --- a/solara/server/app.py +++ b/solara/server/app.py @@ -5,6 +5,7 @@ import sys import threading import traceback +import weakref from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, cast @@ -125,7 +126,7 @@ def add_path(): else: # the module itself will be added by reloader # automatically - with reload.reloader.watch(): + with kernel_context.without_context(), reload.reloader.watch(): self.type = AppType.MODULE try: spec = importlib.util.find_spec(self.name) @@ -345,6 +346,9 @@ def solara_comm_target(comm, msg_first): def on_msg(msg): nonlocal app + comm = comm_ref() + assert comm is not None + context = kernel_context.get_current_context() data = msg["content"]["data"] method = data["method"] if method == "run": @@ -378,9 +382,10 @@ def on_msg(msg): else: logger.error("Unknown comm method called on solara.control comm: %s", method) - comm.on_msg(on_msg) - def reload(): + comm = comm_ref() + assert comm is not None + context = kernel_context.get_current_context() # we don't reload the app ourself, we send a message to the client # this ensures that we don't run code of any client that for some reason is connected # but not working anymore. 
And it indirectly passes a message from the current thread @@ -388,8 +393,11 @@ def reload(): logger.debug(f"Send reload to client: {context.id}") comm.send({"method": "reload"}) - context = kernel_context.get_current_context() - context.reload = reload + comm.on_msg(on_msg) + comm_ref = weakref.ref(comm) + del comm + + kernel_context.get_current_context().reload = reload def register_solara_comm_target(kernel: Kernel): diff --git a/solara/server/kernel.py b/solara/server/kernel.py index b9aa206e1..e5f6ef0bc 100644 --- a/solara/server/kernel.py +++ b/solara/server/kernel.py @@ -217,7 +217,10 @@ def send_websockets(websockets: Set[websocket.WebsocketWrapper], binary_msg): ws.send(binary_msg) except: # noqa # in case of any issue, we simply remove it from the list - websockets.remove(ws) + try: + websockets.remove(ws) + except KeyError: + pass # already removed class SessionWebsocket(session.Session): @@ -233,6 +236,8 @@ def close(self): pass def send(self, stream, msg_or_type, content=None, parent=None, ident=None, buffers=None, track=False, header=None, metadata=None): + if stream is None: + return # can happen when the kernel is closed but someone was still trying to send a message try: if isinstance(msg_or_type, dict): msg = msg_or_type @@ -290,6 +295,38 @@ def __init__(self): self.shell.display_pub.session = self.session self.shell.display_pub.pub_socket = self.iopub_socket + def close(self): + if self.comm_manager is None: + raise RuntimeError("Kernel already closed") + self.session.close() + self._cleanup_references() + + def _cleanup_references(self): + try: + # all of these reduce the circular references + # making it easier for the garbage collector to clean up + self.shell_handlers.clear() + self.control_handlers.clear() + for comm_object in list(self.comm_manager.comms.values()): # type: ignore + comm_object.close() + self.comm_manager.targets.clear() # type: ignore + # self.comm_manager.kernel points to us, but we cannot set it to None + # so we remove the circular reference by setting the comm_manager to None + self.comm_manager = None # type: ignore + self.session.parent = None # type: ignore + + self.shell.display_pub.session = None # type: ignore + self.shell.display_pub.pub_socket = None # type: ignore + self.shell = None # type: ignore + self.session.websockets.clear() + self.session.stream = None # type: ignore + self.session = None # type: ignore + self.stream.session = None # type: ignore + self.stream = None # type: ignore + self.iopub_socket = None # type: ignore + except Exception: + logger.exception("Error cleaning up references from kernel, not fatal") + async def _flush_control_queue(self): pass diff --git a/solara/server/kernel_context.py b/solara/server/kernel_context.py index 41609ce32..eeff9b8ce 100644 --- a/solara/server/kernel_context.py +++ b/solara/server/kernel_context.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import dataclasses import enum import logging @@ -63,6 +64,7 @@ class VirtualKernelContext: # only used for testing _last_kernel_cull_task: "Optional[asyncio.Future[None]]" = None closed_event: threading.Event = dataclasses.field(default_factory=threading.Event) + lock: threading.RLock = dataclasses.field(default_factory=threading.RLock) def display(self, *args): print(args) # noqa @@ -81,22 +83,27 @@ def __exit__(self, *args): def close(self): logger.info("Shut down virtual kernel: %s", self.id) - with self: - if self.app_object is not None: - if isinstance(self.app_object, reacton.core._RenderContext): - try: - self.app_object.close() - 
except Exception as e: - logger.exception("Could not close render context: %s", e) - # we want to continue, so we at least close all widgets - widgets.Widget.close_all() - # what if we reference each other - # import gc - # gc.collect() - self.kernel.session.close() - if self.id in contexts: - del contexts[self.id] - self.closed_event.set() + with self.lock: + for key in self.page_status: + self.page_status[key] = PageStatus.CLOSED + with self: + if self.app_object is not None: + if isinstance(self.app_object, reacton.core._RenderContext): + try: + self.app_object.close() + except Exception as e: + logger.exception("Could not close render context: %s", e) + # we want to continue, so we at least close all widgets + widgets.Widget.close_all() + # what if we reference each other + # import gc + # gc.collect() + self.kernel.close() + self.kernel = None # type: ignore + if self.id in contexts: + del contexts[self.id] + del current_context[get_current_thread_key()] + self.closed_event.set() def _state_reset(self): state_directory = Path(".") / "states" @@ -123,10 +130,12 @@ def state_save(self, state_directory: os.PathLike): def page_connect(self, page_id: str): logger.info("Connect page %s for kernel %s", page_id, self.id) - assert self.page_status.get(page_id) != PageStatus.CLOSED, "cannot connect with the same page_id after a close" - self.page_status[page_id] = PageStatus.CONNECTED - if self._last_kernel_cull_task: - self._last_kernel_cull_task.cancel() + with self.lock: + if page_id in self.page_status and self.page_status.get(page_id) == PageStatus.CLOSED: + raise RuntimeError("Cannot connect a page that is already closed") + self.page_status[page_id] = PageStatus.CONNECTED + if self._last_kernel_cull_task: + self._last_kernel_cull_task.cancel() def page_disconnect(self, page_id: str) -> "asyncio.Future[None]": """Signal that a page has disconnected, and schedule a kernel cull if needed. 
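[Editor's note] The kernel_context changes above and below hinge on one pattern: every page-status transition (connect, disconnect, close) and the final context shutdown now run under the per-context RLock, and an operation on a page that is already CLOSED degrades to a no-op (or an explicit error, for reconnects) instead of tripping an assertion. The following is a minimal, self-contained sketch of that state machine — a toy PageTracker, not the real VirtualKernelContext:

    import enum
    import threading
    from typing import Dict


    class PageStatus(enum.Enum):
        CONNECTED = "connected"
        DISCONNECTED = "disconnected"
        CLOSED = "closed"


    class PageTracker:
        """Toy model of the lock-guarded page bookkeeping (illustrative only)."""

        def __init__(self) -> None:
            # RLock: close() may be re-entered from page_close() on the same thread
            self.lock = threading.RLock()
            self.page_status: Dict[str, PageStatus] = {}
            self.closed = False

        def page_connect(self, page_id: str) -> None:
            with self.lock:
                if self.page_status.get(page_id) == PageStatus.CLOSED:
                    raise RuntimeError("Cannot connect a page that is already closed")
                self.page_status[page_id] = PageStatus.CONNECTED

        def page_disconnect(self, page_id: str) -> None:
            with self.lock:
                if self.page_status[page_id] == PageStatus.CLOSED:
                    return  # a disconnect racing with close() is ignored
                self.page_status[page_id] = PageStatus.DISCONNECTED

        def page_close(self, page_id: str) -> None:
            with self.lock:
                if self.page_status[page_id] == PageStatus.CLOSED:
                    return  # a late close after shutdown is harmless
                self.page_status[page_id] = PageStatus.CLOSED
                still_alive = {PageStatus.CONNECTED, PageStatus.DISCONNECTED} & set(self.page_status.values())
                if not still_alive:
                    self.close()

        def close(self) -> None:
            with self.lock:
                for key in self.page_status:
                    self.page_status[key] = PageStatus.CLOSED
                self.closed = True

The real implementation additionally schedules the kernel-cull task and tears down the kernel; the point of the sketch is only that every mutation takes the same lock that server.py now holds while processing kernel messages.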
@@ -139,7 +148,13 @@ def page_disconnect(self, page_id: str) -> "asyncio.Future[None]": """ logger.info("Disconnect page %s for kernel %s", page_id, self.id) future: "asyncio.Future[None]" = asyncio.Future() - self.page_status[page_id] = PageStatus.DISCONNECTED + with self.lock: + if self.page_status[page_id] == PageStatus.CLOSED: + logger.info("Page %s already closed for kernel %s", page_id, self.id) + future.set_result(None) + return future + assert self.page_status[page_id] == PageStatus.CONNECTED, "cannot disconnect a page that is in state: %r" % self.page_status[page_id] + self.page_status[page_id] = PageStatus.DISCONNECTED current_event_loop = asyncio.get_event_loop() async def kernel_cull(): @@ -147,15 +162,22 @@ async def kernel_cull(): cull_timeout_sleep_seconds = solara.util.parse_timedelta(solara.server.settings.kernel.cull_timeout) logger.info("Scheduling kernel cull, will wait for max %s before shutting down the virtual kernel %s", cull_timeout_sleep_seconds, self.id) await asyncio.sleep(cull_timeout_sleep_seconds) - has_connected_pages = PageStatus.CONNECTED in self.page_status.values() - if has_connected_pages: - logger.info("We have (re)connected pages, keeping the virtual kernel %s alive", self.id) - else: - logger.info("No connected pages, and timeout reached, shutting down virtual kernel %s", self.id) - self.close() - current_event_loop.call_soon_threadsafe(future.set_result, None) + with self.lock: + has_connected_pages = PageStatus.CONNECTED in self.page_status.values() + if has_connected_pages: + logger.info("We have (re)connected pages, keeping the virtual kernel %s alive", self.id) + else: + logger.info("No connected pages, and timeout reached, shutting down virtual kernel %s", self.id) + self.close() + try: + current_event_loop.call_soon_threadsafe(future.set_result, None) + except RuntimeError: + pass # event loop already closed, happens during testing except asyncio.CancelledError: - current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled") + try: + current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled") + except RuntimeError: + pass # event loop already closed, happens during testing raise has_connected_pages = PageStatus.CONNECTED in self.page_status.values() @@ -168,7 +190,10 @@ async def create_task(): task = asyncio.create_task(kernel_cull()) # create a reference to the task so we can cancel it later self._last_kernel_cull_task = task - await task + try: + await task + except RuntimeError: + pass # event loop already closed, happens during testing asyncio.run_coroutine_threadsafe(create_task(), keep_alive_event_loop) else: @@ -182,15 +207,21 @@ def page_close(self, page_id: str): different from a websocket/page disconnect, which we might want to recover from. 
""" - self.page_status[page_id] = PageStatus.CLOSED - logger.info("Close page %s for kernel %s", page_id, self.id) - has_connected_pages = PageStatus.CONNECTED in self.page_status.values() - has_disconnected_pages = PageStatus.DISCONNECTED in self.page_status.values() - if not (has_connected_pages or has_disconnected_pages): - logger.info("No connected or disconnected pages, shutting down virtual kernel %s", self.id) - if self._last_kernel_cull_task: - self._last_kernel_cull_task.cancel() - self.close() + + logger.info("page status: %s", self.page_status) + with self.lock: + if self.page_status[page_id] == PageStatus.CLOSED: + logger.info("Page %s already closed for kernel %s", page_id, self.id) + return + self.page_status[page_id] = PageStatus.CLOSED + logger.info("Close page %s for kernel %s", page_id, self.id) + has_connected_pages = PageStatus.CONNECTED in self.page_status.values() + has_disconnected_pages = PageStatus.DISCONNECTED in self.page_status.values() + if not (has_connected_pages or has_disconnected_pages): + logger.info("No connected or disconnected pages, shutting down virtual kernel %s", self.id) + if self._last_kernel_cull_task: + self._last_kernel_cull_task.cancel() + self.close() try: @@ -267,6 +298,21 @@ def set_current_context(context: Optional[VirtualKernelContext]): current_context[thread_key] = context +@contextlib.contextmanager +def without_context(): + context = None + try: + context = get_current_context() + except RuntimeError: + pass + thread_key = get_current_thread_key() + current_context[thread_key] = None + try: + yield + finally: + current_context[thread_key] = context + + def initialize_virtual_kernel(session_id: str, kernel_id: str, websocket: websocket.WebsocketWrapper): from solara.server import app as appmodule diff --git a/solara/server/patch.py b/solara/server/patch.py index 354b70636..c677ac833 100644 --- a/solara/server/patch.py +++ b/solara/server/patch.py @@ -13,6 +13,8 @@ import ipywidgets.widgets.widget_output from IPython.core.interactiveshell import InteractiveShell +import solara.util + from . 
import app, kernel_context, reload, settings from .utils import pdb_guard @@ -235,7 +237,8 @@ def auto_watch_get_template(get_template): def wrapper(abs_path): template = get_template(abs_path) - reload.reloader.watcher.add_file(abs_path) + with kernel_context.without_context(): + reload.reloader.watcher.add_file(abs_path) return template return wrapper @@ -255,9 +258,8 @@ def WidgetContextAwareThread__init__(self, *args, **kwargs): def Thread_debug_run(self): - if self.current_context: - kernel_context.set_context_for_thread(self.current_context, self) - with pdb_guard(): + context = self.current_context or solara.util.nullcontext() + with pdb_guard(), context: Thread__run(self) diff --git a/solara/server/server.py b/solara/server/server.py index 1b5f779c4..f38230af4 100644 --- a/solara/server/server.py +++ b/solara/server/server.py @@ -139,7 +139,8 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: s message = await ws.receive() except websocket.WebSocketDisconnect: try: - context.kernel.session.websockets.remove(ws) + if context.kernel is not None and context.kernel.session is not None: + context.kernel.session.websockets.remove(ws) except KeyError: pass logger.debug("Disconnected") @@ -150,10 +151,15 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: s else: msg = deserialize_binary_message(message) t1 = time.time() - if not process_kernel_messages(kernel, msg): - # if we shut down the kernel, we do not keep the page session alive - context.close() - return + # we don't want to have the kernel closed while we are processing a message + # therefore we use this mutex that is also used in the context.close method + with context.lock: + if context.closed_event.is_set(): + return + if not process_kernel_messages(kernel, msg): + # if we shut down the kernel, we do not keep the page session alive + context.close() + return t2 = time.time() if settings.main.timing: print(f"timing: total={t2-t0:.3f}s, deserialize={t1-t0:.3f}s, kernel={t2-t1:.3f}s") # noqa: T201 diff --git a/solara/server/shell.py b/solara/server/shell.py index 4a1d0df57..6def31ef3 100644 --- a/solara/server/shell.py +++ b/solara/server/shell.py @@ -1,3 +1,4 @@ +import atexit import sys from threading import local from unittest.mock import Mock @@ -175,10 +176,20 @@ class SolaraInteractiveShell(InteractiveShell): display_pub_class = Type(SolaraDisplayPublisher) history_manager = Any() # type: ignore + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + atexit.unregister(self.atexit_operations) + + magic = self.magics_manager.registry["ScriptMagics"] + atexit.unregister(magic.kill_bg_processes) + def set_parent(self, parent): """Tell the children about the parent message.""" self.display_pub.set_parent(parent) + def init_sys_modules(self): + pass # don't create a __main__, it will cause a mem leak + def init_history(self): self.history_manager = Mock() # type: ignore diff --git a/tests/integration/memleak_test.py b/tests/integration/memleak_test.py new file mode 100644 index 000000000..8a5a78470 --- /dev/null +++ b/tests/integration/memleak_test.py @@ -0,0 +1,83 @@ +import gc +import time +import weakref +from pathlib import Path +from typing import Optional + +import objgraph +import playwright.sync_api +import pytest + +import solara +import solara.server.kernel_context + +HERE = Path(__file__).parent + + +set_value = None +context: Optional["solara.server.kernel_context.VirtualKernelContext"] = None + + +@pytest.fixture +def no_cull_timeout(): 
+ cull_timeout_previous = solara.server.settings.kernel.cull_timeout + solara.server.settings.kernel.cull_timeout = "0.0001s" + try: + yield + finally: + solara.server.settings.kernel.cull_timeout = cull_timeout_previous + + +def _scoped_test_memleak( + page_session: playwright.sync_api.Page, + solara_server, + solara_app, + extra_include_path, +): + with solara_app("solara.website.pages"): + page_session.goto(solara_server.base_url) + page_session.locator("text=Examples").first.wait_for() + assert len(solara.server.kernel_context.contexts) == 1 + context = weakref.ref(list(solara.server.kernel_context.contexts.values())[0]) + # we should not have created a new context + assert len(solara.server.kernel_context.contexts) == 1 + kernel = weakref.ref(context().kernel) + shell = weakref.ref(kernel().shell) + session = weakref.ref(kernel().session) + page_session.goto("about:blank") + assert context().closed_event.wait(10) + if shell(): + del shell().__dict__ + return context, kernel, shell, session + + +def test_memleak( + pytestconfig, + request, + browser: playwright.sync_api.Browser, + page_session: playwright.sync_api.Page, + solara_server, + solara_app, + extra_include_path, + no_cull_timeout, +): + # for unknown reasons, del does not work in CI + context_ref, kernel_ref, shell_ref, session_ref = _scoped_test_memleak(page_session, solara_server, solara_app, extra_include_path) + + for i in range(200): + time.sleep(0.1) + for gen in [2, 1, 0]: + gc.collect(gen) + if context_ref() is None and kernel_ref() is None and shell_ref() is None and session_ref() is None: + break + else: + name = solara_server.__class__.__name__ + output_path = Path(pytestconfig.getoption("--output")) / f"mem-leak-{name}.pdf" + output_path.parent.mkdir(parents=True, exist_ok=True) + print("output to", output_path, output_path.resolve()) # noqa + objgraph.show_backrefs([context_ref(), kernel_ref(), shell_ref(), session_ref()], filename=str(output_path), max_depth=15, too_many=15) + + assert context_ref() is None + assert kernel_ref() is None + assert shell_ref() is None + assert session_ref() is None diff --git a/tests/unit/lifecycle_test.py b/tests/unit/lifecycle_test.py index 92baab1d2..112cbeaa7 100644 --- a/tests/unit/lifecycle_test.py +++ b/tests/unit/lifecycle_test.py @@ -84,7 +84,7 @@ async def test_kernel_lifecycle_close_single(close_first, short_cull_timeout): async def test_kernel_lifecycle_close_while_disconnected(close_first, short_cull_timeout): # a reconnect should be possible within the reconnect window websocket = Mock() - context = kernel_context.initialize_virtual_kernel("session-id-1", "kernel-id-1", websocket) + context = kernel_context.initialize_virtual_kernel(f"session-id-1-{close_first}", f"kernel-id-1-{close_first}", websocket) context.page_connect("page-id-1") cull_task_1 = context.page_disconnect("page-id-1") await asyncio.sleep(0.1) @@ -93,7 +93,8 @@ async def test_kernel_lifecycle_close_while_disconnected(close_first, short_cull if close_first: context.page_close("page-id-2") await asyncio.sleep(0.01) - cull_task_2 = context.page_disconnect("page-id-2") + context.page_connect("page-id-3") + cull_task_2 = context.page_disconnect("page-id-3") else: cull_task_2 = context.page_disconnect("page-id-2") await asyncio.sleep(0.01) diff --git a/tests/unit/patch_test.py b/tests/unit/patch_test.py index 15976bc30..c00a47475 100644 --- a/tests/unit/patch_test.py +++ b/tests/unit/patch_test.py @@ -20,9 +20,10 @@ def test_widget_error_message_outside_context(no_kernel_context): def 
test_widget_dict(no_kernel_context): - kernel_shared = kernel.Kernel() - context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel_shared, session_id="session-1") - context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel_shared, session_id="session-2") + kernel1 = kernel.Kernel() + kernel2 = kernel.Kernel() + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1, session_id="session-1") + context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel2, session_id="session-2") with context1: btn1 = widgets.Button(description="context1")
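[Editor's note] The new integration test reduces to a reusable recipe: take weakrefs to the objects that must become collectable (context, kernel, shell, session), drop the strong references, force a few garbage-collection passes, and render an objgraph back-reference graph when something survives. Below is a minimal sketch of that recipe with a hypothetical assert_released helper; the real test drives a browser via the solara_server/page_session fixtures instead:

    import gc
    import weakref

    import objgraph  # the dev dependency this patch adds to pyproject.toml


    def assert_released(ref: weakref.ref, label: str, dump_path: str = "backrefs.pdf") -> None:
        # Collect the oldest generations first, as the test does, so that
        # long-lived cycles are swept before the weakref is inspected.
        for generation in (2, 1, 0):
            gc.collect(generation)
        if ref() is not None:
            # Something still points at the object: write a back-reference graph
            # (rendering needs Graphviz, hence the new setup-graphviz CI step).
            objgraph.show_backrefs([ref()], filename=dump_path, max_depth=15, too_many=15)
            raise AssertionError(f"{label} is still referenced, see {dump_path}")


    # Usage sketch, assuming `context` is a VirtualKernelContext taken from
    # solara.server.kernel_context.contexts after a page has connected:
    #
    #     context_ref = weakref.ref(context)
    #     del context                      # drop the last strong reference
    #     assert_released(context_ref, "VirtualKernelContext")

The back-reference dump is the debugging aid: when one of the weakrefs stays alive, the rendered PDF shows exactly which object (comm handler, shell namespace, session, ...) still holds the chain, which is how the circular references removed in kernel.close() and _cleanup_references were found.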