diff --git a/solara/server/app.py b/solara/server/app.py index 94a0c9fb1..302ac7173 100644 --- a/solara/server/app.py +++ b/solara/server/app.py @@ -31,6 +31,13 @@ reload.reloader.start() +class Local(threading.local): + app_context_stack: Optional[List[Optional["AppContext"]]] = None + + +local = Local() + + class AppType(str, Enum): SCRIPT = "script" NOTEBOOK = "notebook" @@ -270,12 +277,16 @@ def display(self, *args): print(args) # noqa def __enter__(self): + if local.app_context_stack is None: + local.app_context_stack = [] key = get_current_thread_key() + local.app_context_stack.append(current_context.get(key, None)) current_context[key] = self def __exit__(self, *args): key = get_current_thread_key() - current_context[key] = None + assert local.app_context_stack is not None + current_context[key] = local.app_context_stack.pop() def close(self): with self: