From f82f7a2a2fd8a74e818ab109620c0ad4021b4e82 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Wed, 17 Jul 2024 18:22:06 +0900 Subject: [PATCH] Minor bug fixes (serialization, search params), playback improvements --- src/viser/client/src/App.tsx | 12 ++++--- src/viser/client/src/FilePlayback.tsx | 2 +- src/viser/client/src/SearchParamsUtils.tsx | 37 +++++++++++----------- src/viser/infra/_messages.py | 26 +++++++++------ 4 files changed, 43 insertions(+), 34 deletions(-) diff --git a/src/viser/client/src/App.tsx b/src/viser/client/src/App.tsx index 4079c53a1..817097273 100644 --- a/src/viser/client/src/App.tsx +++ b/src/viser/client/src/App.tsx @@ -132,10 +132,9 @@ function ViewerRoot() { servers.length >= 1 ? servers[0] : getDefaultServerFromUrl(); // Playback mode for embedding viser. - const playbackPath = new URLSearchParams(window.location.search).get( - "playbackPath", - ); - console.log(playbackPath); + const searchParams = new URLSearchParams(window.location.search); + const playbackPath = searchParams.get("playbackPath"); + const darkMode = searchParams.get("darkMode") !== null; // Values that can be globally accessed by components in a viewer. const viewer: ViewerContextContents = { @@ -179,6 +178,9 @@ function ViewerRoot() { skinnedMeshState: React.useRef({}), }; + // Set dark default if specified in URL. + if (darkMode) viewer.useGui.getState().theme.dark_mode = darkMode; + return ( @@ -441,7 +443,7 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) { - + ); diff --git a/src/viser/client/src/SearchParamsUtils.tsx b/src/viser/client/src/SearchParamsUtils.tsx index d198feb73..747f571fa 100644 --- a/src/viser/client/src/SearchParamsUtils.tsx +++ b/src/viser/client/src/SearchParamsUtils.tsx @@ -5,30 +5,31 @@ export const searchParamKey = "websocket"; export function syncSearchParamServer(server: string) { - setServerParams([server]); -} - -function setServerParams(serverParams: string[]) { + const searchParams = new URLSearchParams(window.location.search); // No need to update the URL bar if the websocket port matches the HTTP port. // So if we navigate to http://localhost:8081, this should by default connect to ws://localhost:8081. - if ( - serverParams.length === 1 && - (window.location.host.includes( - serverParams[0].replace("ws://", "").replace("/", ""), + const isDefaultServer = + window.location.host.includes( + server.replace("ws://", "").replace("/", ""), ) || - window.location.host.includes( - serverParams[0].replace("wss://", "").replace("/", ""), - )) - ) - serverParams = []; - + window.location.host.includes( + server.replace("wss://", "").replace("/", ""), + ); + if (isDefaultServer && searchParams.has(searchParamKey)) { + searchParams.delete(searchParamKey); + } else if (!isDefaultServer) { + searchParams.set(searchParamKey, server); + } window.history.replaceState( null, "Viser", - // We could use URLSearchParams() to build this string, but that would escape - // it. We're going to just not escape the string. :) - serverParams.length === 0 + // We could use URLSearchParams.toString() to build this string, but that + // would escape it. We're going to just not escape the string. :) + searchParams.size === 0 ? window.location.href.split("?")[0] - : `?${serverParams.map((s) => `${searchParamKey}=${s}`).join("&")}`, + : "?" + + Array.from(searchParams.entries()) + .map(([k, v]) => `${k}=${v}`) + .join("&"), ); } diff --git a/src/viser/infra/_messages.py b/src/viser/infra/_messages.py index 204d66afe..a6c17a4cd 100644 --- a/src/viser/infra/_messages.py +++ b/src/viser/infra/_messages.py @@ -44,20 +44,20 @@ def _prepare_for_deserialization(value: Any, annotation: Type) -> Any: return value -def _prepare_for_serialization(value: Any, annotation: Type) -> Any: +def _prepare_for_serialization(value: Any, annotation: object) -> Any: """Prepare any special types for serialization.""" if annotation is Any: annotation = type(value) # Coerce some scalar types: if we've annotated as float / int but we get an # onp.float32 / onp.int64, for example, we should cast automatically. - if annotation is float: + if annotation is float or isinstance(value, onp.floating): return float(value) - if annotation is int: + if annotation is int or isinstance(value, onp.integer): return int(value) # Recursively handle tuples. - if get_origin(annotation) is tuple: + if isinstance(value, tuple): if isinstance(value, onp.ndarray): assert False, ( "Expected a tuple, but got an array... missing a cast somewhere?" @@ -65,12 +65,15 @@ def _prepare_for_serialization(value: Any, annotation: Type) -> Any: ) out = [] - args = get_args(annotation) - if len(args) >= 2 and args[1] == ...: - args = (args[0],) * len(value) - elif len(value) != len(args): - warnings.warn(f"[viser] {value} does not match annotation {annotation}") - return value + if get_origin(annotation) is tuple: + args = get_args(annotation) + if len(args) >= 2 and args[1] == ...: + args = (args[0],) * len(value) + elif len(value) != len(args): + warnings.warn(f"[viser] {value} does not match annotation {annotation}") + return value + else: + args = [Any] * len(value) for i, v in enumerate(value): out.append( @@ -85,6 +88,9 @@ def _prepare_for_serialization(value: Any, annotation: Type) -> Any: if isinstance(value, onp.ndarray): return value.data if value.data.c_contiguous else value.copy().data + if isinstance(value, dict): + return {k: _prepare_for_serialization(v, Any) for k, v in value.items()} # type: ignore + return value