diff --git a/solara/server/patch.py b/solara/server/patch.py index 15f01072c..d81211c3e 100644 --- a/solara/server/patch.py +++ b/solara/server/patch.py @@ -254,12 +254,22 @@ def wrapper(abs_path): return wrapper +class ThreadDebugInfo: + lock = threading.Lock() + created = 0 + running = 0 + stopped = 0 + + Thread__init__ = threading.Thread.__init__ Thread__bootstrap = threading.Thread._bootstrap # type: ignore def WidgetContextAwareThread__init__(self, *args, **kwargs): Thread__init__(self, *args, **kwargs) + with ThreadDebugInfo.lock: + ThreadDebugInfo.created += 1 + self.current_context = None try: self.current_context = kernel_context.get_current_context() @@ -268,6 +278,17 @@ def WidgetContextAwareThread__init__(self, *args, **kwargs): def WidgetContextAwareThread__bootstrap(self): + with ThreadDebugInfo.lock: + ThreadDebugInfo.running += 1 + try: + _WidgetContextAwareThread__bootstrap(self) + finally: + with ThreadDebugInfo.lock: + ThreadDebugInfo.running -= 1 + ThreadDebugInfo.stopped += 1 + + +def _WidgetContextAwareThread__bootstrap(self): if not hasattr(self, "current_context"): return Thread__bootstrap(self) if self.current_context: diff --git a/solara/server/starlette.py b/solara/server/starlette.py index 512748faf..fdd27cd55 100644 --- a/solara/server/starlette.py +++ b/solara/server/starlette.py @@ -1,11 +1,12 @@ import asyncio +import json import logging import math import os import sys import threading import typing -from typing import Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Union, cast from uuid import uuid4 import anyio @@ -39,7 +40,7 @@ from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.gzip import GZipMiddleware from starlette.requests import HTTPConnection, Request -from starlette.responses import HTMLResponse, JSONResponse +from starlette.responses import HTMLResponse, JSONResponse, Response from starlette.routing import Mount, Route, WebSocketRoute from starlette.staticfiles import StaticFiles from starlette.types import Receive, Scope, Send @@ -84,6 +85,14 @@ def _ensure_limiter(): websockets.legacy.http.MAX_LINE = 1024 * 32 +class WebsocketDebugInfo: + lock = threading.Lock() + attempts = 0 + connecting = 0 + open = 0 + closed = 0 + + class WebsocketWrapper(websocket.WebsocketWrapper): ws: starlette.websockets.WebSocket @@ -183,6 +192,18 @@ async def kernels(id): async def kernel_connection(ws: starlette.websockets.WebSocket): _ensure_limiter() + try: + with WebsocketDebugInfo.lock: + WebsocketDebugInfo.attempts += 1 + WebsocketDebugInfo.connecting += 1 + await _kernel_connection(ws) + finally: + with WebsocketDebugInfo.lock: + WebsocketDebugInfo.closed += 1 + WebsocketDebugInfo.open -= 1 + + +async def _kernel_connection(ws: starlette.websockets.WebSocket): session_id = ws.cookies.get(server.COOKIE_KEY_SESSION_ID) if settings.oauth.private and not has_auth_support: @@ -212,6 +233,9 @@ async def kernel_connection(ws: starlette.websockets.WebSocket): return logger.info("Solara kernel requested for session_id=%s kernel_id=%s", session_id, kernel_id) await ws.accept() + with WebsocketDebugInfo.lock: + WebsocketDebugInfo.connecting -= 1 + WebsocketDebugInfo.open += 1 def websocket_thread_runner(ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal): async def run(): @@ -408,6 +432,86 @@ def readyz(request: Request): return JSONResponse(json, status_code=status) +async def resourcez(request: Request): + _ensure_limiter() + assert limiter is not None + data: Dict[str, Any] = {} + verbose = request.query_params.get("verbose", None) is not None + data["websockets"] = { + "attempts": WebsocketDebugInfo.attempts, + "connecting": WebsocketDebugInfo.connecting, + "open": WebsocketDebugInfo.open, + "closed": WebsocketDebugInfo.closed, + } + from . import patch + + data["threads"] = { + "created": patch.ThreadDebugInfo.created, + "running": patch.ThreadDebugInfo.running, + "stopped": patch.ThreadDebugInfo.stopped, + "active": threading.active_count(), + } + contexts = list(kernel_context.contexts.values()) + data["kernels"] = { + "total": len(contexts), + "has_connected": len([k for k in contexts if kernel_context.PageStatus.CONNECTED in k.page_status.values()]), + "has_disconnected": len([k for k in contexts if kernel_context.PageStatus.DISCONNECTED in k.page_status.values()]), + "has_closed": len([k for k in contexts if kernel_context.PageStatus.CLOSED in k.page_status.values()]), + "limiter": { + "total_tokens": limiter.total_tokens, + "borrowed_tokens": limiter.borrowed_tokens, + "available_tokens": limiter.available_tokens, + }, + } + default_limiter = anyio.to_thread.current_default_thread_limiter() + data["anyio.to_thread.limiter"] = { + "total_tokens": default_limiter.total_tokens, + "borrowed_tokens": default_limiter.borrowed_tokens, + "available_tokens": default_limiter.available_tokens, + } + if verbose: + try: + import psutil + + def expand(named_tuple): + return {key: getattr(named_tuple, key) for key in named_tuple._fields} + + data["cpu"] = {} + try: + data["cpu"]["percent"] = psutil.cpu_percent() + except Exception as e: + data["cpu"]["percent"] = str(e) + try: + data["cpu"]["count"] = psutil.cpu_count() + except Exception as e: + data["cpu"]["count"] = str(e) + try: + data["cpu"]["times"] = expand(psutil.cpu_times()) + data["cpu"]["times"]["per_cpu"] = [expand(x) for x in psutil.cpu_times(percpu=True)] + except Exception as e: + data["cpu"]["times"] = str(e) + try: + data["cpu"]["times_percent"] = expand(psutil.cpu_times_percent()) + data["cpu"]["times_percent"]["per_cpu"] = [expand(x) for x in psutil.cpu_times_percent(percpu=True)] + except Exception as e: + data["cpu"]["times_percent"] = str(e) + try: + memory = psutil.virtual_memory() + except Exception as e: + data["memory"] = str(e) + else: + data["memory"] = { + "bytes": expand(memory), + "GB": {key: getattr(memory, key) / 1024**3 for key in memory._fields}, + } + + except ModuleNotFoundError: + pass + + json_string = json.dumps(data, indent=2) + return Response(content=json_string, media_type="application/json") + + middleware = [ Middleware(GZipMiddleware, minimum_size=1000), ] @@ -434,6 +538,7 @@ def readyz(request: Request): ] routes = [ Route("/readyz", endpoint=readyz), + Route("/resourcez", endpoint=resourcez), *routes_auth, Route("/jupyter/api/kernels/{id}", endpoint=kernels), WebSocketRoute("/jupyter/api/kernels/{kernel_id}/{name}", endpoint=kernel_connection), diff --git a/solara/website/pages/docs/content/20-understanding/50-solara-server.md b/solara/website/pages/docs/content/20-understanding/50-solara-server.md index f19116b08..e7835c6a3 100644 --- a/solara/website/pages/docs/content/20-understanding/50-solara-server.md +++ b/solara/website/pages/docs/content/20-understanding/50-solara-server.md @@ -48,6 +48,20 @@ HTTP/1.1 200 OK ... ``` +## Live resource information + + +To check resource usage of the server (CPU, memory, etc.), the `/resourcez` endpoint is added, and should return a 200 HTTP status code and include +various resource information, like threads created and running, number of virtual kernels, etc. in JSON format. To get also memory and cpu usage, you can include +the `?verbose` query parameter, e.g.: + +``` +$ curl http://localhost:8765/resourcez\?verbose +``` + +The JSON format may be subject to change. + + ## Production mode