Skip to content

Commit

Permalink
feat: reactive_task decorator (re)runs a task when a dependency changes
Browse files Browse the repository at this point in the history
Dependencies are other reactive variables.

Examples:

```
import asyncio
import time
import solara
from solara.lab import reactive_task

x = solara.reactive(2)

@reactive_task
async def x_square():
    await asyncio.sleep(2)
    a = b
    return x.value**2

@solara.component
def Page():
    solara.SliderInt("x", value=x, min=0, max=10)
    if x_square.value.state == solara.ResultState.FINISHED:
        solara.Text(repr(x_square.value.value))
    solara.ProgressLinear(x_square.value.state == solara.ResultState.RUNNING)
```
  • Loading branch information
maartenbreddels committed Feb 14, 2024
1 parent d39063b commit f0b4019
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 3 deletions.
2 changes: 1 addition & 1 deletion solara/lab/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# isort: skip_file
from .components import * # noqa: F401, F403
from ..server.kernel_context import on_kernel_start # noqa: F401
from ..tasks import task, use_task, Task, TaskResult # noqa: F401, F403
from ..tasks import reactive_task, task, use_task, Task, TaskResult # noqa: F401, F403
from ..toestand import computed # noqa: F401


Expand Down
92 changes: 92 additions & 0 deletions solara/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from enum import Enum
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Generic,
Expand Down Expand Up @@ -762,3 +763,94 @@ def run():
return wrapper
else:
return wrapper(f)


@overload
def reactive_task(
function: None = None,
) -> Callable[[Callable[[], R]], solara.Reactive[solara.Result[R]]]:
...


@overload
def reactive_task(
function: Callable[[], R],
) -> solara.Reactive[solara.Result[R]]:
...


def reactive_task(
function: Union[None, Callable[[], Union[Coroutine[Any, Any, R], R]]] = None,
) -> Union[Callable[[Callable[[], R]], solara.Reactive[solara.Result[R]]], solara.Reactive[solara.Result[R]]]:
"""Decorator to turn a function into a task that auto-executes when one of its dependencies changes.
The decorator returns a [reactive variable](/api/reactive) with the Result object as its value.
## Example
```solara
import asyncio
import time
import solara
from solara.lab import reactive_task
x = solara.reactive(2)
# now x_square is a Reactive[Result[int]]
@reactive_task
async def x_square():
await asyncio.sleep(2)
a = b
return x.value**2
@solara.component
def Page():
solara.SliderInt("x", value=x, min=0, max=10)
if x_square.value.state == solara.ResultState.FINISHED:
solara.Text(repr(x_square.value.value))
solara.ProgressLinear(x_square.value.state == solara.ResultState.RUNNING)
```
"""

def wrapper(func: Callable[[], Union[Coroutine[Any, Any, R], R]]) -> solara.Reactive[solara.Result[R]]:
from solara.toestand import AutoSubscribeContextManager

def create_task():
auto_subscriber: AutoSubscribeContextManager
if inspect.iscoroutinefunction(function):
task: Task[[], R]

async def run_function_with_auto_subscribe_async():
with auto_subscriber:
return await cast(Awaitable[R], func())

task = TaskAsyncio(run_function_with_auto_subscribe_async)
else:

def run_function_with_auto_subscribe():
with auto_subscriber:
return func()

task = TaskThreaded(cast(Callable[[], R], run_function_with_auto_subscribe))

def on_reactive_variable_change(old_value, new_value):
task()

auto_subscriber = AutoSubscribeContextManager(on_reactive_variable_change)
# used in tests
task.result._task = task # type: ignore
task.result._auto_subscriber = auto_subscriber # type: ignore
task()
return task.result

return cast(solara.Reactive[solara.Result[R]], Proxy(create_task))

if function is None:
return wrapper
else:
return wrapper(function)
4 changes: 2 additions & 2 deletions solara/toestand.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,9 +708,9 @@ def cleanup():


class AutoSubscribeContextManager(AutoSubscribeContextManagerBase):
on_change: Callable[[], None]
on_change: Callable[[Any, Any], None]

def __init__(self, on_change: Callable[[], None]):
def __init__(self, on_change: Callable[[Any, Any], None]):
super().__init__()
self.on_change = on_change

Expand Down
46 changes: 46 additions & 0 deletions tests/unit/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,3 +554,49 @@ async def work():
raise TimeoutError("took too long, state = " + str(result._state))
assert result._state == TaskState.FINISHED
assert last_value == 99


def test_reactive_task(no_kernel_context):
context_id = "1"
x = solara.reactive(1)
y = solara.reactive(2)
calls = 0

from solara.tasks import reactive_task

def conditional_add():
nonlocal calls
calls += 1
if x.value == 0:
return 42
else:
return x.value + y.value

z = reactive_task(conditional_add)

kernel1 = kernel.Kernel()
context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1, session_id="session-1")
with context1:
# assert z._auto_subscriber.value.reactive_used is None
assert z.value.value is None
z._task._last_finished_event.wait() # type: ignore
assert z._auto_subscriber.reactive_used == {x, y} # type: ignore
assert z.value.value == 3
# assert z._auto_subscriber.subscribed == 1
assert len(x._storage.listeners[context_id]) == 0
assert len(x._storage.listeners2[context_id]) == 1
assert len(y._storage.listeners[context_id]) == 0
assert len(y._storage.listeners2[context_id]) == 1
assert calls == 1
x.value = 2
z._task._last_finished_event.wait() # type: ignore
assert z.value.value == 4
assert z._auto_subscriber.reactive_used == {x, y} # type: ignore
assert calls == 2
y.value = 3
z._task._last_finished_event.wait() # type: ignore
assert z.value.value == 5
assert z._auto_subscriber.reactive_used == {x, y} # type: ignore
assert calls == 3
assert len(x._storage.listeners2[context_id]) == 1
assert len(y._storage.listeners2[context_id]) == 1

0 comments on commit f0b4019

Please sign in to comment.