Skip to content

Commit

Permalink
make use_task with overload
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbreddels authored and iisakkirotko committed Feb 5, 2024
1 parent a94b0ef commit 06c4657
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions solara/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,23 @@ def create_task():
return wrapper(function)


def use_task(f: Callable[P, R], dependencies=[], *, raise_error=True) -> Union[Task[P, R], solara.Result[R]]:
@overload
def use_task(
f: None = None,
) -> Callable[[Callable[[], R]], solara.Result[R]]:
...


@overload
def use_task(
f: Callable[P, R],
) -> solara.Result[R]:
...


def use_task(
f: Union[None, Callable[[], R]] = None, dependencies=[], *, raise_error=True
) -> Union[Callable[[Callable[[], R]], solara.Result[R]], solara.Result[R]]:
"""Run a function or coroutine as a task and return the result.
## Example
Expand Down Expand Up @@ -398,14 +414,21 @@ async def square():
"""
task_instance = solara.use_memo(lambda: task(f), dependencies=dependencies)

def run():
task_instance()
return task_instance.cancel
def wrapper(f):
task_instance = solara.use_memo(lambda: task(f), dependencies=dependencies)

def run():
task_instance()
return task_instance.cancel

solara.use_effect(run, dependencies=dependencies)
if raise_error:
if task_instance.state == solara.ResultState.ERROR and task_instance.error is not None:
raise task_instance.error
return task_instance.result.value
solara.use_effect(run, dependencies=dependencies)
if raise_error:
if task_instance.state == solara.ResultState.ERROR and task_instance.error is not None:
raise task_instance.error
return task_instance.result.value

if f is None:
return wrapper
else:
return wrapper(f)

0 comments on commit 06c4657

Please sign in to comment.