Skip to content

Commit

Permalink
feat: /resourcez endpoint which tracks resource usage
Browse files Browse the repository at this point in the history
include threads usage, kernel usage, and memory usage etc.
  • Loading branch information
maartenbreddels committed Mar 8, 2024
1 parent aef5db5 commit 5e578b1
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 2 deletions.
21 changes: 21 additions & 0 deletions solara/server/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,22 @@ def wrapper(abs_path):
return wrapper


class ThreadDebugInfo:
lock = threading.Lock()
created = 0
running = 0
stopped = 0


Thread__init__ = threading.Thread.__init__
Thread__bootstrap = threading.Thread._bootstrap # type: ignore


def WidgetContextAwareThread__init__(self, *args, **kwargs):
Thread__init__(self, *args, **kwargs)
with ThreadDebugInfo.lock:
ThreadDebugInfo.created += 1

self.current_context = None
try:
self.current_context = kernel_context.get_current_context()
Expand All @@ -255,6 +265,17 @@ def WidgetContextAwareThread__init__(self, *args, **kwargs):


def WidgetContextAwareThread__bootstrap(self):
with ThreadDebugInfo.lock:
ThreadDebugInfo.running += 1
try:
_WidgetContextAwareThread__bootstrap(self)
finally:
with ThreadDebugInfo.lock:
ThreadDebugInfo.running -= 1
ThreadDebugInfo.stopped += 1


def _WidgetContextAwareThread__bootstrap(self):
if not hasattr(self, "current_context"):
return Thread__bootstrap(self)
if self.current_context:
Expand Down
109 changes: 107 additions & 2 deletions solara/server/starlette.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
import json
import logging
import math
import os
import sys
import threading
import typing
from typing import Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Union, cast
from uuid import uuid4

import anyio
Expand Down Expand Up @@ -39,7 +40,7 @@
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import HTTPConnection, Request
from starlette.responses import HTMLResponse, JSONResponse
from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.types import Receive, Scope, Send
Expand Down Expand Up @@ -84,6 +85,14 @@ def _ensure_limiter():
websockets.legacy.http.MAX_LINE = 1024 * 32


class WebsocketDebugInfo:
lock = threading.Lock()
attempts = 0
connecting = 0
open = 0
closed = 0


class WebsocketWrapper(websocket.WebsocketWrapper):
ws: starlette.websockets.WebSocket

Expand Down Expand Up @@ -183,6 +192,18 @@ async def kernels(id):

async def kernel_connection(ws: starlette.websockets.WebSocket):
_ensure_limiter()
try:
with WebsocketDebugInfo.lock:
WebsocketDebugInfo.attempts += 1
WebsocketDebugInfo.connecting += 1
await _kernel_connection(ws)
finally:
with WebsocketDebugInfo.lock:
WebsocketDebugInfo.closed += 1
WebsocketDebugInfo.open -= 1


async def _kernel_connection(ws: starlette.websockets.WebSocket):
session_id = ws.cookies.get(server.COOKIE_KEY_SESSION_ID)

if settings.oauth.private and not has_auth_support:
Expand Down Expand Up @@ -212,6 +233,9 @@ async def kernel_connection(ws: starlette.websockets.WebSocket):
return
logger.info("Solara kernel requested for session_id=%s kernel_id=%s", session_id, kernel_id)
await ws.accept()
with WebsocketDebugInfo.lock:
WebsocketDebugInfo.connecting -= 1
WebsocketDebugInfo.open += 1

def websocket_thread_runner(ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal):
async def run():
Expand Down Expand Up @@ -408,6 +432,86 @@ def readyz(request: Request):
return JSONResponse(json, status_code=status)


async def resourcez(request: Request):
_ensure_limiter()
assert limiter is not None
data: Dict[str, Any] = {}
verbose = request.query_params.get("verbose", None) is not None
data["websockets"] = {
"attempts": WebsocketDebugInfo.attempts,
"connecting": WebsocketDebugInfo.connecting,
"open": WebsocketDebugInfo.open,
"closed": WebsocketDebugInfo.closed,
}
from . import patch

data["threads"] = {
"created": patch.ThreadDebugInfo.created,
"running": patch.ThreadDebugInfo.running,
"stopped": patch.ThreadDebugInfo.stopped,
"active": threading.active_count(),
}
contexts = list(kernel_context.contexts.values())
data["kernels"] = {
"total": len(contexts),
"has_connected": len([k for k in contexts if kernel_context.PageStatus.CONNECTED in k.page_status.values()]),
"has_disconnected": len([k for k in contexts if kernel_context.PageStatus.DISCONNECTED in k.page_status.values()]),
"has_closed": len([k for k in contexts if kernel_context.PageStatus.CLOSED in k.page_status.values()]),
"limiter": {
"total_tokens": limiter.total_tokens,
"borrowed_tokens": limiter.borrowed_tokens,
"available_tokens": limiter.available_tokens,
},
}
default_limiter = anyio.to_thread.current_default_thread_limiter()
data["anyio.to_thread.limiter"] = {
"total_tokens": default_limiter.total_tokens,
"borrowed_tokens": default_limiter.borrowed_tokens,
"available_tokens": default_limiter.available_tokens,
}
if verbose:
try:
import psutil

def expand(named_tuple):
return {key: getattr(named_tuple, key) for key in named_tuple._fields}

data["cpu"] = {}
try:
data["cpu"]["percent"] = psutil.cpu_percent()
except Exception as e:
data["cpu"]["percent"] = str(e)
try:
data["cpu"]["count"] = psutil.cpu_count()
except Exception as e:
data["cpu"]["count"] = str(e)
try:
data["cpu"]["times"] = expand(psutil.cpu_times())
data["cpu"]["times"]["per_cpu"] = [expand(x) for x in psutil.cpu_times(percpu=True)]
except Exception as e:
data["cpu"]["times"] = str(e)
try:
data["cpu"]["times_percent"] = expand(psutil.cpu_times_percent())
data["cpu"]["times_percent"]["per_cpu"] = [expand(x) for x in psutil.cpu_times_percent(percpu=True)]
except Exception as e:
data["cpu"]["times_percent"] = str(e)
try:
memory = psutil.virtual_memory()
except Exception as e:
data["memory"] = str(e)
else:
data["memory"] = {
"bytes": expand(memory),
"GB": {key: getattr(memory, key) / 1024**3 for key in memory._fields},
}

except ModuleNotFoundError:
pass

json_string = json.dumps(data, indent=2)
return Response(content=json_string, media_type="application/json")


middleware = [
Middleware(GZipMiddleware, minimum_size=1000),
]
Expand All @@ -434,6 +538,7 @@ def readyz(request: Request):
]
routes = [
Route("/readyz", endpoint=readyz),
Route("/resourcez", endpoint=resourcez),
*routes_auth,
Route("/jupyter/api/kernels/{id}", endpoint=kernels),
WebSocketRoute("/jupyter/api/kernels/{kernel_id}/{name}", endpoint=kernel_connection),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,20 @@ HTTP/1.1 200 OK
...
```

## Live resource information


To check resource usage of the server (CPU, memory, etc.), the `/resourcez` endpoint is added, and should return a 200 HTTP status code and include
various resource information, like threads created and running, number of virtual kernels, etc. in JSON format. To get also memory and cpu usage, you can include
the `?verbose` query parameter, e.g.:

```
$ curl http://localhost:8765/resourcez\?verbose
```

The JSON format may be subject to change.



## Production mode

Expand Down

0 comments on commit 5e578b1

Please sign in to comment.