From ba352c9755843a9137b4cd67a26c63b97ba1d4dd Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Fri, 8 Dec 2023 10:21:22 +0100 Subject: [PATCH] refactor: keep reference to cull task for memory leak debugging --- solara/server/kernel_context.py | 37 +++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/solara/server/kernel_context.py b/solara/server/kernel_context.py index ff9914532..316e1e9dc 100644 --- a/solara/server/kernel_context.py +++ b/solara/server/kernel_context.py @@ -1,4 +1,5 @@ import asyncio +import concurrent.futures import contextlib import dataclasses import enum @@ -70,6 +71,7 @@ class VirtualKernelContext: page_status: Dict[str, PageStatus] = dataclasses.field(default_factory=dict) # only used for testing _last_kernel_cull_task: "Optional[asyncio.Future[None]]" = None + _last_kernel_cull_future: "Optional[concurrent.futures.Future[None]]" = None closed_event: threading.Event = dataclasses.field(default_factory=threading.Event) _on_close_callbacks: List[Callable[[], None]] = dataclasses.field(default_factory=list) @@ -112,6 +114,11 @@ def close(self): with self.lock: for key in self.page_status: self.page_status[key] = PageStatus.CLOSED + if self._last_kernel_cull_task: + self._last_kernel_cull_task.cancel() + if self._last_kernel_cull_future: + self._last_kernel_cull_future.cancel() + with self: for f in reversed(self._on_close_callbacks): f() @@ -157,6 +164,8 @@ def state_save(self, state_directory: os.PathLike): pickle.dump(state, f) def page_connect(self, page_id: str): + if self.closed_event.is_set(): + raise RuntimeError("Cannot connect a page to a closed kernel") logger.info("Connect page %s for kernel %s", page_id, self.id) with self.lock: if page_id in self.page_status and self.page_status.get(page_id) == PageStatus.CLOSED: @@ -210,20 +219,21 @@ async def kernel_cull(): has_connected_pages = PageStatus.CONNECTED in self.page_status.values() if not has_connected_pages: - # when we have no 
connected pages, we will schedule a kernel cull - if self._last_kernel_cull_task: - self._last_kernel_cull_task.cancel() + with self.lock: + # when we have no connected pages, we will schedule a kernel cull + if self._last_kernel_cull_task: + self._last_kernel_cull_task.cancel() - async def create_task(): - task = asyncio.create_task(kernel_cull()) - # create a reference to the task so we can cancel it later - self._last_kernel_cull_task = task - try: - await task - except RuntimeError: - pass # event loop already closed, happens during testing + async def create_task(): + task = asyncio.create_task(kernel_cull()) + # create a reference to the task so we can cancel it later + self._last_kernel_cull_task = task + try: + await task + except RuntimeError: + pass # event loop already closed, happens during testing - asyncio.run_coroutine_threadsafe(create_task(), keep_alive_event_loop) + self._last_kernel_cull_future = asyncio.run_coroutine_threadsafe(create_task(), keep_alive_event_loop) else: future.set_result(None) return future @@ -235,7 +245,8 @@ def page_close(self, page_id: str): different from a websocket/page disconnect, which we might want to recover from. """ - + if self.closed_event.is_set(): + raise RuntimeError("Cannot close a page on a closed kernel") logger.info("page status: %s", self.page_status) with self.lock: if self.page_status[page_id] == PageStatus.CLOSED: