From 448c76d27ccc824c7b4ae6a968fdebefbd02a533 Mon Sep 17 00:00:00 2001 From: Alan Fleming <> Date: Sun, 26 May 2024 10:39:20 +1000 Subject: [PATCH] Add weakref with opt in automatic widget deletion using 'enable_weakreference'. --- .../ipywidgets/ipywidgets/widgets/__init__.py | 2 +- .../ipywidgets/widgets/tests/test_widget.py | 173 +++++++++++++++++- .../widgets/tests/test_widget_box.py | 92 ++++++++-- .../ipywidgets/ipywidgets/widgets/widget.py | 157 ++++++++++------ 4 files changed, 347 insertions(+), 77 deletions(-) diff --git a/python/ipywidgets/ipywidgets/widgets/__init__.py b/python/ipywidgets/ipywidgets/widgets/__init__.py index b90d3ee111..0951bad905 100644 --- a/python/ipywidgets/ipywidgets/widgets/__init__.py +++ b/python/ipywidgets/ipywidgets/widgets/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from .widget import Widget, CallbackDispatcher, register, widget_serialization +from .widget import Widget, CallbackDispatcher, register, widget_serialization, enable_weakreference, disable_weakreference from .domwidget import DOMWidget from .valuewidget import ValueWidget diff --git a/python/ipywidgets/ipywidgets/widgets/tests/test_widget.py b/python/ipywidgets/ipywidgets/widgets/tests/test_widget.py index c5aa36048a..34fd9402a2 100644 --- a/python/ipywidgets/ipywidgets/widgets/tests/test_widget.py +++ b/python/ipywidgets/ipywidgets/widgets/tests/test_widget.py @@ -3,17 +3,21 @@ """Test Widget.""" +import copy +import gc import inspect +import weakref import pytest from IPython.core.interactiveshell import InteractiveShell from IPython.display import display from IPython.utils.capture import capture_output +import ipywidgets as ipw + from .. import widget from ..widget import Widget from ..widget_button import Button -import copy def test_no_widget_view(): @@ -88,4 +92,169 @@ def test_widget_copy(): with pytest.raises(NotImplementedError): copy.copy(button) with pytest.raises(NotImplementedError): - copy.deepcopy(button) \ No newline at end of file + copy.deepcopy(button) + + +def test_widget_open(): + button = Button() + model_id = button.model_id + assert model_id in widget._instances + spec = button.get_view_spec() + assert list(spec) == ["version_major", "version_minor", "model_id"] + assert spec["model_id"] + button.close() + assert model_id not in widget._instances + with pytest.raises(RuntimeError, match="Widget is closed"): + button.open() + with pytest.raises(RuntimeError, match="Widget is closed"): + button.get_view_spec() + + +@pytest.mark.parametrize( + "class_name", + [ + "Accordion", + "AppLayout", + "Audio", + "BoundedFloatText", + "BoundedIntText", + "Box", + "Button", + "ButtonStyle", + "Checkbox", + "ColorPicker", + "ColorsInput", + "Combobox", + "Controller", + "CoreWidget", + "DOMWidget", + "DatePicker", + "DatetimePicker", + "Dropdown", + "FileUpload", + "FloatLogSlider", + "FloatProgress", + "FloatRangeSlider", + "FloatSlider", + "FloatText", + "FloatsInput", + "GridBox", + "HBox", + "HTML", + "HTMLMath", + "Image", + "IntProgress", + "IntRangeSlider", + "IntSlider", + "IntText", + "IntsInput", + "Label", + "Layout", + "NaiveDatetimePicker", + "Output", + "Password", + "Play", + "RadioButtons", + "Select", + "SelectMultiple", + "SelectionRangeSlider", + "SelectionSlider", + "SliderStyle", + "Stack", + "Style", + "Tab", + "TagsInput", + "Text", + "Textarea", + "TimePicker", + "ToggleButton", + "ToggleButtons", + "ToggleButtonsStyle", + "TwoByTwoLayout", + "VBox", + "Valid", + "ValueWidget", + "Video", + "Widget", + ], +) +@pytest.mark.parametrize("enable_weakref", [True, False]) +def test_weakreference(class_name, enable_weakref): + # Ensure the base instance of all widgets can be deleted / garbage collected. + if enable_weakref: + ipw.enable_weakreference() + cls = getattr(ipw, class_name) + if class_name in ['SelectionRangeSlider', 'SelectionSlider']: + kwgs = {"options": [1, 2, 4]} + else: + kwgs = {} + try: + w = cls(**kwgs) + deleted = False + def on_delete(): + nonlocal deleted + deleted = True + weakref.finalize(w, on_delete) + # w should be the only strong ref to the widget. + # calling `del` should invoke its immediate deletion calling the `__del__` method. + if not enable_weakref: + w.close() + del w + gc.collect() + assert deleted + finally: + if enable_weakref: + ipw.disable_weakreference() + + +@pytest.mark.parametrize("weakref_enabled", [True, False]) +def test_button_weakreference(weakref_enabled: bool): + try: + click_count = 0 + deleted = False + + def on_delete(): + nonlocal deleted + deleted = True + + class TestButton(Button): + def my_click(self, b): + nonlocal click_count + click_count += 1 + + b = TestButton(description="button") + weakref.finalize(b, on_delete) + b_ref = weakref.ref(b) + assert b in widget._instances.values() + + b.on_click(b.my_click) + b.on_click(lambda x: setattr(x, "clicked", True)) + + b.click() + assert click_count == 1 + + if weakref_enabled: + ipw.enable_weakreference() + assert b in widget._instances.values(), "Instances not transferred" + ipw.disable_weakreference() + assert b in widget._instances.values(), "Instances not transferred" + ipw.enable_weakreference() + assert b in widget._instances.values(), "Instances not transferred" + + b.click() + assert click_count == 2 + assert getattr(b, "clicked") + + del b + gc.collect() + if weakref_enabled: + assert deleted + else: + assert not deleted + assert b_ref() in widget._instances.values() + b_ref().close() + gc.collect() + assert deleted, "Closing should remove the last strong reference." + + finally: + ipw.disable_weakreference() diff --git a/python/ipywidgets/ipywidgets/widgets/tests/test_widget_box.py b/python/ipywidgets/ipywidgets/widgets/tests/test_widget_box.py index 551f68dcc4..5d50324d08 100644 --- a/python/ipywidgets/ipywidgets/widgets/tests/test_widget_box.py +++ b/python/ipywidgets/ipywidgets/widgets/tests/test_widget_box.py @@ -1,33 +1,85 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from unittest import TestCase +import gc +import weakref +import pytest from traitlets import TraitError import ipywidgets as widgets -class TestBox(TestCase): +def test_box_construction(): + box = widgets.Box() + assert box.get_state()["children"] == [] - def test_construction(self): - box = widgets.Box() - assert box.get_state()['children'] == [] - def test_construction_with_children(self): - html = widgets.HTML('some html') - slider = widgets.IntSlider() - box = widgets.Box([html, slider]) - children_state = box.get_state()['children'] - assert children_state == [ - widgets.widget._widget_to_json(html, None), - widgets.widget._widget_to_json(slider, None), - ] +def test_box_construction_with_children(): + html = widgets.HTML("some html") + slider = widgets.IntSlider() + box = widgets.Box([html, slider]) + children_state = box.get_state()["children"] + assert children_state == [ + widgets.widget._widget_to_json(html, None), + widgets.widget._widget_to_json(slider, None), + ] - def test_construction_style(self): - box = widgets.Box(box_style='warning') - assert box.get_state()['box_style'] == 'warning' - def test_construction_invalid_style(self): - with self.assertRaises(TraitError): - widgets.Box(box_style='invalid') +def test_box_construction_style(): + box = widgets.Box(box_style="warning") + assert box.get_state()["box_style"] == "warning" + + +def test_construction_invalid_style(): + with pytest.raises(TraitError): + widgets.Box(box_style="invalid") + + +def test_box_validate_mode(): + slider = widgets.IntSlider() + closed_button = widgets.Button() + closed_button.close() + with pytest.raises(TraitError, match="Invalid or closed items found.*"): + widgets.Box( + children=[closed_button, slider, "Not a widget"] + ) + box = widgets.Box( + children=[closed_button, slider, "Not a widget"], + validate_mode="log_error", + ) + assert len (box.children) == 1, "Invalid items should be dropped." + assert slider in box.children + + box.validate_mode = "raise" + with pytest.raises(TraitError): + box.children += ("Not a widget", closed_button) + + +def test_box_gc(): + widgets.VBox._active_widgets + widgets.enable_weakreference() + # Test Box gc collected and children lifecycle managed. + try: + deleted = False + + class TestButton(widgets.Button): + def my_click(self, b): + pass + + button = TestButton(description="button") + button.on_click(button.my_click) + + b = widgets.VBox(children=[button]) + + def on_delete(): + nonlocal deleted + deleted = True + + weakref.finalize(b, on_delete) + del b + gc.collect() + assert deleted + widgets.VBox._active_widgets + finally: + widgets.disable_weakreference() diff --git a/python/ipywidgets/ipywidgets/widgets/widget.py b/python/ipywidgets/ipywidgets/widgets/widget.py index 2dc674097d..a9473f04c5 100644 --- a/python/ipywidgets/ipywidgets/widgets/widget.py +++ b/python/ipywidgets/ipywidgets/widgets/widget.py @@ -6,13 +6,13 @@ in the Jupyter notebook front-end. """ import os -import sys import typing +import weakref from contextlib import contextmanager from collections.abc import Iterable from IPython import get_ipython from traitlets import ( - Any, HasTraits, Unicode, Dict, Instance, List, Int, Set, Bytes, observe, default, Container, + Any, HasTraits, Unicode, Dict, Instance, List, Int, Set, observe, default, Container, Undefined) from json import loads as jsonloads, dumps as jsondumps from .. import comm @@ -41,17 +41,39 @@ def envset(name, default): PROTOCOL_VERSION_MAJOR = __protocol_version__.split('.')[0] CONTROL_PROTOCOL_VERSION_MAJOR = __control_protocol_version__.split('.')[0] JUPYTER_WIDGETS_ECHO = envset('JUPYTER_WIDGETS_ECHO', default=True) -# we keep a strong reference for every widget created, for a discussion on using weak references see: +# for a discussion on using weak references see: # https://github.com/jupyter-widgets/ipywidgets/issues/1345 _instances : typing.MutableMapping[str, "Widget"] = {} +def enable_weakreference(): + """Use a WeakValueDictionary instead of a standard dictionary to map + `comm_id` to `widget` for every widget instance. + + By default widgets are mapped using a standard dictionary. Use this feature + to permit widget garbage collection. + """ + global _instances + if not isinstance(_instances, weakref.WeakValueDictionary): + _instances = weakref.WeakValueDictionary(_instances) + +def disable_weakreference(): + """Use a Dictionary to map `comm_id` to `widget` for every widget instance. + + Note: this is the default setting and maintains a strong reference to the + the widget preventing automatic garbage collection. If the close method + is called, the widget will remove itself enabling garbage collection. + """ + global _instances + if isinstance(_instances, weakref.WeakValueDictionary): + _instances = dict(_instances) + def _widget_to_json(x, obj): - if isinstance(x, dict): - return {k: _widget_to_json(v, obj) for k, v in x.items()} + if isinstance(x, Widget): + return f"IPY_MODEL_{x.model_id}" elif isinstance(x, (list, tuple)): return [_widget_to_json(v, obj) for v in x] - elif isinstance(x, Widget): - return "IPY_MODEL_" + x.model_id + elif isinstance(x, dict): + return {k: _widget_to_json(v, obj) for k, v in x.items()} else: return x @@ -215,18 +237,6 @@ def register_callback(self, callback, remove=False): elif not remove and callback not in self.callbacks: self.callbacks.append(callback) -def _show_traceback(method): - """decorator for showing tracebacks""" - def m(self, *args, **kwargs): - try: - return(method(self, *args, **kwargs)) - except Exception as e: - ip = get_ipython() - if ip is None: - self.log.warning("Exception in widget method %s: %s", method, e, exc_info=True) - else: - ip.showtraceback() - return m class WidgetRegistry: @@ -304,7 +314,7 @@ class Widget(LoggingHasTraits): #------------------------------------------------------------------------- _widget_construction_callback = None _control_comm = None - + @_staticproperty def widgets(): # Because this is a static attribute, it will be accessed when initializing this class. In that case, since a user @@ -461,7 +471,7 @@ def _get_embed_state(self, drop_defaults=False): return state def get_view_spec(self): - return dict(version_major=2, version_minor=0, model_id=self._model_id) + return {"version_major":2, "version_minor":0, "model_id": self.model_id} #------------------------------------------------------------------------- # Traits @@ -499,11 +509,12 @@ def _default_keys(self): #------------------------------------------------------------------------- def __init__(self, **kwargs): """Public constructor""" - self._model_id = kwargs.pop('model_id', None) + if 'model_id' in kwargs: + self.comm = self._create_comm(kwargs.pop('model_id')) super().__init__(**kwargs) + self.open() Widget._call_widget_constructed(self) - self.open() def __copy__(self): raise NotImplementedError("Widgets cannot be copied; custom implementation required") @@ -519,53 +530,80 @@ def __del__(self): # Properties #------------------------------------------------------------------------- - def open(self): - """Open a comm to the frontend if one isn't already open.""" - if self.comm is None: - state, buffer_paths, buffers = _remove_buffers(self.get_state()) - args = dict(target_name='jupyter.widget', - data={'state': state, 'buffer_paths': buffer_paths}, - buffers=buffers, - metadata={'version': __protocol_version__} - ) - if self._model_id is not None: - args['comm_id'] = self._model_id + @default('comm') + def _default_comm(self): + return self._create_comm() - self.comm = comm.create_comm(**args) + def open(self): + """Open a comm to the frontend if one isn't already open.""" + assert self.model_id + + + def _create_comm(self, comm_id=None): + """Open a new comm to the frontend.""" + state, buffer_paths, buffers = _remove_buffers(self.get_state()) + self.comm = comm_ = comm.create_comm( + target_name="jupyter.widget", + data={"state": state, "buffer_paths": buffer_paths}, + buffers=buffers, + metadata={"version": __protocol_version__}, + comm_id=comm_id, + ) + return comm_ + @observe('comm') def _comm_changed(self, change): """Called when the comm is changed.""" - if change['new'] is None: - return - self._model_id = self.model_id + if change['old']: + change['old'].on_msg(None) + change['old'].close() + # On python shutdown _instances can be None + if isinstance(_instances, dict): + _instances.pop(change['old'].comm_id, None) + if change['new']: + if isinstance(_instances, dict): + _instances[change['new'].comm_id] = self + + # prevent memory leaks by using a weak reference to self. + ref = weakref.ref(self) + def _handle_msg(msg): + self_ = ref() + if self_ is not None: + try: + self_._handle_msg(msg) + except Exception as e: + self_._show_traceback(self_._handle_msg, e) + + change['new'].on_msg(_handle_msg) + - self.comm.on_msg(self._handle_msg) - _instances[self.model_id] = self @property def model_id(self): """Gets the model id of this widget. If a Comm doesn't exist yet, a Comm will be created automagically.""" - return self.comm.comm_id + if not self._repr_mimebundle_: + # a closed widget will not be found at the frontend so raise an error here. + msg = f"Widget is closed: {self!r}" + raise RuntimeError(msg) + return getattr(self.comm, "comm_id", None) + #------------------------------------------------------------------------- # Methods #------------------------------------------------------------------------- def close(self): - """Close method. + """Permanently close the widget. Closes the underlying comm. When the comm is closed, all of the widget views are automatically removed from the front-end.""" - if self.comm is not None: - _instances.pop(self.model_id, None) - self.comm.close() - self.comm = None - self._repr_mimebundle_ = None + self._repr_mimebundle_ = None + self.comm = None def send_state(self, key=None): """Sends the widget state, or a piece of it, to the front-end, if it exists. @@ -693,15 +731,18 @@ def notify_change(self, change): # Send the state to the frontend before the user-registered callbacks # are called. name = change['name'] - if self.comm is not None and getattr(self.comm, 'kernel', True) is not None: + comm = self._trait_values.get('comm') + if comm and getattr(comm, 'kernel', None): # Make sure this isn't information that the front-end just sent us. - if name in self.keys and self._should_send_property(name, getattr(self, name)): + if name in self.keys and self._should_send_property(name, change['new']): # Send new state to front-end self.send_state(key=name) super().notify_change(change) def __repr__(self): - return self._gen_repr_from_keys(self._repr_keys()) + if not self._repr_mimebundle_: + return f'' + return self._gen_repr_from_keys(self._repr_keys()) #------------------------------------------------------------------------- # Support methods @@ -759,7 +800,6 @@ def _should_send_property(self, key, value): return True # Event handlers - @_show_traceback def _handle_msg(self, msg): """Called when a msg is received from the front-end""" data = msg['content']['data'] @@ -785,6 +825,14 @@ def _handle_msg(self, msg): else: self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method) + def _show_traceback(self, method, e:Exception): + ip = get_ipython() + if ip is None: + self.log.warning("Exception in widget method %s: %s", method, e, exc_info=True) + else: + ip.showtraceback() + + def _handle_custom_msg(self, content, buffers): """Called when a custom msg is received.""" self._msg_callbacks(self, content, buffers) @@ -815,14 +863,15 @@ def _repr_mimebundle_(self, **kwargs): data['application/vnd.jupyter.widget-view+json'] = { 'version_major': 2, 'version_minor': 0, - 'model_id': self._model_id + 'model_id': self.model_id } return data def _send(self, msg, buffers=None): """Sends a message to the model in the front-end.""" - if self.comm is not None and (self.comm.kernel is not None if hasattr(self.comm, "kernel") else True): - self.comm.send(data=msg, buffers=buffers) + comm = self.comm + if comm is not None and getattr(comm, "kernel", True): + comm.send(data=msg, buffers=buffers) def _repr_keys(self): traits = self.traits()