diff --git a/docs/source/gui_handles.md b/docs/source/gui_handles.md index 623ea9619..827166d7d 100644 --- a/docs/source/gui_handles.md +++ b/docs/source/gui_handles.md @@ -10,7 +10,7 @@ clients. When a GUI element is added to a client (for example, via -.. autoapiclass:: viser.GuiHandle +.. autoapiclass:: viser.GuiInputHandle :members: :undoc-members: :inherited-members: diff --git a/examples/03_gui_callbacks.py b/examples/03_gui_callbacks.py index 68669d466..dba1795fb 100644 --- a/examples/03_gui_callbacks.py +++ b/examples/03_gui_callbacks.py @@ -72,7 +72,7 @@ def draw_points() -> None: gui_num_points.on_update(lambda _: draw_points()) @gui_reset_scene.on_click - def _(_: viser.GuiButtonHandle) -> None: + def _(_) -> None: """Reset the scene when the reset button is clicked.""" gui_show.value = True gui_location.value = 0.0 diff --git a/examples/08_smplx_visualizer.py b/examples/08_smplx_visualizer.py index cce06a480..5e1470420 100644 --- a/examples/08_smplx_visualizer.py +++ b/examples/08_smplx_visualizer.py @@ -36,6 +36,7 @@ def main( ext: Literal["npz", "pkl"] = "npz", ) -> None: server = viser.ViserServer() + server.configure_theme(control_layout="collapsible", dark_mode=True) model = smplx.create( model_path=str(model_path), model_type=model_type, @@ -114,10 +115,10 @@ def main( class GuiElements: """Structure containing handles for reading from GUI elements.""" - gui_rgb: viser.GuiHandle[Tuple[int, int, int]] - gui_wireframe: viser.GuiHandle[bool] - gui_betas: List[viser.GuiHandle[float]] - gui_joints: List[viser.GuiHandle[Tuple[float, float, float]]] + gui_rgb: viser.GuiInputHandle[Tuple[int, int, int]] + gui_wireframe: viser.GuiInputHandle[bool] + gui_betas: List[viser.GuiInputHandle[float]] + gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] changed: bool """This flag will be flipped to True whenever the mesh needs to be re-generated.""" @@ -197,7 +198,7 @@ def _(_): joint.value = tf.SO3(wxyz=quat).log() sync_transform_controls() - gui_joints: List[viser.GuiHandle[Tuple[float, float, float]]] = [] + gui_joints: List[viser.GuiInputHandle[Tuple[float, float, float]]] = [] for i in range(num_body_joints + 1): gui_joint = server.add_gui_vector3( label=smplx.joint_names.JOINT_NAMES[i], diff --git a/examples/09_urdf_visualizer.py b/examples/09_urdf_visualizer.py index 485e8d214..40eaa02ba 100644 --- a/examples/09_urdf_visualizer.py +++ b/examples/09_urdf_visualizer.py @@ -27,7 +27,7 @@ def main(urdf_path: Path) -> None: urdf = ViserUrdf(server, urdf_path) # Create joint angle sliders. - gui_joints: List[viser.GuiHandle[float]] = [] + gui_joints: List[viser.GuiInputHandle[float]] = [] initial_angles: List[float] = [] for joint_name, (lower, upper) in urdf.get_actuated_joint_limits().items(): lower = lower if lower is not None else -onp.pi diff --git a/examples/13_theming.py b/examples/13_theming.py index 826aa8cb9..9c4c2da22 100644 --- a/examples/13_theming.py +++ b/examples/13_theming.py @@ -1,6 +1,9 @@ +# mypy: disable-error-code="arg-type" +# +# Waiting on PEP 675 support in mypy. https://github.com/python/mypy/issues/12554 """Theming -Viser is adding support for theming. Work-in-progress. +Viser includes support for light theming. """ import time @@ -33,16 +36,34 @@ image_alt="NerfStudio Logo", href="https://docs.nerf.studio/", ) - -# image = None - titlebar_theme = TitlebarConfig(buttons=buttons, image=image) -server.configure_theme( - dark_mode=True, titlebar_content=titlebar_theme, control_layout="fixed" +server.add_gui_markdown( + "Viser includes support for light theming via the `.configure_theme()` method." +) + +# GUI elements for controllable values. +titlebar = server.add_gui_checkbox("Titlebar", initial_value=True) +dark_mode = server.add_gui_checkbox("Dark mode", initial_value=True) +control_layout = server.add_gui_dropdown( + "Control layout", ("floating", "fixed", "collapsible") ) +brand_color = server.add_gui_rgb("Brand color", (230, 180, 30)) +synchronize = server.add_gui_button("Apply theme") + + +def synchronize_theme() -> None: + server.configure_theme( + dark_mode=dark_mode.value, + titlebar_content=titlebar_theme if titlebar.value else None, + control_layout=control_layout.value, + brand_color=brand_color.value, + ) + server.world_axes.visible = True + -server.world_axes.visible = True +synchronize.on_click(lambda _: synchronize_theme()) +synchronize_theme() while True: time.sleep(10.0) diff --git a/examples/16_modal.py b/examples/16_modal.py index 80b0721ef..88afc612a 100644 --- a/examples/16_modal.py +++ b/examples/16_modal.py @@ -14,24 +14,21 @@ def main(): def _(client: viser.ClientHandle) -> None: with client.add_gui_modal("Modal example"): client.add_gui_markdown( - markdown="**The slider below determines how many modals will appear...**" + "**The input below determines the title of the modal...**" ) - gui_slider = client.add_gui_slider( - "Slider", - min=1, - max=10, - step=1, - initial_value=1, + gui_title = client.add_gui_text( + "Title", + initial_value="My Modal", ) modal_button = client.add_gui_button("Show more modals") @modal_button.on_click - def _(_: viser.GuiButtonHandle) -> None: - for i in range(gui_slider.value): - with client.add_gui_modal(f"Modal #{i}"): - client.add_gui_markdown("This is a modal!") + def _(_) -> None: + with client.add_gui_modal(gui_title.value) as modal: + client.add_gui_markdown("This is content inside the modal!") + client.add_gui_button("Close").on_click(lambda _: modal.close()) while True: time.sleep(0.15) diff --git a/examples/17_background_composite.py b/examples/17_background_composite.py new file mode 100644 index 000000000..5b9e40e99 --- /dev/null +++ b/examples/17_background_composite.py @@ -0,0 +1,36 @@ +# mypy: disable-error-code="var-annotated" +"""Background image example with depth compositing. + +In this example, we show how to use a background image with depth compositing. This can +be useful when we want a 2D image to occlude 3D geometry, such as for NeRF rendering. +""" + +import time + +import numpy as onp +import trimesh +import trimesh.creation + +import viser + +server = viser.ViserServer() + + +img = onp.random.randint(0, 255, size=(1000, 1000, 3), dtype=onp.uint8) +depth = onp.ones((1000, 1000, 1), dtype=onp.float32) + +# Make a square middle portal. +depth[250:750, 250:750, :] = 10.0 +img[250:750, 250:750, :] = 255 + +mesh = trimesh.creation.box((0.5, 0.5, 0.5)) +server.add_mesh_trimesh( + name="/cube", + mesh=mesh, + position=(0, 0, 0.0), +) +server.set_background_image(img, depth=depth) + + +while True: + time.sleep(1.0) diff --git a/pyproject.toml b/pyproject.toml index ddb587f60..6c92ef624 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "viser" -version = "0.0.17" +version = "0.1.0" description = "3D visualization + Python" readme = "README.md" license = { text="MIT" } @@ -91,6 +91,7 @@ select = [ "PLW", # Pylint warnings. ] ignore = [ + "E741", # Ambiguous variable name. (l, O, or I) "E501", # Line too long. "F722", # Forward annotation false positive from jaxtyping. Should be caught by pyright. "F821", # Forward annotation false positive from jaxtyping. Should be caught by pyright. diff --git a/viser/__init__.py b/viser/__init__.py index 08ba9282b..8427fd089 100644 --- a/viser/__init__.py +++ b/viser/__init__.py @@ -1,13 +1,17 @@ +from typing import TYPE_CHECKING + from ._gui_handles import GuiButtonGroupHandle as GuiButtonGroupHandle from ._gui_handles import GuiButtonHandle as GuiButtonHandle from ._gui_handles import GuiDropdownHandle as GuiDropdownHandle +from ._gui_handles import GuiEvent as GuiEvent from ._gui_handles import GuiFolderHandle as GuiFolderHandle -from ._gui_handles import GuiHandle as GuiHandle +from ._gui_handles import GuiInputHandle as GuiInputHandle from ._gui_handles import GuiMarkdownHandle as GuiMarkdownHandle from ._gui_handles import GuiTabGroupHandle as GuiTabGroupHandle from ._gui_handles import GuiTabHandle as GuiTabHandle from ._icons_enum import Icon as Icon from ._scene_handles import CameraFrustumHandle as CameraFrustumHandle +from ._scene_handles import ClickEvent as ClickEvent from ._scene_handles import FrameHandle as FrameHandle from ._scene_handles import Gui3dContainerHandle as Gui3dContainerHandle from ._scene_handles import ImageHandle as ImageHandle @@ -19,3 +23,7 @@ from ._viser import CameraHandle as CameraHandle from ._viser import ClientHandle as ClientHandle from ._viser import ViserServer as ViserServer + +if not TYPE_CHECKING: + # Backwards compatibility. + GuiHandle = GuiInputHandle diff --git a/viser/_client_autobuild.py b/viser/_client_autobuild.py index 35108253a..b1b1b08fd 100644 --- a/viser/_client_autobuild.py +++ b/viser/_client_autobuild.py @@ -9,20 +9,24 @@ build_dir = client_dir / "build" -def _check_process(process_name: str) -> bool: - """ - Check if a process is running - """ +def _check_viser_yarn_running() -> bool: + """Returns True if the viewer client has been launched via `yarn start`.""" for process in psutil.process_iter(): - if process_name == process.name(): - return True + try: + if Path(process.cwd()).as_posix().endswith("viser/client") and any( + [part.endswith("yarn") for part in process.cmdline()] + ): + return True + except (psutil.AccessDenied, psutil.ZombieProcess): + pass return False def ensure_client_is_built() -> None: - """Ensure that the client is built.""" + """Ensure that the client is built or already running.""" if not (client_dir / "src").exists(): + # Can't build client. assert (build_dir / "index.html").exists(), ( "Something went wrong! At least one of the client source or build" " directories should be present." @@ -31,14 +35,15 @@ def ensure_client_is_built() -> None: # Do we need to re-trigger a build? build = False - if not (build_dir / "index.html").exists(): - rich.print("[bold](viser)[/bold] No client build found. Building now...") - build = True - elif _check_process("Viser Viewer"): + if _check_viser_yarn_running(): + # Don't run `yarn build` if `yarn start` is already running. rich.print( - "[bold](viser)[/bold] A Viser viewer is already running. Skipping build check..." + "[bold](viser)[/bold] The Viser viewer looks like it has been launched via `yarn start`. Skipping build check..." ) build = False + elif not (build_dir / "index.html").exists(): + rich.print("[bold](viser)[/bold] No client build found. Building now...") + build = True elif _modified_time_recursive(client_dir / "src") > _modified_time_recursive( build_dir ): diff --git a/viser/_gui_api.py b/viser/_gui_api.py index c21ba5c87..1cad89397 100644 --- a/viser/_gui_api.py +++ b/viser/_gui_api.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc +import dataclasses import re import threading import time @@ -25,21 +26,27 @@ import imageio.v3 as iio import numpy as onp -from typing_extensions import LiteralString +from typing_extensions import Literal, LiteralString from . import _messages from ._gui_handles import ( GuiButtonGroupHandle, GuiButtonHandle, + GuiContainerProtocol, GuiDropdownHandle, + GuiEvent, GuiFolderHandle, - GuiHandle, + GuiInputHandle, GuiMarkdownHandle, GuiModalHandle, GuiTabGroupHandle, + SupportsRemoveProtocol, _GuiHandleState, + _GuiInputHandle, _make_unique_id, ) +from ._icons import base64_from_icon +from ._icons_enum import Icon from ._message_api import MessageApi, _encode_image_base64, cast_vector if TYPE_CHECKING: @@ -112,6 +119,11 @@ def _parse_markdown(markdown: str, image_root: Optional[Path]) -> str: return markdown +@dataclasses.dataclass +class _RootGuiContainer: + _children: Dict[str, SupportsRemoveProtocol] + + class GuiApi(abc.ABC): _target_container_from_thread_id: Dict[int, str] = {} """ID of container to put GUI elements into.""" @@ -119,6 +131,51 @@ class GuiApi(abc.ABC): def __init__(self) -> None: super().__init__() + self._gui_handle_from_id: Dict[str, _GuiInputHandle[Any]] = {} + self._container_handle_from_id: Dict[str, GuiContainerProtocol] = { + "root": _RootGuiContainer({}) + } + self._get_api()._message_handler.register_handler( + _messages.GuiUpdateMessage, self._handle_gui_updates + ) + + def _handle_gui_updates( + self, client_id: ClientId, message: _messages.GuiUpdateMessage + ) -> None: + """Callback for handling GUI messages.""" + handle = self._gui_handle_from_id.get(message.id, None) + if handle is None: + return + + handle_state = handle._impl + value = handle_state.typ(message.value) + + # Only call update when value has actually changed. + if not handle_state.is_button and value == handle_state.value: + return + + # Update state. + with self._get_api()._atomic_lock: + handle_state.value = value + handle_state.update_timestamp = time.time() + + # Trigger callbacks. + for cb in handle_state.update_cb: + from ._viser import ClientHandle, ViserServer + + # Get the handle of the client that triggered this event. + api = self._get_api() + if isinstance(api, ClientHandle): + client = api + elif isinstance(api, ViserServer): + client = api.get_clients()[client_id] + else: + assert False + + cb(GuiEvent(client_id, client, handle)) + if handle_state.sync_cb is not None: + handle_state.sync_cb(client_id, value) + def _get_container_id(self) -> str: """Get container ID associated with the current thread.""" return self._target_container_from_thread_id.get(threading.get_ident(), "root") @@ -155,7 +212,8 @@ def add_gui_folder(self, label: str) -> GuiFolderHandle: ) return GuiFolderHandle( _gui_api=self, - _container_id=folder_container_id, + _id=folder_container_id, + _parent_container_id=self._get_container_id(), ) def add_gui_modal( @@ -170,12 +228,11 @@ def add_gui_modal( order=time.time(), id=modal_container_id, title=title, - container_id=self._get_container_id(), ) ) return GuiModalHandle( _gui_api=self, - _container_id=modal_container_id, + _id=modal_container_id, ) def add_gui_tab_group(self) -> GuiTabGroupHandle: @@ -185,7 +242,7 @@ def add_gui_tab_group(self) -> GuiTabGroupHandle: _tab_group_id=tab_group_id, _labels=[], _icons_base64=[], - _tab_container_ids=[], + _tabs=[], _gui_api=self, _container_id=self._get_container_id(), ) @@ -209,6 +266,7 @@ def add_gui_markdown( _gui_api=self, _id=markdown_id, _visible=True, + _container_id=self._get_container_id(), ) def add_gui_button( @@ -217,6 +275,25 @@ def add_gui_button( disabled: bool = False, visible: bool = True, hint: Optional[str] = None, + color: Optional[ + Literal[ + "dark", + "gray", + "red", + "pink", + "grape", + "violet", + "indigo", + "blue", + "cyan", + "green", + "lime", + "yellow", + "orange", + "teal", + ] + ] = None, + icon: Optional[Icon] = None, ) -> GuiButtonHandle: """Add a button to the GUI. The value of this input is set to `True` every time it is clicked; to detect clicks, we can manually set it back to `False`.""" @@ -233,6 +310,8 @@ def add_gui_button( container_id=self._get_container_id(), hint=hint, initial_value=False, + color=color, + icon_base64=None if icon is None else base64_from_icon(icon), ), disabled=disabled, visible=visible, @@ -302,7 +381,7 @@ def add_gui_checkbox( disabled: bool = False, visible: bool = True, hint: Optional[str] = None, - ) -> GuiHandle[bool]: + ) -> GuiInputHandle[bool]: """Add a checkbox to the GUI.""" assert isinstance(initial_value, bool) id = _make_unique_id() @@ -327,7 +406,7 @@ def add_gui_text( disabled: bool = False, visible: bool = True, hint: Optional[str] = None, - ) -> GuiHandle[str]: + ) -> GuiInputHandle[str]: """Add a text input to the GUI.""" assert isinstance(initial_value, str) id = _make_unique_id() @@ -355,7 +434,7 @@ def add_gui_number( disabled: bool = False, visible: bool = True, hint: Optional[str] = None, - ) -> GuiHandle[IntOrFloat]: + ) -> GuiInputHandle[IntOrFloat]: """Add a number input to the GUI, with user-specifiable bound and precision parameters.""" assert isinstance(initial_value, (int, float)) @@ -404,7 +483,7 @@ def add_gui_vector2( disabled: bool = False, visible: bool = True, hint: Optional[str] = None, - ) -> GuiHandle[Tuple[float, float]]: + ) -> GuiInputHandle[Tuple[float, float]]: """Add a length-2 vector input to the GUI.""" initial_value = cast_vector(initial_value, 2) min = cast_vector(min, 2) if min is not None else None @@ -449,7 +528,7 @@ def add_gui_vector3( disabled: bool = False, visible: bool = True, hint: Optional[str] = None, - ) -> GuiHandle[Tuple[float, float, float]]: + ) -> GuiInputHandle[Tuple[float, float, float]]: """Add a length-3 vector input to the GUI.""" initial_value = cast_vector(initial_value, 2) min = cast_vector(min, 3) if min is not None else None @@ -548,7 +627,7 @@ def add_gui_slider( disabled: bool = False, visible: bool = True, hint: Optional[str] = None, - ) -> GuiHandle[IntOrFloat]: + ) -> GuiInputHandle[IntOrFloat]: """Add a slider to the GUI. Types of the min, max, step, and initial value should match.""" assert max >= min if step > max - min: @@ -595,7 +674,7 @@ def add_gui_rgb( disabled: bool = False, visible: bool = True, hint: Optional[str] = None, - ) -> GuiHandle[Tuple[int, int, int]]: + ) -> GuiInputHandle[Tuple[int, int, int]]: """Add an RGB picker to the GUI.""" id = _make_unique_id() return self._create_gui_input( @@ -619,7 +698,7 @@ def add_gui_rgba( disabled: bool = False, visible: bool = True, hint: Optional[str] = None, - ) -> GuiHandle[Tuple[int, int, int, int]]: + ) -> GuiInputHandle[Tuple[int, int, int, int]]: """Add an RGBA picker to the GUI.""" id = _make_unique_id() return self._create_gui_input( @@ -643,7 +722,7 @@ def _create_gui_input( disabled: bool, visible: bool, is_button: bool = False, - ) -> GuiHandle[T]: + ) -> GuiInputHandle[T]: """Private helper for adding a simple GUI element.""" # Send add GUI input message. @@ -653,14 +732,13 @@ def _create_gui_input( handle_state = _GuiHandleState( label=message.label, typ=type(initial_value), - container=self, + gui_api=self, value=initial_value, update_timestamp=time.time(), container_id=self._get_container_id(), update_cb=[], is_button=is_button, sync_cb=None, - cleanup_cb=None, disabled=False, visible=True, id=message.id, @@ -668,10 +746,6 @@ def _create_gui_input( initial_value=initial_value, hint=message.hint, ) - self._get_api()._gui_handle_state_from_id[handle_state.id] = handle_state - handle_state.cleanup_cb = lambda: self._get_api()._gui_handle_state_from_id.pop( - handle_state.id - ) # For broadcasted GUI handles, we should synchronize all clients. # This will be a no-op for client handles. @@ -684,7 +758,7 @@ def sync_other_clients(client_id: ClientId, value: Any) -> None: handle_state.sync_cb = sync_other_clients - handle = GuiHandle(handle_state) + handle = GuiInputHandle(handle_state) # Set the disabled/visible fields. These will queue messages under-the-hood. if disabled: diff --git a/viser/_gui_handles.py b/viser/_gui_handles.py index f1e73157d..9e10e92fe 100644 --- a/viser/_gui_handles.py +++ b/viser/_gui_handles.py @@ -6,8 +6,8 @@ import uuid from typing import ( TYPE_CHECKING, - Any, Callable, + Dict, Generic, Iterable, List, @@ -19,13 +19,14 @@ ) import numpy as onp +from typing_extensions import Protocol from ._icons import base64_from_icon from ._icons_enum import Icon from ._messages import ( GuiAddDropdownMessage, GuiAddTabGroupMessage, - GuiRemoveContainerChildrenMessage, + GuiCloseModalMessage, GuiRemoveMessage, GuiSetDisabledMessage, GuiSetValueMessage, @@ -35,10 +36,11 @@ if TYPE_CHECKING: from ._gui_api import GuiApi + from ._viser import ClientHandle T = TypeVar("T") -TGuiHandle = TypeVar("TGuiHandle", bound="_GuiHandle") +TGuiHandle = TypeVar("TGuiHandle", bound="_GuiInputHandle") def _make_unique_id() -> str: @@ -46,20 +48,31 @@ def _make_unique_id() -> str: return str(uuid.uuid4()) +class GuiContainerProtocol(Protocol): + _children: Dict[str, SupportsRemoveProtocol] = dataclasses.field( + default_factory=dict + ) + + +class SupportsRemoveProtocol(Protocol): + def remove(self) -> None: + ... + + @dataclasses.dataclass class _GuiHandleState(Generic[T]): """Internal API for GUI elements.""" label: str typ: Type[T] - container: GuiApi + gui_api: GuiApi value: T update_timestamp: float container_id: str """Container that this GUI input was placed into.""" - update_cb: List[Callable[[Any], None]] + update_cb: List[Callable[[GuiEvent], None]] """Registered functions to call when this input is updated.""" is_button: bool @@ -68,9 +81,6 @@ class _GuiHandleState(Generic[T]): sync_cb: Optional[Callable[[ClientId, T], None]] """Callback for synchronizing inputs across clients.""" - cleanup_cb: Optional[Callable[[], Any]] - """Function to call when GUI element is removed.""" - disabled: bool visible: bool @@ -81,7 +91,7 @@ class _GuiHandleState(Generic[T]): @dataclasses.dataclass -class _GuiHandle(Generic[T]): +class _GuiInputHandle(Generic[T]): # Let's shove private implementation details in here... _impl: _GuiHandleState[T] @@ -115,7 +125,7 @@ def value(self, value: Union[T, onp.ndarray]) -> None: # Send to client, except for buttons. if not self._impl.is_button: - self._impl.container._get_api()._queue( + self._impl.gui_api._get_api()._queue( GuiSetValueMessage(self._impl.id, value) # type: ignore ) @@ -128,7 +138,15 @@ def value(self, value: Union[T, onp.ndarray]) -> None: for cb in self._impl.update_cb: # Pushing callbacks into separate threads helps prevent deadlocks when we # have a lock in a callback. TODO: revisit other callbacks. - threading.Thread(target=lambda: cb(self)).start() + threading.Thread( + target=lambda: cb( + GuiEvent( + client_id=None, + client=None, + target=self, + ) + ) + ).start() @property def update_timestamp(self) -> float: @@ -146,7 +164,7 @@ def disabled(self, disabled: bool) -> None: if disabled == self.disabled: return - self._impl.container._get_api()._queue( + self._impl.gui_api._get_api()._queue( GuiSetDisabledMessage(self._impl.id, disabled=disabled) ) self._impl.disabled = disabled @@ -162,58 +180,83 @@ def visible(self, visible: bool) -> None: if visible == self.visible: return - self._impl.container._get_api()._queue( + self._impl.gui_api._get_api()._queue( GuiSetVisibleMessage(self._impl.id, visible=visible) ) self._impl.visible = visible + def __post_init__(self) -> None: + """We need to register ourself after construction for callbacks to work.""" + gui_api = self._impl.gui_api + + # TODO: the current way we track GUI handles and children is fairly manual + + # error-prone. We should revist this design. + gui_api._gui_handle_from_id[self._impl.id] = self + parent = gui_api._container_handle_from_id[self._impl.container_id] + parent._children[self._impl.id] = self + def remove(self) -> None: """Permanently remove this GUI element from the visualizer.""" - self._impl.container._get_api()._queue(GuiRemoveMessage(self._impl.id)) - assert self._impl.cleanup_cb is not None - self._impl.cleanup_cb() + gui_api = self._impl.gui_api + gui_api._get_api()._queue(GuiRemoveMessage(self._impl.id)) + gui_api._gui_handle_from_id.pop(self._impl.id) StringType = TypeVar("StringType", bound=str) +# GuiInputHandle[T] is used for all inputs except for buttons. +# +# We inherit from _GuiInputHandle to special-case buttons because the usage semantics +# are slightly different: we have `on_click()` instead of `on_update()`. @dataclasses.dataclass -class GuiHandle(_GuiHandle[T], Generic[T]): +class GuiInputHandle(_GuiInputHandle[T], Generic[T]): """Handle for a general GUI inputs in our visualizer. Lets us get values, set values, and detect updates.""" def on_update( - self: TGuiHandle, func: Callable[[TGuiHandle], None] - ) -> Callable[[TGuiHandle], None]: + self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] + ) -> Callable[[GuiEvent[TGuiHandle]], None]: """Attach a function to call when a GUI input is updated. Happens in a thread.""" self._impl.update_cb.append(func) return func +@dataclasses.dataclass(frozen=True) +class GuiEvent(Generic[TGuiHandle]): + """Information associated with a GUI event, such as an update or click. + + Passed as input to callback functions.""" + + client_id: Optional[ClientId] + client: Optional[ClientHandle] + target: TGuiHandle + + @dataclasses.dataclass -class GuiButtonHandle(_GuiHandle[bool]): +class GuiButtonHandle(_GuiInputHandle[bool]): """Handle for a button input in our visualizer. Lets us detect clicks.""" def on_click( - self: TGuiHandle, func: Callable[[TGuiHandle], None] - ) -> Callable[[TGuiHandle], None]: + self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] + ) -> Callable[[GuiEvent[TGuiHandle]], None]: """Attach a function to call when a button is pressed. Happens in a thread.""" self._impl.update_cb.append(func) return func @dataclasses.dataclass -class GuiButtonGroupHandle(_GuiHandle[StringType], Generic[StringType]): +class GuiButtonGroupHandle(_GuiInputHandle[StringType], Generic[StringType]): """Handle for a button group input in our visualizer. Lets us detect clicks.""" def on_click( - self: TGuiHandle, func: Callable[[TGuiHandle], None] - ) -> Callable[[TGuiHandle], None]: + self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None] + ) -> Callable[[GuiEvent[TGuiHandle]], None]: """Attach a function to call when a button is pressed. Happens in a thread.""" self._impl.update_cb.append(func) return func @@ -230,7 +273,7 @@ def disabled(self, disabled: bool) -> None: @dataclasses.dataclass -class GuiDropdownHandle(GuiHandle[StringType], Generic[StringType]): +class GuiDropdownHandle(GuiInputHandle[StringType], Generic[StringType]): """Handle for a dropdown-style GUI input in our visualizer. Lets us get values, set values, and detect updates.""" @@ -244,7 +287,7 @@ def options(self) -> Tuple[StringType, ...]: For projects that care about typing: the static type of `options` should be consistent with the `StringType` associated with a handle. Literal types will be inferred where possible when handles are instantiated; for the most flexibility, - we can declare handles as `_GuiHandle[str]`. + we can declare handles as `GuiDropdownHandle[str]`. """ return self._impl_options @@ -254,7 +297,7 @@ def options(self, options: Iterable[StringType]) -> None: if self._impl.initial_value not in self._impl_options: self._impl.initial_value = self._impl_options[0] - self._impl.container._get_api()._queue( + self._impl.gui_api._get_api()._queue( GuiAddDropdownMessage( order=self._impl.order, id=self._impl.id, @@ -275,9 +318,9 @@ class GuiTabGroupHandle: _tab_group_id: str _labels: List[str] _icons_base64: List[Optional[str]] - _tab_container_ids: List[str] + _tabs: List[GuiTabHandle] _gui_api: GuiApi - _container_id: str + _container_id: str # Parent. def add_tab(self, label: str, icon: Optional[Icon] = None) -> GuiTabHandle: """Add a tab. Returns a handle we can use to add GUI elements to it.""" @@ -285,23 +328,20 @@ def add_tab(self, label: str, icon: Optional[Icon] = None) -> GuiTabHandle: id = _make_unique_id() # We may want to make this thread-safe in the future. + out = GuiTabHandle(_parent=self, _id=id) + self._labels.append(label) self._icons_base64.append(None if icon is None else base64_from_icon(icon)) - self._tab_container_ids.append(id) + self._tabs.append(out) self._sync_with_client() - - return GuiTabHandle(_parent=self, _container_id=id) + return out def remove(self) -> None: """Remove this tab group and all contained GUI elements.""" + for tab in self._tabs: + tab.remove() self._gui_api._get_api()._queue(GuiRemoveMessage(self._tab_group_id)) - # Containers will be removed automatically by the client. - # - # for tab_container_id in self._tab_container_ids: - # self._gui_api._get_api()._queue( - # _messages.GuiRemoveContainerChildrenMessage(tab_container_id) - # ) def _sync_with_client(self) -> None: """Send a message that syncs tab state with the client.""" @@ -312,7 +352,7 @@ def _sync_with_client(self) -> None: container_id=self._container_id, tab_labels=tuple(self._labels), tab_icons_base64=tuple(self._icons_base64), - tab_container_ids=tuple(self._tab_container_ids), + tab_container_ids=tuple(tab._id for tab in self._tabs), ) ) @@ -322,12 +362,17 @@ class GuiFolderHandle: """Use as a context to place GUI elements into a folder.""" _gui_api: GuiApi - _container_id: str + _id: str # Used as container ID for children. + _parent_container_id: str # Container ID of parent. _container_id_restore: Optional[str] = None + _children: Dict[str, SupportsRemoveProtocol] = dataclasses.field( + default_factory=dict + ) - def __enter__(self) -> None: + def __enter__(self) -> GuiFolderHandle: self._container_id_restore = self._gui_api._get_container_id() - self._gui_api._set_container_id(self._container_id) + self._gui_api._set_container_id(self._id) + return self def __exit__(self, *args) -> None: del args @@ -335,10 +380,18 @@ def __exit__(self, *args) -> None: self._gui_api._set_container_id(self._container_id_restore) self._container_id_restore = None + def __post_init__(self) -> None: + self._gui_api._container_handle_from_id[self._id] = self + parent = self._gui_api._container_handle_from_id[self._parent_container_id] + parent._children[self._id] = self + def remove(self) -> None: """Permanently remove this folder and all contained GUI elements from the visualizer.""" - self._gui_api._get_api()._queue(GuiRemoveMessage(self._container_id)) + self._gui_api._get_api()._queue(GuiRemoveMessage(self._id)) + self._gui_api._container_handle_from_id.pop(self._id) + for child in self._children.values(): + child.remove() @dataclasses.dataclass @@ -346,12 +399,16 @@ class GuiModalHandle: """Use as a context to place GUI elements into a modal.""" _gui_api: GuiApi - _container_id: str + _id: str # Used as container ID of children. _container_id_restore: Optional[str] = None + _children: Dict[str, SupportsRemoveProtocol] = dataclasses.field( + default_factory=dict + ) - def __enter__(self) -> None: + def __enter__(self) -> GuiModalHandle: self._container_id_restore = self._gui_api._get_container_id() - self._gui_api._set_container_id(self._container_id) + self._gui_api._set_container_id(self._id) + return self def __exit__(self, *args) -> None: del args @@ -359,18 +416,34 @@ def __exit__(self, *args) -> None: self._gui_api._set_container_id(self._container_id_restore) self._container_id_restore = None + def __post_init__(self) -> None: + self._gui_api._container_handle_from_id[self._id] = self + + def close(self) -> None: + """Close this modal and permananently remove all contained GUI elements.""" + self._gui_api._get_api()._queue( + GuiCloseModalMessage(self._id), + ) + self._gui_api._container_handle_from_id.pop(self._id) + for child in self._children.values(): + child.remove() + @dataclasses.dataclass class GuiTabHandle: """Use as a context to place GUI elements into a tab.""" _parent: GuiTabGroupHandle - _container_id: str + _id: str # Used as container ID of children. _container_id_restore: Optional[str] = None + _children: Dict[str, SupportsRemoveProtocol] = dataclasses.field( + default_factory=dict + ) - def __enter__(self) -> None: + def __enter__(self) -> GuiTabHandle: self._container_id_restore = self._parent._gui_api._get_container_id() - self._parent._gui_api._set_container_id(self._container_id) + self._parent._gui_api._set_container_id(self._id) + return self def __exit__(self, *args) -> None: del args @@ -378,24 +451,30 @@ def __exit__(self, *args) -> None: self._parent._gui_api._set_container_id(self._container_id_restore) self._container_id_restore = None + def __post_init__(self) -> None: + self._parent._gui_api._container_handle_from_id[self._id] = self + def remove(self) -> None: """Permanently remove this tab and all contained GUI elements from the visualizer.""" # We may want to make this thread-safe in the future. - container_index = self._parent._tab_container_ids.index(self._container_id) + container_index = -1 + for i, tab in enumerate(self._parent._tabs): + if tab is self: + container_index = i + break assert container_index != -1, "Tab already removed!" - # Container needs to be manually removed. - self._parent._gui_api._get_api()._queue( - GuiRemoveContainerChildrenMessage(self._container_id) - ) + self._parent._gui_api._container_handle_from_id.pop(self._id) self._parent._labels.pop(container_index) self._parent._icons_base64.pop(container_index) - self._parent._tab_container_ids.pop(container_index) - + self._parent._tabs.pop(container_index) self._parent._sync_with_client() + for child in self._children.values(): + child.remove() + @dataclasses.dataclass class GuiMarkdownHandle: @@ -404,6 +483,7 @@ class GuiMarkdownHandle: _gui_api: GuiApi _id: str _visible: bool + _container_id: str # Parent. @property def visible(self) -> bool: @@ -419,6 +499,12 @@ def visible(self, visible: bool) -> None: self._gui_api._get_api()._queue(GuiSetVisibleMessage(self._id, visible=visible)) self._visible = visible + def __post_init__(self) -> None: + """We need to register ourself after construction for callbacks to work.""" + parent = self._gui_api._container_handle_from_id[self._container_id] + parent._children[self._id] = self + def remove(self) -> None: """Permanently remove this markdown from the visualizer.""" - self._gui_api._get_api()._queue(GuiRemoveMessage(self._id)) + api = self._gui_api._get_api() + api._queue(GuiRemoveMessage(self._id)) diff --git a/viser/_message_api.py b/viser/_message_api.py index 39085b0bc..b8a8091e7 100644 --- a/viser/_message_api.py +++ b/viser/_message_api.py @@ -9,11 +9,12 @@ import abc import base64 +import colorsys import io import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypeVar, Union, cast +from typing import TYPE_CHECKING, Dict, Optional, Tuple, TypeVar, Union, cast import imageio.v3 as iio import numpy as onp @@ -23,7 +24,6 @@ from typing_extensions import Literal, ParamSpec, TypeAlias, assert_never from . import _messages, infra, theme -from ._gui_handles import GuiHandle, _GuiHandleState from ._scene_handles import ( CameraFrustumHandle, FrameHandle, @@ -45,6 +45,16 @@ P = ParamSpec("P") +def _hex_from_hls(h: float, l: float, s: float) -> str: + """Converts HLS values in [0.0, 1.0] to a hex-formatted string, eg 0xffffff.""" + return "#" + "".join( + [ + int(min(255, max(0, channel * 255.0)) + 0.5).to_bytes(1, "little").hex() + for channel in colorsys.hls_to_rgb(h, l, s) + ] + ) + + def _colors_to_uint8(colors: onp.ndarray) -> onpt.NDArray[onp.uint8]: """Convert intensity values to uint8. We assume the range [0,1] for floats, and [0,255] for integers.""" @@ -115,15 +125,15 @@ class MessageApi(abc.ABC): invidividual clients.""" def __init__(self, handler: infra.MessageHandler) -> None: + self._message_handler = handler + super().__init__() - self._gui_handle_state_from_id: Dict[str, _GuiHandleState[Any]] = {} self._handle_from_transform_controls_name: Dict[ str, TransformControlsHandle ] = {} self._handle_from_node_name: Dict[str, SceneNodeHandle] = {} - handler.register_handler(_messages.GuiUpdateMessage, self._handle_gui_updates) handler.register_handler( _messages.TransformControlsUpdateMessage, self._handle_transform_controls_updates, @@ -143,13 +153,53 @@ def configure_theme( titlebar_content: Optional[theme.TitlebarConfig] = None, control_layout: Literal["floating", "collapsible", "fixed"] = "floating", dark_mode: bool = False, + brand_color: Optional[Tuple[int, int, int]] = None, ) -> None: """Configure the viser front-end's visual appearance.""" + + colors_cast: Optional[ + Tuple[str, str, str, str, str, str, str, str, str, str] + ] = None + + if brand_color is not None: + assert len(brand_color) in (3, 10) + if len(brand_color) == 3: + assert all( + map(lambda val: isinstance(val, int), brand_color) + ), "All channels should be integers." + + # RGB => HLS. + h, l, s = colorsys.rgb_to_hls( + brand_color[0] / 255.0, + brand_color[1] / 255.0, + brand_color[2] / 255.0, + ) + + # Automatically generate a 10-color palette. + min_l = max(l - 0.08, 0.0) + max_l = min(0.8 + 0.5, 0.9) + l = max(min_l, min(max_l, l)) + + primary_index = 8 + ls = tuple( + onp.interp( + x=onp.arange(10), + xp=(0, primary_index, 9), + fp=(max_l, l, min_l), + ) + ) + colors_cast = tuple(_hex_from_hls(h, ls[i], s) for i in range(10)) # type: ignore + + assert colors_cast is None or all( + [isinstance(val, str) and val.startswith("#") for val in colors_cast] + ), "All string colors should be in hexadecimal + prefixed with #, eg #ffffff." + self._queue( _messages.ThemeConfigurationMessage( titlebar_content=titlebar_content, control_layout=control_layout, dark_mode=dark_mode, + colors=colors_cast, ), ) @@ -386,14 +436,39 @@ def set_background_image( image: onp.ndarray, format: Literal["png", "jpeg"] = "jpeg", jpeg_quality: Optional[int] = None, + depth: Optional[onp.ndarray] = None, ) -> None: - """Set a background image for the scene. Useful for NeRF visualization.""" + """Set a background image for the scene, optionally with depth compositing.""" media_type, base64_data = _encode_image_base64( image, format, jpeg_quality=jpeg_quality ) + + # Encode depth if provided. We use a 3-channel PNG to represent a fixed point + # depth at each pixel. + depth_base64data = None + if depth is not None: + # Convert to fixed-point. + # We'll support from 0 -> (2^24 - 1) / 100_000. + # + # This translates to a range of [0, 167.77215], with a precision of 1e-5. + assert len(depth.shape) == 2 or ( + len(depth.shape) == 3 and depth.shape[2] == 1 + ), "Depth should have shape (H,W) or (H,W,1)." + depth = onp.clip(depth * 100_000, 0, 2**24 - 1).astype(onp.uint32) + assert depth is not None # Appease mypy. + intdepth: onp.ndarray = depth.reshape((*depth.shape[:2], 1)).view(onp.uint8) + assert intdepth.shape == (*depth.shape[:2], 4) + with io.BytesIO() as data_buffer: + iio.imwrite(data_buffer, intdepth[:, :, :3], extension=".png") + depth_base64data = base64.b64encode(data_buffer.getvalue()).decode( + "ascii" + ) + self._queue( _messages.BackgroundImageMessage( - media_type=media_type, base64_data=base64_data + media_type=media_type, + base64_rgb=base64_data, + base64_depth=depth_base64data, ) ) @@ -517,31 +592,6 @@ def _queue_unsafe(self, message: _messages.Message) -> None: """Abstract method for sending messages.""" ... - def _handle_gui_updates( - self, client_id: ClientId, message: _messages.GuiUpdateMessage - ) -> None: - """Callback for handling GUI messages.""" - handle_state = self._gui_handle_state_from_id.get(message.id, None) - if handle_state is None: - return - - value = handle_state.typ(message.value) - - # Only call update when value has actually changed. - if not handle_state.is_button and value == handle_state.value: - return - - # Update state. - with self._atomic_lock: - handle_state.value = value - handle_state.update_timestamp = time.time() - - # Trigger callbacks. - for cb in handle_state.update_cb: - cb(GuiHandle(handle_state)) - if handle_state.sync_cb is not None: - handle_state.sync_cb(client_id, value) - def _handle_transform_controls_updates( self, client_id: ClientId, message: _messages.TransformControlsUpdateMessage ) -> None: @@ -583,6 +633,16 @@ def add_3d_gui_container( # Avoids circular import. from ._gui_api import GuiApi, _make_unique_id + # New name to make the type checker happy; ViserServer and ClientHandle inherit + # from both GuiApi and MessageApi. The pattern below is unideal. + gui_api = self + assert isinstance(gui_api, GuiApi) + + # Remove the 3D GUI container if it already exists. This will make sure + # contained GUI elements are removed, preventing potential memory leaks. + if name in gui_api._handle_from_node_name: + gui_api._handle_from_node_name[name].remove() + container_id = _make_unique_id() self._queue( _messages.Gui3DMessage( @@ -591,7 +651,5 @@ def add_3d_gui_container( container_id=container_id, ) ) - assert isinstance(self, MessageApi) node_handle = SceneNodeHandle._make(self, name, wxyz, position) - assert isinstance(self, GuiApi) - return Gui3dContainerHandle(node_handle._impl, self, container_id) + return Gui3dContainerHandle(node_handle._impl, gui_api, container_id) diff --git a/viser/_messages.py b/viser/_messages.py index 75bf74e78..e4b4be804 100644 --- a/viser/_messages.py +++ b/viser/_messages.py @@ -229,7 +229,8 @@ class BackgroundImageMessage(Message): """Message for rendering a background image.""" media_type: Literal["image/jpeg", "image/png"] - base64_data: str + base64_rgb: str + base64_depth: Optional[str] @dataclasses.dataclass @@ -321,7 +322,11 @@ class GuiModalMessage(Message): order: float id: str title: str - container_id: str + + +@dataclasses.dataclass +class GuiCloseModalMessage(Message): + id: str @dataclasses.dataclass @@ -329,6 +334,25 @@ class GuiAddButtonMessage(_GuiAddInputBase): # All GUI elements currently need an `initial_value` field. # This makes our job on the frontend easier. initial_value: bool + color: Optional[ + Literal[ + "dark", + "gray", + "red", + "pink", + "grape", + "violet", + "indigo", + "blue", + "cyan", + "green", + "lime", + "yellow", + "orange", + "teal", + ] + ] + icon_base64: Optional[str] @dataclasses.dataclass @@ -399,13 +423,6 @@ class GuiAddButtonGroupMessage(_GuiAddInputBase): options: Tuple[str, ...] -@dataclasses.dataclass -class GuiRemoveContainerChildrenMessage(Message): - """Sent server->client to recursively remove children of a GUI container.""" - - container_id: str - - @dataclasses.dataclass class GuiRemoveMessage(Message): """Sent server->client to remove a GUI element.""" @@ -451,6 +468,7 @@ class ThemeConfigurationMessage(Message): titlebar_content: Optional[theme.TitlebarConfig] control_layout: Literal["floating", "collapsible", "fixed"] + colors: Optional[Tuple[str, str, str, str, str, str, str, str, str, str]] dark_mode: bool diff --git a/viser/_scene_handles.py b/viser/_scene_handles.py index d40a31799..4178b44c1 100644 --- a/viser/_scene_handles.py +++ b/viser/_scene_handles.py @@ -6,7 +6,17 @@ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + Generic, + List, + Optional, + Tuple, + Type, + TypeVar, +) import numpy as onp @@ -14,6 +24,7 @@ if TYPE_CHECKING: from ._gui_api import GuiApi + from ._gui_handles import SupportsRemoveProtocol from ._message_api import ClientId, MessageApi @@ -96,11 +107,17 @@ def remove(self) -> None: self._impl.api._queue(_messages.RemoveSceneNodeMessage(self._impl.name)) +@dataclasses.dataclass(frozen=True) +class ClickEvent(Generic[TSceneNodeHandle]): + client_id: ClientId + target: TSceneNodeHandle + + @dataclasses.dataclass class _SupportsClick(SceneNodeHandle): def on_click( - self: TSceneNodeHandle, func: Callable[[TSceneNodeHandle], None] - ) -> Callable[[TSceneNodeHandle], None]: + self: TSceneNodeHandle, func: Callable[[ClickEvent[TSceneNodeHandle]], None] + ) -> Callable[[ClickEvent[TSceneNodeHandle]], None]: """Attach a callback for when a scene node is clicked. TODO: @@ -211,10 +228,14 @@ class Gui3dContainerHandle(SceneNodeHandle): _gui_api: GuiApi _container_id: str _container_id_restore: Optional[str] = None + _children: Dict[str, SupportsRemoveProtocol] = dataclasses.field( + default_factory=dict + ) - def __enter__(self) -> None: + def __enter__(self) -> Gui3dContainerHandle: self._container_id_restore = self._gui_api._get_container_id() self._gui_api._set_container_id(self._container_id) + return self def __exit__(self, *args) -> None: del args @@ -222,6 +243,9 @@ def __exit__(self, *args) -> None: self._gui_api._set_container_id(self._container_id_restore) self._container_id_restore = None + def __post_init__(self) -> None: + self._gui_api._container_handle_from_id[self._container_id] = self + def remove(self) -> None: """Permanently remove this GUI container from the visualizer.""" @@ -229,6 +253,6 @@ def remove(self) -> None: super().remove() # Clean up contained GUI elements. - self._gui_api._get_api()._queue( - _messages.GuiRemoveContainerChildrenMessage(self._container_id) - ) + self._gui_api._container_handle_from_id.pop(self._container_id) + for child in self._children.values(): + child.remove() diff --git a/viser/client/src/App.tsx b/viser/client/src/App.tsx index 170804eb0..6253b2d00 100644 --- a/viser/client/src/App.tsx +++ b/viser/client/src/App.tsx @@ -6,7 +6,7 @@ import { Environment, } from "@react-three/drei"; import * as THREE from "three"; -import { Canvas, useThree } from "@react-three/fiber"; +import { Canvas, useThree, useFrame } from "@react-three/fiber"; import { EffectComposer, Outline, @@ -22,22 +22,31 @@ import { SceneNodeThreeObject, UseSceneTree } from "./SceneTree"; import "./index.css"; import ControlPanel from "./ControlPanel/ControlPanel"; -import { UseGui, useGuiState } from "./ControlPanel/GuiState"; +import { UseGui, useGuiState, useMantineTheme } from "./ControlPanel/GuiState"; import { searchParamKey } from "./SearchParamsUtils"; -import WebsocketInterface from "./WebsocketInterface"; +import { + WebsocketMessageProducer, + FrameSynchronizedMessageHandler, +} from "./WebsocketInterface"; import { Titlebar } from "./Titlebar"; import { ViserModal } from "./Modal"; import { useSceneTreeState } from "./SceneTreeState"; +import { Message } from "./WebsocketMessages"; export type ViewerContextContents = { + // Zustand hooks. useSceneTree: UseSceneTree; useGui: UseGui; + // Useful references. websocketRef: React.MutableRefObject; canvasRef: React.MutableRefObject; sceneRef: React.MutableRefObject; cameraRef: React.MutableRefObject; + backgroundMaterialRef: React.MutableRefObject; cameraControlRef: React.MutableRefObject; + // Scene node attributes. + // This is intentionally placed outside of the Zustand state to reduce overhead. nodeAttributesFromName: React.MutableRefObject<{ [name: string]: | undefined @@ -47,6 +56,7 @@ export type ViewerContextContents = { visibility?: boolean; }; }>; + messageQueueRef: React.MutableRefObject; }; export const ViewerContext = React.createContext( null, @@ -54,8 +64,8 @@ export const ViewerContext = React.createContext( THREE.ColorManagement.enabled = true; -function SingleViewer() { - // Default server logic. +function ViewerRoot() { + // What websocket server should we connect to? function getDefaultServerFromUrl() { // https://localhost:8080/ => ws://localhost:8080 // https://localhost:8080/?server=some_url => ws://localhost:8080 @@ -79,58 +89,58 @@ function SingleViewer() { canvasRef: React.useRef(null), sceneRef: React.useRef(null), cameraRef: React.useRef(null), + backgroundMaterialRef: React.useRef(null), cameraControlRef: React.useRef(null), - // Scene node attributes that aren't placed in the zustand state, for performance reasons. + // Scene node attributes that aren't placed in the zustand state for performance reasons. nodeAttributesFromName: React.useRef({}), + messageQueueRef: React.useRef([]), }; - // Memoize the websocket interface so it isn't remounted when the theme or - // viewer context changes. - const memoizedWebsocketInterface = React.useMemo( - () => , - [], + return ( + + + + ); +} +function ViewerContents() { + const viewer = React.useContext(ViewerContext)!; const control_layout = viewer.useGui((state) => state.theme.control_layout); return ( state.theme.dark_mode) - ? "dark" - : "light", - }} + theme={useMantineTheme()} > - - - - - - ({ - top: 0, - bottom: 0, - left: 0, - right: control_layout === "fixed" ? "20em" : 0, - position: "absolute", - backgroundColor: - theme.colorScheme === "light" ? "#fff" : theme.colors.dark[9], - })} - > - {memoizedWebsocketInterface} - - - - - + + + + + ({ + backgroundColor: + theme.colorScheme === "light" ? "#fff" : theme.colors.dark[9], + flexGrow: 1, + width: "10em", + })} + > + + + + + + + ); } @@ -151,6 +161,7 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) { ref={viewer.canvasRef} > {children} + @@ -174,6 +185,116 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) { ); } +/* Background image with support for depth compositing. */ +function BackgroundImage() { + // Create a fragment shader that composites depth using depth and rgb + const vertShader = ` + varying vec2 vUv; + + void main() { + vUv = uv; + gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0); + } + `.trim(); + const fragShader = ` + #include + precision highp float; + precision highp int; + + varying vec2 vUv; + uniform sampler2D colorMap; + uniform sampler2D depthMap; + uniform float cameraNear; + uniform float cameraFar; + uniform bool enabled; + uniform bool hasDepth; + + float readDepth(sampler2D depthMap, vec2 coord) { + vec4 rgbPacked = texture(depthMap, coord); + + // For the k-th channel, coefficients are calculated as: 255 * 1e-5 * 2^(8 * k). + // Note that: [0, 255] channels are scaled to [0, 1], and we multiply by 1e5 on the server side. + float depth = rgbPacked.r * 0.00255 + rgbPacked.g * 0.6528 + rgbPacked.b * 167.1168; + return depth; + } + + void main() { + if (!enabled) { + // discard the pixel if we're not enabled + discard; + } + vec4 color = texture(colorMap, vUv); + gl_FragColor = vec4(color.rgb, 1.0); + + float bufDepth; + if(hasDepth){ + float depth = readDepth(depthMap, vUv); + bufDepth = viewZToPerspectiveDepth(-depth, cameraNear, cameraFar); + } else { + // If no depth enabled, set depth to 1.0 (infinity) to treat it like a background image. + bufDepth = 1.0; + } + gl_FragDepth = bufDepth; + }`.trim(); + // initialize the rgb texture with all white and depth at infinity + const backgroundMaterial = new THREE.ShaderMaterial({ + fragmentShader: fragShader, + vertexShader: vertShader, + uniforms: { + enabled: { value: false }, + depthMap: { value: null }, + colorMap: { value: null }, + cameraNear: { value: null }, + cameraFar: { value: null }, + hasDepth: { value: false }, + }, + }); + const { backgroundMaterialRef } = React.useContext(ViewerContext)!; + backgroundMaterialRef.current = backgroundMaterial; + const backgroundMesh = React.useRef(null); + useFrame(({ camera }) => { + // Logic ahead relies on perspective camera assumption. + if (!(camera instanceof THREE.PerspectiveCamera)) { + console.error( + "Camera is not a perspective camera, cannot render background image", + ); + return; + } + + // Update the position of the mesh based on the camera position. + const lookdir = camera.getWorldDirection(new THREE.Vector3()); + backgroundMesh.current!.position.set( + camera.position.x, + camera.position.y, + camera.position.z, + ); + backgroundMesh.current!.position.addScaledVector(lookdir, 1.0); + backgroundMesh.current!.quaternion.copy(camera.quaternion); + + // Resize the mesh based on focal length. + const f = camera.getFocalLength(); + backgroundMesh.current!.scale.set( + camera.getFilmWidth() / f, + camera.getFilmHeight() / f, + 1.0, + ); + + // Set near/far uniforms. + backgroundMaterial.uniforms.cameraNear.value = camera.near; + backgroundMaterial.uniforms.cameraFar.value = camera.far; + }); + + return ( + + + + ); +} + /** Component for helping us set the scene reference. */ function SceneContextSetter() { const { sceneRef, cameraRef } = React.useContext(ViewerContext)!; @@ -195,7 +316,7 @@ export function Root() { flexDirection: "column", }} > - + ); } diff --git a/viser/client/src/ControlPanel/BottomPanel.tsx b/viser/client/src/ControlPanel/BottomPanel.tsx index 898c0477e..8be1ce2ac 100644 --- a/viser/client/src/ControlPanel/BottomPanel.tsx +++ b/viser/client/src/ControlPanel/BottomPanel.tsx @@ -1,8 +1,13 @@ import { Box, Collapse, Paper } from "@mantine/core"; import React from "react"; -import { FloatingPanelContext } from "./FloatingPanel"; import { useDisclosure } from "@mantine/hooks"; +const BottomPanelContext = React.createContext; + expanded: boolean; + toggleExpanded: () => void; +}>(null); + export default function BottomPanel({ children, }: { @@ -11,7 +16,7 @@ export default function BottomPanel({ const panelWrapperRef = React.useRef(null); const [expanded, { toggle: toggleExpanded }] = useDisclosure(true); return ( - {children} - + ); } BottomPanel.Handle = function BottomPanelHandle({ @@ -46,7 +51,7 @@ BottomPanel.Handle = function BottomPanelHandle({ }: { children: string | React.ReactNode; }) { - const panelContext = React.useContext(FloatingPanelContext)!; + const panelContext = React.useContext(BottomPanelContext)!; return ( { panelContext.toggleExpanded(); }} > - - {children} - + {children} ); }; @@ -82,6 +83,6 @@ BottomPanel.Contents = function BottomPanelContents({ }: { children: string | React.ReactNode; }) { - const panelContext = React.useContext(FloatingPanelContext)!; + const panelContext = React.useContext(BottomPanelContext)!; return {children}; }; diff --git a/viser/client/src/ControlPanel/ControlPanel.tsx b/viser/client/src/ControlPanel/ControlPanel.tsx index e5b9c8224..7f18db5d2 100644 --- a/viser/client/src/ControlPanel/ControlPanel.tsx +++ b/viser/client/src/ControlPanel/ControlPanel.tsx @@ -1,10 +1,9 @@ import { useDisclosure, useMediaQuery } from "@mantine/hooks"; import GeneratedGuiContainer from "./Generated"; import { ViewerContext } from "../App"; -import ServerControls from "./Server"; +import ServerControls from "./ServerControls"; import { ActionIcon, - Aside, Box, Collapse, Tooltip, @@ -15,116 +14,48 @@ import { IconCloudCheck, IconCloudOff, IconArrowBack, - IconChevronLeft, - IconChevronRight, } from "@tabler/icons-react"; import React from "react"; import BottomPanel from "./BottomPanel"; -import FloatingPanel, { FloatingPanelContext } from "./FloatingPanel"; +import FloatingPanel from "./FloatingPanel"; import { ThemeConfigurationMessage } from "../WebsocketMessages"; +import SidebarPanel from "./SidebarPanel"; // Must match constant in Python. const ROOT_CONTAINER_ID = "root"; -/** Hides contents when floating panel is collapsed. */ -function HideWhenCollapsed({ children }: { children: React.ReactNode }) { - const expanded = React.useContext(FloatingPanelContext)?.expanded ?? true; - return expanded ? children : null; -} - export default function ControlPanel(props: { control_layout: ThemeConfigurationMessage["control_layout"]; }) { const theme = useMantineTheme(); const useMobileView = useMediaQuery(`(max-width: ${theme.breakpoints.xs})`); - // TODO: will result in unnecessary re-renders + // TODO: will result in unnecessary re-renders. const viewer = React.useContext(ViewerContext)!; const showGenerated = viewer.useGui( (state) => "root" in state.guiIdSetFromContainerId, ); const [showSettings, { toggle }] = useDisclosure(false); - const [collapsed, { toggle: toggleCollapse }] = useDisclosure(false); - const handleContents = ( - <> - - - {/* We can't apply translateY directly to the ActionIcon, since it's used by - Mantine for the active/click indicator. */} - - { - evt.stopPropagation(); - toggle(); - }} - > - - {showSettings ? : } - - - - - - { - evt.stopPropagation(); - toggleCollapse(); - }} - > - {} - - - - ); - const collapsedView = ( -
+ const generatedServerToggleButton = ( + { evt.stopPropagation(); - toggleCollapse(); + toggle(); }} > - {} + + {showSettings ? ( + + ) : ( + + )} + -
+ ); const panelContents = ( @@ -139,75 +70,40 @@ export default function ControlPanel(props: { ); if (useMobileView) { + /* Mobile layout. */ return ( - {handleContents} + + + {generatedServerToggleButton} + {panelContents} ); - } else if (props.control_layout !== "floating") { - return ( - <> - - {collapsedView} - - - - ); - } else { + } else if (props.control_layout === "floating") { + /* Floating layout. */ return ( - {handleContents} + + + + {generatedServerToggleButton} + + {panelContents} ); + } else { + /* Sidebar view. */ + return ( + + + + {generatedServerToggleButton} + + {panelContents} + + ); } } @@ -220,9 +116,7 @@ function ConnectionStatus() { const StatusIcon = connected ? IconCloudCheck : IconCloudOff; return ( - + <> -     - {label === "" ? server : label} - + + {label === "" ? server : label} + + ); } diff --git a/viser/client/src/ControlPanel/FloatingPanel.tsx b/viser/client/src/ControlPanel/FloatingPanel.tsx index ed49c8d00..0826c5ba6 100644 --- a/viser/client/src/ControlPanel/FloatingPanel.tsx +++ b/viser/client/src/ControlPanel/FloatingPanel.tsx @@ -5,7 +5,7 @@ import React from "react"; import { isMouseEvent, isTouchEvent, mouseEvents, touchEvents } from "../Utils"; import { useDisclosure } from "@mantine/hooks"; -export const FloatingPanelContext = React.createContext; expanded: boolean; toggleExpanded: () => void; @@ -31,16 +31,16 @@ export default function FloatingPanel({ > { const state = dragInfo.current; @@ -246,14 +250,7 @@ FloatingPanel.Handle = function FloatingPanelHandle({ dragHandler(event); }} > - - {children} - + {children} ); }; @@ -266,3 +263,13 @@ FloatingPanel.Contents = function FloatingPanelContents({ const context = React.useContext(FloatingPanelContext); return {children}; }; + +/** Hides contents when floating panel is collapsed. */ +FloatingPanel.HideWhenCollapsed = function FloatingPanelHideWhenCollapsed({ + children, +}: { + children: React.ReactNode; +}) { + const expanded = React.useContext(FloatingPanelContext)?.expanded ?? true; + return expanded ? children : null; +}; diff --git a/viser/client/src/ControlPanel/Generated.tsx b/viser/client/src/ControlPanel/Generated.tsx index 78a392bb6..858a8edf2 100644 --- a/viser/client/src/ControlPanel/Generated.tsx +++ b/viser/client/src/ControlPanel/Generated.tsx @@ -27,6 +27,7 @@ import { ErrorBoundary } from "react-error-boundary"; /** Root of generated inputs. */ export default function GeneratedGuiContainer({ + // We need to take viewer as input in drei's elements, where contexts break. containerId, viewer, }: { @@ -47,7 +48,7 @@ export default function GeneratedGuiContainer({ {[...guiIdSet] .map((id) => guiConfigFromId[id]) .sort((a, b) => a.order - b.order) - .map((conf, index) => { + .map((conf) => { return ; })} @@ -116,6 +117,7 @@ function GeneratedInput({ @@ -383,7 +398,9 @@ function GeneratedTabGroup({ conf }: { conf: GuiAddTabGroupMessage }) { icon={ icons[index] === null ? undefined : ( ({ filter: theme.colorScheme == "dark" ? "invert(1)" : undefined, diff --git a/viser/client/src/ControlPanel/GuiState.tsx b/viser/client/src/ControlPanel/GuiState.tsx index 843379113..e7f269d4b 100644 --- a/viser/client/src/ControlPanel/GuiState.tsx +++ b/viser/client/src/ControlPanel/GuiState.tsx @@ -2,6 +2,8 @@ import * as Messages from "../WebsocketMessages"; import React from "react"; import { create } from "zustand"; import { immer } from "zustand/middleware/immer"; +import { ViewerContext } from "../App"; +import { MantineThemeOverride } from "@mantine/core"; export type GuiConfig = | Messages.GuiAddButtonMessage @@ -44,12 +46,11 @@ interface GuiActions { setTheme: (theme: Messages.ThemeConfigurationMessage) => void; addGui: (config: GuiConfig) => void; addModal: (config: Messages.GuiModalMessage) => void; - popModal: () => void; + removeModal: (id: string) => void; setGuiValue: (id: string, value: any) => void; setGuiVisible: (id: string, visible: boolean) => void; setGuiDisabled: (id: string, visible: boolean) => void; removeGui: (id: string) => void; - removeGuiContainer: (containerId: string) => void; resetGui: () => void; } @@ -59,6 +60,7 @@ const cleanGuiState: GuiState = { titlebar_content: null, control_layout: "floating", dark_mode: false, + colors: null, }, label: "", server: "ws://localhost:8080", // Currently this will always be overridden. @@ -92,9 +94,9 @@ export function useGuiState(initialServer: string) { set((state) => { state.modals.push(modalConfig); }), - popModal: () => + removeModal: (id) => set((state) => { - state.modals.pop(); + state.modals = state.modals.filter((m) => m.id !== id); }), setGuiValue: (id, value) => set((state) => { @@ -117,10 +119,6 @@ export function useGuiState(initialServer: string) { removeGui: (id) => set((state) => { const guiConfig = state.guiConfigFromId[id]; - if (guiConfig.type === "GuiAddFolderMessage") - state.removeGuiContainer(guiConfig.id); - if (guiConfig.type === "GuiAddTabGroupMessage") - guiConfig.tab_container_ids.forEach(state.removeGuiContainer); state.guiIdSetFromContainerId[guiConfig.container_id]!.delete( guiConfig.id, @@ -129,19 +127,6 @@ export function useGuiState(initialServer: string) { delete state.guiValueFromId[id]; delete state.guiAttributeFromId[id]; }), - removeGuiContainer: (containerId) => - set((state) => { - const guiIdSet = state.guiIdSetFromContainerId[containerId]; - if (guiIdSet === undefined) { - console.log( - "Tried to remove but could not find container ID", - containerId, - ); - return; - } - Object.keys(guiIdSet).forEach(state.removeGui); - delete state.guiIdSetFromContainerId[containerId]; - }), resetGui: () => set((state) => { state.guiIdSetFromContainerId = {}; @@ -154,5 +139,22 @@ export function useGuiState(initialServer: string) { )[0]; } +export function useMantineTheme(): MantineThemeOverride { + const viewer = React.useContext(ViewerContext)!; + const colors = viewer.useGui((state) => state.theme.colors); + return { + colorScheme: viewer.useGui((state) => state.theme.dark_mode) + ? "dark" + : "light", + primaryColor: colors === null ? undefined : "custom", + colors: + colors === null + ? undefined + : { + custom: colors, + }, + }; +} + /** Type corresponding to a zustand-style useGuiState hook. */ export type UseGui = ReturnType; diff --git a/viser/client/src/ControlPanel/Server.tsx b/viser/client/src/ControlPanel/ServerControls.tsx similarity index 99% rename from viser/client/src/ControlPanel/Server.tsx rename to viser/client/src/ControlPanel/ServerControls.tsx index 3683fa099..61f98c22c 100644 --- a/viser/client/src/ControlPanel/Server.tsx +++ b/viser/client/src/ControlPanel/ServerControls.tsx @@ -102,6 +102,7 @@ export default function ServerControls() { }} /> + Scene tree diff --git a/viser/client/src/ControlPanel/SidebarPanel.tsx b/viser/client/src/ControlPanel/SidebarPanel.tsx new file mode 100644 index 000000000..c018f9e9d --- /dev/null +++ b/viser/client/src/ControlPanel/SidebarPanel.tsx @@ -0,0 +1,137 @@ +// @refresh reset + +import { ActionIcon, Box, Paper, Tooltip } from "@mantine/core"; +import React from "react"; +import { useDisclosure } from "@mantine/hooks"; +import { IconChevronLeft, IconChevronRight } from "@tabler/icons-react"; + +export const SidebarPanelContext = React.createContext void; +}>(null); + +/** Root component for control panel. Parents a set of control tabs. + * This could be refactored+cleaned up a lot! */ +export default function SidebarPanel({ + children, + collapsible, +}: { + children: string | React.ReactNode; + collapsible: boolean; +}) { + const [collapsed, { toggle: toggleCollapsed }] = useDisclosure(false); + + const collapsedView = ( + ({ + /* Animate in when collapsed. */ + position: "absolute", + top: "0em", + right: collapsed ? "0em" : "-3em", + transitionProperty: "right", + transitionDuration: "0.5s", + transitionDelay: "0.25s", + /* Visuals. */ + borderBottomLeftRadius: "0.5em", + backgroundColor: + theme.colorScheme == "dark" + ? theme.colors.dark[5] + : theme.colors.gray[2], + padding: "0.5em", + })} + > + { + evt.stopPropagation(); + toggleCollapsed(); + }} + > + {} + + + ); + + return ( + + {collapsedView} + {/* Using an