Skip to content

Commit

Permalink
Untie splat rendering components from viser-specific internals (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 31, 2024
1 parent 9d9bb79 commit 24ad049
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 101 deletions.
9 changes: 8 additions & 1 deletion examples/experimental/gaussian_splats.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,21 @@ def _(event: viser.GuiEvent) -> None:
raise SystemExit("Please provide a filepath to a .splat or .ply file.")

server.scene.add_transform_controls(f"/{i}")
server.scene._add_gaussian_splats(
gs_handle = server.scene._add_gaussian_splats(
f"/{i}/gaussian_splats",
centers=splat_data["centers"],
rgbs=splat_data["rgbs"],
opacities=splat_data["opacities"],
covariances=splat_data["covariances"],
)

remove_button = server.gui.add_button(f"Remove splat object {i}")

@remove_button.on_click
def _(_, gs_handle=gs_handle) -> None:
gs_handle.remove()
remove_button.remove()

while True:
time.sleep(10.0)

Expand Down
21 changes: 8 additions & 13 deletions src/viser/client/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,9 @@ import { useDisclosure } from "@mantine/hooks";
import { rayToViserCoords } from "./WorldTransformUtils";
import { ndcFromPointerXy, opencvXyFromPointerXy } from "./ClickUtils";
import { theme } from "./AppTheme";
import {
GaussianSplatsContext,
useGaussianSplatStore,
} from "./Splatting/SplatContext";
import { FrameSynchronizedMessageHandler } from "./MessageHandler";
import { PlaybackFromFile } from "./FilePlayback";
import GlobalGaussianSplats from "./Splatting/GaussianSplats";
import { SplatRenderContext } from "./Splatting/GaussianSplats";
import { BrowserWarning } from "./BrowserWarning";

export type ViewerContextContents = {
Expand Down Expand Up @@ -267,12 +263,9 @@ function ViewerContents({ children }: { children: React.ReactNode }) {
})}
>
<Viewer2DCanvas />
<GaussianSplatsContext.Provider value={useGaussianSplatStore()}>
<ViewerCanvas>
<GlobalGaussianSplats />
<FrameSynchronizedMessageHandler />
</ViewerCanvas>
</GaussianSplatsContext.Provider>
<ViewerCanvas>
<FrameSynchronizedMessageHandler />
</ViewerCanvas>
{viewer.useGui((state) => state.theme.show_logo) &&
viewer.messageSource == "websocket" ? (
<ViserLogo />
Expand Down Expand Up @@ -453,7 +446,9 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) {
<AdaptiveDpr />
<SceneContextSetter />
<SynchronizedCameraControls />
<SceneNodeThreeObject name="" parent={null} />
<SplatRenderContext>
<SceneNodeThreeObject name="" parent={null} />
</SplatRenderContext>
<Environment path="hdri/" files="potsdamer_platz_1k.hdr" />
<directionalLight color={0xffffff} intensity={1.0} position={[0, 1, 0]} />
<directionalLight
Expand All @@ -473,7 +468,7 @@ function AdaptiveDpr() {
ms={100}
iterations={5}
step={0.2}
bounds={(refreshrate) => (refreshrate > 90 ? [40, 80] : [40, 50])}
bounds={(refreshrate) => (refreshrate > 90 ? [80, 90] : [50, 60])}
onChange={({ factor, fps, refreshrate }) => {
const dpr = window.devicePixelRatio * (0.2 + 0.8 * factor);
console.log(
Expand Down
39 changes: 16 additions & 23 deletions src/viser/client/src/MessageHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import GeneratedGuiContainer from "./ControlPanel/Generated";
import { Paper, Progress } from "@mantine/core";
import { IconCheck } from "@tabler/icons-react";
import { computeT_threeworld_world } from "./WorldTransformUtils";
import { GaussianSplatsContext } from "./Splatting/SplatContext";
import { SplatObject } from "./Splatting/GaussianSplats";

/** Convert raw RGB color buffers to linear color buffers. **/
function threeColorBufferFromUint8Buffer(colors: ArrayBuffer) {
Expand All @@ -51,10 +51,6 @@ function useMessageHandler() {
const viewer = useContext(ViewerContext)!;
const ContextBridge = useContextBridge();

const splatContext = useContext(GaussianSplatsContext)!;
const setGaussianBuffer = splatContext((state) => state.setBuffer);
const removeGaussianBuffer = splatContext((state) => state.removeBuffer);

// We could reduce the redundancy here if we wanted to.
// https://github.com/nerfstudio-project/viser/issues/39
const removeSceneNode = viewer.useSceneTree((state) => state.removeSceneNode);
Expand Down Expand Up @@ -970,25 +966,22 @@ function useMessageHandler() {
return;
}
case "GaussianSplatsMessage": {
setGaussianBuffer(
message.name,
new Uint32Array(
message.buffer.buffer.slice(
message.buffer.byteOffset,
message.buffer.byteOffset + message.buffer.byteLength,
),
),
);
addSceneNodeMakeParents(
new SceneNode<THREE.Group>(
message.name,
(ref) => {
return <group ref={ref}></group>;
},
() => {
removeGaussianBuffer(message.name);
},
),
new SceneNode<THREE.Group>(message.name, (ref) => {
return (
<SplatObject
ref={ref}
buffer={
new Uint32Array(
message.buffer.buffer.slice(
message.buffer.byteOffset,
message.buffer.byteOffset + message.buffer.byteLength,
),
)
}
/>
);
}),
);
return;
}
Expand Down
162 changes: 129 additions & 33 deletions src/viser/client/src/Splatting/GaussianSplats.tsx
Original file line number Diff line number Diff line change
@@ -1,14 +1,85 @@
/** Gaussian splatting implementation for viser.
*
* This borrows heavily from existing open-source implementations. Particularly
* useful references:
* - https://github.com/quadjr/aframe-gaussian-splatting
* - https://github.com/antimatter15/splat
* - https://github.com/pmndrs/drei
* - https://github.com/vincent-lecrubier-skydio/react-three-fiber-gaussian-splat
*
* Usage should look like:
*
* <Canvas>
* <SplatRenderContext>
* <SplatObject buffer={buffer} />
* </SplatRenderContext>
* </Canvas>
*
* Where `buffer` contains serialized Gaussian attributes. SplatObjects are
* globally sorted by a worker (with some help from WebAssembly + SIMD
* intrinsics), and then rendered as a single threejs mesh. Unlike other R3F
* implementations that we're aware of, this enables correct compositing
* between multiple splat objects.
*/

import React from "react";
import * as THREE from "three";
import SplatSortWorker from "./SplatSortWorker?worker";
import { useFrame, useThree } from "@react-three/fiber";
import { shaderMaterial } from "@react-three/drei";
import { GaussianSplatsContext } from "./SplatContext";
import { ViewerContext } from "../App";
import { SorterWorkerIncoming } from "./SplatSortWorker";
import { create } from "zustand";
import { Object3D } from "three";

/**Global splat state.*/
interface SplatState {
groupBufferFromId: { [id: string]: Uint32Array };
nodeRefFromId: React.MutableRefObject<{
[name: string]: undefined | Object3D;
}>;
setBuffer: (id: string, buffer: Uint32Array) => void;
removeBuffer: (id: string) => void;
}

function postToWorker(worker: Worker, message: SorterWorkerIncoming) {
worker.postMessage(message);
/**Hook for creating global splat state.*/
function useGaussianSplatStore() {
const nodeRefFromId = React.useRef({});
return React.useState(() =>
create<SplatState>((set) => ({
groupBufferFromId: {},
nodeRefFromId: nodeRefFromId,
setBuffer: (id, buffer) => {
return set((state) => ({
groupBufferFromId: { ...state.groupBufferFromId, [id]: buffer },
}));
},
removeBuffer: (id) => {
return set((state) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { [id]: _, ...buffers } = state.groupBufferFromId;
return { groupBufferFromId: buffers };
});
},
})),
)[0];
}
const GaussianSplatsContext = React.createContext<ReturnType<
typeof useGaussianSplatStore
> | null>(null);

/**Provider for creating splat rendering context.*/
export function SplatRenderContext({
children,
}: {
children: React.ReactNode;
}) {
const store = useGaussianSplatStore();
return (
<GaussianSplatsContext.Provider value={store}>
<SplatRenderer />
{children}
</GaussianSplatsContext.Provider>
);
}

const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
Expand Down Expand Up @@ -52,14 +123,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
out vec4 vRgba;
out vec2 vPosition;
float hash2D(vec2 value) {
return fract( 1.0e4 * sin( 17.0 * value.x + 0.1 * value.y ) * ( 0.1 + abs( sin( 13.0 * value.y + value.x ) ) ) );
}
float hash3D(vec3 value) {
return hash2D( vec2( hash2D( value.xy ), value.z ) );
}
// Function to fetch and construct the i-th transform matrix using texelFetch
mat4 getGroupTransform(uint i) {
// Calculate the base index for the i-th transform.
Expand All @@ -80,14 +143,16 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
void main () {
// Get position + scale from float buffer.
ivec2 texSize = textureSize(textureBuffer, 0);
ivec2 texPos0 = ivec2((sortedIndex * 2u) % uint(texSize.x), (sortedIndex * 2u) / uint(texSize.x));
uint texStart = sortedIndex << 1u;
ivec2 texPos0 = ivec2(texStart % uint(texSize.x), texStart / uint(texSize.x));
// Fetch from textures.
uvec4 floatBufferData = texelFetch(textureBuffer, texPos0, 0);
mat4 T_camera_group = getGroupTransform(floatBufferData.w);
// Any early return will discard the fragment.
gl_Position = vec4(0.0, 0.0, 2000.0, 1.0);
gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
// Get center wrt camera. modelViewMatrix is T_cam_world.
vec3 center = uintBitsToFloat(floatBufferData.xyz);
Expand All @@ -100,7 +165,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
return;
// Read covariance terms.
ivec2 texPos1 = ivec2((sortedIndex * 2u + 1u) % uint(texSize.x), (sortedIndex * 2u + 1u) / uint(texSize.x));
ivec2 texPos1 = ivec2((texStart + 1u) % uint(texSize.x), (texStart + 1u) / uint(texSize.x));
uvec4 intBufferData = texelFetch(textureBuffer, texPos1, 0);
// Get covariance terms from int buffer.
Expand All @@ -111,7 +176,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
// Transition in.
float startTime = 0.8 * float(sortedIndex) / float(numGaussians);
float cov_scale = clamp((transitionInState - startTime) / 0.2f, 0.0f, 1.0f); // min() can freeze here. Not sure why...
float cov_scale = smoothstep(startTime, startTime + 0.2, transitionInState);
// Do the actual splatting.
mat3 chol = mat3(
Expand Down Expand Up @@ -152,10 +217,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
// Throw the Gaussian off the screen if it's too close, too far, or too small.
float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag);
if (weightedDeterminant < 0.1)
return;
// This is not principled. It just makes things faster.
if (weightedDeterminant < 1.0 && hash3D(center) < weightedDeterminant)
if (weightedDeterminant < 0.5)
return;
vPosition = position.xy;
Expand All @@ -182,16 +244,48 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
}`,
);

export const SplatObject = React.forwardRef<
THREE.Group,
{
buffer: Uint32Array;
}
>(function SplatObject({ buffer }, ref) {
const splatContext = React.useContext(GaussianSplatsContext)!;
const setBuffer = splatContext((state) => state.setBuffer);
const removeBuffer = splatContext((state) => state.removeBuffer);
const nodeRefFromId = splatContext((state) => state.nodeRefFromId);
const name = React.useMemo(() => crypto.randomUUID(), [buffer]);

const [obj, setRef] = React.useState<THREE.Group | null>(null);

React.useEffect(() => {
if (obj === null) return;
setBuffer(name, buffer);
if (ref !== null) {
if ("current" in ref) {
ref.current = obj;
} else {
ref(obj);
}
}
nodeRefFromId.current[name] = obj;
return () => {
removeBuffer(name);
delete nodeRefFromId.current[name];
};
}, [obj]);

return <group ref={setRef}></group>;
});

/** External interface. Component should be added to the root of canvas. */
export default function GlobalGaussianSplats() {
const viewer = React.useContext(ViewerContext)!;
function SplatRenderer() {
const splatContext = React.useContext(GaussianSplatsContext)!;
const groupBufferFromName = splatContext(
(state) => state.groupBufferFromName,
);
const groupBufferFromId = splatContext((state) => state.groupBufferFromId);
const nodeRefFromId = splatContext((state) => state.nodeRefFromId);

// Consolidate Gaussian groups into a single buffer.
const merged = mergeGaussianGroups(groupBufferFromName);
const merged = mergeGaussianGroups(groupBufferFromId);
const meshProps = useGaussianMeshProps(
merged.gaussianBuffer,
merged.numGroups,
Expand All @@ -213,7 +307,11 @@ export default function GlobalGaussianSplats() {
initializedBufferTexture = true;
}
};
postToWorker(sortWorker, {
function postToWorker(message: SorterWorkerIncoming) {
sortWorker.postMessage(message);
}

postToWorker({
setBuffer: merged.gaussianBuffer,
setGroupIndices: merged.groupIndices,
});
Expand All @@ -224,7 +322,7 @@ export default function GlobalGaussianSplats() {
meshProps.textureBuffer.dispose();
meshProps.geometry.dispose();
meshProps.material.dispose();
postToWorker(sortWorker, { close: true });
postToWorker({ close: true });
};
});

Expand Down Expand Up @@ -267,10 +365,8 @@ export default function GlobalGaussianSplats() {
const T_camera_world = state.camera.matrixWorldInverse;
const groupVisibles: boolean[] = [];
let visibilitiesChanged = false;
for (const [groupIndex, name] of Object.keys(
groupBufferFromName,
).entries()) {
const node = viewer.nodeRefFromName.current[name];
for (const [groupIndex, name] of Object.keys(groupBufferFromId).entries()) {
const node = nodeRefFromId.current[name];
if (node === undefined) continue;
tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld);
const colMajorElements = tmpT_camera_group.elements;
Expand Down Expand Up @@ -312,7 +408,7 @@ export default function GlobalGaussianSplats() {

if (groupsMovedWrtCam) {
// Gaussians need to be re-sorted.
postToWorker(sortWorker, {
postToWorker({
setTz_camera_groups: Tz_camera_groups,
});
}
Expand Down
Loading

0 comments on commit 24ad049

Please sign in to comment.