Skip to content

Commit

Permalink
refactor: put nullcontext and cancel_guard in util.py for reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbreddels committed Jun 29, 2023
1 parent 40d3930 commit 6934ee5
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 61 deletions.
15 changes: 1 addition & 14 deletions solara/hooks/misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import dataclasses
import io
import json
Expand Down Expand Up @@ -35,18 +34,6 @@
MaybeResult = Union[T, Result[T]]


# not available in python 3.6
class nullcontext(contextlib.AbstractContextManager):
def __init__(self, enter_result=None):
self.enter_result = enter_result

def __enter__(self):
return self.enter_result

def __exit__(self, *excinfo):
pass


def use_retry(*actions: Callable[[], Any]):
counter, set_counter = solara.use_state(0)

Expand Down Expand Up @@ -85,7 +72,7 @@ def download(cancel: threading.Event):

context: Any = None
if file_object:
context = nullcontext()
context = solara.util.nullcontext()
output_file = cast(IO, f.value)
else:
# f = cast(Result[Union[str, os.PathLike]], f)
Expand Down
50 changes: 4 additions & 46 deletions solara/hooks/use_thread.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,19 @@
import contextlib
import functools
import inspect
import logging
import os
import sys
import threading
from typing import Callable, Iterator, Optional, TypeVar, Union, cast

import reacton

import solara
from solara.datatypes import Result, ResultState
from solara.util import cancel_guard, nullcontext

SOLARA_ALLOW_OTHER_TRACER = os.environ.get("SOLARA_ALLOW_OTHER_TRACER", False) in (True, "True", "true", "1")
T = TypeVar("T")
logger = logging.getLogger("solara.hooks.use_thread")


# inherit from BaseException so less change of being caught
# in an except
class CancelledError(BaseException):
pass


def use_thread(
callback=Union[
Callable[[threading.Event], T],
Expand Down Expand Up @@ -50,39 +41,6 @@ def make_lock():
counter, retry = use_retry()
cancel: threading.Event = solara.use_memo(make_event, [*dependencies, counter])

@contextlib.contextmanager
def cancel_guard():
if not intrusive_cancel:
yield
return

def tracefunc(frame, event, arg):
# this gets called at least for every line executed
if cancel.is_set():
rc = reacton.core._get_render_context(required=False)
# we do not want to cancel the rendering cycle
if rc is None or not rc._is_rendering:
# this will bubble up
raise CancelledError()
if prev and SOLARA_ALLOW_OTHER_TRACER:
prev(frame, event, arg)
# keep tracing:
return tracefunc

# see https://docs.python.org/3/library/sys.html#sys.settrace
# it is for the calling thread only
# not every Python implementation has it
prev = None
if hasattr(sys, "gettrace"):
prev = sys.gettrace()
if hasattr(sys, "settrace"):
sys.settrace(tracefunc)
try:
yield
finally:
if hasattr(sys, "settrace"):
sys.settrace(prev)

def run():
set_result_state(ResultState.STARTING)

Expand Down Expand Up @@ -122,12 +80,12 @@ def runner():
# the function calls to f. We don't want to guard around
# a call to react, since that might slow down rendering
# during rendering
with cancel_guard():
with cancel_guard(cancel) if intrusive_cancel else nullcontext():
value = f()
if inspect.isgenerator(value):
while True:
try:
with cancel_guard():
with cancel_guard(cancel) if intrusive_cancel else nullcontext():
result.current = next(value)
error.current = None
except StopIteration:
Expand All @@ -147,7 +105,7 @@ def runner():
logger.exception(e)
set_result_state(ResultState.ERROR)
return
except CancelledError:
except solara.util.CancelledError:
pass
# this means this thread is cancelled not be request, but because
# a new thread is running, we can ignore this
Expand Down
2 changes: 1 addition & 1 deletion solara/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, connection_i
run_context = viztracer.VizTracer(output_file=output_file, max_stack_depth=10)
logger.warning(f"Running with tracer: {output_file}")
else:
run_context = contextlib.nullcontext()
run_context = solara.util.nullcontext()

kernel = context.kernel
with run_context, context:
Expand Down
52 changes: 52 additions & 0 deletions solara/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@
import contextlib
import os
import sys
import threading
from collections import abc
from pathlib import Path
from typing import Dict, List, Union

import numpy as np
import PIL.Image
import reacton

import solara

SOLARA_ALLOW_OTHER_TRACER = os.environ.get("SOLARA_ALLOW_OTHER_TRACER", False) in (True, "True", "true", "1")


def github_url(file):
rel_path = os.path.relpath(file, Path(solara.__file__).parent.parent)
Expand Down Expand Up @@ -159,3 +163,51 @@ def nested_get(object, dotted_name: str, default=None):
else:
object = getattr(object, name)
return object


# inherit from BaseException so less change of being caught
# in an except
class CancelledError(BaseException):
pass


# not available in python 3.6
class nullcontext(contextlib.AbstractContextManager):
def __init__(self, enter_result=None):
self.enter_result = enter_result

def __enter__(self):
return self.enter_result

def __exit__(self, *excinfo):
pass


@contextlib.contextmanager
def cancel_guard(cancelled: threading.Event):
def tracefunc(frame, event, arg):
# this gets called at least for every line executed
if cancelled.is_set():
rc = reacton.core.get_render_context(required=False)
# we do not want to cancel the rendering cycle
if rc is None or not rc._is_rendering:
# this will bubble up
raise CancelledError()
if prev and SOLARA_ALLOW_OTHER_TRACER:
prev(frame, event, arg)
# keep tracing:
return tracefunc

# see https://docs.python.org/3/library/sys.html#sys.settrace
# it is for the calling thread only
# not every Python implementation has it
prev = None
if hasattr(sys, "gettrace"):
prev = sys.gettrace()
if hasattr(sys, "settrace"):
sys.settrace(tracefunc)
try:
yield
finally:
if hasattr(sys, "settrace"):
sys.settrace(prev)
2 changes: 2 additions & 0 deletions tests/unit/hooks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def FetchFile(url=url, expected_size=None):
nonlocal result
result = use_fetch(url)
data = result.value
if result.error:
raise result.error
return w.Label(value=f"{len(data) if data else '-'}")

label, rc = render_fixed(FetchFile(), handle_error=False)
Expand Down

0 comments on commit 6934ee5

Please sign in to comment.