diff --git a/mesa/experimental/signals/observable_collections.py b/mesa/experimental/signals/observable_collections.py index 2aa83e4f6c9..c5a31ad1fce 100644 --- a/mesa/experimental/signals/observable_collections.py +++ b/mesa/experimental/signals/observable_collections.py @@ -51,7 +51,7 @@ def __set__(self, instance: "HasObservables", value: Iterable): """ instance.notify( - self.public_name, self.__get__(instance, None), value, "on_change" + self.public_name, getattr(instance, self.private_name, self.fallback_value), value, "on_change" ) setattr( instance, diff --git a/mesa/experimental/signals/signal.py b/mesa/experimental/signals/signal.py index e42fab824b4..3ef8dadc9f9 100644 --- a/mesa/experimental/signals/signal.py +++ b/mesa/experimental/signals/signal.py @@ -54,6 +54,9 @@ def __get__(self, instance, owner): class Observable: """Base Observable class.""" + # fixme, we might want to have a base observable + # and move some of this into this class and use super to allways use it + # for example the on_change notify can go there def __init__(self): """Initialize an Observable.""" @@ -67,9 +70,7 @@ def __init__(self): self.fallback_value = None # fixme, should this be user specifiable def __get__(self, instance, owner): # noqa D103 - # fixme how do we want to handle the fallback value - # and when should it raise an attribute error? - return getattr(instance, self.private_name, self.fallback_value) + return getattr(instance, self.private_name) def __set_name__(self, owner, name): # noqa D103 self.public_name = name @@ -78,7 +79,7 @@ def __set_name__(self, owner, name): # noqa D103 def __set__(self, instance: "HasObservables", value): # noqa D103 instance.notify( - self.public_name, self.__get__(instance, None), value, "on_change" + self.public_name, getattr(instance, self.private_name, self.fallback_value), value, "on_change" ) setattr(instance, self.private_name, value) @@ -194,6 +195,7 @@ def unobserve(self, name: str | All, signal_type: str | All): if not isinstance(name, All) else self.observables.keys() ) + if isinstance(signal_type, All): signal_types = self.observables[name].signal_types else: @@ -222,7 +224,7 @@ def unobserve_all(self, name: str | All): # ignore when unsubscribing to Observables that have no subscription else: self.subscribers = defaultdict( - functools.partial(defaultdict, weakref.WeakValueDictionary) + functools.partial(defaultdict, weakref.WeakSet) ) def notify(self, observable: str, old_value: Any, new_value: Any, signal_type: str): @@ -237,7 +239,9 @@ def notify(self, observable: str, old_value: Any, new_value: Any, signal_type: s """ # fixme: currently strongly tied to just on_change signals # this needs to be refined for e.g. list and dicts in due course - # idea is to just mimic how traitlets handles this + # idea is to just mimic how traitlets handles this. + # Traitlets uses a Bunch helper class which turns a dict into something with + # attribute access. This will be richer than the current Signal named tuple signal = Signal(self, observable, old_value, new_value, signal_type) observers = self.subscribers[observable][signal_type]