diff --git a/.gitignore b/.gitignore index 204279b6d..35074b16e 100644 --- a/.gitignore +++ b/.gitignore @@ -14,5 +14,6 @@ htmlcov .DS_Store .envrc .vite +build src/viser/client/build src/viser/client/.nodeenv diff --git a/docs/source/conventions.md b/docs/source/conventions.md index f7c5b4f0c..4ab30ff39 100644 --- a/docs/source/conventions.md +++ b/docs/source/conventions.md @@ -41,8 +41,7 @@ where `wxyz` is the quaternion form of the :math:`\mathrm{SO}(3)` matrix ## World coordinates In the world coordinate space, +Z points upward by default. This can be -overridden with :func:`viser.ViserServer.set_up_direction()` or -:func:`viser.ClientHandle.set_up_direction()`. +overridden with :func:`viser.SceneApi.set_up_direction()`. ## Cameras diff --git a/docs/source/examples/02_gui.rst b/docs/source/examples/02_gui.rst index 2d3178eb6..3f7e185e2 100644 --- a/docs/source/examples/02_gui.rst +++ b/docs/source/examples/02_gui.rst @@ -29,7 +29,6 @@ Examples of basic GUI elements that we can create, read from, and write to. initial_value=0, disabled=True, ) - gui_slider = server.gui.add_slider( "Slider", min=0, @@ -38,6 +37,7 @@ Examples of basic GUI elements that we can create, read from, and write to. initial_value=0, disabled=True, ) + gui_progress = server.gui.add_progress_bar(25, animated=True) with server.gui.add_folder("Editable"): gui_vector2 = server.gui.add_vector2( @@ -119,6 +119,8 @@ Examples of basic GUI elements that we can create, read from, and write to. point_shape="circle", ) + gui_progress.value = float((counter % 100)) + # We can use `.visible` and `.disabled` to toggle GUI elements. gui_text.visible = not gui_checkbox_hide.value gui_button.visible = not gui_checkbox_hide.value diff --git a/docs/source/examples/23_smpl_visualizer_skinned.rst b/docs/source/examples/23_smpl_visualizer_skinned.rst index bed4961b6..9650659c5 100644 --- a/docs/source/examples/23_smpl_visualizer_skinned.rst +++ b/docs/source/examples/23_smpl_visualizer_skinned.rst @@ -139,8 +139,6 @@ See here for download instructions: # Match transform control gizmos to joint positions. for i, control in enumerate(gui_elements.transform_controls): control.position = smpl_outputs.T_parent_joint[i, :3, 3] - print(control.position) - skinned_handle.bones[i].wxyz = tf.SO3.from_matrix( smpl_outputs.T_world_joint[i, :3, :3] ).wxyz diff --git a/docs/source/index.md b/docs/source/index.md index ddd76a199..c55da51a3 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -29,8 +29,6 @@ pip install viser[examples] After an example script is running, you can connect by navigating to the printed URL (default: `http://localhost:8080`). -See also: our [development docs](https://viser.studio/development/). - .. toctree:: diff --git a/examples/02_gui.py b/examples/02_gui.py index ebbc62bbc..e13228359 100644 --- a/examples/02_gui.py +++ b/examples/02_gui.py @@ -18,7 +18,6 @@ def main() -> None: initial_value=0, disabled=True, ) - gui_slider = server.gui.add_slider( "Slider", min=0, @@ -27,6 +26,7 @@ def main() -> None: initial_value=0, disabled=True, ) + gui_progress = server.gui.add_progress_bar(25, animated=True) with server.gui.add_folder("Editable"): gui_vector2 = server.gui.add_vector2( @@ -108,6 +108,8 @@ def _(_) -> None: point_shape="circle", ) + gui_progress.value = float((counter % 100)) + # We can use `.visible` and `.disabled` to toggle GUI elements. gui_text.visible = not gui_checkbox_hide.value gui_button.visible = not gui_checkbox_hide.value diff --git a/examples/20_scene_pointer.py b/examples/20_scene_pointer.py index 1069c3024..6d5a2e903 100644 --- a/examples/20_scene_pointer.py +++ b/examples/20_scene_pointer.py @@ -71,7 +71,7 @@ def _(event: viser.ScenePointerEvent) -> None: client.scene.remove_pointer_callback() # Get the first hit position (based on distance from the ray origin). - hit_pos = min(hit_pos, key=lambda x: onp.linalg.norm(x - origin)) + hit_pos = hit_pos[onp.argmin(onp.sum((hit_pos - origin) ** 2, axis=-1))] # Create a sphere at the hit location. hit_pos_mesh = trimesh.creation.icosphere(radius=0.1) diff --git a/examples/23_smpl_visualizer_skinned.py b/examples/23_smpl_visualizer_skinned.py index 7fc261009..2fb364c48 100644 --- a/examples/23_smpl_visualizer_skinned.py +++ b/examples/23_smpl_visualizer_skinned.py @@ -134,8 +134,6 @@ def main(model_path: Path) -> None: # Match transform control gizmos to joint positions. for i, control in enumerate(gui_elements.transform_controls): control.position = smpl_outputs.T_parent_joint[i, :3, 3] - print(control.position) - skinned_handle.bones[i].wxyz = tf.SO3.from_matrix( smpl_outputs.T_world_joint[i, :3, :3] ).wxyz diff --git a/examples/assets/mdx_example.mdx b/examples/assets/mdx_example.mdx index f72a4f6b5..fe73a976d 100644 --- a/examples/assets/mdx_example.mdx +++ b/examples/assets/mdx_example.mdx @@ -16,7 +16,7 @@ In inline code blocks, you can show off colors with color chips: `#FED363` Adding images from a remote origin is simple. -![Viser Logo](http://nerfstudio-project.github.io/viser/_static/viser.svg) +![Viser Logo](https://viser.studio/latest/_static/logo.svg) For local images with relative paths, you can either directly use a [data URL](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs) @@ -30,7 +30,7 @@ Tables follow the standard markdown spec: | Application | Description | | ---------------------------------------------------- | -------------------------------------------------- | -| [Nerfstudio](https://nerf.studio) | A collaboration friendly studio for NeRFs | +| [NS](https://nerf.studio) | A collaboration friendly studio for NeRFs | | [Viser](https://nerfstudio-project.github.io/viser/) | An interactive 3D visualization toolbox for Python | Code blocks, while being not nearly as exciting as some of the things presented, @@ -90,5 +90,3 @@ So that's MDX in Viser. It has support for: blocks, inline code - [x] Color chips - [x] JSX enhanced components -- [ ] Prism highlighted code blocks and code block tabs -- [ ] Exposed Mantine in markdown diff --git a/examples/experimental/gaussian_splats.py b/examples/experimental/gaussian_splats.py index 0960391ee..24d7b10f5 100644 --- a/examples/experimental/gaussian_splats.py +++ b/examples/experimental/gaussian_splats.py @@ -29,6 +29,7 @@ class SplatFile(TypedDict): def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: """Load an antimatter15-style splat file.""" + start_time = time.time() splat_buffer = splat_path.read_bytes() bytes_per_gaussian = ( # Each Gaussian is serialized as: @@ -43,7 +44,6 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: ) assert len(splat_buffer) % bytes_per_gaussian == 0 num_gaussians = len(splat_buffer) // bytes_per_gaussian - print("Number of gaussians to render: ", f"{num_gaussians=}") # Reinterpret cast to dtypes that we want to extract. splat_uint8 = onp.frombuffer(splat_buffer, dtype=onp.uint8).reshape( @@ -51,14 +51,16 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: ) scales = splat_uint8[:, 12:24].copy().view(onp.float32) wxyzs = splat_uint8[:, 28:32] / 255.0 * 2.0 - 1.0 - Rs = onp.array([tf.SO3(wxyz).as_matrix() for wxyz in wxyzs]) + Rs = tf.SO3(wxyzs).as_matrix() covariances = onp.einsum( "nij,njk,nlk->nil", Rs, onp.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs ) centers = splat_uint8[:, 0:12].copy().view(onp.float32) if center: centers -= onp.mean(centers, axis=0, keepdims=True) - print("Splat file loaded") + print( + f"Splat file with {num_gaussians=} loaded in {time.time() - start_time} seconds" + ) return { "centers": centers, # Colors should have shape (N, 3). @@ -70,64 +72,40 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: - plydata = PlyData.read(ply_file_path) - vert = plydata["vertex"] - sorted_indices = onp.argsort( - -onp.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"]) - / (1 + onp.exp(-vert["opacity"])) - ) - numgaussians = len(vert) - print("Number of gaussians to render: ", numgaussians) - colors = onp.zeros((numgaussians, 3)) - opacities = onp.zeros((numgaussians, 1)) - positions = onp.zeros((numgaussians, 3)) - wxyzs = onp.zeros((numgaussians, 4)) - scales = onp.zeros((numgaussians, 3)) - for idx in sorted_indices: - v = plydata["vertex"][idx] - position = onp.array([v["x"], v["y"], v["z"]], dtype=onp.float32) - scale = onp.exp( - onp.array([v["scale_0"], v["scale_1"], v["scale_2"]], dtype=onp.float32) - ) + """Load Gaussians stored in a PLY file.""" + start_time = time.time() - rot = onp.array( - [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], dtype=onp.float32 - ) - SH_C0 = 0.28209479177387814 - color = onp.array( - [ - 0.5 + SH_C0 * v["f_dc_0"], - 0.5 + SH_C0 * v["f_dc_1"], - 0.5 + SH_C0 * v["f_dc_2"], - ] - ) - opacity = 1 / (1 + onp.exp(-v["opacity"])) - wxyz = rot / onp.linalg.norm(rot) # normalize - scales[idx] = scale - colors[idx] = color - opacities[idx] = onp.array([opacity]) - positions[idx] = position - wxyzs[idx] = wxyz - - Rs = onp.array([tf.SO3(wxyz).as_matrix() for wxyz in wxyzs]) + SH_C0 = 0.28209479177387814 + + plydata = PlyData.read(ply_file_path) + v = plydata["vertex"] + positions = onp.stack([v["x"], v["y"], v["z"]], axis=-1) + scales = onp.exp(onp.stack([v["scale_0"], v["scale_1"], v["scale_2"]], axis=-1)) + wxyzs = onp.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1) + colors = 0.5 + SH_C0 * onp.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) + opacities = 1.0 / (1.0 + onp.exp(-v["opacity"][:, None])) + + Rs = tf.SO3(wxyzs).as_matrix() covariances = onp.einsum( "nij,njk,nlk->nil", Rs, onp.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs ) if center: positions -= onp.mean(positions, axis=0, keepdims=True) - print("PLY file loaded") + + num_gaussians = len(v) + print( + f"PLY file with {num_gaussians=} loaded in {time.time() - start_time} seconds" + ) return { "centers": positions, - # Colors should have shape (N, 3). "rgbs": colors, "opacities": opacities, - # Covariances should have shape (N, 3, 3). "covariances": covariances, } -def main(splat_paths: tuple[Path, ...], test_multisplat: bool = False) -> None: - server = viser.ViserServer(share=True) +def main(splat_paths: tuple[Path, ...]) -> None: + server = viser.ViserServer() server.gui.configure_theme(dark_mode=True) gui_reset_up = server.gui.add_button( "Reset up direction", @@ -151,7 +129,7 @@ def _(event: viser.GuiEvent) -> None: raise SystemExit("Please provide a filepath to a .splat or .ply file.") server.scene.add_transform_controls(f"/{i}") - server.scene._add_gaussian_splats( + gs_handle = server.scene._add_gaussian_splats( f"/{i}/gaussian_splats", centers=splat_data["centers"], rgbs=splat_data["rgbs"], @@ -159,6 +137,13 @@ def _(event: viser.GuiEvent) -> None: covariances=splat_data["covariances"], ) + remove_button = server.gui.add_button(f"Remove splat object {i}") + + @remove_button.on_click + def _(_, gs_handle=gs_handle, remove_button=remove_button) -> None: + gs_handle.remove() + remove_button.remove() + while True: time.sleep(10.0) diff --git a/pyproject.toml b/pyproject.toml index 7dcd84069..b9fee7b6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ dependencies = [ "websockets>=10.4", "numpy>=1.0.0", - "msgpack>=1.0.7", + "msgspec>=0.18.6", "imageio>=2.0.0", "pyliblzfse>=0.4.1; platform_system!='Windows'", "scikit-image>=0.18.0", @@ -71,6 +71,13 @@ examples = [ viser = ["py.typed", "*.pyi", "_icons/tabler-icons.tar", "client/**/*", "client/**/.*"] # +[tool.setuptools.exclude-package-data] +# We exclude node_modules to prevent long build times for wheels when +# installing from source, eg via `pip install .`. +# +# https://github.com/nerfstudio-project/viser/issues/271 +viser = ["**/node_modules/**"] + [project.scripts] viser-dev-checks = "viser.scripts.dev_checks:entrypoint" diff --git a/src/viser/_gui_api.py b/src/viser/_gui_api.py index 6408a567e..43a98d106 100644 --- a/src/viser/_gui_api.py +++ b/src/viser/_gui_api.py @@ -33,6 +33,7 @@ GuiModalHandle, GuiNotificationHandle, GuiPlotlyHandle, + GuiProgressBarHandle, GuiTabGroupHandle, GuiUploadButtonHandle, SupportsRemoveProtocol, @@ -58,6 +59,22 @@ TLiteralString = TypeVar("TLiteralString", bound=LiteralString) T = TypeVar("T") LengthTenStrTuple: TypeAlias = Tuple[str, str, str, str, str, str, str, str, str, str] +Color: TypeAlias = Literal[ + "dark", + "gray", + "red", + "pink", + "grape", + "violet", + "indigo", + "blue", + "cyan", + "green", + "lime", + "yellow", + "orange", + "teal", +] def _hex_from_hls(h: float, l: float, s: float) -> str: @@ -140,7 +157,7 @@ class GuiApi: """Interface for working with the 2D GUI in viser. Used by both our global server object, for sharing the same GUI elements - with all clients, and by invidividual client handles.""" + with all clients, and by individual client handles.""" _target_container_from_thread_id: dict[int, str] = {} """ID of container to put GUI elements into.""" @@ -326,6 +343,10 @@ def _set_container_id(self, container_id: str) -> None: """Set container ID associated with the current thread.""" self._target_container_from_thread_id[threading.get_ident()] = container_id + def reset(self) -> None: + """Reset the GUI.""" + self._websock_interface.queue_message(_messages.ResetGuiMessage()) + def set_panel_label(self, label: str | None) -> None: """Set the main label that appears in the GUI panel. @@ -638,25 +659,7 @@ def add_button( disabled: bool = False, visible: bool = True, hint: str | None = None, - color: ( - Literal[ - "dark", - "gray", - "red", - "pink", - "grape", - "violet", - "indigo", - "blue", - "cyan", - "green", - "lime", - "yellow", - "orange", - "teal", - ] - | None - ) = None, + color: Color | None = None, icon: IconName | None = None, order: float | None = None, ) -> GuiButtonHandle: @@ -704,25 +707,7 @@ def add_upload_button( disabled: bool = False, visible: bool = True, hint: str | None = None, - color: ( - Literal[ - "dark", - "gray", - "red", - "pink", - "grape", - "violet", - "indigo", - "blue", - "cyan", - "green", - "lime", - "yellow", - "orange", - "teal", - ] - | None - ) = None, + color: Color | None = None, icon: IconName | None = None, mime_type: str = "*/*", order: float | None = None, @@ -1229,6 +1214,49 @@ def add_dropdown( _impl_options=tuple(options), ) + def add_progress_bar( + self, + value: float, + visible: bool = True, + animated: bool = False, + color: Color | None = None, + order: float | None = None, + ) -> GuiProgressBarHandle: + """Add a progress bar to the GUI. + + Args: + value: Value of the progress bar. (0 - 100) + visible: Whether the progress bar is visible. + animated: Whether the progress bar is in a loading state (animated, striped). + color: The color of the progress bar. + order: Optional ordering, smallest values will be displayed first. + + Returns: + A handle that can be used to interact with the GUI element. + """ + assert value >= 0 and value <= 100 + handle = GuiProgressBarHandle( + _gui_api=self, + _id=_make_unique_id(), + _visible=visible, + _animated=animated, + _parent_container_id=self._get_container_id(), + _order=_apply_default_order(order), + _value=value, + ) + self._websock_interface.queue_message( + _messages.GuiAddProgressBarMessage( + order=handle._order, + id=handle._id, + value=value, + animated=animated, + color=color, + container_id=handle._parent_container_id, + visible=visible, + ) + ) + return handle + def add_slider( self, label: str, diff --git a/src/viser/_gui_handles.py b/src/viser/_gui_handles.py index 33cb9756d..6d8cfff2b 100644 --- a/src/viser/_gui_handles.py +++ b/src/viser/_gui_handles.py @@ -1,9 +1,9 @@ from __future__ import annotations +import base64 import dataclasses import re import time -import urllib.parse import uuid import warnings from pathlib import Path @@ -15,15 +15,7 @@ from ._icons import svg_from_icon from ._icons_enum import IconName -from ._messages import ( - RemoveNotificationMessage, - GuiCloseModalMessage, - GuiRemoveMessage, - GuiUpdateMessage, - Message, - NotificationMessage, - UpdateNotificationMessage, -) +from ._messages import GuiCloseModalMessage, GuiRemoveMessage, GuiUpdateMessage, Message, RemoveNotificationMessage from ._scene_api import _encode_image_base64 from .infra import ClientId @@ -586,9 +578,9 @@ def _get_data_url(url: str, image_root: Path | None) -> str: image_root = Path(__file__).parent try: image = iio.imread(image_root / url) - data_uri = _encode_image_base64(image, "png") - url = urllib.parse.quote(f"{data_uri[1]}") - return f"data:{data_uri[0]};base64,{url}" + media_type, binary = _encode_image_binary(image, "png") + url = base64.b64encode(binary).decode("utf-8") + return f"data:{media_type};base64,{url}" except (IOError, FileNotFoundError): warnings.warn( f"Failed to read image {url}, with image_root set to {image_root}.", @@ -608,6 +600,84 @@ def _parse_markdown(markdown: str, image_root: Path | None) -> str: return markdown +@dataclasses.dataclass +class GuiProgressBarHandle: + """Use to remove markdown.""" + + _gui_api: GuiApi + _id: str + _visible: bool + _animated: bool + _parent_container_id: str + _order: float + _value: float + + @property + def value(self) -> float: + """Current content of this progress bar element, 0 - 100. Synchronized + automatically when assigned.""" + return self._value + + @value.setter + def value(self, value: float) -> None: + assert value >= 0 and value <= 100 + self._value = value + self._gui_api._websock_interface.queue_message( + GuiUpdateMessage( + self._id, + {"value": value}, + ) + ) + + @property + def animated(self) -> bool: + """Show this progress bar as loading (animated, striped).""" + return self._animated + + @animated.setter + def animated(self, animated: bool) -> None: + self._animated = animated + self._gui_api._websock_interface.queue_message( + GuiUpdateMessage( + self._id, + {"animated": animated}, + ) + ) + + @property + def order(self) -> float: + """Read-only order value, which dictates the position of the GUI element.""" + return self._order + + @property + def visible(self) -> bool: + """Temporarily show or hide this GUI element from the visualizer. Synchronized + automatically when assigned.""" + return self._visible + + @visible.setter + def visible(self, visible: bool) -> None: + if visible == self.visible: + return + + self._gui_api._websock_interface.queue_message( + GuiUpdateMessage(self._id, {"visible": visible}) + ) + self._visible = visible + + def __post_init__(self) -> None: + """We need to register ourself after construction for callbacks to work.""" + parent = self._gui_api._container_handle_from_id[self._parent_container_id] + parent._children[self._id] = self + + def remove(self) -> None: + """Permanently remove this progress bar from the visualizer.""" + self._gui_api._websock_interface.queue_message(GuiRemoveMessage(self._id)) + + parent = self._gui_api._container_handle_from_id[self._parent_container_id] + parent._children.pop(self._id) + + @dataclasses.dataclass class GuiMarkdownHandle: """Use to remove markdown.""" @@ -615,7 +685,7 @@ class GuiMarkdownHandle: _gui_api: GuiApi _id: str _visible: bool - _parent_container_id: str # Parent. + _parent_container_id: str _order: float _image_root: Path | None _content: str | None @@ -677,7 +747,7 @@ class GuiPlotlyHandle: _gui_api: GuiApi _id: str _visible: bool - _parent_container_id: str # Parent. + _parent_container_id: str _order: float _figure: go.Figure | None _aspect: float | None @@ -745,7 +815,7 @@ def __post_init__(self) -> None: parent._children[self._id] = self def remove(self) -> None: - """Permanently remove this markdown from the visualizer.""" + """Permanently remove this figure from the visualizer.""" self._gui_api._websock_interface.queue_message(GuiRemoveMessage(self._id)) parent = self._gui_api._container_handle_from_id[self._parent_container_id] parent._children.pop(self._id) diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 2e36fe5a5..a373a4b2c 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -25,6 +25,22 @@ from . import infra, theme GuiSliderMark = TypedDict("GuiSliderMark", {"value": float, "label": NotRequired[str]}) +Color = Literal[ + "dark", + "gray", + "red", + "pink", + "grape", + "violet", + "indigo", + "blue", + "cyan", + "green", + "lime", + "yellow", + "orange", + "teal", +] class Message(infra.Message): @@ -170,7 +186,7 @@ class CameraFrustumMessage(Message): scale: float color: int image_media_type: Optional[Literal["image/jpeg", "image/png"]] - image_base64_data: Optional[str] + image_binary: Optional[bytes] @dataclasses.dataclass @@ -311,7 +327,7 @@ class SkinnedMeshMessage(MeshMessage): bone_wxyzs: Tuple[Tuple[float, float, float, float], ...] bone_positions: Tuple[Tuple[float, float, float], ...] - skin_indices: onpt.NDArray[onp.uint32] + skin_indices: onpt.NDArray[onp.uint16] skin_weights: onpt.NDArray[onp.float32] def __post_init__(self): @@ -443,8 +459,8 @@ class BackgroundImageMessage(Message): """Message for rendering a background image.""" media_type: Literal["image/jpeg", "image/png"] - base64_rgb: str - base64_depth: Optional[str] + rgb_bytes: bytes + depth_bytes: Optional[bytes] @dataclasses.dataclass @@ -453,7 +469,7 @@ class ImageMessage(Message): name: str media_type: Literal["image/jpeg", "image/png"] - base64_data: str + data: bytes render_width: float render_height: float @@ -495,6 +511,11 @@ class ResetSceneMessage(Message): """Reset scene.""" +@dataclasses.dataclass +class ResetGuiMessage(Message): + """Reset GUI.""" + + @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddFolderMessage(Message): @@ -516,6 +537,18 @@ class GuiAddMarkdownMessage(Message): visible: bool +@tag_class("GuiAddComponentMessage") +@dataclasses.dataclass +class GuiAddProgressBarMessage(Message): + order: float + id: str + value: float + animated: bool + color: Optional[Color] + container_id: str + visible: bool + + @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddPlotlyMessage(Message): @@ -571,48 +604,14 @@ class GuiAddButtonMessage(_GuiAddInputBase): # All GUI elements currently need an `value` field. # This makes our job on the frontend easier. value: bool - color: Optional[ - Literal[ - "dark", - "gray", - "red", - "pink", - "grape", - "violet", - "indigo", - "blue", - "cyan", - "green", - "lime", - "yellow", - "orange", - "teal", - ] - ] + color: Optional[Color] icon_html: Optional[str] @tag_class("GuiAddComponentMessage") @dataclasses.dataclass class GuiAddUploadButtonMessage(_GuiAddInputBase): - color: Optional[ - Literal[ - "dark", - "gray", - "red", - "pink", - "grape", - "violet", - "indigo", - "blue", - "cyan", - "green", - "lime", - "yellow", - "orange", - "teal", - ] - ] + color: Optional[Color] icon_html: Optional[str] mime_type: str diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index d273551bc..649abae94 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1,6 +1,5 @@ from __future__ import annotations -import base64 import io import time import warnings @@ -73,11 +72,11 @@ def _encode_rgb(rgb: RgbTupleOrArray) -> int: return int(rgb_fixed[0] * (256**2) + rgb_fixed[1] * 256 + rgb_fixed[2]) -def _encode_image_base64( +def _encode_image_binary( image: onp.ndarray, format: Literal["png", "jpeg"], jpeg_quality: int | None = None, -) -> tuple[Literal["image/png", "image/jpeg"], str]: +) -> tuple[Literal["image/png", "image/jpeg"], bytes]: media_type: Literal["image/png", "image/jpeg"] image = _colors_to_uint8(image) with io.BytesIO() as data_buffer: @@ -94,10 +93,8 @@ def _encode_image_base64( ) else: assert_never(format) - - base64_data = base64.b64encode(data_buffer.getvalue()).decode("ascii") - - return media_type, base64_data + binary = data_buffer.getvalue() + return media_type, binary TVector = TypeVar("TVector", bound=tuple) @@ -115,7 +112,7 @@ class SceneApi: """Interface for adding 3D primitives to the scene. Used by both our global server object, for sharing the same GUI elements - with all clients, and by invidividual client handles.""" + with all clients, and by individual client handles.""" def __init__( self, @@ -444,12 +441,12 @@ def add_camera_frustum( """ if image is not None: - media_type, base64_data = _encode_image_base64( + media_type, binary = _encode_image_binary( image, format, jpeg_quality=jpeg_quality ) else: media_type = None - base64_data = None + binary = None self._websock_interface.queue_message( _messages.CameraFrustumMessage( @@ -460,7 +457,7 @@ def add_camera_frustum( # (255, 255, 255) => 0xffffff, etc color=_encode_rgb(color), image_media_type=media_type, - image_base64_data=base64_data, + image_binary=binary, ) ) return CameraFrustumHandle._make(self, name, wxyz, position, visible) @@ -946,8 +943,7 @@ def _add_gaussian_splats( position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, ) -> GaussianSplatHandle: - """Add a model to render using Gaussian Splatting. Does not yet support - spherical harmonics. + """Add a model to render using Gaussian Splatting. **Work-in-progress.** This feature is experimental and still under development. It may be changed or removed. @@ -971,7 +967,8 @@ def _add_gaussian_splats( assert opacities.shape == (num_gaussians, 1) assert covariances.shape == (num_gaussians, 3, 3) - # Get cholesky factor of covariance. + # Get cholesky factor of covariance. This helps retain precision when + # we convert to float16. cov_cholesky_triu = ( onp.linalg.cholesky(covariances.astype(onp.float64) + onp.ones(3) * 1e-7) .swapaxes(-1, -2) # tril => triu @@ -980,10 +977,14 @@ def _add_gaussian_splats( buffer = onp.concatenate( [ # First texelFetch. + # - xyz (96 bits): centers. centers.astype(onp.float32).view(onp.uint8), + # - w (32 bits): this is reserved for use by the renderer. onp.zeros((num_gaussians, 4), dtype=onp.uint8), # Second texelFetch. + # - xyz (96 bits): upper-triangular Cholesky factor of covariance. cov_cholesky_triu.astype(onp.float16).copy().view(onp.uint8), + # - w (32 bits): rgba. _colors_to_uint8(rgbs), _colors_to_uint8(opacities), ], @@ -1095,13 +1096,13 @@ def set_background_image( jpeg_quality: Quality of the jpeg image (if jpeg format is used). depth: Optional depth image to use to composite background with scene elements. """ - media_type, base64_data = _encode_image_base64( + media_type, rgb_bytes = _encode_image_binary( image, format, jpeg_quality=jpeg_quality ) # Encode depth if provided. We use a 3-channel PNG to represent a fixed point # depth at each pixel. - depth_base64data = None + depth_bytes = None if depth is not None: # Convert to fixed-point. # We'll support from 0 -> (2^24 - 1) / 100_000. @@ -1116,15 +1117,13 @@ def set_background_image( assert intdepth.shape == (*depth.shape[:2], 4) with io.BytesIO() as data_buffer: iio.imwrite(data_buffer, intdepth[:, :, :3], extension=".png") - depth_base64data = base64.b64encode(data_buffer.getvalue()).decode( - "ascii" - ) + depth_bytes = data_buffer.getvalue() self._websock_interface.queue_message( _messages.BackgroundImageMessage( media_type=media_type, - base64_rgb=base64_data, - base64_depth=depth_base64data, + rgb_bytes=rgb_bytes, + depth_bytes=depth_bytes, ) ) @@ -1158,14 +1157,14 @@ def add_image( Handle for manipulating scene node. """ - media_type, base64_data = _encode_image_base64( + media_type, binary = _encode_image_binary( image, format, jpeg_quality=jpeg_quality ) self._websock_interface.queue_message( _messages.ImageMessage( name=name, media_type=media_type, - base64_data=base64_data, + data=binary, render_width=render_width, render_height=render_height, ) diff --git a/src/viser/_viser.py b/src/viser/_viser.py index 675ceb1bd..43485669e 100644 --- a/src/viser/_viser.py +++ b/src/viser/_viser.py @@ -533,6 +533,7 @@ def _(conn: infra.WebsockClientConnection) -> None: self.request_share_url() self.scene.reset() + self.gui.reset() self.gui.set_panel_label(label) def get_host(self) -> str: @@ -697,7 +698,11 @@ def _start_scene_recording(self) -> RecordHandle: **Work-in-progress.** This API may be changed or removed. """ - return self._websock_server.start_recording( + recorder = self._websock_server.start_recording( # Don't record GUI messages. This feels brittle. filter=lambda message: "Gui" not in type(message).__name__ ) + # Insert current scene state. + for message in self._websock_server._broadcast_buffer.message_from_id.values(): + recorder._insert_message(message) + return recorder diff --git a/src/viser/client/package.json b/src/viser/client/package.json index 2630f2390..fad1d82b6 100644 --- a/src/viser/client/package.json +++ b/src/viser/client/package.json @@ -10,6 +10,7 @@ "@mantine/vanilla-extract": "^7.6.2", "@mdx-js/mdx": "^3.0.1", "@mdx-js/react": "^3.0.1", + "@msgpack/msgpack": "^3.0.0-beta2", "@react-three/drei": "^9.64.0", "@react-three/fiber": "^8.12.0", "@tabler/icons-react": "^3.1.0", @@ -23,11 +24,12 @@ "clsx": "^2.1.0", "colortranslator": "^4.1.0", "dayjs": "^1.11.10", + "detect-browser": "^5.3.0", + "fflate": "^0.8.2", "hold-event": "^1.1.0", "immer": "^10.0.4", "its-fine": "^1.2.5", "mantine-react-table": "^2.0.0-beta.0", - "msgpackr": "^1.10.2", "postcss": "^8.4.38", "prettier": "^3.0.3", "react": "^18.2.0", diff --git a/src/viser/client/public/model.splat b/src/viser/client/public/model.splat deleted file mode 100644 index 63c303115..000000000 Binary files a/src/viser/client/public/model.splat and /dev/null differ diff --git a/src/viser/client/src/App.tsx b/src/viser/client/src/App.tsx index a2dfd1c6e..17e7c533e 100644 --- a/src/viser/client/src/App.tsx +++ b/src/viser/client/src/App.tsx @@ -6,10 +6,10 @@ import "./App.css"; import { Notifications } from "@mantine/notifications"; import { - AdaptiveDpr, - AdaptiveEvents, CameraControls, Environment, + PerformanceMonitor, + Stats, } from "@react-three/drei"; import * as THREE from "three"; import { Canvas, useThree, useFrame } from "@react-three/fiber"; @@ -40,14 +40,15 @@ import { Titlebar } from "./Titlebar"; import { ViserModal } from "./Modal"; import { useSceneTreeState } from "./SceneTreeState"; import { GetRenderRequestMessage, Message } from "./WebsocketMessages"; -import { makeThrottledMessageSender } from "./WebsocketFunctions"; +import { useThrottledMessageSender } from "./WebsocketFunctions"; import { useDisclosure } from "@mantine/hooks"; import { rayToViserCoords } from "./WorldTransformUtils"; import { ndcFromPointerXy, opencvXyFromPointerXy } from "./ClickUtils"; import { theme } from "./AppTheme"; -import { GaussianSplatsContext } from "./Splatting/GaussianSplats"; import { FrameSynchronizedMessageHandler } from "./MessageHandler"; import { PlaybackFromFile } from "./FilePlayback"; +import { SplatRenderContext } from "./Splatting/GaussianSplats"; +import { BrowserWarning } from "./BrowserWarning"; export type ViewerContextContents = { messageSource: "websocket" | "file_playback"; @@ -57,7 +58,7 @@ export type ViewerContextContents = { // Useful references. // TODO: there's really no reason these all need to be their own ref objects. // We could have just one ref to a global mutable struct. - websocketRef: React.MutableRefObject; + sendMessageRef: React.MutableRefObject<(message: Message) => void>; canvasRef: React.MutableRefObject; sceneRef: React.MutableRefObject; cameraRef: React.MutableRefObject; @@ -132,17 +133,27 @@ function ViewerRoot() { servers.length >= 1 ? servers[0] : getDefaultServerFromUrl(); // Playback mode for embedding viser. - const playbackPath = new URLSearchParams(window.location.search).get( - "playbackPath", - ); - console.log(playbackPath); + const searchParams = new URLSearchParams(window.location.search); + const playbackPath = searchParams.get("playbackPath"); + const darkMode = searchParams.get("darkMode") !== null; + const showStats = searchParams.get("showStats") !== null; // Values that can be globally accessed by components in a viewer. + const nodeRefFromName = React.useRef<{ + [name: string]: undefined | THREE.Object3D; + }>({}); const viewer: ViewerContextContents = { messageSource: playbackPath === null ? "websocket" : "file_playback", - useSceneTree: useSceneTreeState(), + useSceneTree: useSceneTreeState(nodeRefFromName), useGui: useGuiState(initialServer), - websocketRef: React.useRef(null), + sendMessageRef: React.useRef( + playbackPath == null + ? (message) => + console.log( + `Tried to send ${message.type} but websocket is not connected!`, + ) + : () => null, + ), canvasRef: React.useRef(null), sceneRef: React.useRef(null), cameraRef: React.useRef(null), @@ -161,7 +172,7 @@ function ViewerRoot() { })(), }, }), - nodeRefFromName: React.useRef({}), + nodeRefFromName: nodeRefFromName, messageQueueRef: React.useRef([]), getRenderRequestState: React.useRef("ready"), getRenderRequest: React.useRef(null), @@ -175,27 +186,32 @@ function ViewerRoot() { skinnedMeshState: React.useRef({}), }; + // Set dark default if specified in URL. + if (darkMode) viewer.useGui.getState().theme.dark_mode = darkMode; + return ( - {viewer.messageSource === "websocket" ? ( - - ) : null} - {viewer.messageSource === "file_playback" ? ( - - ) : null} - + + {viewer.messageSource === "websocket" ? ( + + ) : null} + {viewer.messageSource === "file_playback" ? ( + + ) : null} + {showStats ? : null} + ); } -function ViewerContents() { +function ViewerContents({ children }: { children: React.ReactNode }) { const viewer = React.useContext(ViewerContext)!; - const dark_mode = viewer.useGui((state) => state.theme.dark_mode); + const darkMode = viewer.useGui((state) => state.theme.dark_mode); const colors = viewer.useGui((state) => state.theme.colors); - const control_layout = viewer.useGui((state) => state.theme.control_layout); + const controlLayout = viewer.useGui((state) => state.theme.control_layout); return ( <> - + + {children} + ({ - backgroundColor: dark_mode ? theme.colors.dark[9] : "#fff", + backgroundColor: darkMode ? theme.colors.dark[9] : "#fff", flexGrow: 1, overflow: "hidden", height: "100%", })} > - - - - - + + + {viewer.useGui((state) => state.theme.show_logo) && viewer.messageSource == "websocket" ? ( ) : null} {viewer.messageSource == "websocket" ? ( - + ) : null} @@ -269,10 +283,7 @@ function ViewerContents() { function ViewerCanvas({ children }: { children: React.ReactNode }) { const viewer = React.useContext(ViewerContext)!; - const sendClickThrottled = makeThrottledMessageSender( - viewer.websocketRef, - 20, - ); + const sendClickThrottled = useThrottledMessageSender(20); const theme = useMantineTheme(); return ( @@ -285,7 +296,6 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) { width: "100%", height: "100%", }} - performance={{ min: 0.95 }} ref={viewer.canvasRef} // Handle scene click events (onPointerDown, onPointerMove, onPointerUp) onPointerDown={(e) => { @@ -433,12 +443,13 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) { > {children} - - + - - + + + + state.setDpr); + return ( + { + const max = Math.min(refreshrate * 0.9, 85); + const min = Math.max(max * 0.5, 38); + return [min, max]; + }} + onChange={({ factor, fps, refreshrate }) => { + const dpr = window.devicePixelRatio * (0.2 + 0.8 * factor); + console.log( + `[Performance] Setting DPR to ${dpr}; FPS=${fps}/${refreshrate}`, + ); + setDpr(dpr); + }} + /> + ); +} + /* HTML Canvas, for drawing 2D. */ function Viewer2DCanvas() { const viewer = React.useContext(ViewerContext)!; diff --git a/src/viser/client/src/BrowserWarning.tsx b/src/viser/client/src/BrowserWarning.tsx new file mode 100644 index 000000000..7e29780f8 --- /dev/null +++ b/src/viser/client/src/BrowserWarning.tsx @@ -0,0 +1,45 @@ +import { notifications } from "@mantine/notifications"; +import { detect } from "detect-browser"; +import { useEffect } from "react"; + +export function BrowserWarning() { + useEffect(() => { + const browser = detect(); + + // Browser version are based loosely on support for SIMD, OffscreenCanvas. + // + // https://caniuse.com/?search=simd + // https://caniuse.com/?search=OffscreenCanvas + if (browser === null || browser.version === null) { + console.log("Failed to detect browser"); + notifications.show({ + title: "Could not detect browser version", + message: + "Your browser version could not be detected. It may not be supported.", + autoClose: false, + color: "red", + }); + } else { + const version = parseFloat(browser.version); + console.log(`Detected ${browser.name} version ${version}`); + if ( + (browser.name === "chrome" && version < 91) || + (browser.name === "edge" && version < 91) || + (browser.name === "firefox" && version < 89) || + (browser.name === "opera" && version < 77) || + (browser.name === "safari" && version < 16.4) + ) + notifications.show({ + title: "Unsuppported browser", + message: `Your browser (${ + browser.name.slice(0, 1).toUpperCase() + browser.name.slice(1) + }/${ + browser.version + }) is outdated, which may cause problems. Consider updating.`, + autoClose: false, + color: "red", + }); + } + }); + return null; +} diff --git a/src/viser/client/src/CameraControls.tsx b/src/viser/client/src/CameraControls.tsx index 0eda6b611..a9c3ea7cf 100644 --- a/src/viser/client/src/CameraControls.tsx +++ b/src/viser/client/src/CameraControls.tsx @@ -1,5 +1,4 @@ import { ViewerContext } from "./App"; -import { makeThrottledMessageSender } from "./WebsocketFunctions"; import { CameraControls } from "@react-three/drei"; import { useThree } from "@react-three/fiber"; import * as holdEvent from "hold-event"; @@ -7,15 +6,13 @@ import React, { useContext, useRef } from "react"; import { PerspectiveCamera } from "three"; import * as THREE from "three"; import { computeT_threeworld_world } from "./WorldTransformUtils"; +import { useThrottledMessageSender } from "./WebsocketFunctions"; export function SynchronizedCameraControls() { const viewer = useContext(ViewerContext)!; const camera = useThree((state) => state.camera as PerspectiveCamera); - const sendCameraThrottled = makeThrottledMessageSender( - viewer.websocketRef, - 20, - ); + const sendCameraThrottled = useThrottledMessageSender(20); // Helper for resetting camera poses. const initialCameraRef = useRef<{ diff --git a/src/viser/client/src/ControlPanel/ControlPanel.tsx b/src/viser/client/src/ControlPanel/ControlPanel.tsx index 92c130a44..59c515aa0 100644 --- a/src/viser/client/src/ControlPanel/ControlPanel.tsx +++ b/src/viser/client/src/ControlPanel/ControlPanel.tsx @@ -38,7 +38,6 @@ import BottomPanel from "./BottomPanel"; import FloatingPanel from "./FloatingPanel"; import { ThemeConfigurationMessage } from "../WebsocketMessages"; import SidebarPanel from "./SidebarPanel"; -import { sendWebsocketMessage } from "../WebsocketFunctions"; // Must match constant in Python. const ROOT_CONTAINER_ID = "root"; @@ -270,7 +269,7 @@ function ShareButton() {