Skip to content

Commit

Permalink
change how fallback_value is used
Browse files Browse the repository at this point in the history
  • Loading branch information
quaquel committed Sep 22, 2024
1 parent c7cedde commit e4b9593
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion mesa/experimental/signals/observable_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __set__(self, instance: "HasObservables", value: Iterable):
"""
instance.notify(
self.public_name, self.__get__(instance, None), value, "on_change"
self.public_name, getattr(instance, self.private_name, self.fallback_value), value, "on_change"
)
setattr(
instance,
Expand Down
16 changes: 10 additions & 6 deletions mesa/experimental/signals/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def __get__(self, instance, owner):

class Observable:
"""Base Observable class."""
# fixme, we might want to have a base observable
# and move some of this into this class and use super to allways use it
# for example the on_change notify can go there

def __init__(self):
"""Initialize an Observable."""
Expand All @@ -67,9 +70,7 @@ def __init__(self):
self.fallback_value = None # fixme, should this be user specifiable

def __get__(self, instance, owner): # noqa D103
# fixme how do we want to handle the fallback value
# and when should it raise an attribute error?
return getattr(instance, self.private_name, self.fallback_value)
return getattr(instance, self.private_name)

def __set_name__(self, owner, name): # noqa D103
self.public_name = name
Expand All @@ -78,7 +79,7 @@ def __set_name__(self, owner, name): # noqa D103

def __set__(self, instance: "HasObservables", value): # noqa D103
instance.notify(
self.public_name, self.__get__(instance, None), value, "on_change"
self.public_name, getattr(instance, self.private_name, self.fallback_value), value, "on_change"
)
setattr(instance, self.private_name, value)

Expand Down Expand Up @@ -194,6 +195,7 @@ def unobserve(self, name: str | All, signal_type: str | All):
if not isinstance(name, All)
else self.observables.keys()
)

if isinstance(signal_type, All):
signal_types = self.observables[name].signal_types
else:
Expand Down Expand Up @@ -222,7 +224,7 @@ def unobserve_all(self, name: str | All):
# ignore when unsubscribing to Observables that have no subscription
else:
self.subscribers = defaultdict(
functools.partial(defaultdict, weakref.WeakValueDictionary)
functools.partial(defaultdict, weakref.WeakSet)
)

def notify(self, observable: str, old_value: Any, new_value: Any, signal_type: str):
Expand All @@ -237,7 +239,9 @@ def notify(self, observable: str, old_value: Any, new_value: Any, signal_type: s
"""
# fixme: currently strongly tied to just on_change signals
# this needs to be refined for e.g. list and dicts in due course
# idea is to just mimic how traitlets handles this
# idea is to just mimic how traitlets handles this.
# Traitlets uses a Bunch helper class which turns a dict into something with
# attribute access. This will be richer than the current Signal named tuple
signal = Signal(self, observable, old_value, new_value, signal_type)

observers = self.subscribers[observable][signal_type]
Expand Down

0 comments on commit e4b9593

Please sign in to comment.