Skip to content

Commit

Permalink
make All() work for signal_type
Browse files Browse the repository at this point in the history
  • Loading branch information
quaquel committed Sep 22, 2024
1 parent 62d27ae commit a0102e9
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 20 deletions.
1 change: 0 additions & 1 deletion mesa/experimental/signals/observable_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def __init__(self):
self.signal_types: set = {
"added",
"removed",
"cleared",
"replaced",
"on_change",
}
Expand Down
66 changes: 47 additions & 19 deletions mesa/experimental/signals/signal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Core classes for Observables."""

import contextlib
import functools
import itertools
import weakref
from collections import defaultdict, namedtuple
from collections.abc import Callable
Expand Down Expand Up @@ -57,6 +58,8 @@ def __init__(self):
"""Initialize an Observable."""
self.public_name: str
self.private_name: str

# fixme can we make this an innerclass enum?
self.signal_types: set = set(
"on_change",
)
Expand Down Expand Up @@ -99,6 +102,7 @@ class HasObservables:
"""HasObservables class."""

observables: dict[str, Observable] = {}
subscribers: dict[str, dict[str, weakref.WeakSet]]

def __new__(cls, *args, **kwargs): # noqa D102
# fixme dirty hack because super does not work on agents
Expand All @@ -108,9 +112,9 @@ def __new__(cls, *args, **kwargs): # noqa D102
# we have the name of observable as a key
# we have signal_type as a key
# we want weakrefs for the callable
obj.subscribers: dict[str, dict[str, weakref.WeakValueDictionary]] = (
defaultdict(functools.partial(defaultdict, weakref.WeakSet))
)

obj.subscribers: dict[str, dict[str, weakref.WeakSet]] = defaultdict(
functools.partial(defaultdict, weakref.WeakSet))

return obj

Expand All @@ -123,10 +127,12 @@ def register_observable(self, observable: Observable):
"""
self.observables[observable.public_name] = observable



def observe(
self,
name: str | All,
signal_type: str,
signal_type: str | All,
handler: Callable,
):
"""Subscribe to the Observable <name> for signal_type.
Expand All @@ -136,27 +142,40 @@ def observe(
signal_type: the type of signal on the Observable to subscribe to
handler: the handler to call
# fixme, all should also work for signal type
Raises:
ValueError: if the Observable <name> is not registered or if the Observable
does not emit the given signal_type
fixme should name/signal_type also take a list?
"""
names = (
[
name,
]
if not isinstance(name, All)
else self.observables.keys()
)
# fixme: we have the same code here twice, can we move this to a helper method?
if not isinstance(name, All):
if name not in self.observables:
raise ValueError(
f"you are trying to subscribe to {name}, but this Observable is not known"
)
else:
names = [name,]
else:
names = self.observables.keys()

for name in names:
if not isinstance(signal_type, All):
if signal_type not in self.observables[name].signal_types:
raise ValueError(
f"you are trying to subscribe to a signal of {signal_type}"
"on Observable {name}, which does not emit this signal_type"
f"on Observable {name}, which does not emit this signal_type"
)
else:
signal_types = [signal_type, ]
else:
signal_types = self.observables[name].signal_types

for name, signal_type in itertools.product(names, signal_types):
self.subscribers[name][signal_type].add(handler)

def unobserve(self, name: str | All, signal_type: str):
def unobserve(self, name: str | All, signal_type: str | All):
"""Unsubscribe to the Observable <name> for signal_type.
Args:
Expand All @@ -171,9 +190,16 @@ def unobserve(self, name: str | All, signal_type: str):
if not isinstance(name, All)
else self.observables.keys()
)
if isinstance(signal_type, All):
signal_types = self.observables[name].signal_types
else:
signal_types = [signal_type, ]

for name in names:
del self.subscribers[name][signal_type]
for name, signal_type in itertools.product(names, signal_types):
with contextlib.suppress(KeyError):
del self.subscribers[name][signal_type]
# we silently ignore trying to remove unsubscribed
# observables and/or signal types

def unobserve_all(self, name: str | All):
"""Clears all subscriptions for the observable <name>.
Expand All @@ -185,7 +211,9 @@ def unobserve_all(self, name: str | All):
"""
if name is not isinstance(name, All):
del self.subscribers[name]
with contextlib.suppress(KeyError):
del self.subscribers[name]
# ignore when unsubscribing to Observables that have no subscription
else:
self.subscribers = defaultdict(
functools.partial(defaultdict, weakref.WeakValueDictionary)
Expand Down

0 comments on commit a0102e9

Please sign in to comment.