Skip to content

Commit

Permalink
URDF refactor + related low-level improvements & fixes (#67)
Browse files Browse the repository at this point in the history
* GUI container refactor

* Clean up panel display, still some kinks to work out

* Improve expanded panel behavior; drag regression for mobile view

* Fix toggle for sidebar display

* Fix type

* URDF refactor

* URDF refactor, state serialization

* Add scale parameter to URDF

* Folder / tab remove(), switch back to contexts

* Performance improvements

* Put icons in tarball, clean up SMPL-X example

* Lint

* Cleanup

* Suppress mypy error

* Cleanup

* Add missing icons

* Remove debug print

* Remove outdated client README

* Nits

* Fix Accordion for empty folder names

* Fix container removal, context thread safety

* Address comments from Chungmin, fix container removal, context thread safety

* Fix type error

* Fix URDF example

* Fix tooltip zIndex

* Reduce asyncio overhead (significantly!)

* Bug fixes, details
  • Loading branch information
brentyi committed Jul 24, 2023
1 parent 307f06d commit e02410f
Show file tree
Hide file tree
Showing 15 changed files with 466 additions and 212 deletions.
6 changes: 3 additions & 3 deletions examples/05_camera_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ def _(_):

T_current_target = T_world_current.inverse() @ T_world_target

for j in range(50):
for j in range(20):
T_world_set = T_world_current @ tf.SE3.exp(
T_current_target.log() * j / 49.0
T_current_target.log() * j / 19.0
)

# Important bit: we atomically set both the orientation and the position
# of the camera.
with client.atomic():
client.camera.wxyz = T_world_set.rotation().wxyz
client.camera.position = T_world_set.translation()
time.sleep(0.01)
time.sleep(1.0 / 60.0)

# Mouse interactions should orbit around the frame origin.
client.camera.look_at = frame.position
Expand Down
17 changes: 11 additions & 6 deletions examples/07_record3d_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

def main(
data_path: Path = Path(__file__).parent / "assets/record3d_dance",
downsample_factor: int = 2,
max_frames: int = 50,
downsample_factor: int = 4,
max_frames: int = 100,
) -> None:
server = viser.ViserServer()

Expand All @@ -30,11 +30,16 @@ def main(
# Add playback UI.
with server.add_gui_folder("Playback"):
gui_timestep = server.add_gui_slider(
"Timestep", min=0, max=num_frames - 1, step=1, initial_value=0
"Timestep",
min=0,
max=num_frames - 1,
step=1,
initial_value=0,
disabled=True,
)
gui_next_frame = server.add_gui_button("Next Frame")
gui_prev_frame = server.add_gui_button("Prev Frame")
gui_playing = server.add_gui_checkbox("Playing", False)
gui_next_frame = server.add_gui_button("Next Frame", disabled=True)
gui_prev_frame = server.add_gui_button("Prev Frame", disabled=True)
gui_playing = server.add_gui_checkbox("Playing", True)
gui_framerate = server.add_gui_slider(
"FPS", min=1, max=60, step=0.1, initial_value=loader.fps
)
Expand Down
111 changes: 36 additions & 75 deletions examples/09_urdf_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,97 +6,58 @@
- https://github.com/OrebroUniversity/yumi/blob/master/yumi_description/urdf/yumi.urdf
- https://github.com/ankurhanda/robot-assets
"""
from __future__ import annotations

import time
from functools import partial
from pathlib import Path
from typing import List

import numpy as onp
import trimesh
import tyro
import yourdfpy

import viser
import viser.transforms as tf
from viser.extras import ViserUrdf


def main(urdf_path: Path) -> None:
urdf = yourdfpy.URDF.load(
urdf_path,
filename_handler=partial(yourdfpy.filename_handler_magic, dir=urdf_path.parent),
)
server = viser.ViserServer()

def frame_name_with_parents(frame_name: str) -> str:
frames = []
while frame_name != urdf.scene.graph.base_frame:
frames.append(frame_name)
frame_name = urdf.scene.graph.transforms.parents[frame_name]
return "/" + "/".join(frames[::-1])

for frame_name, mesh in urdf.scene.geometry.items():
assert isinstance(mesh, trimesh.Trimesh)
T_parent_child = urdf.get_transform(
frame_name, urdf.scene.graph.transforms.parents[frame_name]
# Create a helper for adding URDFs to Viser. This just adds meshes to the scene,
# helps us set the joint angles, etc.
urdf = ViserUrdf(server, urdf_path)

# Create joint angle sliders.
gui_joints: List[viser.GuiHandle[float]] = []
initial_angles: List[float] = []
for joint_name, (lower, upper) in urdf.get_actuated_joint_limits().items():
lower = lower if lower is not None else -onp.pi
upper = upper if upper is not None else onp.pi

initial_angle = 0.0 if lower < 0 and upper > 0 else (lower + upper) / 2.0
slider = server.add_gui_slider(
label=joint_name,
min=lower,
max=upper,
step=1e-3,
initial_value=initial_angle,
)
server.add_mesh_trimesh(
frame_name_with_parents(frame_name),
mesh,
wxyz=tf.SO3.from_matrix(T_parent_child[:3, :3]).wxyz,
position=T_parent_child[:3, 3],
slider.on_update( # When sliders move, we update the URDF configuration.
lambda _: urdf.update_cfg(onp.array([gui.value for gui in gui_joints]))
)

gui_joints: List[viser.GuiHandle[float]] = []
with server.add_gui_folder("Joints"):
button = server.add_gui_button("Reset")

@button.on_click
def _(_):
for g in gui_joints:
g.value = 0.0

def update_frames():
urdf.update_cfg(onp.array([gui.value for gui in gui_joints]))
for joint in urdf.joint_map.values():
assert isinstance(joint, yourdfpy.Joint)
T_parent_child = urdf.get_transform(joint.child, joint.parent)
server.add_frame(
frame_name_with_parents(joint.child),
wxyz=tf.SO3.from_matrix(T_parent_child[:3, :3]).wxyz,
position=T_parent_child[:3, 3],
show_axes=False,
)

for joint_name, joint in urdf.joint_map.items():
assert isinstance(joint, yourdfpy.Joint)

min = (
joint.limit.lower
if joint.limit is not None and joint.limit.lower is not None
else -onp.pi
)
max = (
joint.limit.upper
if joint.limit is not None and joint.limit.upper is not None
else onp.pi
)
slider = server.add_gui_slider(
label=joint_name,
min=min,
max=max,
step=1e-3,
initial_value=0.0 if min < 0 and max > 0 else (min + max) / 2.0,
)
if joint.limit is None:
slider.visible = False

@slider.on_update
def _(_):
update_frames()

gui_joints.append(slider)

update_frames()
gui_joints.append(slider)
initial_angles.append(initial_angle)

# Create joint reset button.
reset_button = server.add_gui_button("Reset")

@reset_button.on_click
def _(_):
for g, initial_angle in zip(gui_joints, initial_angles):
g.value = initial_angle

# Apply initial joint angles.
urdf.update_cfg(onp.array([gui.value for gui in gui_joints]))

while True:
time.sleep(10.0)
Expand Down
16 changes: 15 additions & 1 deletion viser/_message_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import io
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypeVar, Union, cast

import imageio.v3 as iio
Expand Down Expand Up @@ -133,6 +134,7 @@ def __init__(self, handler: infra.MessageHandler) -> None:

self._atomic_lock = threading.Lock()
self._locked_thread_id = -1
self._queue_thread = ThreadPoolExecutor(max_workers=1)

def configure_theme(
self,
Expand Down Expand Up @@ -447,8 +449,20 @@ def reset_scene(self):
"""Reset the scene."""
self._queue(_messages.ResetSceneMessage())

@abc.abstractmethod
def _queue(self, message: _messages.Message) -> None:
"""Wrapped method for sending messages safely."""
# This implementation will retain message ordering because _queue_thread has
# just 1 worker.
self._queue_thread.submit(lambda: self._queue_blocking(message))

def _queue_blocking(self, message: _messages.Message) -> None:
"""Wrapped method for sending messages safely. Blocks until ready to send."""
self._atomic_lock.acquire()
self._queue_unsafe(message)
self._atomic_lock.release()

@abc.abstractmethod
def _queue_unsafe(self, message: _messages.Message) -> None:
"""Abstract method for sending messages."""
...

Expand Down
10 changes: 0 additions & 10 deletions viser/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,16 +419,6 @@ class GuiSetValueMessage(Message):
value: Any


@dataclasses.dataclass
class MessageGroupStart(Message):
"""Sent server->client to indicate the start of a message group."""


@dataclasses.dataclass
class MessageGroupEnd(Message):
"""Sent server->client to indicate the end of a message group."""


@dataclasses.dataclass
class ThemeConfigurationMessage(Message):
"""Message from server->client to configure parts of the GUI."""
Expand Down
37 changes: 13 additions & 24 deletions viser/_viser.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def on_update(

@dataclasses.dataclass
class _ClientHandleState:
server: infra.Server
connection: infra.ClientConnection


Expand All @@ -175,15 +176,16 @@ def _get_api(self) -> MessageApi:
return self

@override
def _queue(self, message: infra.Message) -> None:
def _queue_unsafe(self, message: _messages.Message) -> None:
"""Define how the message API should send messages."""
self._state.connection.send(message)

@contextlib.contextmanager
def atomic(self) -> Generator[None, None, None]:
"""Returns a context where:
- All outgoing messages are grouped and applied by clients atomically.
- No incoming messages, like camera or GUI state updates, are processed.
- `viser` will attempt to group outgoing messages, which will then be sent after
the context is exited.
This can be helpful for things like animations, or when we want position and
orientation updates to happen synchronously.
Expand All @@ -194,18 +196,21 @@ def atomic(self) -> Generator[None, None, None]:
got_lock = False
else:
self._atomic_lock.acquire()
self._queue(_messages.MessageGroupStart())
self._locked_thread_id = thread_id
got_lock = True

yield

if got_lock:
self._queue(_messages.MessageGroupEnd())
self._atomic_lock.release()
self._locked_thread_id = -1


# We can serialize the state of a ViserServer via a tuple of
# (serialized message, timestamp) pairs.
SerializedServerState = Tuple[Tuple[bytes, float], ...]


@dataclasses.dataclass
class _ViserServerState:
connection: infra.Server
Expand All @@ -231,6 +236,7 @@ def __init__(self, host: str = "0.0.0.0", port: int = 8080):
message_class=_messages.Message,
http_server_root=Path(__file__).absolute().parent / "client" / "build",
)
self._server = server
super().__init__(server)

_client_autobuild.ensure_client_is_built()
Expand Down Expand Up @@ -260,7 +266,7 @@ def _(conn: infra.ClientConnection) -> None:
client = ClientHandle(
conn.client_id,
camera,
_ClientHandleState(conn),
_ClientHandleState(server, conn),
)
camera._state.client = client
first = True
Expand Down Expand Up @@ -330,9 +336,9 @@ def _get_api(self) -> MessageApi:
return self

@override
def _queue(self, message: infra.Message) -> None:
def _queue_unsafe(self, message: _messages.Message) -> None:
"""Define how the message API should send messages."""
self._state.connection.broadcast(message)
self._server.broadcast(message)

def get_clients(self) -> Dict[int, ClientHandle]:
"""Creates and returns a copy of the mapping from connected client IDs to
Expand Down Expand Up @@ -393,25 +399,8 @@ def atomic(self) -> Generator[None, None, None]:
for client in self.get_clients().values():
stack.enter_context(client._atomic_lock)

self._queue(_messages.MessageGroupStart())

# There's a possible race condition here if we write something like:
#
# with server.atomic():
# client.add_frame(...)
#
# - We enter the server's atomic() context.
# - The server pushes a MessageGroupStart() to the broadcast buffer.
# - The client pushes a frame message to the client buffer.
# - For whatever reason the client buffer is handled before the broadcast
# buffer.
#
# The likelihood of this seems exceedingly low, but I don't think we have
# any actual guarantees here. Likely worth revisiting but super lower
# priority.
yield

if got_lock:
self._queue(_messages.MessageGroupEnd())
self._atomic_lock.release()
self._locked_thread_id = -1
14 changes: 11 additions & 3 deletions viser/client/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ function SingleViewer() {
// Scene node attributes that aren't placed in the zustand state, for performance reasons.
nodeAttributesFromName: React.useRef({}),
};

// Memoize the websocket interface so it isn't remounted when the theme or
// viewer context changes.
const memoizedWebsocketInterface = React.useMemo(
() => <WebsocketInterface />,
[]
);

const fixed_sidebar = viewer.useGui((state) => state.theme.fixed_sidebar);
return (
<ViewerContext.Provider value={viewer}>
Expand All @@ -94,7 +102,6 @@ function SingleViewer() {
flex: "1 0 auto",
}}
>
<WebsocketInterface />
<MediaQuery smallerThan={"xs"} styles={{ right: 0, bottom: "3.5em" }}>
<Box
sx={(theme) => ({
Expand All @@ -107,7 +114,7 @@ function SingleViewer() {
theme.colorScheme === "light" ? "#fff" : theme.colors.dark[9],
})}
>
<ViewerCanvas />
<ViewerCanvas>{memoizedWebsocketInterface}</ViewerCanvas>
</Box>
</MediaQuery>
<ControlPanel fixed_sidebar={fixed_sidebar} />
Expand All @@ -116,7 +123,7 @@ function SingleViewer() {
);
}

function ViewerCanvas() {
function ViewerCanvas({ children }: { children: React.ReactNode }) {
const viewer = React.useContext(ViewerContext)!;
return (
<Canvas
Expand All @@ -131,6 +138,7 @@ function ViewerCanvas() {
performance={{ min: 0.95 }}
ref={viewer.canvasRef}
>
{children}
<AdaptiveDpr pixelated />
<AdaptiveEvents />
<SceneContextSetter />
Expand Down
Loading

0 comments on commit e02410f

Please sign in to comment.