Skip to content

Commit

Permalink
refactor: use kernel_id instead of session_id
Browse files Browse the repository at this point in the history
We have been using the jupyter session_id as a unique key for
the ill defined concept of a "AppContext".
For consistency, we now use the kernel_id as the unique key
and renamed the AppContext to VirtualKernelContext.

A VirtualKernelContext is a context manager that
needs to wrap any code that uses a kernel, either explicitly
or implicitly, such as widgets.
  • Loading branch information
maartenbreddels committed Oct 5, 2023
1 parent 800f751 commit ac85fa2
Show file tree
Hide file tree
Showing 20 changed files with 312 additions and 289 deletions.
2 changes: 1 addition & 1 deletion packages/solara-widget-manager/src/kernel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export async function connectKernel(
// if (!model) {
// return;
// }
const model = { 'id': 'solara-id', 'name': 'solara-name' }
const model = { 'id': kernelId, 'name': 'solara-name' }
const kernel = new KernelConnection({ model, serverSettings });
return kernel;
}
Expand Down
4 changes: 2 additions & 2 deletions solara/scope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def __init__(self, name="connection"):

def _get_dict(self) -> MutableMapping:
if _in_solara_server():
import solara.server.app
import solara.server.kernel_context

context = solara.server.app.get_current_context()
context = solara.server.kernel_context.get_current_context()
if self.name not in context.user_dicts:
with self.lock:
if self.name not in context.user_dicts:
Expand Down
186 changes: 17 additions & 169 deletions solara/server/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import dataclasses
import importlib.util
import logging
import os
Expand All @@ -8,17 +7,16 @@
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, cast

import ipywidgets as widgets
import reacton
from ipywidgets import DOMWidget, Widget
from reacton.core import Element, render

import solara

from . import kernel, reload, settings, websocket
from .kernel import Kernel, WebsocketStreamWrapper
from . import kernel_context, reload, settings
from .kernel import Kernel
from .utils import pdb_guard

WebSocket = Any
Expand All @@ -30,13 +28,6 @@
reload.reloader.start()


class Local(threading.local):
app_context_stack: Optional[List[Optional["AppContext"]]] = None


local = Local()


class AppType(str, Enum):
SCRIPT = "script"
NOTEBOOK = "notebook"
Expand Down Expand Up @@ -68,13 +59,13 @@ def __init__(self, name, default_app_name="Page"):
self.name = name
self.path: Path = Path(self.name).resolve()
try:
context = get_current_context()
context = kernel_context.get_current_context()
except RuntimeError:
context = None
if context is not None:
raise RuntimeError(f"We should not have an existing Solara app context when running an app for the first time: {context}")
app_context = create_dummy_context()
with app_context:
dummy_kernel_context = kernel_context.create_dummy_context()
with dummy_kernel_context:
app = self._execute()

self._first_execute_app = app
Expand All @@ -85,7 +76,7 @@ def __init__(self, name, default_app_name="Page"):
if mod.__file__ is not None:
package_root_path = Path(mod.__file__).parent
reload.reloader.root_path = package_root_path
app_context.close()
dummy_kernel_context.close()

def _execute(self):
logger.info("Executing %s", self.name)
Expand Down Expand Up @@ -187,8 +178,8 @@ def add_path():

def close(self):
reload.reloader.on_change = None
context_values = list(contexts.values())
contexts.clear()
context_values = list(kernel_context.contexts.values())
kernel_context.contexts.clear()
for context in context_values:
context.close()

Expand All @@ -204,7 +195,7 @@ def on_file_change(self, name):
if path.suffix == ".vue":
logger.info("Vue file changed: %s", name)
template_content = path.read_text()
for context in list(contexts.values()):
for context in list(kernel_context.contexts.values()):
with context:
for filepath, widget in context.templates.items():
if filepath == str(path):
Expand All @@ -224,7 +215,7 @@ def reload(self):

solara.lab.toestand.ConnectionStore._type_counter.clear()

context_values = list(contexts.values())
context_values = list(kernel_context.contexts.values())
# save states into the context so the hot reload will
# keep the same state
for context in context_values:
Expand Down Expand Up @@ -256,136 +247,6 @@ def reload(self):
context.reload()


@dataclasses.dataclass
class AppContext:
id: str
kernel: kernel.Kernel
control_sockets: List[WebSocket] = dataclasses.field(default_factory=list)
# this is the 'private' version of the normally global ipywidgets.Widgets.widget dict
# see patch.py
widgets: Dict[str, Widget] = dataclasses.field(default_factory=dict)
# same, for ipyvue templates
# see patch.py
templates: Dict[str, Widget] = dataclasses.field(default_factory=dict)
user_dicts: Dict[str, Dict] = dataclasses.field(default_factory=dict)
# anything we need to attach to the context
# e.g. for a react app the render context, so that we can store/restore the state
app_object: Optional[Any] = None
reload: Callable = lambda: None
state: Any = None
container: Optional[DOMWidget] = None

def display(self, *args):
print(args) # noqa

def __enter__(self):
if local.app_context_stack is None:
local.app_context_stack = []
key = get_current_thread_key()
local.app_context_stack.append(current_context.get(key, None))
current_context[key] = self

def __exit__(self, *args):
key = get_current_thread_key()
assert local.app_context_stack is not None
current_context[key] = local.app_context_stack.pop()

def close(self):
with self:
if self.app_object is not None:
if isinstance(self.app_object, reacton.core._RenderContext):
try:
self.app_object.close()
except Exception as e:
logger.exception("Could not close render context: %s", e)
# we want to continue, so we at least close all widgets
widgets.Widget.close_all()
# what if we reference each other
# import gc
# gc.collect()
if self.id in contexts:
del contexts[self.id]

def _state_reset(self):
state_directory = Path(".") / "states"
state_directory.mkdir(exist_ok=True)
path = state_directory / f"{self.id}.pickle"
path = path.absolute()
try:
path.unlink()
except: # noqa
pass
del contexts[self.id]
key = get_current_thread_key()
del current_context[key]

def state_save(self, state_directory: os.PathLike):
path = Path(state_directory) / f"{self.id}.pickle"
render_context = self.app_object
if render_context is not None:
render_context = cast(reacton.core._RenderContext, render_context)
state = render_context.state_get()
with path.open("wb") as f:
logger.debug("State: %r", state)
pickle.dump(state, f)


contexts: Dict[str, AppContext] = {}
# maps from thread key to AppContext, if AppContext is None, it exists, but is not set as current
current_context: Dict[str, Optional[AppContext]] = {}


def create_dummy_context():
from . import kernel

app_context = AppContext(
id="dummy",
kernel=kernel.Kernel(),
)
return app_context


def get_current_thread_key() -> str:
thread = threading.current_thread()
return get_thread_key(thread)


def get_thread_key(thread: threading.Thread) -> str:
thread_key = thread._name + str(thread._ident) # type: ignore
return thread_key


def set_context_for_thread(context: AppContext, thread: threading.Thread):
key = get_thread_key(thread)
current_context[key] = context


def has_current_context() -> bool:
thread_key = get_current_thread_key()
return (thread_key in current_context) and (current_context[thread_key] is not None)


def get_current_context() -> AppContext:
thread_key = get_current_thread_key()
if thread_key not in current_context:
raise RuntimeError(
f"Tried to get the current context for thread {thread_key}, but no known context found. This might be a bug in Solara. "
f"(known contexts: {list(current_context.keys())}"
)
context = current_context[thread_key]
if context is None:
raise RuntimeError(
f"Tried to get the current context for thread {thread_key!r}, although the context is know, it was not set for this thread. "
+ "This might be a bug in Solara."
)
return context


def set_current_context(context: Optional[AppContext]):
thread_key = get_current_thread_key()
current_context[thread_key] = context


def _run_app(
app_state,
app_script: AppScript,
Expand All @@ -398,7 +259,7 @@ def _run_app(
if app_state:
logger.info("Restoring state: %r", app_state)

context = get_current_context()
context = kernel_context.get_current_context()
container = context.container
if isinstance(main_object, widgets.Widget):
return main_object, render_context
Expand Down Expand Up @@ -438,7 +299,7 @@ def _run_app(
def load_app_widget(app_state, app_script: AppScript, pathname: str):
# load the app, and set it at the child of the context's container
app_state_initial = app_state
context = get_current_context()
context = kernel_context.get_current_context()
container = context.container
assert container is not None
try:
Expand Down Expand Up @@ -482,18 +343,18 @@ def on_msg(msg):
path = data.get("path", "")
app_name = data.get("appName") or "__default__"
app = apps[app_name]
context = get_current_context()
context = kernel_context.get_current_context()
import ipyvuetify

container = ipyvuetify.Html(tag="div")
context.container = container
load_app_widget(None, app, path)
comm.send({"method": "finished", "widget_id": context.container._model_id})
elif method == "check":
context = get_current_context()
context = kernel_context.get_current_context()
elif method == "reload":
assert app is not None
context = get_current_context()
context = kernel_context.get_current_context()
path = data.get("path", "")
with context:
load_app_widget(context.state, app, path)
Expand All @@ -509,27 +370,14 @@ def reload():
logger.debug(f"Send reload to client: {context.id}")
comm.send({"method": "reload"})

context = get_current_context()
context = kernel_context.get_current_context()
context.reload = reload


def register_solara_comm_target(kernel: Kernel):
kernel.comm_manager.register_target("solara.control", solara_comm_target)


def initialize_virtual_kernel(context_id: str, websocket: websocket.WebsocketWrapper):
kernel = Kernel()
logger.info("new virtual kernel: %s", context_id)
context = contexts[context_id] = AppContext(id=context_id, kernel=kernel, control_sockets=[], widgets={}, templates={})
with context:
widgets.register_comm_target(kernel)
register_solara_comm_target(kernel)
assert kernel is Kernel.instance()
kernel.shell_stream = WebsocketStreamWrapper(websocket, "shell")
kernel.control_stream = WebsocketStreamWrapper(websocket, "control")
kernel.session.websockets.add(websocket)


from . import patch # noqa

patch.patch()
Expand Down
20 changes: 10 additions & 10 deletions solara/server/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def allowed():
from solara.server.threaded import ServerBase

from . import app as appmod
from . import cdn_helper, server, settings, websocket
from . import cdn_helper, kernel_context, server, settings, websocket

os.environ["SERVER_SOFTWARE"] = "solara/" + str(solara.__version__)

Expand Down Expand Up @@ -111,8 +111,8 @@ def kernels(id):
return {"name": "lala", "id": "dsa"}


@websocket_extension.route("/jupyter/api/kernels/<id>/<name>")
def kernels_connection(ws: simple_websocket.Server, id: str, name: str):
@websocket_extension.route("/jupyter/api/kernels/<kernel_id>/<name>")
def kernels_connection(ws: simple_websocket.Server, kernel_id: str, name: str):
if not settings.main.base_url:
settings.main.base_url = url_for("blueprint-solara.read_root", _external=True)
if settings.oauth.private and not has_solara_enterprise:
Expand All @@ -127,24 +127,24 @@ def kernels_connection(ws: simple_websocket.Server, id: str, name: str):
user = None

try:
connection_id = request.args["session_id"]
page_id = request.args["session_id"]
session_id = request.cookies.get(server.COOKIE_KEY_SESSION_ID)
logger.info("Solara kernel requested for session_id=%s connection_id=%s", session_id, connection_id)
logger.info("Solara kernel requested for session_id=%s kernel_id=%s", session_id, kernel_id)
if session_id is None:
logger.error("no session cookie")
ws.close()
return
ws_wrapper = WebsocketWrapper(ws)
asyncio.run(server.app_loop(ws_wrapper, session_id=session_id, connection_id=connection_id, user=user))
asyncio.run(server.app_loop(ws_wrapper, session_id=session_id, kernel_id=kernel_id, page_id=page_id, user=user))
except: # noqa
logger.exception("Error in kernel handler")
raise


@blueprint.route("/_solara/api/close/<connection_id>", methods=["GET", "POST"])
def close(connection_id: str):
if connection_id in appmod.contexts:
context = appmod.contexts[connection_id]
@blueprint.route("/_solara/api/close/<kernel_id>", methods=["GET", "POST"])
def close(kernel_id: str):
if kernel_id in kernel_context.contexts:
context = kernel_context.contexts[kernel_id]
context.close()
return ""

Expand Down
2 changes: 1 addition & 1 deletion solara/server/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def publish_msg(self, msg_type, data=None, metadata=None, buffers=None, **keys):
comm.create_comm = Comm

def get_comm_manager():
from .app import get_current_context, has_current_context
from .kernel_context import get_current_context, has_current_context

if has_current_context():
return get_current_context().kernel.comm_manager
Expand Down
Loading

0 comments on commit ac85fa2

Please sign in to comment.