From 291c7214adf7a027145080021fd88cd52d931183 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Fri, 27 Sep 2024 15:27:40 +0200 Subject: [PATCH] various cleanups --- mesa/experimental/signals/signal.py | 62 ++++++++++++----------- mesa/experimental/signals/signals_util.py | 12 +++++ 2 files changed, 45 insertions(+), 29 deletions(-) create mode 100644 mesa/experimental/signals/signals_util.py diff --git a/mesa/experimental/signals/signal.py b/mesa/experimental/signals/signal.py index c32cab44d4c..a6d8e6715cf 100644 --- a/mesa/experimental/signals/signal.py +++ b/mesa/experimental/signals/signal.py @@ -4,13 +4,14 @@ import contextlib import functools -import itertools import weakref from abc import ABC, abstractmethod from collections import defaultdict, namedtuple from collections.abc import Callable from typing import Any +from .signals_util import create_weakref + __all__ = ["Observable", "HasObservables", "All", "Computable"] _hashable_signal = namedtuple("_HashableSignal", "instance name") @@ -18,7 +19,7 @@ CURRENT_COMPUTED: Computed | None = None # the current Computed that is evaluating PROCESSING_SIGNALS: set[tuple[str,]] = ( set() -) # fixme what to put here, we can't put the observable, but it is the name and has_observable combination +) class BaseObservable(ABC): @@ -76,7 +77,7 @@ class Observable(BaseObservable): # fixme, how do we "traverse" the tree # do we go by layer, or by branch? it seems signals goes by layer with its batch construction - def __init__(self): + def __init__(self, fallback_value=None): """Initialize an Observable.""" super().__init__() @@ -85,20 +86,16 @@ def __init__(self): self.signal_types: set = { "on_change", } - self.fallback_value = None # fixme, should this be user specifiable + self.fallback_value = fallback_value def __set__(self, instance: HasObservables, value): # noqa D103 if ( CURRENT_COMPUTED is not None and _hashable_signal(instance, self.public_name) in PROCESSING_SIGNALS ): - # fixme make cyclical dependency explict - # so CURRENT_COMPUTED tries to modified self while being dependent on self raise ValueError( f"cyclical dependency detected: Computed({CURRENT_COMPUTED.name}) tries to change " - f"{instance.__class__.__name__}.{self.public_name} while being dependent " - "" - f"on {instance.__class__.__name__}.{self.public_name}" + f"{instance.__class__.__name__}.{self.public_name} while also being dependent it" ) setattr(instance, self.private_name, value) @@ -111,11 +108,6 @@ def __set__(self, instance: HasObservables, value): # noqa D103 class Computable(BaseObservable): """A Computable that is depended on one or more Observables. - fixme how to make this work with Descriptors? - just to it as with ObservableList and SingalingList - so have a Computable and Computed class - declare the Computable at the top - assign the Computed in the instance .. code-block:: python @@ -124,7 +116,7 @@ class MyAgent(Agent): def __init__(self, model): super().__init__(model) - wealth = some_callable, args, kwargs # wip + wealth = Computed(func, args, kwargs) """ @@ -174,7 +166,7 @@ def __init__(self, func: Callable, *args, **kwargs): self._value = None self.name: str = "" # set by Computable - self.parents: weakref.WeakKeyDictionary[HasObservables, dict[str], Any] = ( + self.parents: weakref.WeakKeyDictionary[HasObservables, dict[str, Any]] = ( weakref.WeakKeyDictionary() ) @@ -282,17 +274,17 @@ class HasObservables: """HasObservables class.""" observables: dict[str, BaseObservable] = {} - subscribers: dict[str, dict[str, weakref.WeakSet]] + + # we can't use a weakset here because it does not handle bound methods correctly + # also, a list is faster for our use case + subscribers: dict[str, dict[str, list]] def __new__(cls, *args, **kwargs): # noqa D102 # fixme dirty hack because super does not work on agents obj = super().__new__(cls) - # some kind of nested dict - # we have the name of observable as a key - # we have signal_type as a key - # we want weakrefs for the callable - + # subscribers is a nested defaultdict + # obj.subscribers[observable_name][signal_type] = list of weakref handlers obj.subscribers = defaultdict(functools.partial(defaultdict, list)) return obj @@ -340,7 +332,7 @@ def observe( else: names = self.observables.keys() - # fixme, see unsubscribe, but event types differ accross names + # fixme, see unsubscribe, but event types differ across names if not isinstance(signal_type, All): if signal_type not in self.observables[name].signal_types: raise ValueError( @@ -354,13 +346,23 @@ def observe( else: signal_types = self.observables[name].signal_types - for name, signal_type in itertools.product(names, signal_types): - # fixme, we might built our own weakSet that handles this internally.... - if hasattr(handler, "__self__"): - ref = weakref.WeakMethod(handler) + for name in names: + if not isinstance(signal_type, All): + if signal_type not in self.observables[name].signal_types: + raise ValueError( + f"you are trying to subscribe to a signal of {signal_type}" + f"on Observable {name}, which does not emit this signal_type" + ) + else: + signal_types = [ + signal_type, + ] else: - ref = weakref.ref(handler) - self.subscribers[name][signal_type].append(ref) + signal_types = self.observables[name].signal_types + + ref = create_weakref(handler) + for signal_type in signal_types: + self.subscribers[name][signal_type].append(ref) def unobserve(self, name: str | All, signal_type: str | All): """Unsubscribe to the Observable for signal_type. @@ -428,6 +430,8 @@ def notify(self, observable: str, old_value: Any, new_value: Any, signal_type: s # attribute access. This will be richer than the current Signal named tuple signal = Signal(self, observable, old_value, new_value, signal_type) + # because we are using a list of subscribers + # we should update this list to subscribers that are still alive observers = self.subscribers[observable][signal_type] active_observers = [] for observer in observers: diff --git a/mesa/experimental/signals/signals_util.py b/mesa/experimental/signals/signals_util.py new file mode 100644 index 00000000000..b7bcf29b39b --- /dev/null +++ b/mesa/experimental/signals/signals_util.py @@ -0,0 +1,12 @@ +import weakref + +__all__ = ["create_weakref"] + + +def create_weakref(item, callback=None): + """Helper function to create a correct weakref for any item""" + if hasattr(item, "__self__"): + ref = weakref.WeakMethod(item, callback) + else: + ref = weakref.ref(item, callback) + return ref