From 79f99a1130886d756a7ad356f51c0e3c5c4b7279 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Sun, 22 Sep 2024 21:19:50 +0200 Subject: [PATCH] start of adding Computed and Computable --- mesa/experimental/signals/code/__init__.py | 21 + mesa/experimental/signals/code/_cellmagic.py | 176 ++++++++ mesa/experimental/signals/code/_core.py | 450 +++++++++++++++++++ mesa/experimental/signals/code/_version.py | 5 + mesa/experimental/signals/code/test.py | 10 + mesa/experimental/signals/signal.py | 193 ++++++-- 6 files changed, 807 insertions(+), 48 deletions(-) create mode 100644 mesa/experimental/signals/code/__init__.py create mode 100644 mesa/experimental/signals/code/_cellmagic.py create mode 100644 mesa/experimental/signals/code/_core.py create mode 100644 mesa/experimental/signals/code/_version.py create mode 100644 mesa/experimental/signals/code/test.py diff --git a/mesa/experimental/signals/code/__init__.py b/mesa/experimental/signals/code/__init__.py new file mode 100644 index 00000000000..b6901fef336 --- /dev/null +++ b/mesa/experimental/signals/code/__init__.py @@ -0,0 +1,21 @@ +"""A signals implementation for Python.""" + +from __future__ import annotations + +from ._core import Signal, batch, computed, effect +from ._version import __version__ + + +def load_ipython_extension(ipython): + """Load the IPython extension. + + `%load_ext signals` will load the extension and enable the `%%effect` cell magic. + + Parameters + ---------- + ipython : IPython.core.interactiveshell.InteractiveShell + The IPython shell instance. + """ + from ._cellmagic import load_ipython_extension # noqa: PLC0415 + + load_ipython_extension(ipython) diff --git a/mesa/experimental/signals/code/_cellmagic.py b/mesa/experimental/signals/code/_cellmagic.py new file mode 100644 index 00000000000..da6bbc293bc --- /dev/null +++ b/mesa/experimental/signals/code/_cellmagic.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import ast +import os +import typing + +from IPython.core.magic import Magics, cell_magic, magics_class +from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring +from IPython.display import DisplayHandle, display + +from ._core import effect + +if typing.TYPE_CHECKING: + from IPython.core.interactiveshell import InteractiveShell + +EFFECTS = {} +CELL_ID = None + + +def run_ast_nodes( + nodelist: list, + cell_name: str, + user_global_ns: dict, + user_ns: dict, +) -> dict: + """Run a list of AST nodes. + + Parameters + ---------- + nodelist : list + List of AST nodes. + cell_name : str + The name of the cell. + user_global_ns : dict + The global namespace. + user_ns : dict + The user namespace. + + Returns + ------- + dict + A dictionary with the value of the last expression in the cell. + """ + # If the last node is not an expression, run everything + if not isinstance(nodelist[-1], ast.Expr): + code = compile(ast.Module(nodelist, []), cell_name, "exec") + exec(code, user_global_ns, user_ns) + return {} + + to_run_exec = nodelist[:-1] + if to_run_exec: + exec_code = compile(ast.Module(to_run_exec, []), cell_name, "exec") + exec(exec_code, user_global_ns, user_ns) + + expr_code = compile(ast.Expression(nodelist[-1].value), cell_name, "eval") + value = eval(expr_code, user_global_ns, user_ns) + return {"value": value} + + +def prepare_cell_execution(shell: InteractiveShell, raw_code: str): + dh = DisplayHandle() + transformed_code = shell.transform_cell(raw_code) + cell_name = shell.compile.cache( + transformed_code=transformed_code, + number=shell.execution_count, + raw_code=raw_code, + ) + code_ast = shell.compile.ast_parse(transformed_code, filename=cell_name) + + def run_cell(): + try: + result = run_ast_nodes( + nodelist=code_ast.body, + cell_name=cell_name, + user_global_ns=shell.user_global_ns, + user_ns=shell.user_ns, + ) + if "value" in result: + dh.update(result["value"]) + except Exception: + shell.showtraceback() + + dh.display(None) # create the display + return effect(run_cell) + + +def prepare_cell_execution_ipywidgets(shell: InteractiveShell, raw_code: str): + try: + import ipywidgets # noqa: PLC0415 + except ImportError: + raise ImportError("ipywidgets is required for this feature.") # noqa: B904 + + import ipywidgets # noqa: PLC0415 + + output_widget = ipywidgets.Output() + display(output_widget) + + @output_widget.capture(clear_output=True, wait=True) + def run_cell(): + shell.run_cell(raw_code) + + cell_effect = effect(run_cell) + + def cleanup(): + cell_effect() + output_widget.close() + + return cleanup + + +@magics_class +class SignalsMagics(Magics): + @magic_arguments() + @argument( + "-n", + "--name", + type=str, + default=None, + help="Name the effect. Effects are cleaned up by name. default is the cell id.", + ) + @argument( + "--mode", + type=str, + default="displayhandle", + help="The output mode for the effect. Either 'widget' or 'displayhandle'.", + ) + @cell_magic + def effect(self, line, cell): + """Excute code cell as an effect.""" + args = parse_argstring(SignalsMagics.effect, line) + name = args.name or CELL_ID + + # Cleanup previous effect + if name in EFFECTS: + cleanup = EFFECTS.pop(name) + cleanup() + + shell = typing.cast("InteractiveShell", self.shell) + mode = os.environ.get("SIGNALS_MODE", args.mode) + + if mode == "widget": + cleanup = prepare_cell_execution_ipywidgets(shell, cell) + elif mode == "displayhandle": + cleanup = prepare_cell_execution(shell, cell) + else: + raise ValueError(f"Invalid mode: {args.mode}") + + EFFECTS[name] = cleanup + + @cell_magic + def clear_effects(self, line, cell): # noqa: PLR6301 + """Clear all effects.""" + for cleanup in EFFECTS.values(): + cleanup() + EFFECTS.clear() + + +def load_ipython_extension(ipython): + """Load the IPython extension. + + `%load_ext signals` will load the extension and enable the `%%effect` cell magic. + + Parameters + ---------- + ipython : IPython.core.interactiveshell.InteractiveShell + The IPython shell instance. + """ + + # Not how else to get the cell id, seems like a hack + # https://stackoverflow.com/questions/75185964/ipython-cell-magic-access-to-cell-id + def pre_run_cell(info): + global CELL_ID # noqa: PLW0603 + CELL_ID = info.cell_id + + ipython.events.register("pre_run_cell", pre_run_cell) + ipython.register_magics(SignalsMagics) diff --git a/mesa/experimental/signals/code/_core.py b/mesa/experimental/signals/code/_core.py new file mode 100644 index 00000000000..eb0312ec91a --- /dev/null +++ b/mesa/experimental/signals/code/_core.py @@ -0,0 +1,450 @@ +"""A signals implementation for Python based on @preact/signals.""" + +from __future__ import annotations + +import typing +import weakref + +__all__ = ["Signal", "batch", "computed", "effect"] + +Disposer = typing.Callable[[], None] +Listener = typing.Callable[[], None] + +# current computed that is running +CURRENT_COMPUTED: Computed | None = None + +# a set of listeners which will be triggered after the batch is complete +BATCH_PENDING: set[Listener] | None = None + +PROCESSING_SIGNALS: set[Signal] = set() + +T = typing.TypeVar("T") + + +def batch(fn: typing.Callable[[], T]) -> T: + """Combine multiple updates into one "commit" at the end of the provided callback. + + Batches can be nested, and changes are only flushed once the outermost batch + callback completes. Accessing a signal that has been modified within a batch + will reflect its updated value. + + Parameters + ---------- + fn : Callable[[], T] + The callback function to execute within the batch. + + Returns + ------- + T + The value returned by the callback function. + """ + global BATCH_PENDING # noqa: PLW0603 + + if BATCH_PENDING is None: + listeners = set() + old = BATCH_PENDING + BATCH_PENDING = listeners + + try: + return fn() + finally: + BATCH_PENDING = old + PROCESSING_SIGNALS.clear() + + # trigger any pending listeners + for listener in listeners: + listener() + else: + return fn() + + +class Signal(typing.Generic[T]): + """Represents a signal that can be subscribed to for changes in value.""" + + __slots__ = ["__weakref__", "_children", "_value"] + + _value: T + + # Uses weak references to avoid memory leaks + # If the child is not used anywhere, then it can be garbage collected + _children: set[weakref.ref[Signal]] + + def __init__(self, value: T) -> None: + self._value = value + self._children = set() + + def __call__(self) -> T: + """Get the current value of the signal.""" + return self.get() + + # def __str__(self) -> str: + # return f"{self()}" + + # def __repr__(self) -> str: + # return f"Signal({self()})" + + # Recurse down all children, marking them as diry and adding + # listeners to batch_pending + def _wakeup(self): + to_remove = set() + for child_ref in self._children: + child = child_ref() + if child is not None: + child._wakeup() + else: + to_remove.add(child_ref) + + for child_ref in to_remove: + # If the child has been garbage collected, remove it from the set + self._children.remove(child_ref) + + def peek(self): + """Get the current value of the signal without subscribing to changes.""" + return self._value + + def get(self) -> T: + """Get the current value of the signal.""" + value = self._value + if CURRENT_COMPUTED is not None: + # this is ued to detect infinite cycles + if BATCH_PENDING is not None: + PROCESSING_SIGNALS.add(self) + + # if accessing inside of a computed, add this to the computed's parents + CURRENT_COMPUTED._add_dependency(self, value) + + return value + + def set(self, value: T) -> None: + if ( + CURRENT_COMPUTED is not None + and BATCH_PENDING is not None + and self in PROCESSING_SIGNALS + ): + raise RuntimeError("Cycle detected") + + self._value = value + + # If the value is set outside of a batch, this ensures that all of the + # children will be fully marked as dirty before triggering any listeners + batch(self._wakeup) + + def subscribe( + self, fn: typing.Callable[[T], typing.Any] + ) -> typing.Callable[[], None]: + """Subscribe to changes in the signal. + + Parameters + ---------- + fn : Callable[[T], None] + The callback function to run when the signal changes. + + Returns + ------- + Callable[[], None] + A function for unsubscribing from the signal. + """ + return effect(lambda: fn(self())) + + +class Computed(Signal[T]): + """Represents a signal whose value is derived from other signals.""" + + __slots__ = ["_callback", "_dirty", "_first", "_has_error", "_parents", "_weak"] + + # Whether this is the first time processing the computed + _first: bool + + # Whether any of the computed's parents have changed or not + _dirty: bool + + # Whether the callback errored or not + _has_error: bool + + # Weakrefs has their own object identity, so we must reuse the same weakref + # over and over again + _weak: weakref.ref[Signal | Computed] + + # The parent dependencies of this computed. + _parents: dict[Signal, typing.Any] + + _callback: typing.Callable[[], T] + + def __init__(self, callback: typing.Callable[[], T]) -> None: + super().__init__(typing.cast(T, None)) + self._first = True + self._dirty = True + self._has_error = False + self._weak = weakref.ref(self) + self._parents = {} + self._callback = callback + + def __call__(self) -> T: + return self.get() + + def _wakeup(self): + """Mark this computed as dirty whenever any of its parents change.""" + self._dirty = True + super()._wakeup() + + def _add_dependency(self, parent: Signal, value: typing.Any) -> None: + """Add the Signal as a dependency of this computed. + + Called when another Signal's .value is accessed inside of this computed. + """ + self._parents[parent] = value + parent._children.add(self._weak) + + def _remove_dependencies(self): + """Remove all links between this computed and its dependencies.""" + for parent in self._parents: + parent._children.remove(self._weak) + + def peek(self) -> T: + global CURRENT_COMPUTED # noqa: PLW0603 + + if self._dirty: + try: + changed = False + if self._first: + self._first = False + changed = True + else: + for parent, old_value in self._parents.items(): + new_value = parent.peek() + if old_value != new_value: + changed = True + + if changed: + self._has_error = False + # Because the dependencies might have changed, we first + # remove all of the old links between this computed and + # its dependencies. + # + # The links will be recreated by the _addDependency method. + self._remove_dependencies() + + old = CURRENT_COMPUTED + CURRENT_COMPUTED = self + + try: + self._value = self._callback() + finally: + CURRENT_COMPUTED = old + except Exception as e: + self._has_error = True + # We reuse the _value slot for the error, instead of using + # a separate property + self._value = typing.cast(T, e) + + if self._has_error: + # We know that the value is an exception + raise self._value + + return self._value + + def get(self) -> T: + """Get the current value of the computed.""" + value = self.peek() + + if CURRENT_COMPUTED is not None: + # If accessing inside of a computed, add this to the computed's parents + CURRENT_COMPUTED._add_dependency(self, value) + + return value + + def set(self, value: T) -> None: # noqa: PLR6301 + raise AttributeError("Computed singals are read-only") + + # def __repr__(self) -> str: + # return f"Computed({self()})" + + +def computed(fn: typing.Callable[[], T]) -> Computed[T]: + """Create a new signal that is computed based on the values of other signals. + + The returned computed signal is read-only, and its value is automatically + updated when any signals accessed from within the callback function change. + + Parameters + ---------- + fn : Callable[[], T] + The function to compute the value of the signal. + + Returns + ------- + Computed[T] + A new read-only signal. + """ + return Computed(fn) + + +class Effect(Computed[T]): + """Represents a side-effect that runs in response to signal changes.""" + + __slots__ = ["_listener"] + + _listener: Listener | None + + def __init__(self, fn: typing.Callable[[], T]) -> None: + self._listener = None + super().__init__(fn) + + def __repr__(self) -> str: + return f"Effect({self()})" + + def _wakeup(self): + """Mark this effect as dirty whenever any of its parents change.""" + if BATCH_PENDING is None: + raise RuntimeError("invalid batch_pending") + + if self._listener is not None: + BATCH_PENDING.add(self._listener) + + super()._wakeup() + + def _listen(self, callback: typing.Callable[[T], None]) -> Disposer: + old_value = self() + + def listener(): + nonlocal old_value + new_value = self() + if old_value != new_value: + old_value = new_value + callback(old_value) + + self._listener = listener + callback(old_value) + + def dispose(): + self._listener = None + self._remove_dependencies() + + return dispose + + +def _effect(fn: typing.Callable[[], None]) -> Disposer: + """Create an effect to run arbitrary code in response to signal changes. + + An effect tracks which signals are accessed within the given callback + function `fn`, and re-runs the callback when those signals change. + + The callback may return a cleanup function. The cleanup function gets + run once, either when the callback is next called or when the effect + gets disposed, whichever happens first. + + Parameters + ---------- + fn : Callable[[], None] + The effect callback. + + Returns + ------- + Callable[[], None] + A function for disposing the effect. + """ + return Effect(lambda: batch(fn))._listen(lambda _: None) + + +@typing.overload +def effect( # noqa: D418 + deps: typing.Sequence[Signal], + *, + defer: bool = False, +) -> typing.Callable[[typing.Callable[..., None]], Disposer]: + """Create an effect with explicit dependencies. + + An effect is a side-effect that runs in response to signal changes. + + Parameters + ---------- + deps : Sequence[Signal] + The signals that the effect depends on. + + defer : bool, optional + Defer the effect until the next change, rather than running immediately. + By default, False. + + Returns + ------- + Callable[[Callable[..., None]], Disposer] + A decorator function for creating effects. + """ + + +@typing.overload +def effect(fn: typing.Callable[[], None], /) -> Disposer: # noqa: D418 + """Create an effect to run arbitrary code in response to signal changes. + + An effect tracks which signals are accessed within the given callback + function `fn`, and re-runs the callback when those signals change. + + The callback may return a cleanup function. The cleanup function gets + run once, either when the callback is next called or when the effect + gets disposed, whichever happens first. + + Parameters + ---------- + fn : Callable[[], None] + The effect callback. + + Returns + ------- + Callable[[], None] + A function for disposing the effect. + """ + + +def effect(*args, **kwargs) -> typing.Callable: + """Create an effect to run arbitrary code in response to signal changes.""" + if len(args) == 1 and callable(args[0]): + return _effect(args[0]) + + deps = args[0] if len(args) == 1 else kwargs.get("deps", []) + defer = kwargs.get("defer", False) + + def wrap(fn): + return _effect(on(deps=deps, defer=defer)(fn)) + + return wrap + + +def on(deps: typing.Sequence[Signal], *, defer: bool = False): + """Make dependencies for a function explicit. + + Parameters + ---------- + deps : Sequence[Signal] + The signals that the effect depends on. + + defer : bool, optional + Defer the effect until the next change, rather than running immediately. + By default, False. + + Returns + ------- + Callable[[Callable[..., None]], Callable[[], None]] + A callback function that can be registered as an effect. + """ + + def decorator(fn: typing.Callable[..., None]) -> typing.Callable[[], None]: + # The main effect function that will be run. + def main(): + return fn(*(dep() for dep in deps)) + + func = main + + if defer: + # Create a void function that accesses all of the + # dependencies so they will be tracked in an effect. + def void(): + nonlocal func + for dep in deps: + dep() + func = main + + func = void + + return lambda: func() # noqa: PLW0108 + + return decorator diff --git a/mesa/experimental/signals/code/_version.py b/mesa/experimental/signals/code/_version.py new file mode 100644 index 00000000000..e11bdb6a136 --- /dev/null +++ b/mesa/experimental/signals/code/_version.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +import importlib.metadata + +__version__ = importlib.metadata.version("signals") diff --git a/mesa/experimental/signals/code/test.py b/mesa/experimental/signals/code/test.py new file mode 100644 index 00000000000..82e3ececd2f --- /dev/null +++ b/mesa/experimental/signals/code/test.py @@ -0,0 +1,10 @@ + +from _core import Signal, Computed + +a = Signal(2) +b = Signal(3) + +c = Computed(lambda: a() + b()) + +c() +a.set(5) \ No newline at end of file diff --git a/mesa/experimental/signals/signal.py b/mesa/experimental/signals/signal.py index f6aaa5ce71c..d97c1b7493f 100644 --- a/mesa/experimental/signals/signal.py +++ b/mesa/experimental/signals/signal.py @@ -11,49 +11,7 @@ __all__ = ["Observable", "HasObservables"] - -class Computable: - """A Computable that is depended on one or more Observables. - - fixme how to make this work with Descriptors? - just to it as with ObservableList and SingallingList - so have a Computable and Computed class - declare the Computable at the top - assign the Computed in the instance - - """ - - def __init__(self, callable: Callable, *args, **kwargs): - """Draft Computable. - - Args: - callable: the callable that is computed - args: arguments to pass to the callable - **kwargs: keyword arguments to pass to the callable - - """ - # fixme: what if these are observable? - # basically the logic for subscribing should go into the observable class - # but we might have to split up a few things here - # easy fix is to just declare an attribute Observable at the class level - # and next assign a Computed to this attribute. - # not sure how this would work, because observable would return the computable - # not its internal value. So, you could either - # have a separate observable or let the observable check if the value - # is a computed and thus do an additional get operation on this - - self.callable = callable - self.args = args - self.kwargs = kwargs - self._is_dirty = True - - def __get__(self, instance, owner): - # fixme: not sure this will work correctly - - if self._is_dirty: - self.value = self.callable(*self.args, **self.kwargs) - self._is_dirty = False - return self.value +CURRENT_COMPUTED: "Computed" | None = None # the current Computed that is evaluating class BaseObservable(ABC): @@ -74,7 +32,14 @@ def __init__(self): self.fallback_value = None # fixme, should this be user specifiable? def __get__(self, instance, owner): - return getattr(instance, self.private_name) + value = getattr(instance, self.private_name) + + if CURRENT_COMPUTED is not None: + # there is a computed dependent on this Observable, so let's add + # this Observable as a parent + CURRENT_COMPUTED._add_parent(self, instance, self.public_name, value) + + return value def __set_name__(self, owner: "HasObservables", name: str): self.public_name = name @@ -118,6 +83,138 @@ def __set__(self, instance: "HasObservables", value): # noqa D103 setattr(instance, self.private_name, value) +class Computable(BaseObservable): + """A Computable that is depended on one or more Observables. + + fixme how to make this work with Descriptors? + just to it as with ObservableList and SingalingList + so have a Computable and Computed class + declare the Computable at the top + assign the Computed in the instance + + """ + + def __init__(self): + """Initialize a Computable.""" + super().__init__() + self.public_name: str + self.private_name: str + + self.signal_types: set = {"on_change"} + + def __get__(self, instance, owner): + computed = getattr(instance, self.private_name) + + old_value = computed.value + new_value = computed() + + if new_value != old_value: + instance.notify( + self.public_name, + old_value, + new_value, + "on_change", + ) + else: + return new_value + + def __set_name__(self, owner: "HasObservables", name: str): + self.public_name = name + self.private_name = f"_{name}" + owner.register_observable(self) + + def __set__(self, instance: "HasObservables", value): + setattr(instance, self.private_name, Computed(*value)) + + +class Computed: + def __init__(self, callable: Callable, *args, **kwargs): + self.callable = callable + self.args = args + self.kwargs = kwargs + self._is_dirty = False + self.value = None + + # fixme this is not correct, our HasObservable might have disappeared.... + # so we need to use weakrefs here. + self.parents: weakref.WeakKeyDictionary[HasObservables, dict[str], Any] = weakref.WeakKeyDictionary() + + def _set_dirty(self, signal): + self._is_dirty = True + # propagate this to all dependents + + def _add_parent(self, parent: "HasObservables", name: str, current_value: Any): + """Add a parent Observable. + + Args: + parent: the HasObservable instance to which the Observable belongs + name: the public name of the Observable + current_value: the current value of the Observable + + """ + parent.observe(name, All(), self._set_dirty) + + try: + self.parents[parent][name] = current_value + except KeyError: + self.parents[parent] = {name: current_value} + + def _remove_parents(self): + """Remove all parent Observables.""" + # we can ubsubscribe from everything on each parent + for parent in self.parents: + parent.unobserve(All(), All()) + + def __call__(self): + global CURRENT_COMPUTED # noqa: PLW0603 + CURRENT_COMPUTED = self + + if self._is_dirty: + changed = False + + # we might be dirty but values might have changed + # back and forth in our parents so let's check to make sure we + # really need to recalculate + for parent in self.parents.keyrefs(): + # does parent still exist? + if parent := parent(): + # if yes, compare old and new values for all + # tracked observables on this parent + for name, old_value in self.parents[parent].items(): + new_value = getattr(parent, name) + if new_value != old_value: + changed = True + break # we need to recalculate + else: + # trick for breaking cleanly out of nested for loops + # see https://stackoverflow.com/questions/653509/breaking-out-of-nested-loops + continue + break + else: + # one of our parents no longer exists + changed = True + break + + if changed: + # the dependencies of the computable function might have changed + # so we rebuilt + self._remove_parents() + + old = CURRENT_COMPUTED + CURRENT_COMPUTED = self + + try: + # fixme we need to handle error propagation somehow correctly + self._value = self.callable(*self.args, **self.kwargs) + except Exception as e: + raise e + finally: + CURRENT_COMPUTED = old + + self._is_dirty = False + return self.value + + class All: """Helper constant to subscribe to all Observables.""" @@ -166,10 +263,10 @@ def register_observable(cls, observable: BaseObservable): cls.observables[observable.public_name] = observable def observe( - self, - name: str | All, - signal_type: str | All, - handler: Callable, + self, + name: str | All, + signal_type: str | All, + handler: Callable, ): """Subscribe to the Observable for signal_type.