Skip to content

Commit

Permalink
Callback, scene node removal API improvements (#290)
Browse files Browse the repository at this point in the history
* Callback, scene node removal API improvements

* Warning
  • Loading branch information
brentyi committed Sep 25, 2024
1 parent 125328e commit a2f4abd
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 60 deletions.
61 changes: 60 additions & 1 deletion src/viser/_gui_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Dict,
Generic,
Iterable,
Literal,
Tuple,
TypeVar,
cast,
Expand Down Expand Up @@ -106,6 +107,8 @@ class _GuiHandleState(Generic[T]):
sync_cb: Callable[[ClientId, dict[str, Any]], None] | None = None
"""Callback for synchronizing inputs across clients."""

removed: bool = False


class _OverridableGuiPropApi:
"""Mixin that allows reading/assigning properties defined in each scene node message."""
Expand Down Expand Up @@ -157,6 +160,17 @@ def __init__(self, _impl: _GuiHandleState[T]) -> None:

def remove(self) -> None:
"""Permanently remove this GUI element from the visualizer."""

# Warn if already removed.
if self._impl.removed:
warnings.warn(
f"Attempted to remove an already removed {self.__class__.__name__}.",
stacklevel=2,
)
return
self._impl.removed = True

# Send remove to client(s) + update internal state.
self._impl.gui_api._websock_interface.queue_message(
GuiRemoveMessage(self._impl.id)
)
Expand Down Expand Up @@ -241,10 +255,25 @@ class GuiInputHandle(_GuiInputHandle[T], Generic[T]):
def on_update(
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], Any]
) -> Callable[[GuiEvent[TGuiHandle]], None]:
"""Attach a function to call when a GUI input is updated. Happens in a thread."""
"""Attach a function to call when a GUI input is updated. Callbacks stack (need
to be manually removed via :meth:`remove_update_callback()`) and will be called
from a thread."""
self._impl.update_cb.append(func)
return func

def remove_update_callback(
self, callback: Literal["all"] | Callable = "all"
) -> None:
"""Remove update callbacks from the GUI input.
Args:
callback: Either "all" to remove all callbacks, or a specific callback function to remove.
"""
if callback == "all":
self._impl.update_cb.clear()
else:
self._impl.update_cb = [cb for cb in self._impl.update_cb if cb != callback]


class GuiCheckboxHandle(GuiInputHandle[bool], GuiCheckboxProps):
"""Handle for checkbox inputs.
Expand Down Expand Up @@ -506,6 +535,16 @@ def __post_init__(self) -> None:

def remove(self) -> None:
"""Remove this tab group and all contained GUI elements."""
# Warn if already removed.
if self._impl.removed:
warnings.warn(
f"Attempted to remove an already removed {self.__class__.__name__}.",
stacklevel=2,
)
return
self._impl.removed = True

# Remove tabs, then self.
for tab in tuple(self._tab_handles):
tab.remove()
gui_api = self._impl.gui_api
Expand All @@ -524,6 +563,7 @@ class GuiTabHandle:
_children: dict[str, SupportsRemoveProtocol] = dataclasses.field(
default_factory=dict
)
_removed: bool = False

def __enter__(self) -> GuiTabHandle:
self._container_id_restore = self._parent._impl.gui_api._get_container_id()
Expand All @@ -542,6 +582,15 @@ def __post_init__(self) -> None:
def remove(self) -> None:
"""Permanently remove this tab and all contained GUI elements from the
visualizer."""
# Warn if already removed.
if self._removed:
warnings.warn(
f"Attempted to remove an already removed {self.__class__.__name__}.",
stacklevel=2,
)
return
self._removed = True

# We may want to make this thread-safe in the future.
found_index = -1
for i, tab in enumerate(self._parent._tab_handles):
Expand Down Expand Up @@ -594,6 +643,16 @@ def __exit__(self, *args) -> None:
def remove(self) -> None:
"""Permanently remove this folder and all contained GUI elements from the
visualizer."""
# Warn if already removed.
if self._impl.removed:
warnings.warn(
f"Attempted to remove an already removed {self.__class__.__name__}.",
stacklevel=2,
)
return
self._impl.removed = True

# Remove children, then self.
self._impl.gui_api._websock_interface.queue_message(
GuiRemoveMessage(self._impl.id)
)
Expand Down
2 changes: 0 additions & 2 deletions src/viser/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,6 @@ class MeshProps:
"""A numpy array of faces, where each face is represented by indices of vertices. Should have shape (F, 3). Synchronized automatically when assigned."""
color: Optional[Tuple[int, int, int]]
"""Color of the mesh as RGB integers. Synchronized automatically when assigned."""
vertex_colors: Optional[npt.NDArray[np.uint8]]
"""Optional array of vertex colors. Synchronized automatically when assigned."""
wireframe: bool
"""Boolean indicating if the mesh should be rendered as a wireframe. Synchronized automatically when assigned."""
opacity: Optional[float]
Expand Down
12 changes: 10 additions & 2 deletions src/viser/_scene_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,7 @@ def add_mesh_skinned(
stacklevel=2,
)

assert len(bone_wxyzs) == len(bone_positions)
num_bones = len(bone_wxyzs)
assert skin_weights.shape == (vertices.shape[0], num_bones)

Expand All @@ -1059,7 +1060,6 @@ def add_mesh_skinned(
vertices=vertices.astype(np.float32),
faces=faces.astype(np.uint32),
color=_encode_rgb(color),
vertex_colors=None,
wireframe=wireframe,
opacity=opacity,
flat_shading=flat_shading,
Expand Down Expand Up @@ -1153,7 +1153,6 @@ def add_mesh_simple(
vertices=vertices.astype(np.float32),
faces=faces.astype(np.uint32),
color=_encode_rgb(color),
vertex_colors=None,
wireframe=wireframe,
opacity=opacity,
flat_shading=flat_shading,
Expand Down Expand Up @@ -1757,3 +1756,12 @@ def add_3d_gui_container(
self, message, name, wxyz, position, visible=visible
)
return Gui3dContainerHandle(node_handle._impl, gui_api, container_id)

def remove_by_name(self, name: str) -> None:
"""Helper to call `.remove()` on the scene node handles of the `name`
element or any of its children."""
handle_from_node_name = self._handle_from_node_name.copy()
name = name.rstrip("/") # '/parent/' => '/parent'
for node_name, handle in handle_from_node_name.items():
if node_name == name or node_name.startswith(name + "/"):
handle.remove()
57 changes: 48 additions & 9 deletions src/viser/_scene_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import dataclasses
import warnings
from functools import cached_property
from typing import (
TYPE_CHECKING,
Expand All @@ -17,7 +18,7 @@

import numpy as np
import numpy.typing as onpt
from typing_extensions import get_type_hints
from typing_extensions import Self, get_type_hints

from . import _messages
from .infra._infra import WebsockClientConnection, WebsockServer
Expand Down Expand Up @@ -123,10 +124,10 @@ class _SceneNodeHandleState:
default_factory=lambda: np.array([0.0, 0.0, 0.0])
)
visible: bool = True
# TODO: we should remove SceneNodeHandle as an argument here.
click_cb: list[Callable[[SceneNodePointerEvent[SceneNodeHandle]], None]] | None = (
None
)
click_cb: list[
Callable[[SceneNodePointerEvent[_ClickableSceneNodeHandle]], None]
] = dataclasses.field(default_factory=list)
removed: bool = False


class _SceneNodeMessage(Protocol):
Expand Down Expand Up @@ -223,6 +224,12 @@ def visible(self, visible: bool) -> None:

def remove(self) -> None:
"""Remove the node from the scene."""
# Warn if already removed.
if self._impl.removed:
warnings.warn(f"Attempted to remove already removed node: {self.name}")
return

self._impl.removed = True
self._impl.api._websock_interface.queue_message(
_messages.RemoveSceneNodeMessage(self._impl.name)
)
Expand Down Expand Up @@ -253,18 +260,35 @@ class SceneNodePointerEvent(Generic[TSceneNodeHandle]):

class _ClickableSceneNodeHandle(SceneNodeHandle):
def on_click(
self: TSceneNodeHandle,
func: Callable[[SceneNodePointerEvent[TSceneNodeHandle]], None],
) -> Callable[[SceneNodePointerEvent[TSceneNodeHandle]], None]:
self: Self,
func: Callable[[SceneNodePointerEvent[Self]], None],
) -> Callable[[SceneNodePointerEvent[Self]], None]:
"""Attach a callback for when a scene node is clicked."""
self._impl.api._websock_interface.queue_message(
_messages.SetSceneNodeClickableMessage(self._impl.name, True)
)
if self._impl.click_cb is None:
self._impl.click_cb = []
self._impl.click_cb.append(func) # type: ignore
self._impl.click_cb.append(
cast(
Callable[[SceneNodePointerEvent[_ClickableSceneNodeHandle]], None], func
)
)
return func

def remove_click_callback(
self, callback: Literal["all"] | Callable = "all"
) -> None:
"""Remove click callbacks from scene node.
Args:
callback: Either "all" to remove all callbacks, or a specific callback function to remove.
"""
if callback == "all":
self._impl.click_cb.clear()
else:
self._impl.click_cb = [cb for cb in self._impl.click_cb if cb != callback]


class CameraFrustumHandle(
_ClickableSceneNodeHandle,
Expand Down Expand Up @@ -510,6 +534,21 @@ def on_update(
self._impl_aux.update_cb.append(func)
return func

def remove_update_callback(
self, callback: Literal["all"] | Callable = "all"
) -> None:
"""Remove update callbacks from the transform controls.
Args:
callback: Either "all" to remove all callbacks, or a specific callback function to remove.
"""
if callback == "all":
self._impl_aux.update_cb.clear()
else:
self._impl_aux.update_cb = [
cb for cb in self._impl_aux.update_cb if cb != callback
]


class Gui3dContainerHandle(
SceneNodeHandle,
Expand Down
8 changes: 4 additions & 4 deletions src/viser/client/src/ControlPanel/ControlPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ export default function ControlPanel(props: {
controlWidthString == "small"
? "16em"
: controlWidthString == "medium"
? "20em"
: controlWidthString == "large"
? "24em"
: null
? "20em"
: controlWidthString == "large"
? "24em"
: null
)!;

const generatedServerToggleButton = (
Expand Down
24 changes: 14 additions & 10 deletions src/viser/client/src/SceneTree.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,20 @@ function useObjectFactory(message: SceneNodeMessage | undefined): {
message.props.plane == "xz"
? new THREE.Euler(0.0, 0.0, 0.0)
: message.props.plane == "xy"
? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0)
: message.props.plane == "yx"
? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0)
: message.props.plane == "yz"
? new THREE.Euler(0.0, 0.0, Math.PI / 2.0)
: message.props.plane == "zx"
? new THREE.Euler(0.0, Math.PI / 2.0, 0.0)
: message.props.plane == "zy"
? new THREE.Euler(-Math.PI / 2.0, 0.0, -Math.PI / 2.0)
: undefined
? new THREE.Euler(Math.PI / 2.0, 0.0, 0.0)
: message.props.plane == "yx"
? new THREE.Euler(0.0, Math.PI / 2.0, Math.PI / 2.0)
: message.props.plane == "yz"
? new THREE.Euler(0.0, 0.0, Math.PI / 2.0)
: message.props.plane == "zx"
? new THREE.Euler(0.0, Math.PI / 2.0, 0.0)
: message.props.plane == "zy"
? new THREE.Euler(
-Math.PI / 2.0,
0.0,
-Math.PI / 2.0,
)
: undefined
}
/>
</group>
Expand Down
40 changes: 10 additions & 30 deletions src/viser/client/src/ThreeAssets.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -408,19 +408,6 @@ export const InstancedAxes = React.forwardRef<
});

/** Convert raw RGB color buffers to linear color buffers. **/
function threeColorBufferFromUint8Buffer(colors: ArrayBuffer) {
return new THREE.Float32BufferAttribute(
new Float32Array(new Uint8Array(colors)).map((value) => {
value = value / 255.0;
if (value <= 0.04045) {
return value / 12.92;
} else {
return Math.pow((value + 0.055) / 1.055, 2.4);
}
}),
3,
);
}
export const ViserMesh = React.forwardRef<
THREE.Mesh | THREE.SkinnedMesh,
MeshMessage | SkinnedMeshMessage
Expand Down Expand Up @@ -448,7 +435,6 @@ export const ViserMesh = React.forwardRef<
const standardArgs = {
color:
message.props.color === null ? undefined : rgbToInt(message.props.color),
vertexColors: message.props.vertex_colors !== null,
wireframe: message.props.wireframe,
transparent: message.props.opacity !== null,
opacity: message.props.opacity ?? 1.0,
Expand All @@ -474,16 +460,16 @@ export const ViserMesh = React.forwardRef<
message.props.material == "standard" || message.props.wireframe
? new THREE.MeshStandardMaterial(standardArgs)
: message.props.material == "toon3"
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(3),
...standardArgs,
})
: message.props.material == "toon5"
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(5),
...standardArgs,
})
: assertUnreachable(message.props.material);
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(3),
...standardArgs,
})
: message.props.material == "toon5"
? new THREE.MeshToonMaterial({
gradientMap: generateGradientMap(5),
...standardArgs,
})
: assertUnreachable(message.props.material);
const geometry = new THREE.BufferGeometry();
geometry.setAttribute(
"position",
Expand All @@ -498,12 +484,6 @@ export const ViserMesh = React.forwardRef<
3,
),
);
if (message.props.vertex_colors !== null) {
geometry.setAttribute(
"color",
threeColorBufferFromUint8Buffer(message.props.vertex_colors),
);
}

geometry.setIndex(
new THREE.Uint32BufferAttribute(
Expand Down
Loading

0 comments on commit a2f4abd

Please sign in to comment.