diff --git a/solara/lab/__init__.py b/solara/lab/__init__.py index 65ed389b9..7ae6fd5d5 100644 --- a/solara/lab/__init__.py +++ b/solara/lab/__init__.py @@ -2,7 +2,7 @@ from .components import * # noqa: F401, F403 from .utils import cookies, headers # noqa: F401, F403 from ..server.kernel_context import on_kernel_start # noqa: F401 -from ..tasks import task, use_task, Task, TaskResult # noqa: F401, F403 +from ..tasks import reactive_task, task, use_task, Task, TaskResult # noqa: F401, F403 from ..toestand import computed # noqa: F401 diff --git a/solara/tasks.py b/solara/tasks.py index 45949d8da..920fc8cb9 100644 --- a/solara/tasks.py +++ b/solara/tasks.py @@ -7,6 +7,7 @@ from enum import Enum from typing import ( Any, + Awaitable, Callable, Coroutine, Generic, @@ -850,3 +851,94 @@ def run(): return wrapper else: return wrapper(f) + + +@overload +def reactive_task( + function: None = None, +) -> Callable[[Callable[[], R]], solara.Reactive[solara.Result[R]]]: + ... + + +@overload +def reactive_task( + function: Callable[[], R], +) -> solara.Reactive[solara.Result[R]]: + ... + + +def reactive_task( + function: Union[None, Callable[[], Union[Coroutine[Any, Any, R], R]]] = None, +) -> Union[Callable[[Callable[[], R]], solara.Reactive[solara.Result[R]]], solara.Reactive[solara.Result[R]]]: + """Decorator to turn a function into a task that auto-executes when one of its dependencies changes. + + + The decorator returns a [reactive variable](/api/reactive) with the Result object as its value. + + ## Example + + ```solara + import asyncio + import time + import solara + from solara.lab import reactive_task + + + x = solara.reactive(2) + + # now x_square is a Reactive[Result[int]] + @reactive_task + async def x_square(): + await asyncio.sleep(2) + a = b + return x.value**2 + + + @solara.component + def Page(): + solara.SliderInt("x", value=x, min=0, max=10) + if x_square.value.state == solara.ResultState.FINISHED: + solara.Text(repr(x_square.value.value)) + solara.ProgressLinear(x_square.value.state == solara.ResultState.RUNNING) + + ``` + + """ + + def wrapper(func: Callable[[], Union[Coroutine[Any, Any, R], R]]) -> solara.Reactive[solara.Result[R]]: + from solara.toestand import AutoSubscribeContextManager + + def create_task(): + auto_subscriber: AutoSubscribeContextManager + if inspect.iscoroutinefunction(function): + task: Task[[], R] + + async def run_function_with_auto_subscribe_async(): + with auto_subscriber: + return await cast(Awaitable[R], func()) + + task = TaskAsyncio(run_function_with_auto_subscribe_async) + else: + + def run_function_with_auto_subscribe(): + with auto_subscriber: + return func() + + task = TaskThreaded(cast(Callable[[], R], run_function_with_auto_subscribe)) + + def on_reactive_variable_change(old_value, new_value): + task() + + auto_subscriber = AutoSubscribeContextManager(on_reactive_variable_change) + # used in tests + task.result._task = task # type: ignore + task.result._auto_subscriber = auto_subscriber # type: ignore + task() + return task.result + + return cast(solara.Reactive[solara.Result[R]], Proxy(create_task)) + + if function is None: + return wrapper + else: + return wrapper(function) diff --git a/solara/toestand.py b/solara/toestand.py index 247bb68b3..95a6ac0af 100644 --- a/solara/toestand.py +++ b/solara/toestand.py @@ -749,9 +749,9 @@ def cleanup(): class AutoSubscribeContextManager(AutoSubscribeContextManagerBase): - on_change: Callable[[], None] + on_change: Callable[[Any, Any], None] - def __init__(self, on_change: Callable[[], None]): + def __init__(self, on_change: Callable[[Any, Any], None]): super().__init__() self.on_change = on_change diff --git a/tests/unit/task_test.py b/tests/unit/task_test.py index 5584bfbfd..ebf345d9c 100644 --- a/tests/unit/task_test.py +++ b/tests/unit/task_test.py @@ -559,3 +559,49 @@ async def work(): raise TimeoutError("took too long, state = " + str(task._state)) assert task._state == TaskState.FINISHED assert last_value == 99 + + +def test_reactive_task(no_kernel_context): + context_id = "1" + x = solara.reactive(1) + y = solara.reactive(2) + calls = 0 + + from solara.tasks import reactive_task + + def conditional_add(): + nonlocal calls + calls += 1 + if x.value == 0: + return 42 + else: + return x.value + y.value + + z = reactive_task(conditional_add) + + kernel1 = kernel.Kernel() + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1, session_id="session-1") + with context1: + # assert z._auto_subscriber.value.reactive_used is None + assert z.value.value is None + z._task._last_finished_event.wait() # type: ignore + assert z._auto_subscriber.reactive_used == {x, y} # type: ignore + assert z.value.value == 3 + # assert z._auto_subscriber.subscribed == 1 + assert len(x._storage.listeners[context_id]) == 0 + assert len(x._storage.listeners2[context_id]) == 1 + assert len(y._storage.listeners[context_id]) == 0 + assert len(y._storage.listeners2[context_id]) == 1 + assert calls == 1 + x.value = 2 + z._task._last_finished_event.wait() # type: ignore + assert z.value.value == 4 + assert z._auto_subscriber.reactive_used == {x, y} # type: ignore + assert calls == 2 + y.value = 3 + z._task._last_finished_event.wait() # type: ignore + assert z.value.value == 5 + assert z._auto_subscriber.reactive_used == {x, y} # type: ignore + assert calls == 3 + assert len(x._storage.listeners2[context_id]) == 1 + assert len(y._storage.listeners2[context_id]) == 1