Skip to content

Commit

Permalink
Merge pull request #504 from widgetti/fix_tracking_ref_double_counts
Browse files Browse the repository at this point in the history
Fix: tracking ref double counts
  • Loading branch information
maartenbreddels authored Feb 15, 2024
2 parents d86d63a + ad2027f commit 73647ee
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 15 deletions.
62 changes: 49 additions & 13 deletions solara/toestand.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import sys
import threading
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from operator import getitem
Expand Down Expand Up @@ -101,6 +102,9 @@ def value(self, value: T):
def set(self, value: T):
raise NotImplementedError

def peek(self) -> T:
raise NotImplementedError

def get(self) -> T:
raise NotImplementedError

Expand Down Expand Up @@ -222,6 +226,9 @@ def _get_dict(self):
scope_id = context.id
return cast(Dict[str, S], scope_dict), scope_id

def peek(self):
return self.get()

def get(self):
scope_dict, scope_id = self._get_dict()
if self.storage_key not in scope_dict:
Expand Down Expand Up @@ -314,7 +321,7 @@ def __set_name__(self, owner, name):
self._owner = owner

def __repr__(self):
value = self.get(add_watch=False)
value = self.peek()
if self._name:
return f"<Reactive {self._owner.__name__}.{self._name} value={value!r} id={hex(id(self))}>"
else:
Expand All @@ -341,10 +348,17 @@ def set(self, value: S):
raise ValueError("Can't set a reactive to itself")
self._storage.set(value)

def get(self, add_watch=True) -> S:
if add_watch and thread_local.reactive_used is not None:
def get(self, add_watch=None) -> S:
if add_watch is not None:
warnings.warn("add_watch is deprecated, use .peek()", DeprecationWarning)
if thread_local.reactive_used is not None:
thread_local.reactive_used.add(self)
return self._storage.get()
# peek to avoid parents also adding themselves to the reactive_used set
return self._storage.peek()

def peek(self) -> S:
"""Return the value without automatically subscribing to listeners."""
return self._storage.peek()

def subscribe(self, listener: Callable[[S], None], scope: Optional[ContextManager] = None):
return self._storage.subscribe(listener, scope=scope)
Expand Down Expand Up @@ -482,7 +496,7 @@ def wrapper(f: Callable[[], T]):
return wrapper(f)


class ValueSubField(ValueBase[T]):
class ReactiveField(ValueBase[T]):
def __init__(self, field: "FieldBase"):
super().__init__() # type: ignore
self._field = field
Expand All @@ -496,7 +510,7 @@ def __str__(self):
return str(self._field)

def __repr__(self):
return f"<Reactive subfield {self._field}>"
return f"<Reactive field {self._field}>"

@property
def lock(self):
Expand Down Expand Up @@ -530,18 +544,24 @@ def on_change(new, old):

return self._root.subscribe_change(on_change, scope=scope)

def get(self, obj=None, add_watch=True) -> T:
if add_watch and thread_local.reactive_used is not None:
def get(self, add_watch=None) -> T:
if add_watch is not None:
warnings.warn("add_watch is deprecated, use .peek()", DeprecationWarning)
if thread_local.reactive_used is not None:
thread_local.reactive_used.add(self)
return self._field.get(obj)
# peek to avoid parents also adding themselves to the reactive_used set
return self._field.peek()

def peek(self) -> T:
return self._field.peek()

def set(self, value: T):
self._field.set(value)


def Ref(field: T) -> Reactive[T]:
_field = cast(FieldBase, field)
return Reactive[T](ValueSubField[T](_field))
return cast(Reactive[T], ReactiveField[T](_field))


class FieldBase:
Expand Down Expand Up @@ -572,7 +592,14 @@ def get(self, obj=None):
# so we can get the 'old' value
if obj is not None:
return obj
return self._parent.get(add_watch=False)
return self._parent.get()

def peek(self, obj=None):
# we are at the root, so override the object
# so we can get the 'old' value
if obj is not None:
return obj
return self._parent.peek()

def set(self, value):
self._parent.set(value)
Expand All @@ -591,9 +618,13 @@ def get(self, obj=None):
obj = self._parent.get(obj)
return getattr(obj, self.key)

def peek(self, obj=None):
obj = self._parent.peek(obj)
return getattr(obj, self.key)

def set(self, value):
with self._lock:
parent_value = self._parent.get()
parent_value = self._parent.peek()
if isinstance(self.key, str):
parent_value = merge_state(parent_value, **{self.key: value})
self._parent.set(parent_value)
Expand All @@ -617,9 +648,13 @@ def get(self, obj=None):
obj = self._parent.get(obj)
return getitem(obj, self.key)

def peek(self, obj=None):
obj = self._parent.peek(obj)
return getitem(obj, self.key)

def set(self, value):
with self._lock:
parent_value = self._parent.get()
parent_value = self._parent.peek()
if isinstance(self.key, int) and isinstance(parent_value, (list, tuple)):
parent_type = type(parent_value)
parent_value = parent_value.copy() # type: ignore
Expand Down Expand Up @@ -725,6 +760,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

# alias for compatibility
State = Reactive
ValueSubField = ReactiveField

auto_subscribe_context_manager = AutoSubscribeContextManagerReacton
reacton.core._component_context_manager_classes.append(auto_subscribe_context_manager)
18 changes: 16 additions & 2 deletions tests/unit/toestand_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import solara
import solara as sol
import solara.lab
import solara.toestand as toestand
from solara.server import kernel, kernel_context
from solara.toestand import Reactive, Ref, State, use_sync_external_store

Expand Down Expand Up @@ -842,17 +843,24 @@ def test_reactive_auto_subscribe_sub():
bears = Reactive(Bears(type="brown", count=1))
renders = 0

ref = Ref(bears.fields.count)
reactive_used = None

@solara.component
def Test():
nonlocal reactive_used
nonlocal renders
reactive_used = toestand.thread_local.reactive_used
renders += 1
count = Ref(bears.fields.count).value
count = ref.value
return solara.Info(f"{count} bears around here")

box, rc = solara.render(Test(), handle_error=False)
assert rc.find(v.Alert).widget.children[0] == "1 bears around here"
Ref(bears.fields.count).value += 1
assert reactive_used == {ref}
ref.value += 1
assert rc.find(v.Alert).widget.children[0] == "2 bears around here"
assert reactive_used == {ref}
# now check that we didn't listen to the while object, just count changes
renders_before = renders
Ref(bears.fields.type).value = "pink"
Expand Down Expand Up @@ -900,19 +908,25 @@ def Test():
def test_reactive_auto_subscribe_subfield_limit(kernel_context):
bears = Reactive(Bears(type="brown", count=1))
renders = 0
reactive_used = None

@solara.component
def Test():
nonlocal renders
nonlocal reactive_used
reactive_used = toestand.thread_local.reactive_used
renders += 1
_ = bears.value # access it to trigger the subscription
return solara.IntSlider("test", value=Ref(bears.fields.count).value)

box, rc = solara.render(Test(), handle_error=False)
assert rc.find(v.Slider).widget.v_model == 1
assert renders == 1
assert reactive_used is not None
assert len(reactive_used) == 2 # bears and bears.fields.count
Ref(bears.fields.count).value = 2
assert renders == 2
assert len(reactive_used) == 2 # bears and bears.fields.count
rc.close()
assert not bears._storage.listeners[kernel_context.id]
assert not bears._storage.listeners2[kernel_context.id]
Expand Down

0 comments on commit 73647ee

Please sign in to comment.