Skip to content

Commit

Permalink
Move WebSocket client to web worker
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 16, 2024
1 parent 2c7303e commit 50d8302
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 137 deletions.
2 changes: 1 addition & 1 deletion examples/experimental/gaussian_splats.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:
}


def main(splat_paths: tuple[Path, ...], test_multisplat: bool = False) -> None:
def main(splat_paths: tuple[Path, ...]) -> None:
server = viser.ViserServer(share=True)
server.gui.configure_theme(dark_mode=True)
gui_reset_up = server.gui.add_button(
Expand Down
17 changes: 9 additions & 8 deletions src/viser/client/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import { Titlebar } from "./Titlebar";
import { ViserModal } from "./Modal";
import { useSceneTreeState } from "./SceneTreeState";
import { GetRenderRequestMessage, Message } from "./WebsocketMessages";
import { makeThrottledMessageSender } from "./WebsocketFunctions";
import { useThrottledMessageSender } from "./WebsocketFunctions";
import { useDisclosure } from "@mantine/hooks";
import { rayToViserCoords } from "./WorldTransformUtils";
import { ndcFromPointerXy, opencvXyFromPointerXy } from "./ClickUtils";
Expand All @@ -57,7 +57,7 @@ export type ViewerContextContents = {
// Useful references.
// TODO: there's really no reason these all need to be their own ref objects.
// We could have just one ref to a global mutable struct.
websocketRef: React.MutableRefObject<WebSocket | null>;
sendMessageRef: React.MutableRefObject<(message: Message) => void>;
canvasRef: React.MutableRefObject<HTMLCanvasElement | null>;
sceneRef: React.MutableRefObject<THREE.Scene | null>;
cameraRef: React.MutableRefObject<THREE.PerspectiveCamera | null>;
Expand Down Expand Up @@ -142,7 +142,11 @@ function ViewerRoot() {
messageSource: playbackPath === null ? "websocket" : "file_playback",
useSceneTree: useSceneTreeState(),
useGui: useGuiState(initialServer),
websocketRef: React.useRef(null),
sendMessageRef: React.useRef((message) =>
console.log(
`Tried to send ${message.type} but websocket is not connected!`,
),
),
canvasRef: React.useRef(null),
sceneRef: React.useRef(null),
cameraRef: React.useRef(null),
Expand Down Expand Up @@ -206,7 +210,7 @@ function ViewerContents({ children }: { children: React.ReactNode }) {
})}
forceColorScheme={darkMode ? "dark" : "light"}
>
{ children }
{children}
<Notifications
position="top-left"
containerWidth="20em"
Expand Down Expand Up @@ -271,10 +275,7 @@ function ViewerContents({ children }: { children: React.ReactNode }) {

function ViewerCanvas({ children }: { children: React.ReactNode }) {
const viewer = React.useContext(ViewerContext)!;
const sendClickThrottled = makeThrottledMessageSender(
viewer.websocketRef,
20,
);
const sendClickThrottled = useThrottledMessageSender(20);
const theme = useMantineTheme();

return (
Expand Down
7 changes: 2 additions & 5 deletions src/viser/client/src/CameraControls.tsx
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
import { ViewerContext } from "./App";
import { makeThrottledMessageSender } from "./WebsocketFunctions";
import { CameraControls } from "@react-three/drei";
import { useThree } from "@react-three/fiber";
import * as holdEvent from "hold-event";
import React, { useContext, useRef } from "react";
import { PerspectiveCamera } from "three";
import * as THREE from "three";
import { computeT_threeworld_world } from "./WorldTransformUtils";
import { useThrottledMessageSender } from "./WebsocketFunctions";

export function SynchronizedCameraControls() {
const viewer = useContext(ViewerContext)!;
const camera = useThree((state) => state.camera as PerspectiveCamera);

const sendCameraThrottled = makeThrottledMessageSender(
viewer.websocketRef,
20,
);
const sendCameraThrottled = useThrottledMessageSender(20);

// Helper for resetting camera poses.
const initialCameraRef = useRef<{
Expand Down
5 changes: 2 additions & 3 deletions src/viser/client/src/ControlPanel/ControlPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import BottomPanel from "./BottomPanel";
import FloatingPanel from "./FloatingPanel";
import { ThemeConfigurationMessage } from "../WebsocketMessages";
import SidebarPanel from "./SidebarPanel";
import { sendWebsocketMessage } from "../WebsocketFunctions";

// Must match constant in Python.
const ROOT_CONTAINER_ID = "root";
Expand Down Expand Up @@ -270,7 +269,7 @@ function ShareButton() {
<Button
fullWidth
onClick={() => {
sendWebsocketMessage(viewer.websocketRef, {
viewer.sendMessageRef.current({
type: "ShareUrlRequest",
});
setDoingSomething(true); // Loader state will help with debouncing.
Expand Down Expand Up @@ -316,7 +315,7 @@ function ShareButton() {
<Button
color="red"
onClick={() => {
sendWebsocketMessage(viewer.websocketRef, {
viewer.sendMessageRef.current({
type: "ShareUrlDisconnect",
});
setShareUrl(null);
Expand Down
4 changes: 2 additions & 2 deletions src/viser/client/src/ControlPanel/Generated.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { ViewerContext } from "../App";
import { makeThrottledMessageSender } from "../WebsocketFunctions";
import { useThrottledMessageSender } from "../WebsocketFunctions";
import { GuiComponentContext } from "./GuiComponentContext";

import { Box } from "@mantine/core";
Expand Down Expand Up @@ -31,7 +31,7 @@ export default function GeneratedGuiContainer({
}) {
const viewer = React.useContext(ViewerContext)!;
const updateGuiProps = viewer.useGui((state) => state.updateGuiProps);
const messageSender = makeThrottledMessageSender(viewer.websocketRef, 50);
const messageSender = useThrottledMessageSender(50);

function setValue(id: string, value: NonNullable<unknown>) {
updateGuiProps(id, { value: value });
Expand Down
57 changes: 27 additions & 30 deletions src/viser/client/src/MessageHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ import {
Message,
} from "./WebsocketMessages";
import { PivotControls } from "@react-three/drei";
import {
isTexture,
makeThrottledMessageSender,
sendWebsocketMessage,
} from "./WebsocketFunctions";
import { isTexture, makeThrottledMessageSender } from "./WebsocketFunctions";
import { isGuiConfig } from "./ControlPanel/GuiState";
import { useFrame } from "@react-three/fiber";
import GeneratedGuiContainer from "./ControlPanel/Generated";
Expand Down Expand Up @@ -234,16 +230,20 @@ function useMessageHandler() {
message.plane == "xz"
? new THREE.Euler(0.0, 0.0, 0.0)
: message.plane == "xy"
? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0)
: message.plane == "yx"
? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0)
: message.plane == "yz"
? new THREE.Euler(0.0, 0.0, Math.PI / 2.0)
: message.plane == "zx"
? new THREE.Euler(0.0, Math.PI / 2.0, 0.0)
: message.plane == "zy"
? new THREE.Euler(-Math.PI / 2.0, 0.0, -Math.PI / 2.0)
: undefined
? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0)
: message.plane == "yx"
? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0)
: message.plane == "yz"
? new THREE.Euler(0.0, 0.0, Math.PI / 2.0)
: message.plane == "zx"
? new THREE.Euler(0.0, Math.PI / 2.0, 0.0)
: message.plane == "zy"
? new THREE.Euler(
-Math.PI / 2.0,
0.0,
-Math.PI / 2.0,
)
: undefined
}
/>
</group>
Expand Down Expand Up @@ -331,16 +331,16 @@ function useMessageHandler() {
message.material == "standard" || message.wireframe
? new THREE.MeshStandardMaterial(standardArgs)
: message.material == "toon3"
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(3),
...standardArgs,
})
: message.material == "toon5"
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(5),
...standardArgs,
})
: assertUnreachable(message.material);
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(3),
...standardArgs,
})
: message.material == "toon5"
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(5),
...standardArgs,
})
: assertUnreachable(message.material);
geometry.setAttribute(
"position",
new THREE.Float32BufferAttribute(
Expand Down Expand Up @@ -559,10 +559,7 @@ function useMessageHandler() {
}
case "TransformControlsMessage": {
const name = message.name;
const sendDragMessage = makeThrottledMessageSender(
viewer.websocketRef,
50,
);
const sendDragMessage = makeThrottledMessageSender(viewer, 50);
addSceneNodeMakeParents(
new SceneNode<THREE.Group>(
message.name,
Expand Down Expand Up @@ -1175,7 +1172,7 @@ export function FrameSynchronizedMessageHandler() {
return;
}
const payload = new Uint8Array(await blob.arrayBuffer());
sendWebsocketMessage(viewer.websocketRef, {
viewer.sendMessageRef.current({
type: "GetRenderResponseMessage",
payload: payload,
});
Expand Down
7 changes: 2 additions & 5 deletions src/viser/client/src/SceneTree.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import React from "react";
import * as THREE from "three";

import { ViewerContext } from "./App";
import { makeThrottledMessageSender } from "./WebsocketFunctions";
import { useThrottledMessageSender } from "./WebsocketFunctions";
import { Html } from "@react-three/drei";
import { immerable } from "immer";
import { useSceneTreeState } from "./SceneTreeState";
Expand Down Expand Up @@ -265,10 +265,7 @@ export function SceneNodeThreeObject(props: {
});

// Clicking logic.
const sendClicksThrottled = makeThrottledMessageSender(
viewer.websocketRef,
50,
);
const sendClicksThrottled = useThrottledMessageSender(50);
const [hovered, setHovered] = React.useState(false);
useCursor(hovered);
const hoveredRef = React.useRef(false);
Expand Down
21 changes: 9 additions & 12 deletions src/viser/client/src/WebsocketFunctions.tsx
Original file line number Diff line number Diff line change
@@ -1,31 +1,28 @@
import { MutableRefObject } from "react";
import React from "react";
import * as THREE from "three";
import { Message } from "./WebsocketMessages";
import { encode } from "@msgpack/msgpack";
import { ViewerContext, ViewerContextContents } from "./App";

/** Send message over websocket. */
export function sendWebsocketMessage(
websocketRef: MutableRefObject<WebSocket | null>,
message: Message,
) {
if (websocketRef.current === null) return;
websocketRef.current.send(encode(message));
/** Easier, hook version of makeThrottledMessageSender. */
export function useThrottledMessageSender(throttleMilliseconds: number) {
const viewer = React.useContext(ViewerContext)!;
return makeThrottledMessageSender(viewer, throttleMilliseconds);
}

/** Returns a function for sending messages, with automatic throttling. */
export function makeThrottledMessageSender(
websocketRef: MutableRefObject<WebSocket | null>,
viewer: ViewerContextContents,
throttleMilliseconds: number,
) {
let readyToSend = true;
let stale = false;
let latestMessage: Message | null = null;

function send(message: Message) {
if (websocketRef.current === null) return;
if (viewer.sendMessageRef.current === null) return;
latestMessage = message;
if (readyToSend) {
websocketRef.current.send(encode(message));
viewer.sendMessageRef.current(message);
stale = false;
readyToSend = false;

Expand Down
91 changes: 29 additions & 62 deletions src/viser/client/src/WebsocketInterface.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import AwaitLock from "await-lock";
import { decode } from "@msgpack/msgpack";

import WebsocketServerWorker from "./WebsocketServerWorker?worker";
import React, { useContext } from "react";

import { ViewerContext } from "./App";
import { syncSearchParamServer } from "./SearchParamsUtils";
import { Message } from "./WebsocketMessages";
import { WsWorkerIncoming, WsWorkerOutgoing } from "./WebsocketServerWorker";
/** Component for handling websocket connections. */
export function WebsocketMessageProducer() {
const messageQueueRef = useContext(ViewerContext)!.messageQueueRef;
Expand All @@ -16,70 +14,39 @@ export function WebsocketMessageProducer() {
syncSearchParamServer(server);

React.useEffect(() => {
// Lock for making sure messages are handled in order.
const orderLock = new AwaitLock();

let ws: null | WebSocket = null;
let done = false;

function tryConnect(): void {
if (done) return;

ws = new WebSocket(server);
const worker = new WebsocketServerWorker();

// Timeout is necessary when we're connecting to an SSH/tunneled port.
const retryTimeout = setTimeout(() => {
ws?.close();
}, 5000);

ws.onopen = () => {
clearTimeout(retryTimeout);
console.log(`Connected! ${server}`);
viewer.websocketRef.current = ws;
worker.onmessage = (event) => {
const data: WsWorkerOutgoing = event.data;
if (data.type === "connected") {
resetGui();
viewer.useGui.setState({ websocketConnected: true });
};

ws.onclose = (event) => {
console.log(`Disconnected! ${server} code=${event.code}`);
clearTimeout(retryTimeout);
viewer.websocketRef.current = null;
viewer.scenePointerInfo.current!.enabled = false;
viewer.useGui.setState({ websocketConnected: false });
viewer.sendMessageRef.current = (message) => {
postToWorker({ type: "send", message: message });
};
} else if (data.type === "closed") {
resetGui();

// Try to reconnect.
timeout = setTimeout(tryConnect, 1000);
};

ws.onmessage = async (event) => {
// Reduce websocket backpressure.
const messagePromise = new Promise<Message[]>((resolve) => {
(event.data.arrayBuffer() as Promise<ArrayBuffer>).then((buffer) => {
resolve(decode(new Uint8Array(buffer)) as Message[]);
});
});

// Try our best to handle messages in order. If this takes more than 1 second, we give up. :)
await orderLock.acquireAsync({ timeout: 1000 }).catch(() => {
console.log("Order lock timed out.");
orderLock.release();
});
try {
const messages = await messagePromise;
messageQueueRef.current.push(...messages);
} finally {
orderLock.acquired && orderLock.release();
}
};
viewer.useGui.setState({ websocketConnected: false });
viewer.sendMessageRef.current = (message) => {
console.log(
`Tried to send ${message.type} but websocket is not connected!`,
);
};
} else if (data.type === "message_batch") {
messageQueueRef.current.push(...data.messages);
}
};
function postToWorker(data: WsWorkerIncoming) {
worker.postMessage(data);
}

let timeout = setTimeout(tryConnect, 500);
postToWorker({ type: "set_server", server: server });
return () => {
done = true;
clearTimeout(timeout);
postToWorker({ type: "close" });
viewer.sendMessageRef.current = (message) =>
console.log(
`Tried to send ${message.type} but websocket is not connected!`,
);
viewer.useGui.setState({ websocketConnected: false });
ws?.close();
clearTimeout(timeout);
};
}, [server, resetGui]);

Expand Down
Loading

0 comments on commit 50d8302

Please sign in to comment.