From 3cdce2bf6c6ff456e6398f5273cbf425326bf7a0 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Fri, 12 Jan 2024 12:00:29 +0100 Subject: [PATCH] feat: task and use_task for better asyncio task a thread execution use_task can replace use_thread, but also support coroutine functions. task is a decorator that creates a top level task that can be executed from any location (e.g. a click handler of a button). --- .github/workflows/codequality.yaml | 2 +- .pre-commit-config.yaml | 2 +- solara/lab/__init__.py | 1 + solara/server/starlette.py | 6 +- solara/tasks.py | 548 ++++++++++++++++++ solara/toestand.py | 2 +- solara/website/pages/api/__init__.py | 2 + solara/website/pages/api/task.py | 39 ++ solara/website/pages/api/use_task.py | 19 + .../pages/docs/content/10-howto/30-tasks.md | 92 +++ tests/unit/task_test.py | 509 ++++++++++++++++ 11 files changed, 1216 insertions(+), 6 deletions(-) create mode 100644 solara/tasks.py create mode 100644 solara/website/pages/api/task.py create mode 100644 solara/website/pages/api/use_task.py create mode 100644 solara/website/pages/docs/content/10-howto/30-tasks.md create mode 100644 tests/unit/task_test.py diff --git a/.github/workflows/codequality.yaml b/.github/workflows/codequality.yaml index 07233d3bc..f1b72765d 100644 --- a/.github/workflows/codequality.yaml +++ b/.github/workflows/codequality.yaml @@ -25,7 +25,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install ".[dev]" mypy==0.991 black==22.12.0 codespell==2.2.4 "click<8.1.4" "traitlets<5.10.0" "matplotlib<3.8.0" + pip install ".[dev]" mypy==1.1.1 black==22.12.0 codespell==2.2.4 "click<8.1.4" "traitlets<5.10.0" "matplotlib<3.8.0" mypy --install-types --non-interactive solara - name: Run codespell run: codespell diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fbc8ae961..6ebc4854a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: files: \.py$ args: [--profile=black] - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v0.991" # Use the sha / tag you want to point at + rev: "v1.1.1" # Use the sha / tag you want to point at hooks: - id: mypy args: [--no-strict-optional, --ignore-missing-imports] diff --git a/solara/lab/__init__.py b/solara/lab/__init__.py index f864a8ae6..f2e892c58 100644 --- a/solara/lab/__init__.py +++ b/solara/lab/__init__.py @@ -1,6 +1,7 @@ # isort: skip_file from .components import * # noqa: F401, F403 from ..server.kernel_context import on_kernel_start # noqa: F401 +from ..tasks import task, use_task, Task # noqa: F401, F403 from ..toestand import computed # noqa: F401 diff --git a/solara/server/starlette.py b/solara/server/starlette.py index c39eedfec..ba3a67217 100644 --- a/solara/server/starlette.py +++ b/solara/server/starlette.py @@ -90,19 +90,19 @@ async def process_messages_task(self): await self.ws.send_text(first) def close(self): - self.portal.call(self.ws.close) + self.portal.call(self.ws.close) # type: ignore def send_text(self, data: str) -> None: if settings.main.experimental_performance: self.to_send.append(data) else: - self.portal.call(self.ws.send_bytes, data) + self.portal.call(self.ws.send_bytes, data) # type: ignore def send_bytes(self, data: bytes) -> None: if settings.main.experimental_performance: self.to_send.append(data) else: - self.portal.call(self.ws.send_bytes, data) + self.portal.call(self.ws.send_bytes, data) # type: ignore async def receive(self): if hasattr(self.portal, "start_task_soon"): diff --git a/solara/tasks.py b/solara/tasks.py new file mode 100644 index 000000000..dfb793689 --- /dev/null +++ b/solara/tasks.py @@ -0,0 +1,548 @@ +import abc +import asyncio +import inspect +import logging +import threading +from typing import ( + Any, + Callable, + Coroutine, + Generic, + Optional, + TypeVar, + Union, + cast, + overload, +) + +import typing_extensions + +import solara +import solara.util +from solara.toestand import Singleton + +R = TypeVar("R") + +P = typing_extensions.ParamSpec("P") +logger = logging.getLogger("solara.task") + +try: + threading.Thread(target=lambda: None).start() + has_threads = True +except RuntimeError: + has_threads = False +has_threads + + +class Task(Generic[P, R], abc.ABC): + def __init__(self): + self.result = solara.reactive( + solara.Result[R]( + value=None, + state=solara.ResultState.INITIAL, + ) + ) + self.last_value: Optional[R] = None + + @property + def value(self) -> Optional[R]: + return self.result.value.value if self.result.value is not None else None + + @property + def state(self) -> solara.ResultState: + return self.result.value.state if self.result.value is not None else solara.ResultState.INITIAL + + @property + def error(self) -> Optional[Exception]: + return self.result.value.error if self.result.value is not None else None + + @abc.abstractmethod + def retry(self) -> None: + ... + + @abc.abstractmethod + def cancel(self) -> None: + ... + + @abc.abstractmethod + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: + ... + + +class _CancelledErrorInOurTask(BaseException): + pass + + +class TaskAsyncio(Task[P, R]): + current_task: Optional[asyncio.Task] = None + current_future: Optional[asyncio.Future] = None + _cancel: Optional[Callable[[], None]] = None + _retry: Optional[Callable[[], None]] = None + + def __init__(self, run_in_thread: bool, function: Callable[P, Coroutine[Any, Any, R]]): + self.run_in_thread = run_in_thread + self.function = function + super().__init__() + + def cancel(self) -> None: + if self._cancel: + self._cancel() + else: + raise RuntimeError("Cannot cancel task, never started") + + def retry(self): + if self._retry: + self._retry() + else: + raise RuntimeError("Cannot retry task, never started") + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: + self.current_future = future = asyncio.Future[R]() + current_task: asyncio.Task[None] + if self.current_task: + self.current_task.cancel() + + def retry(): + self(*args, **kwargs) + + def cancel(): + event_loop = current_task.get_loop() + # cancel after cancel is a no-op + self._cancel = lambda: None + if asyncio.current_task() == current_task: + if event_loop == asyncio.get_event_loop(): + # we got called in our own task and event loop + raise _CancelledErrorInOurTask() + else: + current_task.cancel() + self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.CANCELLED, cancel=self.cancel, _retry=retry) + else: + current_task.cancel() + self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.CANCELLED, cancel=self.cancel, _retry=retry) + + common: _CommonArgs = { + "cancel": cancel, + "_retry": retry, + } + self._cancel = cancel + self._retry = retry + call_event_loop = asyncio.get_event_loop() + if self.run_in_thread: + thread_event_loop = asyncio.new_event_loop() + self.current_task = current_task = thread_event_loop.create_task(self._async_run(call_event_loop, future, common, args, kwargs)) + + def runs_in_thread(): + try: + thread_event_loop.run_until_complete(current_task) + except asyncio.CancelledError as e: + # if : + call_event_loop.call_soon_threadsafe(future.set_exception, e) + # self.result.value = self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.CANCELLED, **common) + except Exception as e: + logger.exception("error running in thread") + call_event_loop.call_soon_threadsafe(future.set_exception, e) + raise + + self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.STARTING, **common) + thread = threading.Thread(target=runs_in_thread) + thread.start() + else: + self.current_task = current_task = asyncio.create_task(self._async_run(call_event_loop, future, common, args, kwargs)) + self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.STARTING, **common) + + async def _async_run(self, call_event_loop: asyncio.AbstractEventLoop, future: asyncio.Future, common: "_CommonArgs", args, kwargs) -> None: + task_for_this_call = asyncio.current_task() + assert task_for_this_call is not None + + def still_active(): + assert task_for_this_call is not None + return (self.current_task == task_for_this_call) and not task_for_this_call.cancelled() + + assert self.current_task is task_for_this_call + + if still_active(): + self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.STARTING, **common) + + async def runner(): + try: + if still_active(): + self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.RUNNING, **common) + self.last_value = value = await self.function(*args, **kwargs) + if still_active() and not task_for_this_call.cancelled(): + self.result.value = self.result.value = solara.Result[R](value=value, state=solara.ResultState.FINISHED, **common) + logger.info("setting result to %r", value) + call_event_loop.call_soon_threadsafe(future.set_result, value) + except Exception as e: + if still_active(): + self.result.value = self.result.value = solara.Result[R](value=self.last_value, error=e, state=solara.ResultState.ERROR, **common) + call_event_loop.call_soon_threadsafe(future.set_exception, e) + # Although this seems like an easy way to handle cancellation, an early cancelled task will never execute + # so this code will never execute, so we need to handle this in the cancel function in __call__ + # except asyncio.CancelledError as e: + # if still_active(): + # self.result.value = self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.CANCELLED, **common) + # call_event_loop.call_soon_threadsafe(future.set_exception, e) + # But... if we call cancel in our own task, we still need to do it from this place + except _CancelledErrorInOurTask as e: + try: + # maybe there is a different way to get a full stack trace? + raise asyncio.CancelledError() from e + except asyncio.CancelledError as e: + if still_active(): + self.result.value = self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.CANCELLED, **common) + call_event_loop.call_soon_threadsafe(future.set_exception, e) + + await runner() + + +class TaskThreaded(Task[P, R]): + _current_cancel_event: Optional[threading.Event] = None + current_thread: Optional[threading.Thread] = None + running_thread: Optional[threading.Thread] = None + _last_finished_event: Optional[threading.Event] = None + _cancel: Optional[Callable[[], None]] = None + _retry: Optional[Callable[[], None]] = None + + def __init__(self, function: Callable[P, R]): + super().__init__() + self.__qualname__ = function.__qualname__ + self.function = function + self.lock = threading.Lock() + + def cancel(self) -> None: + if self._cancel: + self._cancel() + else: + raise RuntimeError("Cannot cancel task, never started") + + def retry(self): + if self._retry: + self._retry() + else: + raise RuntimeError("Cannot retry task, never started") + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: + self._last_finished_event = _last_finished_event = threading.Event() + self._current_cancel_event = cancel_event = threading.Event() + + def retry(): + self(*args, **kwargs) + + def cancel(): + cancel_event.set() + if threading.current_thread() == current_thread: + raise solara.util.CancelledError() + self._current_cancel_event = None + + self._retry = retry + self._cancel = cancel + + common: _CommonArgs = { + "cancel": cancel, + "_retry": retry, + } + self.current_thread = current_thread = threading.Thread(target=lambda: self._run(_last_finished_event, cancel_event, common, args, kwargs), daemon=True) + self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.STARTING, **common) + current_thread.start() + + def _run(self, _last_finished_event, cancel_event, common: "_CommonArgs", args, kwargs) -> None: + def am_i_the_last_called_thread(): + return self.running_thread == threading.current_thread() + + def runner(): + intrusive_cancel = True + wait_for_thread = None + with self.lock: + # if there is a current thread already, we'll need + # to wait for it. copy the ref, and set ourselves + # as the current one + if self.running_thread: + wait_for_thread = self.running_thread + self.running_thread = threading.current_thread() + if wait_for_thread is not None: + if am_i_the_last_called_thread(): + self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.WAITING, **common) + # don't start before the previous is stopped + try: + wait_for_thread.join() + except: # noqa + pass + if threading.current_thread() != self.running_thread: + # in case a new thread was started that also was waiting for the previous + # thread to st stop, we can finish this + return + # we previously set current to None, but if we do not do that, we can still render the old value + # while we can still show a loading indicator using the .state + # result.current = None + if am_i_the_last_called_thread(): + self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.RUNNING, **common) + + callback = self.function + try: + guard = solara.util.cancel_guard(cancel_event) if intrusive_cancel else solara.util.nullcontext() + try: + # we only use the cancel_guard context manager around + # the function calls to f. We don't want to guard around + # a call to react, since that might slow down rendering + # during rendering + with guard: + if am_i_the_last_called_thread(): + value = callback(*args, **kwargs) + if inspect.isgenerator(value): + generator = value + self.last_value = None + while True: + try: + with guard: + self.last_value = next(generator) + if am_i_the_last_called_thread(): + self.result.value = self.result.value = solara.Result[R]( + value=self.last_value, state=solara.ResultState.RUNNING, **common + ) + except StopIteration: + break + if am_i_the_last_called_thread(): + self.result.value = self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.FINISHED, **common) + else: + self.last_value = None + self.last_value = value + if am_i_the_last_called_thread(): + self.result.value = self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.FINISHED, **common) + except Exception as e: + if am_i_the_last_called_thread(): + logger.exception(e) + self.result.value = self.result.value = solara.Result[R](value=self.last_value, error=e, state=solara.ResultState.ERROR, **common) + return + except solara.util.CancelledError: + pass + # this means this thread is cancelled not be request, but because + # a new thread is running, we can ignore this + finally: + if am_i_the_last_called_thread(): + self.running_thread = None + logger.info("thread done!") + if cancel_event.is_set(): + self.result.value = solara.Result[R](value=self.last_value, state=solara.ResultState.CANCELLED, **common) + _last_finished_event.set() + + runner() + + +class _CommonArgs(typing_extensions.TypedDict): + cancel: Callable[[], None] + _retry: Callable[[], None] + + +# TODO: Not sure if we want to use this, or have all local variables in Task subclasses be reactive vars +class Proxy: + def __init__(self, factory): + self._instance = Singleton(factory) + + def __getattr__(self, name): + return getattr(self._instance.value, name) + + def __setattr__(self, name, value): + if name == "_instance": + super().__setattr__(name, value) + else: + setattr(self._instance.value, name, value) + + def __call__(self, *args, **kwargs): + return self._instance.value(*args, **kwargs) + + +@overload +def task( + f: None = None, + *, + prefer_threaded: bool = ..., +) -> Callable[[Callable[P, R]], Task[P, R]]: + ... + + +@overload +def task( + f: Callable[P, Union[Coroutine[Any, Any, R], R]], + *, + prefer_threaded: bool = ..., +) -> Task[P, R]: + ... + + +def task( + f: Union[None, Callable[P, Union[Coroutine[Any, Any, R], R]]] = None, + *, + prefer_threaded: bool = True, +) -> Union[Callable[[Callable[P, R]], Task[P, R]], Task[P, R]]: + """Decorator to turn a function or coroutine function into a task. + + A task is a callable that will run the function in a separate thread for normal functions + or a asyncio task for a coroutine function. + + The task callable does only execute the function once when called multiple times, + and will cancel previous executions if the function is called again before the previous finished. + + The wrapped function return value is available as the `.value` attribute of the task object. + + ## Example + + ```solara + import asyncio + import solara + from solara.lab import task + + @task + async def fetch_data(): + await asyncio.sleep(2) + return "The answer is 42" + + + @solara.component + def Page(): + solara.Button("Fetch data", on_click=fetch_data) + solara.ProgressLinear(fetch_data.state == solara.ResultState.RUNNING) + if fetch_data.state == solara.ResultState.FINISHED: + solara.Text(fetch_data.value) + + + ## Arguments + + - `f`: Function to turn into task or None + - `prefer_threaded` - bool: Will run coroutine functions as a task in a thread when threads are available. + This ensures that even when a coroutine functions calls a blocking function the UI is still responsive. + On platform where threads are not supported (like Pyodide / WASM / Emscripten / PyScript), a coroutine + function will always run in the current event loop. + + ``` + + """ + + def wrapper(f: Union[None, Callable[P, Union[Coroutine[Any, Any, R], R]]]) -> Task[P, R]: + def create_task(): + if inspect.iscoroutinefunction(f): + return TaskAsyncio[P, R](prefer_threaded and has_threads, f) + else: + return TaskThreaded[P, R](cast(Callable[P, R], f)) + + return cast(Task[P, R], Proxy(create_task)) + + if f is None: + return wrapper + else: + return wrapper(f) + + +@overload +def use_task( + f: None = None, + dependencies=[], + *, + raise_error=..., + prefer_threaded=..., +) -> Callable[[Callable[[], R]], solara.Result[R]]: + ... + + +@overload +def use_task( + f: Callable[P, R], + dependencies=[], + *, + raise_error=..., + prefer_threaded=..., +) -> solara.Result[R]: + ... + + +def use_task( + f: Union[None, Callable[[], R]] = None, + dependencies=[], + *, + raise_error=True, + prefer_threaded=True, +) -> Union[Callable[[Callable[[], R]], solara.Result[R]], solara.Result[R]]: + """Run a function or coroutine as a task and return the result. + + ## Example + + ### Running in a thread + + ```solara + import time + import solara + from solara.lab import use_task + + + @solara.component + def Page(): + number = solara.use_reactive(4) + + def square(): + time.sleep(1) + return number.value**2 + + result: solara.Result[bool] = use_task(square, dependencies=[number.value]) + + solara.InputInt("Square", value=number, continuous_update=True) + if result.state == solara.ResultState.FINISHED: + solara.Success(f"Square of {number} == {result.value}") + solara.ProgressLinear(result.state == solara.ResultState.RUNNING) + ``` + + ### Running in a asyncio task + + Note that the only difference is our function is now a coroutine function, + and we use `asyncio.sleep` instead of `time.sleep`. + ```solara + import asyncio + import solara + from solara.lab import use_task + + + @solara.component + def Page(): + number = solara.use_reactive(4) + + async def square(): + await asyncio.sleep(1) + return number.value**2 + + result: solara.Result[bool] = use_task(square, dependencies=[number.value]) + + solara.InputInt("Square", value=number, continuous_update=True) + if result.state == solara.ResultState.FINISHED: + solara.Success(f"Square of {number} == {result.value}") + solara.ProgressLinear(result.state == solara.ResultState.RUNNING) + ``` + + ## Arguments + + - `f`: The function or coroutine to run as a task. + - `dependencies`: A list of dependencies that will trigger a rerun of the task when changed. + - `raise_error`: If true, an error in the task will be raised. If false, the error should be handled by the + user and is available in the `.error` attribute of the task object. + - `prefer_threaded` - bool: Will run coroutine functions as a task in a thread when threads are available. + This ensures that even when a coroutine functions calls a blocking function the UI is still responsive. + On platform where threads are not supported (like Pyodide / WASM / Emscripten / PyScript), a coroutine + function will always run in the current event loop. + + + """ + + def wrapper(f): + task_instance = solara.use_memo(lambda: task(f, prefer_threaded=prefer_threaded), dependencies=dependencies) + + def run(): + task_instance() + return task_instance.cancel + + solara.use_effect(run, dependencies=dependencies) + if raise_error: + if task_instance.state == solara.ResultState.ERROR and task_instance.error is not None: + raise task_instance.error + return task_instance.result.value + + if f is None: + return wrapper + else: + return wrapper(f) diff --git a/solara/toestand.py b/solara/toestand.py index c1f8ee25e..30d70fba6 100644 --- a/solara/toestand.py +++ b/solara/toestand.py @@ -74,7 +74,7 @@ def use_sync_external_store_with_selector(subscribe, get_snapshot: Callable[[], def merge_state(d1: S, **kwargs) -> S: if dataclasses.is_dataclass(d1): - return dataclasses.replace(d1, **kwargs) + return dataclasses.replace(d1, **kwargs) # type: ignore if "pydantic" in sys.modules and isinstance(d1, sys.modules["pydantic"].BaseModel): return type(d1)(**{**d1.dict(), **kwargs}) # type: ignore return cast(S, {**cast(dict, d1), **kwargs}) diff --git a/solara/website/pages/api/__init__.py b/solara/website/pages/api/__init__.py index 9d1aa41ec..0666c9d90 100644 --- a/solara/website/pages/api/__init__.py +++ b/solara/website/pages/api/__init__.py @@ -132,6 +132,8 @@ "on_kernel_start", "tab", "tabs", + "task", + "use_task", ], }, ] diff --git a/solara/website/pages/api/task.py b/solara/website/pages/api/task.py new file mode 100644 index 000000000..7dfdc686e --- /dev/null +++ b/solara/website/pages/api/task.py @@ -0,0 +1,39 @@ +"""# Task + +A global way to run code in the background, with the UI available to the user. This is useful for long running tasks, like downloading data. + +The task decorator turns a function or coroutine function into a task object. +A task is a callable that will run the function in a separate thread for normal functions +or an asyncio task for a coroutine function. Note that on platforms where threads are supported, +asyncio tasks will still be executed in threads (unless the +`prefer_thread=False` argument is passed) + +The task object will execute the function only once per virtual kernel. When called multiple times, +the previously started run will be cancelled. + +The Result object is wrapped as a reactive variable in the `.result` attribute of the task object, so to access the underlying value, use `.result.value`. + +The return value of the function is available as the `.value` attribute of the result object, meaning it's accessible as `.result.value.value`. While +a demonstation of composability, this is not very readable, so you can also use `.value` property to access the return value of the function. + +## Task object + + * `.result`: A reactive variable wrapping the result object. + * `.value`: Alias for `.result.value.value` + * `.error`: Alias for `.result.value.error` + * `.state`: Alias for `.result.value.state` + * `.cancel()`: Cancels the task. + + + +""" +import solara +import solara.autorouting +import solara.lab +from solara.website.utils import apidoc + +from . import NoPage + +title = "Task" +Page = NoPage +__doc__ += apidoc(solara.lab.task) # type: ignore diff --git a/solara/website/pages/api/use_task.py b/solara/website/pages/api/use_task.py new file mode 100644 index 000000000..8101d6d88 --- /dev/null +++ b/solara/website/pages/api/use_task.py @@ -0,0 +1,19 @@ +"""# use_task + +A hook that allows you to run code in the background, with the UI available to the user. This is useful for long running tasks, like downloading data. + +Unlike with the [`@task`](/api/task) decorator, the result is not globally shared, but only available to the component that called `use_task`. + +Note that unlike the [`@task`](/api/task) decorator, the task is invoked immediately, and the hook will return the Result object, instead of the task object. + +""" +import solara +import solara.autorouting +import solara.lab +from solara.website.utils import apidoc + +from . import NoPage + +title = "use_task" +Page = NoPage +__doc__ += apidoc(solara.lab.use_task) # type: ignore diff --git a/solara/website/pages/docs/content/10-howto/30-tasks.md b/solara/website/pages/docs/content/10-howto/30-tasks.md new file mode 100644 index 000000000..5bb1eb678 --- /dev/null +++ b/solara/website/pages/docs/content/10-howto/30-tasks.md @@ -0,0 +1,92 @@ +# Long running code + +Solara can run long running code in tasks to have a responsive UI while code runs in the background. +For IO bounds code, we often use async code, which runs in the same thread as your UI. +For CPU intensive, or blocking code, we want to use threads. + +## Async task running on an event + +```python +import asyncio +import solara +from solara.lab import task + +@task +async def fetch_data(): + await asyncio.sleep(2) + return "The answer is 42" + + +@solara.component +def Page(): + solara.Button("Fetch data", on_click=fetch_data) + solara.ProgressLinear(fetch_data.state == solara.ResultState.RUNNING) + if fetch_data.state == solara.ResultState.FINISHED: + solara.Text(fetch_data.value) +``` + + +## Threaded task running on an event + +```python +import time +import solara +from solara.lab import task + +@task +def fetch_data(): + time.sleep(2) + return "The answer is 42" + + +@solara.component +def Page(): + print(fetch_data.state) + solara.Button("Fetch data", on_click=fetch_data) + solara.ProgressLinear(fetch_data.state == solara.ResultState.RUNNING) + if fetch_data.state == solara.ResultState.FINISHED: + solara.Text(fetch_data.value) +``` + + + +## Threaded task running when data changes + +A common situation in UI's is the need to run and re-run a long running +function when data changes. An example of that is a UI elements like a +[Select](/api/select) which triggers fetching of data from a server. +In this case, the [`use_task`](/api/use_task) hook can be used (with call=True). + + +```solara +import time +import solara +from solara.lab import task, use_task +import requests + + +countries = ['Aruba', 'the Netherlands', 'USA', 'China'] +country = solara.reactive('Aruba') + +@solara.component +def Page(): + def get_country_data(): + return requests.get(f"https://restcountries.com/v3.1/name/{country.value}").json()[0] + + result = use_task(get_country_data, dependencies=[country.value]) + + solara.Select("Choose country", value=country, values=countries) + solara.ProgressLinear(result.state == solara.ResultState.RUNNING) + + if result.state == solara.ResultState.FINISHED: + languages = result.value["languages"] + solara.Markdown(f""" + # Languages in {country.value} + +
+            {repr(languages)}
+            
+ """) + elif result.state == solara.ResultState.ERROR: + solara.Error(f"Error occurred: {result.error}") +``` diff --git a/tests/unit/task_test.py b/tests/unit/task_test.py new file mode 100644 index 000000000..d449a66d9 --- /dev/null +++ b/tests/unit/task_test.py @@ -0,0 +1,509 @@ +import asyncio +import time + +import ipyvuetify as v +import pytest +from reacton import ipywidgets as w + +import solara.tasks +from solara.server import kernel, kernel_context +from solara.tasks import use_task +from solara.toestand import Computed + + +@solara.tasks.task +def something(count: int, delay: float = 0.1): + time.sleep(delay) + return "42" * count + + +@solara.component +def ComputeButton(count, delay: float = 0.1, on_render=lambda: None): + solara.Button("Run", on_click=lambda: something(count, delay)) + on_render() + # print(something.result.value) + if something.result.value: + if something.state == solara.ResultState.RUNNING: + solara.Info("running") + elif something.state == solara.ResultState.FINISHED: + solara.Info("Done: " + str(something.value)) + elif something.state == solara.ResultState.ERROR: + solara.Info("Error: " + str(something.error)) + else: + solara.Info("Cancelled") + + +@solara.component +def Page(): + ComputeButton(2) + ComputeButton(3) + + +cancel_square = False + + +@solara.tasks.task +def square(value: float): + if cancel_square: + square.cancel() + return value**2 + + +@solara.component +def SquareButton(value, on_render=lambda: None): + solara.Button("Run", on_click=lambda: square(value)) + on_render() + if square.result.value: + if square.state == solara.ResultState.RUNNING: + solara.Info("running") + elif square.state == solara.ResultState.FINISHED: + solara.Info("Done: " + str(square.value)) + elif square.state == solara.ResultState.ERROR: + solara.Info("Error: " + str(square.error)) + else: + solara.Info("Cancelled") + + +def test_task_basic(): + results = [] + + def collect(): + results.append((square.state, square.value)) + + box, rc = solara.render(SquareButton(3, on_render=collect), handle_error=False) + button = rc.find(v.Btn, children=["Run"]).widget + button.click() + assert square._last_finished_event # type: ignore + square._last_finished_event.wait() # type: ignore + assert results == [ + (solara.ResultState.INITIAL, None), + (solara.ResultState.STARTING, None), + (solara.ResultState.RUNNING, None), + (solara.ResultState.FINISHED, 9), + ] + results.clear() + rc.render(SquareButton(2, on_render=collect)) + button = rc.find(v.Btn, children=["Run"]).widget + button.click() + square._last_finished_event.wait() # type: ignore + assert results == [ + # extra finished due to the rc.render call + (solara.ResultState.FINISHED, 9), + (solara.ResultState.STARTING, 9), + (solara.ResultState.RUNNING, 9), + (solara.ResultState.FINISHED, 4), + ] + + +# async version + +cancel_square_async = False + + +@solara.tasks.task +async def square_async(value: float): + if cancel_square_async: + square_async.cancel() + return value**2 + + +@solara.component +def SquareButtonAsync(value, on_render=lambda: None): + solara.Button("Run", on_click=lambda: square_async(value)) + on_render() + if square_async.result.value: + if square_async.state == solara.ResultState.RUNNING: + solara.Info("running") + elif square_async.state == solara.ResultState.FINISHED: + solara.Info("Done: " + str(square_async.value)) + elif square_async.state == solara.ResultState.ERROR: + solara.Info("Error: " + str(square_async.error)) + else: + solara.Info("Cancelled") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("run_in_thread", [True, False]) +async def test_task_basic_async(run_in_thread): + results = [] + assert square_async._instance.value.run_in_thread # type: ignore + square_async._instance.value.run_in_thread = run_in_thread # type: ignore + + def collect(): + results.append((square_async.state, square_async.value)) + + box, rc = solara.render(SquareButtonAsync(3, on_render=collect), handle_error=False) + button = rc.find(v.Btn, children=["Run"]).widget + button.click() + assert square_async.current_future # type: ignore + await square_async.current_future # type: ignore + assert results == [ + (solara.ResultState.INITIAL, None), + (solara.ResultState.STARTING, None), + (solara.ResultState.RUNNING, None), + (solara.ResultState.FINISHED, 9), + ] + results.clear() + rc.render(SquareButtonAsync(2, on_render=collect)) + button = rc.find(v.Btn, children=["Run"]).widget + button.click() + await square_async.current_future # type: ignore + assert results == [ + # extra finished due to the rc.render call + (solara.ResultState.FINISHED, 9), + (solara.ResultState.STARTING, 9), + (solara.ResultState.RUNNING, 9), + (solara.ResultState.FINISHED, 4), + ] + square_async._instance.value.run_in_thread = True # type: ignore + + +def test_task_two(): + results2 = [] + results3 = [] + # ugly reset + square.last_value = None + + def collect2(): + results2.append((square.state, square.value)) + + def collect3(): + results3.append((square.state, square.value)) + + @solara.component + def Test(): + SquareButton(2, on_render=collect2) + SquareButton(3, on_render=collect3) + + box, rc = solara.render(Test(), handle_error=False) + button = rc.find(v.Btn, children=["Run"])[0].widget + button.click() + assert square._last_finished_event # type: ignore + square._last_finished_event.wait() # type: ignore + assert ( + results2 + == results3 + == [ + (solara.ResultState.INITIAL, None), + (solara.ResultState.STARTING, None), + (solara.ResultState.RUNNING, None), + (solara.ResultState.FINISHED, 4), + ] + ) + assert len(rc.find(children=["Done: 4"])) == 2 + + # now we press the second button + results2.clear() + results3.clear() + button = rc.find(v.Btn, children=["Run"])[1].widget + button.click() + assert square._last_finished_event # type: ignore + square._last_finished_event.wait() # type: ignore + assert ( + results2 + == results3 + == [ + # not a finished event, because we don't render from the start + (solara.ResultState.STARTING, 4), + (solara.ResultState.RUNNING, 4), + (solara.ResultState.FINISHED, 9), + ] + ) + assert len(rc.find(children=["Done: 9"])) == 2 + + +def test_task_cancel_retry(): + global cancel_square + results = [] + + # ugly reset + square.last_value = None + + def collect(): + results.append((square.state, square.value)) + + box, rc = solara.render(SquareButton(5, on_render=collect), handle_error=False) + button = rc.find(v.Btn, children=["Run"]).widget + cancel_square = True + try: + button.click() + assert square._last_finished_event # type: ignore + square._last_finished_event.wait() # type: ignore + assert results == [ + (solara.ResultState.INITIAL, None), + (solara.ResultState.STARTING, None), + (solara.ResultState.RUNNING, None), + (solara.ResultState.CANCELLED, None), + ] + finally: + cancel_square = False + results.clear() + square.retry() + square._last_finished_event.wait() # type: ignore + assert results == [ + (solara.ResultState.STARTING, None), + (solara.ResultState.RUNNING, None), + (solara.ResultState.FINISHED, 5**2), + ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("run_in_thread", [True, False]) +async def test_task_async_cancel_retry(run_in_thread): + global cancel_square_async + results = [] + + assert square_async._instance.value.run_in_thread # type: ignore + square_async._instance.value.run_in_thread = run_in_thread # type: ignore + + # ugly reset + square_async.last_value = None + + def collect(): + results.append((square_async.state, square_async.value)) + + box, rc = solara.render(SquareButtonAsync(5, on_render=collect), handle_error=False) + button = rc.find(v.Btn, children=["Run"]).widget + cancel_square_async = True + try: + button.click() + assert square_async.current_future # type: ignore + try: + await square_async.current_future # type: ignore + except asyncio.CancelledError: + pass + + assert results == [ + (solara.ResultState.INITIAL, None), + (solara.ResultState.STARTING, None), + (solara.ResultState.RUNNING, None), + (solara.ResultState.CANCELLED, None), + ] + finally: + cancel_square_async = False + results.clear() + square_async.retry() + await square_async.current_future # type: ignore + assert results == [ + (solara.ResultState.STARTING, None), + (solara.ResultState.RUNNING, None), + (solara.ResultState.FINISHED, 5**2), + ] + + square_async._instance.value.run_in_thread = True # type: ignore + + +def test_task_scopes(no_kernel_context): + results1 = [] + results2 = [] + + def collect1(): + results1.append((something.state, something.value)) + + def collect2(): + results2.append((something.state, something.value)) + + kernel1 = kernel.Kernel() + kernel2 = kernel.Kernel() + assert kernel_context.current_context[kernel_context.get_current_thread_key()] is None + + context1 = kernel_context.VirtualKernelContext(id="toestand-1", kernel=kernel1, session_id="session-1") + context2 = kernel_context.VirtualKernelContext(id="toestand-2", kernel=kernel2, session_id="session-2") + + with context1: + box1, rc1 = solara.render(ComputeButton(5, on_render=collect1), handle_error=False) + button1 = rc1.find(v.Btn, children=["Run"]).widget + + with context2: + box2, rc2 = solara.render(ComputeButton(5, on_render=collect2), handle_error=False) + button2 = rc2.find(v.Btn, children=["Run"]).widget + + with context1: + button1.click() + finished_event1 = something._last_finished_event # type: ignore + assert finished_event1 + + with context2: + assert something._last_finished_event is not finished_event1 # type: ignore + assert something._last_finished_event is None # type: ignore + + finished_event1.wait() + assert results1 == [ + (solara.ResultState.INITIAL, None), + (solara.ResultState.STARTING, None), + (solara.ResultState.RUNNING, None), + (solara.ResultState.FINISHED, "4242424242"), + ] + # results1.clear() + assert results2 == [(solara.ResultState.INITIAL, None)] + + with context2: + button2.click() + finished_event2 = something._last_finished_event # type: ignore + assert finished_event2 + finished_event2.wait() + assert results2 == [ + (solara.ResultState.INITIAL, None), + (solara.ResultState.STARTING, None), + (solara.ResultState.RUNNING, None), + (solara.ResultState.FINISHED, "4242424242"), + ] + + +def test_task_and_computed(no_kernel_context): + @Computed + def square_minus_one(): + # breakpoint() + return square.value - 1 + + kernel1 = kernel.Kernel() + kernel2 = kernel.Kernel() + assert kernel_context.current_context[kernel_context.get_current_thread_key()] is None + + context1 = kernel_context.VirtualKernelContext(id="t1", kernel=kernel1, session_id="session-1") + context2 = kernel_context.VirtualKernelContext(id="t2", kernel=kernel2, session_id="session-2") + + with context1: + r1 = square.result + assert len(square.result._storage.listeners2["t1"]) == 0 + square(5) + assert square._last_finished_event # type: ignore + square._last_finished_event.wait() # type: ignore + # accessing will add it to the listeners + assert len(square.result._storage.listeners2["t1"]) == 0 + assert square_minus_one.value == 24 + assert len(square.result._storage.listeners2["t1"]) == 1 + square_minus_one._auto_subscriber.value.reactive_used == {square.value} + + with context2: + r2 = square.result + assert len(square.result._storage.listeners2["t2"]) == 0 + square(6) + assert square._last_finished_event # type: ignore + square._last_finished_event.wait() # type: ignore + assert len(square.result._storage.listeners2["t2"]) == 0 + assert square_minus_one.value == 35 + assert len(square.result._storage.listeners2["t2"]) == 1 + square_minus_one._auto_subscriber.value.reactive_used == {square.value} + + with context1: + assert r1 is square.result + assert len(square.result._storage.listeners2["t1"]) == 1 + square._last_finished_event = None # type: ignore + square_minus_one._auto_subscriber.value.reactive_used == {square.value} + assert square_minus_one.value == 24 + square(7) + square_minus_one._auto_subscriber.value.reactive_used == {square.value} + assert square._last_finished_event # type: ignore + square._last_finished_event.wait() # type: ignore + assert square_minus_one.value == 48 + + with context2: + assert r2 is square.result + assert square_minus_one.value == 35 + square(8) + assert square._last_finished_event # type: ignore + square._last_finished_event.wait() # type: ignore + assert square_minus_one.value == 63 + + +# copied from hooks_test.py + + +def test_use_task_intrusive_cancel(): + result = None + last_value = 0 + seconds = 4.0 + + @solara.component + def Test(): + nonlocal result + nonlocal last_value + + def work(): + nonlocal last_value + for i in range(100): + last_value = i + # if not cancelled, might take 4 seconds + time.sleep(seconds / 100) + return 2**42 + + result = use_task(work, dependencies=[]) + return w.Label(value="test") + + solara.render_fixed(Test(), handle_error=False) + assert result is not None + assert isinstance(result, solara.Result) + result.cancel() + while result.state in [solara.ResultState.STARTING, solara.ResultState.RUNNING]: + time.sleep(0.1) + assert result.state == solara.ResultState.CANCELLED + assert last_value != 99 + + # also test retry + seconds = 0.1 + result.retry() + while result.state == solara.ResultState.CANCELLED: + time.sleep(0.1) + while result.state in [solara.ResultState.STARTING, solara.ResultState.RUNNING]: + time.sleep(0.1) + assert last_value == 99 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("prefer_threaded", [True, False]) +async def test_use_task_async(prefer_threaded): + result = None + last_value = 0 + seconds = 4.0 + + @solara.component + def Test(): + nonlocal result + nonlocal last_value + + async def work(): + # print("work", id) + nonlocal last_value + for i in range(100): + last_value = i + # print("work", id, i) + # if not cancelled, might take 4 seconds + await asyncio.sleep(seconds / 100) + return 2**42 + + result = use_task(work, dependencies=[], prefer_threaded=prefer_threaded) + # print("render with", result) + return w.Label(value="test") + + solara.render_fixed(Test(), handle_error=False) + assert result is not None + assert isinstance(result, solara.Result) + result.cancel() + # the current implementation if cancel is direct, we so we not need the code below + # n = 0 + # while result.state in [solara.ResultState.INITIAL, solara.ResultState.STARTING, solara.ResultState.RUNNING]: + # await asyncio.sleep(0.1) + # n += 1 + # if n == 100: + # raise TimeoutError("took too long, state = " + str(result.state)) + assert result.state == solara.ResultState.CANCELLED + assert last_value != 99 + + # also test retry + seconds = 0.1 + result.retry() + n = 0 + while result.state == solara.ResultState.CANCELLED: + await asyncio.sleep(0.1) + n += 1 + if n == 100: + raise TimeoutError("took too long, state = " + str(result.state)) + n = 0 + while result.state in [solara.ResultState.STARTING, solara.ResultState.RUNNING]: + await asyncio.sleep(0.1) + n += 1 + if n == 100: + raise TimeoutError("took too long, state = " + str(result.state)) + assert result.state == solara.ResultState.FINISHED + assert last_value == 99