From 6934ee5108ad2a05d10a50ae8b3ad16dcca22fb4 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Thu, 29 Jun 2023 13:20:17 +0200 Subject: [PATCH] refactor: put nullcontext and cancel_guard in util.py for reuse --- solara/hooks/misc.py | 15 +---------- solara/hooks/use_thread.py | 50 +++--------------------------------- solara/server/server.py | 2 +- solara/util.py | 52 ++++++++++++++++++++++++++++++++++++++ tests/unit/hooks_test.py | 2 ++ 5 files changed, 60 insertions(+), 61 deletions(-) diff --git a/solara/hooks/misc.py b/solara/hooks/misc.py index 817d15fdb..f4eff52df 100644 --- a/solara/hooks/misc.py +++ b/solara/hooks/misc.py @@ -1,4 +1,3 @@ -import contextlib import dataclasses import io import json @@ -35,18 +34,6 @@ MaybeResult = Union[T, Result[T]] -# not available in python 3.6 -class nullcontext(contextlib.AbstractContextManager): - def __init__(self, enter_result=None): - self.enter_result = enter_result - - def __enter__(self): - return self.enter_result - - def __exit__(self, *excinfo): - pass - - def use_retry(*actions: Callable[[], Any]): counter, set_counter = solara.use_state(0) @@ -85,7 +72,7 @@ def download(cancel: threading.Event): context: Any = None if file_object: - context = nullcontext() + context = solara.util.nullcontext() output_file = cast(IO, f.value) else: # f = cast(Result[Union[str, os.PathLike]], f) diff --git a/solara/hooks/use_thread.py b/solara/hooks/use_thread.py index 034f197de..90a20e585 100644 --- a/solara/hooks/use_thread.py +++ b/solara/hooks/use_thread.py @@ -1,28 +1,19 @@ -import contextlib import functools import inspect import logging import os -import sys import threading from typing import Callable, Iterator, Optional, TypeVar, Union, cast -import reacton - import solara from solara.datatypes import Result, ResultState +from solara.util import cancel_guard, nullcontext SOLARA_ALLOW_OTHER_TRACER = os.environ.get("SOLARA_ALLOW_OTHER_TRACER", False) in (True, "True", "true", "1") T = TypeVar("T") logger = logging.getLogger("solara.hooks.use_thread") -# inherit from BaseException so less change of being caught -# in an except -class CancelledError(BaseException): - pass - - def use_thread( callback=Union[ Callable[[threading.Event], T], @@ -50,39 +41,6 @@ def make_lock(): counter, retry = use_retry() cancel: threading.Event = solara.use_memo(make_event, [*dependencies, counter]) - @contextlib.contextmanager - def cancel_guard(): - if not intrusive_cancel: - yield - return - - def tracefunc(frame, event, arg): - # this gets called at least for every line executed - if cancel.is_set(): - rc = reacton.core._get_render_context(required=False) - # we do not want to cancel the rendering cycle - if rc is None or not rc._is_rendering: - # this will bubble up - raise CancelledError() - if prev and SOLARA_ALLOW_OTHER_TRACER: - prev(frame, event, arg) - # keep tracing: - return tracefunc - - # see https://docs.python.org/3/library/sys.html#sys.settrace - # it is for the calling thread only - # not every Python implementation has it - prev = None - if hasattr(sys, "gettrace"): - prev = sys.gettrace() - if hasattr(sys, "settrace"): - sys.settrace(tracefunc) - try: - yield - finally: - if hasattr(sys, "settrace"): - sys.settrace(prev) - def run(): set_result_state(ResultState.STARTING) @@ -122,12 +80,12 @@ def runner(): # the function calls to f. We don't want to guard around # a call to react, since that might slow down rendering # during rendering - with cancel_guard(): + with cancel_guard(cancel) if intrusive_cancel else nullcontext(): value = f() if inspect.isgenerator(value): while True: try: - with cancel_guard(): + with cancel_guard(cancel) if intrusive_cancel else nullcontext(): result.current = next(value) error.current = None except StopIteration: @@ -147,7 +105,7 @@ def runner(): logger.exception(e) set_result_state(ResultState.ERROR) return - except CancelledError: + except solara.util.CancelledError: pass # this means this thread is cancelled not be request, but because # a new thread is running, we can ignore this diff --git a/solara/server/server.py b/solara/server/server.py index c6fe2e9be..c07914068 100644 --- a/solara/server/server.py +++ b/solara/server/server.py @@ -122,7 +122,7 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, connection_i run_context = viztracer.VizTracer(output_file=output_file, max_stack_depth=10) logger.warning(f"Running with tracer: {output_file}") else: - run_context = contextlib.nullcontext() + run_context = solara.util.nullcontext() kernel = context.kernel with run_context, context: diff --git a/solara/util.py b/solara/util.py index e198402ac..26c2e3836 100644 --- a/solara/util.py +++ b/solara/util.py @@ -2,15 +2,19 @@ import contextlib import os import sys +import threading from collections import abc from pathlib import Path from typing import Dict, List, Union import numpy as np import PIL.Image +import reacton import solara +SOLARA_ALLOW_OTHER_TRACER = os.environ.get("SOLARA_ALLOW_OTHER_TRACER", False) in (True, "True", "true", "1") + def github_url(file): rel_path = os.path.relpath(file, Path(solara.__file__).parent.parent) @@ -159,3 +163,51 @@ def nested_get(object, dotted_name: str, default=None): else: object = getattr(object, name) return object + + +# inherit from BaseException so less change of being caught +# in an except +class CancelledError(BaseException): + pass + + +# not available in python 3.6 +class nullcontext(contextlib.AbstractContextManager): + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + def __exit__(self, *excinfo): + pass + + +@contextlib.contextmanager +def cancel_guard(cancelled: threading.Event): + def tracefunc(frame, event, arg): + # this gets called at least for every line executed + if cancelled.is_set(): + rc = reacton.core.get_render_context(required=False) + # we do not want to cancel the rendering cycle + if rc is None or not rc._is_rendering: + # this will bubble up + raise CancelledError() + if prev and SOLARA_ALLOW_OTHER_TRACER: + prev(frame, event, arg) + # keep tracing: + return tracefunc + + # see https://docs.python.org/3/library/sys.html#sys.settrace + # it is for the calling thread only + # not every Python implementation has it + prev = None + if hasattr(sys, "gettrace"): + prev = sys.gettrace() + if hasattr(sys, "settrace"): + sys.settrace(tracefunc) + try: + yield + finally: + if hasattr(sys, "settrace"): + sys.settrace(prev) diff --git a/tests/unit/hooks_test.py b/tests/unit/hooks_test.py index 8648efa74..f09cbba91 100644 --- a/tests/unit/hooks_test.py +++ b/tests/unit/hooks_test.py @@ -176,6 +176,8 @@ def FetchFile(url=url, expected_size=None): nonlocal result result = use_fetch(url) data = result.value + if result.error: + raise result.error return w.Label(value=f"{len(data) if data else '-'}") label, rc = render_fixed(FetchFile(), handle_error=False)