Skip to content

Commit

Permalink
Atomic updates for multiple GUI properties (#170)
Browse files Browse the repository at this point in the history
* Multi-property GUI updates

* ruff

* Fix value casting

* Remove now-unused interface gen code

* Stronger client-side type for GuiUpdateMessage

* Docstring
  • Loading branch information
brentyi authored Feb 8, 2024
1 parent fa3ff41 commit 27cd1ad
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 92 deletions.
1 change: 1 addition & 0 deletions examples/02_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def main() -> None:
* color_coeffs[:, None]
).astype(onp.uint8),
position=gui_vector2.value + (0,),
point_shape="circle",
)

# We can use `.visible` and `.disabled` to toggle GUI elements.
Expand Down
59 changes: 29 additions & 30 deletions src/viser/_gui_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,36 +128,38 @@ def _handle_gui_updates(
handle = self._gui_handle_from_id.get(message.id, None)
if handle is None:
return

prop_name = message.prop_name
prop_value = message.prop_value
del message

handle_state = handle._impl
assert hasattr(handle_state, prop_name)
current_value = getattr(handle_state, prop_name)

has_changed = current_value != prop_value

if prop_name == "value":
# Do some type casting. This is necessary when we expect floats but the
# Javascript side gives us integers.
if handle_state.typ is tuple:
assert len(prop_value) == len(handle_state.value)
prop_value = tuple(
type(handle_state.value[i])(prop_value[i])
for i in range(len(prop_value))
)
else:
prop_value = handle_state.typ(prop_value)

has_changed = False
updates_cast = {}
for prop_name, prop_value in message.updates.items():
assert hasattr(handle_state, prop_name)
current_value = getattr(handle_state, prop_name)

# Do some type casting. This is brittle, but necessary when we
# expect floats but the Javascript side gives us integers.
if prop_name == "value":
if handle_state.typ is tuple:
assert len(prop_value) == len(handle_state.value)
prop_value = tuple(
type(handle_state.value[i])(prop_value[i])
for i in range(len(prop_value))
)
else:
prop_value = handle_state.typ(prop_value)

# Update handle property.
if current_value != prop_value:
has_changed = True
setattr(handle_state, prop_name, prop_value)

# Save value, which might have been cast.
updates_cast[prop_name] = prop_value

# Only call update when value has actually changed.
if not handle_state.is_button and not has_changed:
return

# Update state.
setattr(handle_state, prop_name, prop_value)

# Trigger callbacks.
for cb in handle_state.update_cb:
from ._viser import ClientHandle, ViserServer
Expand All @@ -174,7 +176,7 @@ def _handle_gui_updates(
cb(GuiEvent(client, client_id, handle))

if handle_state.sync_cb is not None:
handle_state.sync_cb(client_id, prop_name, prop_value)
handle_state.sync_cb(client_id, updates_cast)

def _get_container_id(self) -> str:
"""Get container ID associated with the current thread."""
Expand Down Expand Up @@ -1080,7 +1082,6 @@ def _create_gui_input(
typ=type(value),
gui_api=self,
value=value,
initial_value=value,
update_timestamp=time.time(),
container_id=self._get_container_id(),
update_cb=[],
Expand All @@ -1098,11 +1099,9 @@ def _create_gui_input(
if not is_button:

def sync_other_clients(
client_id: ClientId, prop_name: str, prop_value: Any
client_id: ClientId, updates: Dict[str, Any]
) -> None:
message = _messages.GuiUpdateMessage(
handle_state.id, prop_name, prop_value
)
message = _messages.GuiUpdateMessage(handle_state.id, updates)
message.excluded_self_client = client_id
self._get_api()._queue(message)

Expand Down
62 changes: 30 additions & 32 deletions src/viser/_gui_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@
from ._icons import base64_from_icon
from ._icons_enum import IconName
from ._message_api import _encode_image_base64
from ._messages import (
GuiCloseModalMessage,
GuiRemoveMessage,
GuiUpdateMessage,
Message,
)
from ._messages import GuiCloseModalMessage, GuiRemoveMessage, GuiUpdateMessage, Message
from .infra import ClientId

if TYPE_CHECKING:
Expand Down Expand Up @@ -81,15 +76,14 @@ class _GuiHandleState(Generic[T]):
is_button: bool
"""Indicates a button element, which requires special handling."""

sync_cb: Optional[Callable[[ClientId, str, Any], None]]
sync_cb: Optional[Callable[[ClientId, Dict[str, Any]], None]]
"""Callback for synchronizing inputs across clients."""

disabled: bool
visible: bool

order: float
id: str
initial_value: T
hint: Optional[str]

message_type: Type[Message]
Expand Down Expand Up @@ -136,7 +130,7 @@ def value(self, value: T | onp.ndarray) -> None:
# Send to client, except for buttons.
if not self._impl.is_button:
self._impl.gui_api._get_api()._queue(
GuiUpdateMessage(self._impl.id, "value", value)
GuiUpdateMessage(self._impl.id, {"value": value})
)

# Set internal state. We automatically convert numpy arrays to the expected
Expand Down Expand Up @@ -175,7 +169,7 @@ def disabled(self, disabled: bool) -> None:
return

self._impl.gui_api._get_api()._queue(
GuiUpdateMessage(self._impl.id, "disabled", disabled)
GuiUpdateMessage(self._impl.id, {"disabled": disabled})
)
self._impl.disabled = disabled

Expand All @@ -191,7 +185,7 @@ def visible(self, visible: bool) -> None:
return

self._impl.gui_api._get_api()._queue(
GuiUpdateMessage(self._impl.id, "visible", visible)
GuiUpdateMessage(self._impl.id, {"visible": visible})
)
self._impl.visible = visible

Expand Down Expand Up @@ -307,15 +301,23 @@ def options(self) -> Tuple[StringType, ...]:
@options.setter
def options(self, options: Iterable[StringType]) -> None:
self._impl_options = tuple(options)
if self._impl.initial_value not in self._impl_options:
self._impl.initial_value = self._impl_options[0]

self._impl.gui_api._get_api()._queue(
GuiUpdateMessage(self._impl.id, "options", self._impl_options)
)

if self.value not in self._impl_options:
self.value = self._impl_options[0]
need_to_overwrite_value = self.value not in self._impl_options
if need_to_overwrite_value:
self._impl.gui_api._get_api()._queue(
GuiUpdateMessage(
self._impl.id,
{"options": self._impl_options, "value": self._impl_options[0]},
)
)
self._impl.value = self._impl_options[0]
else:
self._impl.gui_api._get_api()._queue(
GuiUpdateMessage(
self._impl.id,
{"options": self._impl_options},
)
)


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -355,19 +357,14 @@ def remove(self) -> None:

def _sync_with_client(self) -> None:
"""Send messages for syncing tab state with the client."""
self._gui_api._get_api()._queue(
GuiUpdateMessage(self._tab_group_id, "tab_labels", tuple(self._labels))
)
self._gui_api._get_api()._queue(
GuiUpdateMessage(
self._tab_group_id, "tab_icons_base64", tuple(self._icons_base64)
)
)
self._gui_api._get_api()._queue(
GuiUpdateMessage(
self._tab_group_id,
"tab_container_ids",
tuple(tab._id for tab in self._tabs),
{
"tab_labels": tuple(self._labels),
"tab_icons_base64": tuple(self._icons_base64),
"tab_container_ids": tuple(tab._id for tab in self._tabs),
},
)
)

Expand Down Expand Up @@ -558,8 +555,7 @@ def content(self, content: str) -> None:
self._gui_api._get_api()._queue(
GuiUpdateMessage(
self._id,
"markdown",
_parse_markdown(content, self._image_root),
{"markdown": _parse_markdown(content, self._image_root)},
)
)

Expand All @@ -579,7 +575,9 @@ def visible(self, visible: bool) -> None:
if visible == self.visible:
return

self._gui_api._get_api()._queue(GuiUpdateMessage(self._id, "visible", visible))
self._gui_api._get_api()._queue(
GuiUpdateMessage(self._id, {"visible": visible})
)
self._visible = visible

def __post_init__(self) -> None:
Expand Down
19 changes: 14 additions & 5 deletions src/viser/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from __future__ import annotations

import dataclasses
from typing import Any, Callable, ClassVar, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union

import numpy as onp
import numpy.typing as onpt
from typing_extensions import Literal, NotRequired, TypedDict, override
from typing_extensions import Annotated, Literal, NotRequired, TypedDict, override

from . import infra, theme

Expand Down Expand Up @@ -552,12 +552,21 @@ class GuiUpdateMessage(Message):
"""Sent client<->server when any property of a GUI component is changed."""

id: str
prop_name: str
prop_value: Any
updates: Annotated[
Dict[str, Any],
infra.TypeScriptAnnotationOverride("Partial<GuiAddComponentMessage>"),
]
"""Mapping from property name to new value."""

@override
def redundancy_key(self) -> str:
return type(self).__name__ + "-" + self.id + "-" + self.prop_name
return (
type(self).__name__
+ "-"
+ self.id
+ "-"
+ ",".join(list(self.updates.keys()))
)


@dataclasses.dataclass
Expand Down
5 changes: 2 additions & 3 deletions src/viser/client/src/ControlPanel/Generated.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@ export default function GeneratedGuiContainer({
const messageSender = makeThrottledMessageSender(viewer.websocketRef, 50);

function setValue(id: string, value: any) {
updateGuiProps(id, "value", value);
updateGuiProps(id, { value: value });
messageSender({
type: "GuiUpdateMessage",
id: id,
prop_name: "value",
prop_value: value,
updates: { value: value },
});
}
return (
Expand Down
15 changes: 12 additions & 3 deletions src/viser/client/src/ControlPanel/GuiState.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ interface GuiActions {
addGui: (config: GuiConfig) => void;
addModal: (config: Messages.GuiModalMessage) => void;
removeModal: (id: string) => void;
updateGuiProps: (id: string, prop_name: string, prop_value: any) => void;
updateGuiProps: (id: string, updates: { [key: string]: any }) => void;
removeGui: (id: string) => void;
resetGui: () => void;
}
Expand Down Expand Up @@ -121,16 +121,25 @@ export function useGuiState(initialServer: string) {
state.guiOrderFromId = {};
state.guiConfigFromId = {};
}),
updateGuiProps: (id, name, value) => {
updateGuiProps: (id, updates) => {
set((state) => {
const config = state.guiConfigFromId[id];
if (config === undefined) {
console.error("Tried to update non-existent component", id);
return;
}

// Double-check that key exists.
Object.keys(updates).forEach((key) => {
if (!(key in config))
console.error(
`Tried to update nonexistent property '${key}' of GUI element ${id}!`,
);
});

state.guiConfigFromId[id] = {
...config,
[name]: value,
...updates,
} as GuiConfig;
});
},
Expand Down
2 changes: 1 addition & 1 deletion src/viser/client/src/WebsocketInterface.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ function useMessageHandler() {
}
// Update props of a GUI component
case "GuiUpdateMessage": {
updateGuiProps(message.id, message.prop_name, message.prop_value);
updateGuiProps(message.id, message.updates);
return;
}
// Remove a GUI input.
Expand Down
3 changes: 1 addition & 2 deletions src/viser/client/src/WebsocketMessages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -622,8 +622,7 @@ export interface GuiRemoveMessage {
export interface GuiUpdateMessage {
type: "GuiUpdateMessage";
id: string;
prop_name: string;
prop_value: any;
updates: Partial<GuiAddComponentMessage>;
}
/** Message from server->client to configure parts of the GUI.
*
Expand Down
3 changes: 1 addition & 2 deletions src/viser/client/src/components/Button.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ export default function ButtonComponent({
messageSender({
type: "GuiUpdateMessage",
id: id,
prop_name: "value",
prop_value: true,
updates: { value: true },
})
}
style={{ height: "2.125em" }}
Expand Down
3 changes: 1 addition & 2 deletions src/viser/client/src/components/ButtonGroup.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ export default function ButtonGroupComponent({
messageSender({
type: "GuiUpdateMessage",
id: id,
prop_name: "value",
prop_value: option,
updates: { value: option },
})
}
style={{ flexGrow: 1, width: 0 }}
Expand Down
3 changes: 3 additions & 0 deletions src/viser/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from ._infra import MessageHandler as MessageHandler
from ._infra import Server as Server
from ._messages import Message as Message
from ._typescript_interface_gen import (
TypeScriptAnnotationOverride as TypeScriptAnnotationOverride,
)
from ._typescript_interface_gen import (
generate_typescript_interfaces as generate_typescript_interfaces,
)
Loading

0 comments on commit 27cd1ad

Please sign in to comment.