From 1cfa556f520efa979a411524a67d2bcbab31b490 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Fri, 19 Jan 2024 12:48:11 +0100 Subject: [PATCH] feat: reactive_task decorator (re)runs a task when a dependency changes Dependencies are other reactive variables. Examples: ``` import asyncio import time import solara from solara.lab import reactive_task x = solara.reactive(2) @reactive_task async def x_square(): await asyncio.sleep(2) a = b return x.value**2 @solara.component def Page(): solara.SliderInt("x", value=x, min=0, max=10) if x_square.value.state == solara.ResultState.FINISHED: solara.Text(repr(x_square.value.value)) solara.ProgressLinear(x_square.value.state == solara.ResultState.RUNNING) ``` --- solara/lab/__init__.py | 2 +- solara/tasks.py | 92 +++++++++++++++++++++++++++++++++++++++++ solara/toestand.py | 4 +- tests/unit/task_test.py | 46 +++++++++++++++++++++ 4 files changed, 141 insertions(+), 3 deletions(-) diff --git a/solara/lab/__init__.py b/solara/lab/__init__.py index ccbf3071e..857cbb56a 100644 --- a/solara/lab/__init__.py +++ b/solara/lab/__init__.py @@ -1,6 +1,6 @@ # isort: skip_file from .components import * # noqa: F401, F403 -from ..tasks import task, use_task, Task # noqa: F401, F403 +from ..tasks import reactive_task, task, use_task, Task # noqa: F401, F403 def __getattr__(name): diff --git a/solara/tasks.py b/solara/tasks.py index 12a0bf5cc..e743185e1 100644 --- a/solara/tasks.py +++ b/solara/tasks.py @@ -5,6 +5,7 @@ import threading from typing import ( Any, + Awaitable, Callable, Coroutine, Generic, @@ -432,3 +433,94 @@ def run(): return wrapper else: return wrapper(f) + + +@overload +def reactive_task( + function: None = None, +) -> Callable[[Callable[[], R]], solara.Reactive[solara.Result[R]]]: + ... + + +@overload +def reactive_task( + function: Callable[[], R], +) -> solara.Reactive[solara.Result[R]]: + ... + + +def reactive_task( + function: Union[None, Callable[[], Union[Coroutine[Any, Any, R], R]]] = None, +) -> Union[Callable[[Callable[[], R]], solara.Reactive[solara.Result[R]]], solara.Reactive[solara.Result[R]]]: + """Decorator to turn a function into a task that auto-executes when one of its dependencies changes. + + + The decorator returns a [reactive variable](/api/reactive) with the Result object as its value. + + ## Example + + ```solara + import asyncio + import time + import solara + from solara.lab import reactive_task + + + x = solara.reactive(2) + + # now x_square is a Reactive[Result[int]] + @reactive_task + async def x_square(): + await asyncio.sleep(2) + a = b + return x.value**2 + + + @solara.component + def Page(): + solara.SliderInt("x", value=x, min=0, max=10) + if x_square.value.state == solara.ResultState.FINISHED: + solara.Text(repr(x_square.value.value)) + solara.ProgressLinear(x_square.value.state == solara.ResultState.RUNNING) + + ``` + + """ + + def wrapper(func: Callable[[], Union[Coroutine[Any, Any, R], R]]) -> solara.Reactive[solara.Result[R]]: + from solara.toestand import AutoSubscribeContextManager + + def create_task(): + auto_subscriber: AutoSubscribeContextManager + if inspect.iscoroutinefunction(function): + task: Task[[], R] + + async def run_function_with_auto_subscribe_async(): + with auto_subscriber: + return await cast(Awaitable[R], func()) + + task = TaskAsyncio(run_function_with_auto_subscribe_async) + else: + + def run_function_with_auto_subscribe(): + with auto_subscriber: + return func() + + task = TaskThreaded(cast(Callable[[], R], run_function_with_auto_subscribe)) + + def on_reactive_variable_change(old_value, new_value): + task() + + auto_subscriber = AutoSubscribeContextManager(on_reactive_variable_change) + # used in tests + task.result._task = task # type: ignore + task.result._auto_subscriber = auto_subscriber # type: ignore + task() + return task.result + + return cast(solara.Reactive[solara.Result[R]], Proxy(create_task)) + + if function is None: + return wrapper + else: + return wrapper(function) diff --git a/solara/toestand.py b/solara/toestand.py index da91da691..d0997b101 100644 --- a/solara/toestand.py +++ b/solara/toestand.py @@ -627,9 +627,9 @@ def cleanup(): class AutoSubscribeContextManager(AutoSubscribeContextManagerBase): - on_change: Callable[[], None] + on_change: Callable[[Any, Any], None] - def __init__(self, on_change: Callable[[], None]): + def __init__(self, on_change: Callable[[Any, Any], None]): super().__init__() self.on_change = on_change diff --git a/tests/unit/task_test.py b/tests/unit/task_test.py index af6bb403e..2327e3c1b 100644 --- a/tests/unit/task_test.py +++ b/tests/unit/task_test.py @@ -472,3 +472,49 @@ async def work(): while result.state in [solara.ResultState.STARTING, solara.ResultState.RUNNING]: await asyncio.sleep(0.1) assert last_value == 99 + + +def test_reactive_task(no_kernel_context): + context_id = "1" + x = solara.reactive(1) + y = solara.reactive(2) + calls = 0 + + from solara.tasks import reactive_task + + def conditional_add(): + nonlocal calls + calls += 1 + if x.value == 0: + return 42 + else: + return x.value + y.value + + z = reactive_task(conditional_add) + + kernel1 = kernel.Kernel() + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1, session_id="session-1") + with context1: + # assert z._auto_subscriber.value.reactive_used is None + assert z.value.value is None + z._task._last_finished_event.wait() # type: ignore + assert z._auto_subscriber.reactive_used == {x, y} # type: ignore + assert z.value.value == 3 + # assert z._auto_subscriber.subscribed == 1 + assert len(x._storage.listeners[context_id]) == 0 + assert len(x._storage.listeners2[context_id]) == 1 + assert len(y._storage.listeners[context_id]) == 0 + assert len(y._storage.listeners2[context_id]) == 1 + assert calls == 1 + x.value = 2 + z._task._last_finished_event.wait() # type: ignore + assert z.value.value == 4 + assert z._auto_subscriber.reactive_used == {x, y} # type: ignore + assert calls == 2 + y.value = 3 + z._task._last_finished_event.wait() # type: ignore + assert z.value.value == 5 + assert z._auto_subscriber.reactive_used == {x, y} # type: ignore + assert calls == 3 + assert len(x._storage.listeners2[context_id]) == 1 + assert len(y._storage.listeners2[context_id]) == 1