Skip to content

Commit

Permalink
feat: allow reconnecting to existing kernel and display widget by id
Browse files Browse the repository at this point in the history
This allows an ipypopout or a similar library to open a new browser
window and show a widget that is already running in the main window.

Note that this is limited to the same browser, because the session_id
is required to be the same. This is a security feature.
  • Loading branch information
maartenbreddels committed Oct 4, 2023
1 parent 707b481 commit d602541
Show file tree
Hide file tree
Showing 17 changed files with 501 additions and 55 deletions.
5 changes: 5 additions & 0 deletions packages/solara-widget-manager/src/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ export class WidgetManager extends JupyterLabManager {
}
}

async fetchAll() {
    // Re-request the full widget state from the (possibly reused) kernel via
    // the base manager. Presumably used so a new browser window can display
    // widgets that were created by another page — confirm against callers.
    await this._loadFromKernel();
}

async run(appName: string, path: string) {
// used for routing
// should be similar to what we do in navigator.vue
Expand Down
3 changes: 2 additions & 1 deletion solara/server/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,10 @@ def kernels_connection(ws: simple_websocket.Server, kernel_id: str, name: str):

@blueprint.route("/_solara/api/close/<kernel_id>", methods=["GET", "POST"])
def close(kernel_id: str):
    """Handle an explicit page close for the given virtual kernel.

    The page id is passed as the ``session_id`` query argument; unknown
    kernel ids are silently ignored.
    """
    page_id = request.args["session_id"]
    context = kernel_context.contexts.get(kernel_id)
    if context is not None:
        context.page_close(page_id)
    return ""


Expand Down
7 changes: 7 additions & 0 deletions solara/server/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,13 @@ def __init__(self, *args, **kwargs):
super(SessionWebsocket, self).__init__(*args, **kwargs)
self.websockets: Set[websocket.WebsocketWrapper] = set() # map from .. msg id to websocket?

def close(self):
    """Best-effort close of every websocket attached to this session.

    Iterates over a snapshot of ``self.websockets`` since closing a socket
    may mutate the set.
    """
    for ws in list(self.websockets):
        try:
            ws.close()
        except Exception:  # noqa
            # Best effort: a websocket may already be closed or broken; keep
            # closing the remaining ones. Use `except Exception` (not a bare
            # except) so KeyboardInterrupt/SystemExit still propagate.
            pass

def send(self, stream, msg_or_type, content=None, parent=None, ident=None, buffers=None, track=False, header=None, metadata=None):
try:
if isinstance(msg_or_type, dict):
Expand Down
141 changes: 133 additions & 8 deletions solara/server/kernel_context.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import asyncio
import dataclasses
import enum
import logging
import os
import pickle
import threading
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, cast

import ipywidgets as widgets
import reacton
from ipywidgets import DOMWidget, Widget

import solara.server.settings
import solara.util

from . import kernel, kernel_context, websocket
from .kernel import Kernel, WebsocketStreamWrapper

Expand All @@ -24,10 +30,20 @@ class Local(threading.local):
local = Local()


class PageStatus(enum.Enum):
    # Lifecycle of a browser page with respect to a virtual kernel.
    CONNECTED = "connected"  # page currently has an open websocket
    DISCONNECTED = "disconnected"  # websocket dropped; may reconnect before the cull timeout
    CLOSED = "closed"  # page explicitly closed; reconnecting with the same page_id is not allowed


@dataclasses.dataclass
class VirtualKernelContext:
id: str
kernel: kernel.Kernel
# we keep track of the session id to prevent kernel hijacking
# to 'steal' a kernel, one would need to know the session id
# *and* the kernel id
session_id: str
control_sockets: List[WebSocket] = dataclasses.field(default_factory=list)
# this is the 'private' version of the normally global ipywidgets.Widgets.widget dict
# see patch.py
Expand All @@ -42,6 +58,11 @@ class VirtualKernelContext:
reload: Callable = lambda: None
state: Any = None
container: Optional[DOMWidget] = None
# we track which pages are connected to implement kernel culling
page_status: Dict[str, PageStatus] = dataclasses.field(default_factory=dict)
# only used for testing
_last_kernel_cull_task: "Optional[asyncio.Future[None]]" = None
closed: bool = False

def display(self, *args):
    # NOTE(review): display() just prints the raw args tuple — presumably a
    # minimal fallback implementation; confirm intended behavior.
    print(args)  # noqa
Expand All @@ -59,6 +80,7 @@ def __exit__(self, *args):
current_context[key] = local.kernel_context_stack.pop()

def close(self):
logger.info("Shut down virtual kernel: %s", self.id)
with self:
if self.app_object is not None:
if isinstance(self.app_object, reacton.core._RenderContext):
Expand All @@ -71,8 +93,10 @@ def close(self):
# what if we reference each other
# import gc
# gc.collect()
self.kernel.session.close()
if self.id in contexts:
del contexts[self.id]
self.closed = True

def _state_reset(self):
state_directory = Path(".") / "states"
Expand All @@ -97,6 +121,94 @@ def state_save(self, state_directory: os.PathLike):
logger.debug("State: %r", state)
pickle.dump(state, f)

def page_connect(self, page_id: str):
    """Mark a page as connected and cancel any pending kernel cull.

    A page id that was explicitly closed may not reconnect.
    """
    logger.info("Connect page %s for kernel %s", page_id, self.id)
    previous_status = self.page_status.get(page_id)
    assert previous_status != PageStatus.CLOSED, "cannot connect with the same page_id after a close"
    self.page_status[page_id] = PageStatus.CONNECTED
    pending_cull = self._last_kernel_cull_task
    if pending_cull is not None:
        pending_cull.cancel()

def page_disconnect(self, page_id: str) -> "asyncio.Future[None]":
    """Signal that a page has disconnected, and schedule a kernel cull if needed.

    During the kernel reconnect window, we will keep the kernel alive, even if all pages have disconnected.
    Returns a future that is set when the kernel cull is done.
    The scheduled kernel cull can be cancelled when a new page connects, a new disconnect is scheduled,
    or a page is explicitly closed.
    """
    logger.info("Disconnect page %s for kernel %s", page_id, self.id)
    future: "asyncio.Future[None]" = asyncio.Future()
    self.page_status[page_id] = PageStatus.DISCONNECTED
    # the cull task runs on the keep-alive loop (another thread), so the
    # future must be resolved back on the caller's event loop
    current_event_loop = asyncio.get_event_loop()

    async def kernel_cull():
        try:
            cull_timeout_sleep_seconds = solara.util.parse_timedelta(solara.server.settings.kernel.cull_timeout)
            logger.info("Scheduling kernel cull, will wait for max %s before shutting down the virtual kernel %s", cull_timeout_sleep_seconds, self.id)
            await asyncio.sleep(cull_timeout_sleep_seconds)
            # a page may have (re)connected while we were sleeping
            has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
            if has_connected_pages:
                logger.info("We have (re)connected pages, keeping the virtual kernel %s alive", self.id)
            else:
                logger.info("No connected pages, and timeout reached, shutting down virtual kernel %s", self.id)
                self.close()
            current_event_loop.call_soon_threadsafe(future.set_result, None)
        except asyncio.CancelledError:
            current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled")
            raise

    has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
    if not has_connected_pages:
        # when we have no connected pages, we will schedule a kernel cull
        if self._last_kernel_cull_task:
            self._last_kernel_cull_task.cancel()

        async def create_task():
            task = asyncio.create_task(kernel_cull())
            # create a reference to the task so we can cancel it later
            self._last_kernel_cull_task = task
            await task

        asyncio.run_coroutine_threadsafe(create_task(), keep_alive_event_loop)
    else:
        # other pages are still connected: nothing to cull, resolve immediately
        future.set_result(None)
    return future

def page_close(self, page_id: str):
    """Signal that a page has closed, and close the context if needed.

    Closing the browser tab or a page navigation means an explicit close, which is
    different from a websocket/page disconnect, which we might want to recover from.
    """
    self.page_status[page_id] = PageStatus.CLOSED
    logger.info("Close page %s for kernel %s", page_id, self.id)
    statuses = self.page_status.values()
    still_in_use = (PageStatus.CONNECTED in statuses) or (PageStatus.DISCONNECTED in statuses)
    if still_in_use:
        return
    logger.info("No connected or disconnected pages, shutting down virtual kernel %s", self.id)
    if self._last_kernel_cull_task is not None:
        self._last_kernel_cull_task.cancel()
    self.close()


try:
    # Normal Python: run a dedicated event loop on a daemon thread so that
    # scheduled kernel culls keep running after request/websocket threads exit.
    keep_alive_event_loop = asyncio.new_event_loop()

    def _run():
        # thread entry point: make the keep-alive loop this thread's loop and run it forever
        asyncio.set_event_loop(keep_alive_event_loop)
        try:
            keep_alive_event_loop.run_forever()
        except Exception:
            logger.exception("Error in keep alive event loop")
            raise

    threading.Thread(target=_run, daemon=True).start()
except RuntimeError:
    # Emscripten/pyodide/lite: presumably thread creation raises RuntimeError
    # there (confirm), so fall back to reusing the existing event loop
    keep_alive_event_loop = asyncio.get_event_loop()

contexts: Dict[str, VirtualKernelContext] = {}
# maps from thread key to VirtualKernelContext, if VirtualKernelContext is None, it exists, but is not set as current
Expand All @@ -108,6 +220,7 @@ def create_dummy_context():

kernel_context = VirtualKernelContext(
id="dummy",
session_id="dummy",
kernel=kernel.Kernel(),
)
return kernel_context
Expand Down Expand Up @@ -154,15 +267,27 @@ def set_current_context(context: Optional[VirtualKernelContext]):
current_context[thread_key] = context


def initialize_virtual_kernel(kernel_id: str, websocket: websocket.WebsocketWrapper):
import solara.server.app

kernel = Kernel()
logger.info("new virtual kernel: %s", kernel_id)
context = contexts[kernel_id] = VirtualKernelContext(id=kernel_id, kernel=kernel, control_sockets=[], widgets={}, templates={})
def initialize_virtual_kernel(session_id: str, kernel_id: str, websocket: websocket.WebsocketWrapper):
from solara.server import app as appmodule

if kernel_id in contexts:
logger.info("reusing virtual kernel: %s", kernel_id)
context = contexts[kernel_id]
if context.session_id != session_id:
logger.critical("Session id mismatch when reusing kernel (hack attempt?): %s != %s", context.session_id, session_id)
websocket.send_text("Session id mismatch when reusing kernel (hack attempt?)")
# to avoid very fast reconnects (we are in a thread anyway)
time.sleep(0.5)
raise ValueError("Session id mismatch")
kernel = context.kernel
else:
kernel = Kernel()
logger.info("new virtual kernel: %s", kernel_id)
context = contexts[kernel_id] = VirtualKernelContext(id=kernel_id, session_id=session_id, kernel=kernel, control_sockets=[], widgets={}, templates={})
with context:
widgets.register_comm_target(kernel)
appmodule.register_solara_comm_target(kernel)
with context:
widgets.register_comm_target(kernel)
solara.server.app.register_solara_comm_target(kernel)
assert kernel is Kernel.instance()
kernel.shell_stream = WebsocketStreamWrapper(websocket, "shell")
kernel.control_stream = WebsocketStreamWrapper(websocket, "control")
Expand Down
58 changes: 31 additions & 27 deletions solara/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def is_ready(url) -> bool:


async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: str, page_id: str, user: dict = None):
context = initialize_virtual_kernel(kernel_id, ws)
context = initialize_virtual_kernel(session_id, kernel_id, ws)
if context is None:
logging.warning("invalid kernel id: %r", kernel_id)
# to avoid very fast reconnects (we are in a thread anyway)
Expand All @@ -124,35 +124,39 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: s
run_context = solara.util.nullcontext()

kernel = context.kernel
with run_context, context:
if user:
from solara_enterprise.auth import user as solara_user
try:
context.page_connect(page_id)
with run_context, context:
if user:
from solara_enterprise.auth import user as solara_user

solara_user.set(user)
solara_user.set(user)

while True:
try:
message = await ws.receive()
except websocket.WebSocketDisconnect:
while True:
try:
context.kernel.session.websockets.remove(ws)
except KeyError:
pass
logger.debug("Disconnected")
return
t0 = time.time()
if isinstance(message, str):
msg = json.loads(message)
else:
msg = deserialize_binary_message(message)
t1 = time.time()
if not process_kernel_messages(kernel, msg):
# if we shut down the kernel, we do not keep the page session alive
context.close()
return
t2 = time.time()
if settings.main.timing:
print(f"timing: total={t2-t0:.3f}s, deserialize={t1-t0:.3f}s, kernel={t2-t1:.3f}s") # noqa: T201
message = await ws.receive()
except websocket.WebSocketDisconnect:
try:
context.kernel.session.websockets.remove(ws)
except KeyError:
pass
logger.debug("Disconnected")
break
t0 = time.time()
if isinstance(message, str):
msg = json.loads(message)
else:
msg = deserialize_binary_message(message)
t1 = time.time()
if not process_kernel_messages(kernel, msg):
# if we shut down the kernel, we do not keep the page session alive
context.close()
return
t2 = time.time()
if settings.main.timing:
print(f"timing: total={t2-t0:.3f}s, deserialize={t1-t0:.3f}s, kernel={t2-t1:.3f}s") # noqa: T201
finally:
context.page_disconnect(page_id)


def process_kernel_messages(kernel: Kernel, msg: Dict) -> bool:
Expand Down
13 changes: 13 additions & 0 deletions solara/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from filelock import FileLock

import solara.util
from solara.minisettings import BaseSettings

from .. import ( # noqa # sidefx is that this module creates the ~/.solara directory
Expand Down Expand Up @@ -85,6 +86,15 @@ class Config:
env_file = ".env"


class Kernel(BaseSettings):
    # How long a virtual kernel is kept alive after all of its pages have
    # disconnected. Parsed with solara.util.parse_timedelta (e.g. "24h").
    cull_timeout: str = "24h"

    class Config:
        env_prefix = "solara_kernel_"  # e.g. SOLARA_KERNEL_CULL_TIMEOUT
        case_sensitive = False
        env_file = ".env"


AUTH0_TEST_CLIENT_ID = "cW7owP5Q52YHMZAnJwT8FPlH2ZKvvL3U"
AUTH0_TEST_CLIENT_SECRET = "zxITXxoz54OjuSmdn-PluQgAwbeYyoB7ALlnLoodftvAn81usDXW0quchvoNvUYD"
AUTH0_TEST_API_BASE_URL = "dev-y02f2bpr8skxu785.us.auth0.com"
Expand Down Expand Up @@ -156,6 +166,9 @@ class Config:
assets = Assets()
oauth = OAuth()
session = Session()
kernel = Kernel()
# fail early
solara.util.parse_timedelta(kernel.cull_timeout)

if assets.proxy:
try:
Expand Down
3 changes: 2 additions & 1 deletion solara/server/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,10 @@ async def run():

def close(request: Request):
    """Handle an explicit page close for the given virtual kernel.

    The page id is passed as the ``session_id`` query parameter; unknown
    kernel ids are silently ignored.
    """
    kernel_id = request.path_params["kernel_id"]
    page_id = request.query_params["session_id"]
    context = kernel_context.contexts.get(kernel_id)
    if context is not None:
        context.page_close(page_id)
    return HTMLResponse(content="", status_code=200)

Expand Down
Loading

0 comments on commit d602541

Please sign in to comment.