From 24ad049830d0714b61f5c9ca66d822dbec24e174 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 30 Jul 2024 19:45:56 -0700 Subject: [PATCH] Untie splat rendering components from `viser`-specific internals (#258) --- examples/experimental/gaussian_splats.py | 9 +- src/viser/client/src/App.tsx | 21 +-- src/viser/client/src/MessageHandler.tsx | 39 ++--- .../client/src/Splatting/GaussianSplats.tsx | 162 ++++++++++++++---- .../client/src/Splatting/SplatContext.ts | 31 ---- 5 files changed, 161 insertions(+), 101 deletions(-) delete mode 100644 src/viser/client/src/Splatting/SplatContext.ts diff --git a/examples/experimental/gaussian_splats.py b/examples/experimental/gaussian_splats.py index ddc41233..b5f6e282 100644 --- a/examples/experimental/gaussian_splats.py +++ b/examples/experimental/gaussian_splats.py @@ -129,7 +129,7 @@ def _(event: viser.GuiEvent) -> None: raise SystemExit("Please provide a filepath to a .splat or .ply file.") server.scene.add_transform_controls(f"/{i}") - server.scene._add_gaussian_splats( + gs_handle = server.scene._add_gaussian_splats( f"/{i}/gaussian_splats", centers=splat_data["centers"], rgbs=splat_data["rgbs"], @@ -137,6 +137,13 @@ def _(event: viser.GuiEvent) -> None: covariances=splat_data["covariances"], ) + remove_button = server.gui.add_button(f"Remove splat object {i}") + + @remove_button.on_click + def _(_, gs_handle=gs_handle) -> None: + gs_handle.remove() + remove_button.remove() + while True: time.sleep(10.0) diff --git a/src/viser/client/src/App.tsx b/src/viser/client/src/App.tsx index 86df31c5..9e9eb180 100644 --- a/src/viser/client/src/App.tsx +++ b/src/viser/client/src/App.tsx @@ -45,13 +45,9 @@ import { useDisclosure } from "@mantine/hooks"; import { rayToViserCoords } from "./WorldTransformUtils"; import { ndcFromPointerXy, opencvXyFromPointerXy } from "./ClickUtils"; import { theme } from "./AppTheme"; -import { - GaussianSplatsContext, - useGaussianSplatStore, -} from "./Splatting/SplatContext"; import { FrameSynchronizedMessageHandler } from "./MessageHandler"; import { PlaybackFromFile } from "./FilePlayback"; -import GlobalGaussianSplats from "./Splatting/GaussianSplats"; +import { SplatRenderContext } from "./Splatting/GaussianSplats"; import { BrowserWarning } from "./BrowserWarning"; export type ViewerContextContents = { @@ -267,12 +263,9 @@ function ViewerContents({ children }: { children: React.ReactNode }) { })} > - - - - - - + + + {viewer.useGui((state) => state.theme.show_logo) && viewer.messageSource == "websocket" ? ( @@ -453,7 +446,9 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) { - + + + (refreshrate > 90 ? [40, 80] : [40, 50])} + bounds={(refreshrate) => (refreshrate > 90 ? [80, 90] : [50, 60])} onChange={({ factor, fps, refreshrate }) => { const dpr = window.devicePixelRatio * (0.2 + 0.8 * factor); console.log( diff --git a/src/viser/client/src/MessageHandler.tsx b/src/viser/client/src/MessageHandler.tsx index a533a31e..3d43408c 100644 --- a/src/viser/client/src/MessageHandler.tsx +++ b/src/viser/client/src/MessageHandler.tsx @@ -29,7 +29,7 @@ import GeneratedGuiContainer from "./ControlPanel/Generated"; import { Paper, Progress } from "@mantine/core"; import { IconCheck } from "@tabler/icons-react"; import { computeT_threeworld_world } from "./WorldTransformUtils"; -import { GaussianSplatsContext } from "./Splatting/SplatContext"; +import { SplatObject } from "./Splatting/GaussianSplats"; /** Convert raw RGB color buffers to linear color buffers. **/ function threeColorBufferFromUint8Buffer(colors: ArrayBuffer) { @@ -51,10 +51,6 @@ function useMessageHandler() { const viewer = useContext(ViewerContext)!; const ContextBridge = useContextBridge(); - const splatContext = useContext(GaussianSplatsContext)!; - const setGaussianBuffer = splatContext((state) => state.setBuffer); - const removeGaussianBuffer = splatContext((state) => state.removeBuffer); - // We could reduce the redundancy here if we wanted to. // https://github.com/nerfstudio-project/viser/issues/39 const removeSceneNode = viewer.useSceneTree((state) => state.removeSceneNode); @@ -970,25 +966,22 @@ function useMessageHandler() { return; } case "GaussianSplatsMessage": { - setGaussianBuffer( - message.name, - new Uint32Array( - message.buffer.buffer.slice( - message.buffer.byteOffset, - message.buffer.byteOffset + message.buffer.byteLength, - ), - ), - ); addSceneNodeMakeParents( - new SceneNode( - message.name, - (ref) => { - return ; - }, - () => { - removeGaussianBuffer(message.name); - }, - ), + new SceneNode(message.name, (ref) => { + return ( + + ); + }), ); return; } diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index 4501d012..4d42313a 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -1,14 +1,85 @@ +/** Gaussian splatting implementation for viser. + * + * This borrows heavily from existing open-source implementations. Particularly + * useful references: + * - https://github.com/quadjr/aframe-gaussian-splatting + * - https://github.com/antimatter15/splat + * - https://github.com/pmndrs/drei + * - https://github.com/vincent-lecrubier-skydio/react-three-fiber-gaussian-splat + * + * Usage should look like: + * + * + * + * + * + * + * + * Where `buffer` contains serialized Gaussian attributes. SplatObjects are + * globally sorted by a worker (with some help from WebAssembly + SIMD + * intrinsics), and then rendered as a single threejs mesh. Unlike other R3F + * implementations that we're aware of, this enables correct compositing + * between multiple splat objects. + */ + import React from "react"; import * as THREE from "three"; import SplatSortWorker from "./SplatSortWorker?worker"; import { useFrame, useThree } from "@react-three/fiber"; import { shaderMaterial } from "@react-three/drei"; -import { GaussianSplatsContext } from "./SplatContext"; -import { ViewerContext } from "../App"; import { SorterWorkerIncoming } from "./SplatSortWorker"; +import { create } from "zustand"; +import { Object3D } from "three"; + +/**Global splat state.*/ +interface SplatState { + groupBufferFromId: { [id: string]: Uint32Array }; + nodeRefFromId: React.MutableRefObject<{ + [name: string]: undefined | Object3D; + }>; + setBuffer: (id: string, buffer: Uint32Array) => void; + removeBuffer: (id: string) => void; +} -function postToWorker(worker: Worker, message: SorterWorkerIncoming) { - worker.postMessage(message); +/**Hook for creating global splat state.*/ +function useGaussianSplatStore() { + const nodeRefFromId = React.useRef({}); + return React.useState(() => + create((set) => ({ + groupBufferFromId: {}, + nodeRefFromId: nodeRefFromId, + setBuffer: (id, buffer) => { + return set((state) => ({ + groupBufferFromId: { ...state.groupBufferFromId, [id]: buffer }, + })); + }, + removeBuffer: (id) => { + return set((state) => { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { [id]: _, ...buffers } = state.groupBufferFromId; + return { groupBufferFromId: buffers }; + }); + }, + })), + )[0]; +} +const GaussianSplatsContext = React.createContext | null>(null); + +/**Provider for creating splat rendering context.*/ +export function SplatRenderContext({ + children, +}: { + children: React.ReactNode; +}) { + const store = useGaussianSplatStore(); + return ( + + + {children} + + ); } const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( @@ -52,14 +123,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( out vec4 vRgba; out vec2 vPosition; - float hash2D(vec2 value) { - return fract( 1.0e4 * sin( 17.0 * value.x + 0.1 * value.y ) * ( 0.1 + abs( sin( 13.0 * value.y + value.x ) ) ) ); - } - - float hash3D(vec3 value) { - return hash2D( vec2( hash2D( value.xy ), value.z ) ); - } - // Function to fetch and construct the i-th transform matrix using texelFetch mat4 getGroupTransform(uint i) { // Calculate the base index for the i-th transform. @@ -80,14 +143,16 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( void main () { // Get position + scale from float buffer. ivec2 texSize = textureSize(textureBuffer, 0); - ivec2 texPos0 = ivec2((sortedIndex * 2u) % uint(texSize.x), (sortedIndex * 2u) / uint(texSize.x)); + uint texStart = sortedIndex << 1u; + ivec2 texPos0 = ivec2(texStart % uint(texSize.x), texStart / uint(texSize.x)); + // Fetch from textures. uvec4 floatBufferData = texelFetch(textureBuffer, texPos0, 0); mat4 T_camera_group = getGroupTransform(floatBufferData.w); // Any early return will discard the fragment. - gl_Position = vec4(0.0, 0.0, 2000.0, 1.0); + gl_Position = vec4(0.0, 0.0, 2.0, 1.0); // Get center wrt camera. modelViewMatrix is T_cam_world. vec3 center = uintBitsToFloat(floatBufferData.xyz); @@ -100,7 +165,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( return; // Read covariance terms. - ivec2 texPos1 = ivec2((sortedIndex * 2u + 1u) % uint(texSize.x), (sortedIndex * 2u + 1u) / uint(texSize.x)); + ivec2 texPos1 = ivec2((texStart + 1u) % uint(texSize.x), (texStart + 1u) / uint(texSize.x)); uvec4 intBufferData = texelFetch(textureBuffer, texPos1, 0); // Get covariance terms from int buffer. @@ -111,7 +176,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // Transition in. float startTime = 0.8 * float(sortedIndex) / float(numGaussians); - float cov_scale = clamp((transitionInState - startTime) / 0.2f, 0.0f, 1.0f); // min() can freeze here. Not sure why... + float cov_scale = smoothstep(startTime, startTime + 0.2, transitionInState); // Do the actual splatting. mat3 chol = mat3( @@ -152,10 +217,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // Throw the Gaussian off the screen if it's too close, too far, or too small. float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag); - if (weightedDeterminant < 0.1) - return; - // This is not principled. It just makes things faster. - if (weightedDeterminant < 1.0 && hash3D(center) < weightedDeterminant) + if (weightedDeterminant < 0.5) return; vPosition = position.xy; @@ -182,16 +244,48 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( }`, ); +export const SplatObject = React.forwardRef< + THREE.Group, + { + buffer: Uint32Array; + } +>(function SplatObject({ buffer }, ref) { + const splatContext = React.useContext(GaussianSplatsContext)!; + const setBuffer = splatContext((state) => state.setBuffer); + const removeBuffer = splatContext((state) => state.removeBuffer); + const nodeRefFromId = splatContext((state) => state.nodeRefFromId); + const name = React.useMemo(() => crypto.randomUUID(), [buffer]); + + const [obj, setRef] = React.useState(null); + + React.useEffect(() => { + if (obj === null) return; + setBuffer(name, buffer); + if (ref !== null) { + if ("current" in ref) { + ref.current = obj; + } else { + ref(obj); + } + } + nodeRefFromId.current[name] = obj; + return () => { + removeBuffer(name); + delete nodeRefFromId.current[name]; + }; + }, [obj]); + + return ; +}); + /** External interface. Component should be added to the root of canvas. */ -export default function GlobalGaussianSplats() { - const viewer = React.useContext(ViewerContext)!; +function SplatRenderer() { const splatContext = React.useContext(GaussianSplatsContext)!; - const groupBufferFromName = splatContext( - (state) => state.groupBufferFromName, - ); + const groupBufferFromId = splatContext((state) => state.groupBufferFromId); + const nodeRefFromId = splatContext((state) => state.nodeRefFromId); // Consolidate Gaussian groups into a single buffer. - const merged = mergeGaussianGroups(groupBufferFromName); + const merged = mergeGaussianGroups(groupBufferFromId); const meshProps = useGaussianMeshProps( merged.gaussianBuffer, merged.numGroups, @@ -213,7 +307,11 @@ export default function GlobalGaussianSplats() { initializedBufferTexture = true; } }; - postToWorker(sortWorker, { + function postToWorker(message: SorterWorkerIncoming) { + sortWorker.postMessage(message); + } + + postToWorker({ setBuffer: merged.gaussianBuffer, setGroupIndices: merged.groupIndices, }); @@ -224,7 +322,7 @@ export default function GlobalGaussianSplats() { meshProps.textureBuffer.dispose(); meshProps.geometry.dispose(); meshProps.material.dispose(); - postToWorker(sortWorker, { close: true }); + postToWorker({ close: true }); }; }); @@ -267,10 +365,8 @@ export default function GlobalGaussianSplats() { const T_camera_world = state.camera.matrixWorldInverse; const groupVisibles: boolean[] = []; let visibilitiesChanged = false; - for (const [groupIndex, name] of Object.keys( - groupBufferFromName, - ).entries()) { - const node = viewer.nodeRefFromName.current[name]; + for (const [groupIndex, name] of Object.keys(groupBufferFromId).entries()) { + const node = nodeRefFromId.current[name]; if (node === undefined) continue; tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld); const colMajorElements = tmpT_camera_group.elements; @@ -312,7 +408,7 @@ export default function GlobalGaussianSplats() { if (groupsMovedWrtCam) { // Gaussians need to be re-sorted. - postToWorker(sortWorker, { + postToWorker({ setTz_camera_groups: Tz_camera_groups, }); } diff --git a/src/viser/client/src/Splatting/SplatContext.ts b/src/viser/client/src/Splatting/SplatContext.ts deleted file mode 100644 index bbe45a31..00000000 --- a/src/viser/client/src/Splatting/SplatContext.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { create } from "zustand"; -import React from "react"; - -interface SplatState { - groupBufferFromName: { [name: string]: Uint32Array }; - setBuffer: (name: string, buffer: Uint32Array) => void; - removeBuffer: (name: string) => void; -} - -export function useGaussianSplatStore() { - return React.useState(() => - create((set) => ({ - groupBufferFromName: {}, - setBuffer: (name, buffer) => { - return set((state) => ({ - groupBufferFromName: { ...state.groupBufferFromName, [name]: buffer }, - })); - }, - removeBuffer: (name) => { - return set((state) => { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { [name]: _, ...buffers } = state.groupBufferFromName; - return { groupBufferFromName: buffers }; - }); - }, - })), - )[0]; -} -export const GaussianSplatsContext = React.createContext | null>(null);