diff --git a/mesa/experimental/signals/observable_collections.py b/mesa/experimental/signals/observable_collections.py index 97450706134..2aa83e4f6c9 100644 --- a/mesa/experimental/signals/observable_collections.py +++ b/mesa/experimental/signals/observable_collections.py @@ -24,7 +24,6 @@ def __init__(self): self.signal_types: set = { "added", "removed", - "cleared", "replaced", "on_change", } diff --git a/mesa/experimental/signals/signal.py b/mesa/experimental/signals/signal.py index 44aa683c398..916dd9dd37a 100644 --- a/mesa/experimental/signals/signal.py +++ b/mesa/experimental/signals/signal.py @@ -1,6 +1,7 @@ """Core classes for Observables.""" - +import contextlib import functools +import itertools import weakref from collections import defaultdict, namedtuple from collections.abc import Callable @@ -57,6 +58,8 @@ def __init__(self): """Initialize an Observable.""" self.public_name: str self.private_name: str + + # fixme can we make this an innerclass enum? self.signal_types: set = set( "on_change", ) @@ -99,6 +102,7 @@ class HasObservables: """HasObservables class.""" observables: dict[str, Observable] = {} + subscribers: dict[str, dict[str, weakref.WeakSet]] def __new__(cls, *args, **kwargs): # noqa D102 # fixme dirty hack because super does not work on agents @@ -108,9 +112,9 @@ def __new__(cls, *args, **kwargs): # noqa D102 # we have the name of observable as a key # we have signal_type as a key # we want weakrefs for the callable - obj.subscribers: dict[str, dict[str, weakref.WeakValueDictionary]] = ( - defaultdict(functools.partial(defaultdict, weakref.WeakSet)) - ) + + obj.subscribers: dict[str, dict[str, weakref.WeakSet]] = defaultdict( + functools.partial(defaultdict, weakref.WeakSet)) return obj @@ -123,10 +127,12 @@ def register_observable(self, observable: Observable): """ self.observables[observable.public_name] = observable + + def observe( self, name: str | All, - signal_type: str, + signal_type: str | All, handler: Callable, ): """Subscribe to the Observable for signal_type. @@ -136,27 +142,40 @@ def observe( signal_type: the type of signal on the Observable to subscribe to handler: the handler to call - # fixme, all should also work for signal type + Raises: + ValueError: if the Observable is not registered or if the Observable + does not emit the given signal_type + + + fixme should name/signal_type also take a list? """ - names = ( - [ - name, - ] - if not isinstance(name, All) - else self.observables.keys() - ) + # fixme: we have the same code here twice, can we move this to a helper method? + if not isinstance(name, All): + if name not in self.observables: + raise ValueError( + f"you are trying to subscribe to {name}, but this Observable is not known" + ) + else: + names = [name,] + else: + names = self.observables.keys() - for name in names: + if not isinstance(signal_type, All): if signal_type not in self.observables[name].signal_types: raise ValueError( f"you are trying to subscribe to a signal of {signal_type}" - "on Observable {name}, which does not emit this signal_type" + f"on Observable {name}, which does not emit this signal_type" ) + else: + signal_types = [signal_type, ] + else: + signal_types = self.observables[name].signal_types + for name, signal_type in itertools.product(names, signal_types): self.subscribers[name][signal_type].add(handler) - def unobserve(self, name: str | All, signal_type: str): + def unobserve(self, name: str | All, signal_type: str | All): """Unsubscribe to the Observable for signal_type. Args: @@ -171,9 +190,16 @@ def unobserve(self, name: str | All, signal_type: str): if not isinstance(name, All) else self.observables.keys() ) + if isinstance(signal_type, All): + signal_types = self.observables[name].signal_types + else: + signal_types = [signal_type, ] - for name in names: - del self.subscribers[name][signal_type] + for name, signal_type in itertools.product(names, signal_types): + with contextlib.suppress(KeyError): + del self.subscribers[name][signal_type] + # we silently ignore trying to remove unsubscribed + # observables and/or signal types def unobserve_all(self, name: str | All): """Clears all subscriptions for the observable . @@ -185,7 +211,9 @@ def unobserve_all(self, name: str | All): """ if name is not isinstance(name, All): - del self.subscribers[name] + with contextlib.suppress(KeyError): + del self.subscribers[name] + # ignore when unsubscribing to Observables that have no subscription else: self.subscribers = defaultdict( functools.partial(defaultdict, weakref.WeakValueDictionary)