diff --git a/solara/server/patch.py b/solara/server/patch.py index 354b70636..cd8b06d5a 100644 --- a/solara/server/patch.py +++ b/solara/server/patch.py @@ -264,6 +264,7 @@ def Thread_debug_run(self): _patched = False global_widgets_dict = {} global_templates_dict: Dict[Any, Any] = {} +widgets = context_dict_widgets() def Output_enter(self): @@ -314,14 +315,14 @@ def patch(): if ipywidget_version_major < 8: global_widgets_dict = ipywidgets.widget.Widget.widgets - ipywidgets.widget.Widget.widgets = context_dict_widgets() # type: ignore + ipywidgets.widget.Widget.widgets = widgets # type: ignore else: if hasattr(ipywidgets.widgets.widget, "_instances"): # since 8.0.3 global_widgets_dict = ipywidgets.widgets.widget._instances - ipywidgets.widgets.widget._instances = context_dict_widgets() # type: ignore + ipywidgets.widgets.widget._instances = widgets # type: ignore elif hasattr(ipywidgets.widget.Widget, "_instances"): global_widgets_dict = ipywidgets.widget.Widget._instances - ipywidgets.widget.Widget._instances = context_dict_widgets() # type: ignore + ipywidgets.widget.Widget._instances = widgets # type: ignore else: raise RuntimeError("Could not find _instances on ipywidgets version %r" % ipywidgets.__version__) threading.Thread.__init__ = WidgetContextAwareThread__init__ # type: ignore diff --git a/solara/server/server.py b/solara/server/server.py index 1b5f779c4..9cc959d2d 100644 --- a/solara/server/server.py +++ b/solara/server/server.py @@ -16,7 +16,7 @@ import solara import solara.routing -from . import app, jupytertools, settings, websocket +from . import app, jupytertools, patch, settings, websocket from .kernel import Kernel, deserialize_binary_message from .kernel_context import initialize_virtual_kernel @@ -135,6 +135,8 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: s solara_user.set(user) while True: + if settings.main.timing: + widgets_ids = set(patch.widgets) try: message = await ws.receive() except websocket.WebSocketDisconnect: @@ -156,7 +158,13 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: s return t2 = time.time() if settings.main.timing: - print(f"timing: total={t2-t0:.3f}s, deserialize={t1-t0:.3f}s, kernel={t2-t1:.3f}s") # noqa: T201 + widgets_ids_after = set(patch.widgets) + created_widgets_count = len(widgets_ids_after - widgets_ids) + close_widgets_count = len(widgets_ids - widgets_ids_after) + print( # noqa: T201 + f"timing: total={t2-t0:.3f}s, deserialize={t1-t0:.3f}s, kernel={t2-t1:.3f}s" + f" widget: created: {created_widgets_count} closed: {close_widgets_count}" + ) finally: context.page_disconnect(page_id)