From 14d4b1e543d805b260dcd8bc4727b7ee97e72c50 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Wed, 20 Sep 2023 16:30:54 -0700 Subject: [PATCH] Add get_render() method for client handles (#102) * Add get_render() method for client handles * Suppress type error --- docs/source/examples/01_image.rst | 59 ++--- .../examples/07_record3d_visualizer.rst | 3 +- docs/source/examples/08_smplx_visualizer.rst | 3 +- examples/19_get_renders.py | 47 ++++ src/viser/_messages.py | 17 ++ src/viser/_viser.py | 57 ++++- src/viser/client/src/App.tsx | 9 +- src/viser/client/src/SceneTree.tsx | 49 +--- src/viser/client/src/WebsocketInterface.tsx | 220 +++++++++++++----- src/viser/client/src/WebsocketMessages.tsx | 36 ++- src/viser/infra/_infra.py | 14 ++ src/viser/infra/_typescript_interface_gen.py | 23 +- 12 files changed, 381 insertions(+), 156 deletions(-) create mode 100644 examples/19_get_renders.py diff --git a/docs/source/examples/01_image.rst b/docs/source/examples/01_image.rst index d0bf4d912..4d56eac34 100644 --- a/docs/source/examples/01_image.rst +++ b/docs/source/examples/01_image.rst @@ -24,37 +24,38 @@ NeRFs), or images to render as 3D textures. import viser - server = viser.ViserServer() - - # Add a background image. - server.set_background_image( - iio.imread(Path(__file__).parent / "assets/Cal_logo.png"), - format="png", - ) - - # Add main image. - server.add_image( - "/img", - iio.imread(Path(__file__).parent / "assets/Cal_logo.png"), - 4.0, - 4.0, - format="png", - wxyz=(1.0, 0.0, 0.0, 0.0), - position=(2.0, 2.0, 0.0), - ) - while True: + if __name__ == "__main__": + server = viser.ViserServer() + + # Add a background image. + server.set_background_image( + iio.imread(Path(__file__).parent / "assets/Cal_logo.png"), + format="png", + ) + + # Add main image. server.add_image( - "/noise", - onp.random.randint( - 0, - 256, - size=(400, 400, 3), - dtype=onp.uint8, - ), + "/img", + iio.imread(Path(__file__).parent / "assets/Cal_logo.png"), 4.0, 4.0, - format="jpeg", + format="png", wxyz=(1.0, 0.0, 0.0, 0.0), - position=(2.0, 2.0, -1e-2), + position=(2.0, 2.0, 0.0), ) - time.sleep(0.2) + while True: + server.add_image( + "/noise", + onp.random.randint( + 0, + 256, + size=(400, 400, 3), + dtype=onp.uint8, + ), + 4.0, + 4.0, + format="jpeg", + wxyz=(1.0, 0.0, 0.0, 0.0), + position=(2.0, 2.0, -1e-2), + ) + time.sleep(0.2) diff --git a/docs/source/examples/07_record3d_visualizer.rst b/docs/source/examples/07_record3d_visualizer.rst index 904c7863d..0952d0b2c 100644 --- a/docs/source/examples/07_record3d_visualizer.rst +++ b/docs/source/examples/07_record3d_visualizer.rst @@ -30,8 +30,9 @@ Parse and stream record3d captures. To get the demo data, see ``./assets/downloa data_path: Path = Path(__file__).parent / "assets/record3d_dance", downsample_factor: int = 4, max_frames: int = 100, + share: bool = False, ) -> None: - server = viser.ViserServer() + server = viser.ViserServer(share=share) print("Loading frames!") loader = viser.extras.Record3dLoader(data_path) diff --git a/docs/source/examples/08_smplx_visualizer.rst b/docs/source/examples/08_smplx_visualizer.rst index 7d40a8ccd..65ff74f58 100644 --- a/docs/source/examples/08_smplx_visualizer.rst +++ b/docs/source/examples/08_smplx_visualizer.rst @@ -41,8 +41,9 @@ parameters to run this script: num_betas: int = 10, num_expression_coeffs: int = 10, ext: Literal["npz", "pkl"] = "npz", + share: bool = False, ) -> None: - server = viser.ViserServer() + server = viser.ViserServer(share=share) server.configure_theme(control_layout="collapsible", dark_mode=True) model = smplx.create( model_path=str(model_path), diff --git a/examples/19_get_renders.py b/examples/19_get_renders.py new file mode 100644 index 000000000..61275f62e --- /dev/null +++ b/examples/19_get_renders.py @@ -0,0 +1,47 @@ +"""Get Renders + +Example for getting renders from a client's viewport to the Python API.""" + +import time + +import imageio.v3 as iio +import numpy as onp + +import viser + + +def main(): + server = viser.ViserServer() + + button = server.add_gui_button("Render a GIF") + + @button.on_click + def _(event: viser.GuiEvent) -> None: + client = event.client + assert client is not None + + client.reset_scene() + + images = [] + + for i in range(20): + positions = onp.random.normal(size=(30, 3)) * 3.0 + client.add_spline_catmull_rom( + f"/catmull_{i}", + positions, + tension=0.5, + line_width=3.0, + color=onp.random.uniform(size=3), + ) + images.append(client.get_render(height=720, width=1280)) + + print("Writing GIF...") + iio.imwrite("saved.gif", images) + print("Wrote GIF!") + + while True: + time.sleep(10.0) + + +if __name__ == "__main__": + main() diff --git a/src/viser/_messages.py b/src/viser/_messages.py index b5005e5f4..9a3dc66cf 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -494,3 +494,20 @@ class CubicBezierSplineMessage(Message): control_points: Tuple[Tuple[float, float, float], ...] line_width: float color: int + + +@dataclasses.dataclass +class GetRenderRequestMessage(Message): + """Message from server->client requesting a render of the current viewport.""" + + format: Literal["image/jpeg", "image/png"] + height: int + width: int + quality: int + + +@dataclasses.dataclass +class GetRenderResponseMessage(Message): + """Message from client->server carrying a render.""" + + payload: bytes diff --git a/src/viser/_viser.py b/src/viser/_viser.py index 5dd05feee..692661f29 100644 --- a/src/viser/_viser.py +++ b/src/viser/_viser.py @@ -2,18 +2,20 @@ import contextlib import dataclasses +import io import threading import time from pathlib import Path -from typing import Callable, Dict, Generator, List, Tuple +from typing import Callable, Dict, Generator, List, Optional, Tuple +import imageio.v3 as iio import numpy as onp import numpy.typing as npt import rich from rich import box, style from rich.panel import Panel from rich.table import Table -from typing_extensions import override +from typing_extensions import Literal, override from . import _client_autobuild, _messages, infra from . import transforms as tf @@ -196,6 +198,57 @@ def _queue_unsafe(self, message: _messages.Message) -> None: """Define how the message API should send messages.""" self._state.connection.send(message) + def get_render( + self, height: int, width: int, transport_format: Literal["png", "jpeg"] = "jpeg" + ) -> onp.ndarray: + """Request a render from a client, block until it's done and received, then + return it as a numpy array. + + Args: + height: Height of rendered image. Should be <= the browser height. + width: Width of rendered image. Should be <= the browser width. + transport_format: Image transport format. JPEG will return a lossy (H, W, 3) RGB array. PNG will + return a lossless (H, W, 4) RGBA array, but can cause memory issues on the frontend if called + too quickly for higher-resolution images. + """ + + # Listen for a render reseponse message, which should contain the rendered + # image. + render_ready_event = threading.Event() + out: Optional[onp.ndarray] = None + + def got_render_cb( + client_id: int, message: _messages.GetRenderResponseMessage + ) -> None: + del client_id + self._state.connection.unregister_handler( + _messages.GetRenderResponseMessage, got_render_cb + ) + nonlocal out + out = iio.imread( + io.BytesIO(message.payload), + extension=f".{transport_format}", + ) + render_ready_event.set() + + self._state.connection.register_handler( + _messages.GetRenderResponseMessage, got_render_cb + ) + self._queue( + _messages.GetRenderRequestMessage( + "image/jpeg" if transport_format == "jpeg" else "image/png", + height=height, + width=width, + # Only used for JPEG. The main reason to use a lower quality version + # value is (unfortunately) to make life easier for the Javascript + # garbage collector. + quality=80, + ) + ) + render_ready_event.wait() + assert out is not None + return out + @contextlib.contextmanager def atomic(self) -> Generator[None, None, None]: """Returns a context where: diff --git a/src/viser/client/src/App.tsx b/src/viser/client/src/App.tsx index 03b4cac70..f3f6ae810 100644 --- a/src/viser/client/src/App.tsx +++ b/src/viser/client/src/App.tsx @@ -36,7 +36,7 @@ import { import { Titlebar } from "./Titlebar"; import { ViserModal } from "./Modal"; import { useSceneTreeState } from "./SceneTreeState"; -import { Message } from "./WebsocketMessages"; +import { GetRenderRequestMessage, Message } from "./WebsocketMessages"; export type ViewerContextContents = { // Zustand hooks. @@ -61,6 +61,11 @@ export type ViewerContextContents = { }; }>; messageQueueRef: React.MutableRefObject; + // Requested a render. + getRenderRequestState: React.MutableRefObject< + "ready" | "triggered" | "pause" | "in_progress" + >; + getRenderRequest: React.MutableRefObject; }; export const ViewerContext = React.createContext( null @@ -99,6 +104,8 @@ function ViewerRoot() { // Scene node attributes that aren't placed in the zustand state for performance reasons. nodeAttributesFromName: React.useRef({}), messageQueueRef: React.useRef([]), + getRenderRequestState: React.useRef("ready"), + getRenderRequest: React.useRef(null), }; return ( diff --git a/src/viser/client/src/SceneTree.tsx b/src/viser/client/src/SceneTree.tsx index 7928a50f7..548e74183 100644 --- a/src/viser/client/src/SceneTree.tsx +++ b/src/viser/client/src/SceneTree.tsx @@ -40,47 +40,16 @@ function SceneNodeThreeChildren(props: { parent: THREE.Object3D; }) { const viewer = React.useContext(ViewerContext)!; - const [children, setChildren] = React.useState([]); - - // De-bounce updates to children. - React.useEffect(() => { - let readyToUpdate = true; - - let updateChildrenTimeout: NodeJS.Timeout | undefined = undefined; - - function updateChildren() { - const newChildren = - viewer.useSceneTree.getState().nodeFromName[props.name]?.children; - if (newChildren === undefined || children == newChildren) { - return; - } - if (readyToUpdate) { - setChildren(newChildren!); - readyToUpdate = false; - updateChildrenTimeout = setTimeout(() => { - readyToUpdate = true; - updateChildren(); - }, 50); - } - } - const unsubscribe = viewer.useSceneTree.subscribe( - (state) => state.nodeFromName[props.name], - updateChildren - ); - updateChildren(); - - return () => { - clearTimeout(updateChildrenTimeout); - unsubscribe(); - }; - }, [children]); + const children = + viewer.useSceneTree((state) => state.nodeFromName[props.name]?.children); // Create a group of children inside of the parent object. return createPortal( - {children.map((child_id) => { - return ; - })} + {children && + children.map((child_id) => { + return ; + })} , props.parent @@ -129,8 +98,10 @@ export function SceneNodeThreeObject(props: { name: string }) { // For not-fully-understood reasons, wrapping makeObject with useMemo() fixes // stability issues (eg breaking runtime errors) associated with // PivotControls. - const objNode = - makeObject && React.useMemo(() => makeObject(setRef), [makeObject]); + const objNode = React.useMemo( + () => makeObject && makeObject(setRef), + [makeObject] + ); const children = obj === null ? null : ( diff --git a/src/viser/client/src/WebsocketInterface.tsx b/src/viser/client/src/WebsocketInterface.tsx index 116f47ef1..f43bce9b1 100644 --- a/src/viser/client/src/WebsocketInterface.tsx +++ b/src/viser/client/src/WebsocketInterface.tsx @@ -13,7 +13,11 @@ import { CameraFrustum, CoordinateFrame } from "./ThreeAssets"; import { Message } from "./WebsocketMessages"; import styled from "@emotion/styled"; import { Html, PivotControls } from "@react-three/drei"; -import { isTexture, makeThrottledMessageSender } from "./WebsocketFunctions"; +import { + isTexture, + makeThrottledMessageSender, + sendWebsocketMessage, +} from "./WebsocketFunctions"; import { isGuiConfig, useViserMantineTheme } from "./ControlPanel/GuiState"; import { useFrame } from "@react-three/fiber"; import GeneratedGuiContainer from "./ControlPanel/Generated"; @@ -31,7 +35,7 @@ function threeColorBufferFromUint8Buffer(colors: ArrayBuffer) { return Math.pow((value + 0.055) / 1.055, 2.4); } }), - 3, + 3 ); } @@ -63,7 +67,7 @@ function useMessageHandler() { addSceneNodeMakeParents( new SceneNode(parent_name, (ref) => ( - )), + )) ); } addSceneNode(node); @@ -79,6 +83,13 @@ function useMessageHandler() { } switch (message.type) { + // Request a render. + case "GetRenderRequestMessage": { + viewer.getRenderRequest.current = message; + viewer.getRenderRequestState.current = "triggered"; + return; + } + // Configure the theme. case "ThemeConfigurationMessage": { setTheme(message); return; @@ -93,7 +104,7 @@ function useMessageHandler() { axes_length={message.axes_length} axes_radius={message.axes_radius} /> - )), + )) ); return; } @@ -114,18 +125,18 @@ function useMessageHandler() { new Float32Array( message.points.buffer.slice( message.points.byteOffset, - message.points.byteOffset + message.points.byteLength, - ), + message.points.byteOffset + message.points.byteLength + ) ), - 3, - ), + 3 + ) ); geometry.computeBoundingSphere(); // Wrap uint8 buffer for colors. Note that we need to set normalized=true. geometry.setAttribute( "color", - threeColorBufferFromUint8Buffer(message.colors), + threeColorBufferFromUint8Buffer(message.colors) ); addSceneNodeMakeParents( @@ -144,8 +155,8 @@ function useMessageHandler() { // disposal. geometry.dispose(); pointCloudMaterial.dispose(); - }, - ), + } + ) ); return; } @@ -181,16 +192,16 @@ function useMessageHandler() { new Float32Array( message.vertices.buffer.slice( message.vertices.byteOffset, - message.vertices.byteOffset + message.vertices.byteLength, - ), + message.vertices.byteOffset + message.vertices.byteLength + ) ), - 3, - ), + 3 + ) ); if (message.vertex_colors !== null) { geometry.setAttribute( "color", - threeColorBufferFromUint8Buffer(message.vertex_colors), + threeColorBufferFromUint8Buffer(message.vertex_colors) ); } @@ -199,11 +210,11 @@ function useMessageHandler() { new Uint32Array( message.faces.buffer.slice( message.faces.byteOffset, - message.faces.byteOffset + message.faces.byteLength, - ), + message.faces.byteOffset + message.faces.byteLength + ) ), - 1, - ), + 1 + ) ); geometry.computeVertexNormals(); geometry.computeBoundingSphere(); @@ -219,8 +230,8 @@ function useMessageHandler() { // disposal. geometry.dispose(); material.dispose(); - }, - ), + } + ) ); return; } @@ -230,7 +241,7 @@ function useMessageHandler() { message.image_media_type !== null && message.image_base64_data !== null ? new TextureLoader().load( - `data:${message.image_media_type};base64,${message.image_base64_data}`, + `data:${message.image_media_type};base64,${message.image_base64_data}` ) : undefined; @@ -247,8 +258,8 @@ function useMessageHandler() { image={texture} /> ), - () => texture?.dispose(), - ), + () => texture?.dispose() + ) ); return; } @@ -256,7 +267,7 @@ function useMessageHandler() { const name = message.name; const sendDragMessage = makeThrottledMessageSender( viewer.websocketRef, - 50, + 50 ); addSceneNodeMakeParents( new SceneNode(message.name, (ref) => ( @@ -297,7 +308,7 @@ function useMessageHandler() { }} /> - )), + )) ); return; } @@ -306,12 +317,12 @@ function useMessageHandler() { const R_threeworld_world = new THREE.Quaternion(); R_threeworld_world.setFromEuler( - new THREE.Euler(-Math.PI / 2.0, 0.0, 0.0), + new THREE.Euler(-Math.PI / 2.0, 0.0, 0.0) ); const target = new THREE.Vector3( message.look_at[0], message.look_at[1], - message.look_at[2], + message.look_at[2] ); target.applyQuaternion(R_threeworld_world); cameraControls.setTarget(target.x, target.y, target.z); @@ -322,12 +333,12 @@ function useMessageHandler() { const cameraControls = viewer.cameraControlRef.current!; const R_threeworld_world = new THREE.Quaternion(); R_threeworld_world.setFromEuler( - new THREE.Euler(-Math.PI / 2.0, 0.0, 0.0), + new THREE.Euler(-Math.PI / 2.0, 0.0, 0.0) ); const updir = new THREE.Vector3( message.position[0], message.position[1], - message.position[2], + message.position[2] ).applyQuaternion(R_threeworld_world); camera.up.set(updir.x, updir.y, updir.z); @@ -341,7 +352,7 @@ function useMessageHandler() { cameraControls.setPosition( prevPosition.x, prevPosition.y, - prevPosition.z, + prevPosition.z ); return; } @@ -352,18 +363,18 @@ function useMessageHandler() { const position_cmd = new THREE.Vector3( message.position[0], message.position[1], - message.position[2], + message.position[2] ); const R_worldthree_world = new THREE.Quaternion(); R_worldthree_world.setFromEuler( - new THREE.Euler(-Math.PI / 2.0, 0.0, 0.0), + new THREE.Euler(-Math.PI / 2.0, 0.0, 0.0) ); position_cmd.applyQuaternion(R_worldthree_world); cameraControls.setPosition( position_cmd.x, position_cmd.y, - position_cmd.z, + position_cmd.z ); return; } @@ -372,7 +383,7 @@ function useMessageHandler() { // tan(fov / 2.0) = 0.5 * film height / focal length // focal length = 0.5 * film height / tan(fov / 2.0) camera.setFocalLength( - (0.5 * camera.getFilmHeight()) / Math.tan(message.fov / 2.0), + (0.5 * camera.getFilmHeight()) / Math.tan(message.fov / 2.0) ); return; } @@ -408,14 +419,11 @@ function useMessageHandler() { if (isTexture(oldBackgroundTexture)) oldBackgroundTexture.dispose(); viewer.useGui.setState({ backgroundAvailable: true }); - }, + } ); viewer.backgroundMaterialRef.current!.uniforms.enabled.value = true; viewer.backgroundMaterialRef.current!.uniforms.hasDepth.value = message.base64_depth !== null; - console.log( - viewer.backgroundMaterialRef.current!.uniforms.hasDepth.value, - ); if (message.base64_depth !== null) { // If depth is available set the texture @@ -427,7 +435,7 @@ function useMessageHandler() { viewer.backgroundMaterialRef.current!.uniforms.depthMap.value = texture; if (isTexture(oldDepthTexture)) oldDepthTexture.dispose(); - }, + } ); } return; @@ -472,7 +480,7 @@ function useMessageHandler() { ); - }), + }) ); return; } @@ -509,7 +517,7 @@ function useMessageHandler() { ); - }), + }) ); return; } @@ -544,10 +552,10 @@ function useMessageHandler() { ); }, - () => texture.dispose(), - ), + () => texture.dispose() + ) ); - }, + } ); return; } @@ -613,7 +621,7 @@ function useMessageHandler() { > ); - }), + }) ); return; } @@ -636,12 +644,12 @@ function useMessageHandler() { ))} ); - }), + }) ); return; } default: { - console.log("Receivd message did not match any known types:", message); + console.log("Received message did not match any known types:", message); return; } } @@ -650,14 +658,112 @@ function useMessageHandler() { export function FrameSynchronizedMessageHandler() { const handleMessage = useMessageHandler(); - const messageQueueRef = useContext(ViewerContext)!.messageQueueRef; + const viewer = useContext(ViewerContext)!; + const messageQueueRef = viewer.messageQueueRef; + + // We'll reuse the same canvas. + const renderBufferCanvas = React.useMemo(() => new OffscreenCanvas(1, 1), []); useFrame(() => { - // Handle messages before every frame. - // Place this directly in ws.onmessage can cause race conditions! - const numMessages = messageQueueRef.current.length; - const processBatch = messageQueueRef.current.splice(0, numMessages); - processBatch.forEach(handleMessage); + // Send a render along if it was requested! + if (viewer.getRenderRequestState.current === "triggered") { + viewer.getRenderRequestState.current = "pause"; + } else if (viewer.getRenderRequestState.current === "pause") { + const sourceCanvas = viewer.canvasRef.current!; + + const targetWidth = viewer.getRenderRequest.current!.width; + const targetHeight = viewer.getRenderRequest.current!.height; + + // We'll save a render to an intermediate canvas with the requested dimensions. + if (renderBufferCanvas.width !== targetWidth) + renderBufferCanvas.width = targetWidth; + if (renderBufferCanvas.height !== targetHeight) + renderBufferCanvas.height = targetHeight; + + const ctx = renderBufferCanvas.getContext("2d")!; + ctx.reset(); + // Use a white background for JPEGs, which don't have an alpha channel. + if (viewer.getRenderRequest.current?.format === "image/jpeg") { + ctx.fillStyle = "white"; + ctx.fillRect(0, 0, renderBufferCanvas.width, renderBufferCanvas.height); + } + + // Determine offsets for the source canvas. We'll always center our renders. + // https://developer.mozilla.org/en-US/docs/Web/API/CanvasRenderingContext2D/drawImage + let sourceWidth = sourceCanvas.width; + let sourceHeight = sourceCanvas.height; + + const sourceAspect = sourceWidth / sourceHeight; + const targetAspect = targetWidth / targetHeight; + + if (sourceAspect > targetAspect) { + // The source is wider than the target. + // We need to shrink the width. + sourceWidth = Math.round(targetAspect * sourceHeight); + } else if (sourceAspect < targetAspect) { + // The source is narrower than the target. + // We need to shrink the height. + sourceHeight = Math.round(sourceWidth / targetAspect); + } + + console.log( + `Sending render; requested aspect ratio was ${targetAspect} (dimensinos: ${targetWidth}/${targetHeight}), copying from aspect ratio ${ + sourceWidth / sourceHeight + } (dimensions: ${sourceWidth}/${sourceHeight}).` + ); + + ctx.drawImage( + sourceCanvas, + (sourceCanvas.width - sourceWidth) / 2.0, + (sourceCanvas.height - sourceHeight) / 2.0, + sourceWidth, + sourceHeight, + 0, + 0, + targetWidth, + targetHeight + ); + + viewer.getRenderRequestState.current = "in_progress"; + + // Encode the image, the send it. + renderBufferCanvas + .convertToBlob({ + type: viewer.getRenderRequest.current!.format, + quality: viewer.getRenderRequest.current!.quality / 100.0, + }) + .then(async (blob) => { + if (blob === null) { + console.error("Render failed"); + viewer.getRenderRequestState.current = "ready"; + return; + } + const payload = new Uint8Array(await blob.arrayBuffer()); + sendWebsocketMessage(viewer.websocketRef, { + type: "GetRenderResponseMessage", + payload: payload, + }); + viewer.getRenderRequestState.current = "ready"; + }); + } + + // Handle messages, but only if we're not trying to render something. + if (viewer.getRenderRequestState.current === "ready") { + // Handle messages before every frame. + // Place this directly in ws.onmessage can cause race conditions! + // + // If a render is requested, note that we don't handle any more messages + // until the render is done. + const requestRenderIndex = messageQueueRef.current.findIndex( + (message) => message.type === "GetRenderRequestMessage" + ); + const numMessages = + requestRenderIndex !== -1 + ? requestRenderIndex + 1 + : messageQueueRef.current.length; + const processBatch = messageQueueRef.current.splice(0, numMessages); + processBatch.forEach(handleMessage); + } }); return null; @@ -696,8 +802,8 @@ export function WebsocketMessageProducer() { viewer.useGui.setState({ websocketConnected: true }); }; - ws.onclose = () => { - console.log(`Disconnected! ${server}`); + ws.onclose = (event) => { + console.log(`Disconnected! ${server} code=${event.code}`); clearTimeout(retryTimeout); viewer.websocketRef.current = null; viewer.useGui.setState({ websocketConnected: false }); diff --git a/src/viser/client/src/WebsocketMessages.tsx b/src/viser/client/src/WebsocketMessages.tsx index 8c0dadaba..2ebe7f8a2 100644 --- a/src/viser/client/src/WebsocketMessages.tsx +++ b/src/viser/client/src/WebsocketMessages.tsx @@ -1,8 +1,5 @@ // AUTOMATICALLY GENERATED message interfaces, from Python dataclass definitions. // This file should not be manually modified. - -// For numpy arrays, we directly serialize the underlying data buffer. -type ArrayBuffer = Uint8Array; /** Message for a posed viewer camera. * Pose is in the form T_world_camera, OpenCV convention, +Z forward. * @@ -78,8 +75,8 @@ export interface Gui3DMessage { export interface PointCloudMessage { type: "PointCloudMessage"; name: string; - points: ArrayBuffer; - colors: ArrayBuffer; + points: Uint8Array; + colors: Uint8Array; point_size: number; } /** Mesh message. @@ -91,10 +88,10 @@ export interface PointCloudMessage { export interface MeshMessage { type: "MeshMessage"; name: string; - vertices: ArrayBuffer; - faces: ArrayBuffer; + vertices: Uint8Array; + faces: Uint8Array; color: number | null; - vertex_colors: ArrayBuffer | null; + vertex_colors: Uint8Array | null; wireframe: boolean; opacity: number | null; side: "front" | "back" | "double"; @@ -600,6 +597,25 @@ export interface CubicBezierSplineMessage { line_width: number; color: number; } +/** Message from server->client requesting a render of the current viewport. + * + * (automatically generated) + */ +export interface GetRenderRequestMessage { + type: "GetRenderRequestMessage"; + format: "image/jpeg" | "image/png"; + height: number; + width: number; + quality: number; +} +/** Message from client->server carrying a render. + * + * (automatically generated) + */ +export interface GetRenderResponseMessage { + type: "GetRenderResponseMessage"; + payload: Uint8Array; +} export type Message = | ViewerCameraMessage @@ -648,4 +664,6 @@ export type Message = | GuiSetValueMessage | ThemeConfigurationMessage | CatmullRomSplineMessage - | CubicBezierSplineMessage; + | CubicBezierSplineMessage + | GetRenderRequestMessage + | GetRenderResponseMessage; diff --git a/src/viser/infra/_infra.py b/src/viser/infra/_infra.py index 737bb98fe..3b3f6997c 100644 --- a/src/viser/infra/_infra.py +++ b/src/viser/infra/_infra.py @@ -64,6 +64,20 @@ def register_handler( self._incoming_handlers[message_cls] = [] self._incoming_handlers[message_cls].append(callback) # type: ignore + def unregister_handler( + self, + message_cls: Type[TMessage], + callback: Optional[Callable[[ClientId, TMessage], None]] = None, + ): + """Unregister a handler for a particular message type.""" + assert ( + message_cls in self._incoming_handlers + ), "Tried to unregister a handler that hasn't been registered." + if callback is None: + self._incoming_handlers.pop(message_cls) + else: + self._incoming_handlers[message_cls].remove(callback) # type: ignore + def _handle_incoming_message(self, client_id: ClientId, message: Message) -> None: """Handle incoming messages.""" if type(message) in self._incoming_handlers: diff --git a/src/viser/infra/_typescript_interface_gen.py b/src/viser/infra/_typescript_interface_gen.py index 52bc3ce5a..9c82a327b 100644 --- a/src/viser/infra/_typescript_interface_gen.py +++ b/src/viser/infra/_typescript_interface_gen.py @@ -16,7 +16,9 @@ float: "number", int: "number", str: "string", - onp.ndarray: "ArrayBuffer", + # For numpy arrays, we directly serialize the underlying data buffer. + onp.ndarray: "Uint8Array", + bytes: "Uint8Array", Any: "any", None: "null", type(None): "null", @@ -80,9 +82,9 @@ def generate_typescript_interfaces(message_cls: Type[Message]) -> str: map(lambda line: line.strip(), cls.__doc__.split("\n")) ) out_lines.append(f"/** {docstring}") - out_lines.append(f" *") - out_lines.append(f" * (automatically generated)") - out_lines.append(f" */") + out_lines.append(" *") + out_lines.append(" * (automatically generated)") + out_lines.append(" */") out_lines.append(f"export interface {cls.__name__} " + "{") out_lines.append(f' type: "{cls.__name__}";') @@ -117,19 +119,6 @@ def generate_typescript_interfaces(message_cls: Type[Message]) -> str: "// This file should not be manually modified.", "", ] - + ( - # Add numpy type alias if needed. - [ - ( - "// For numpy arrays, we directly serialize the underlying data" - " buffer." - ), - "type ArrayBuffer = Uint8Array;", - "", - ] - if interfaces.count("ArrayBuffer") > 0 - else [] - ) ) + interfaces )