Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Untie splat rendering components from viser-specific internals #258

Merged
merged 1 commit into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion examples/experimental/gaussian_splats.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,21 @@ def _(event: viser.GuiEvent) -> None:
raise SystemExit("Please provide a filepath to a .splat or .ply file.")

server.scene.add_transform_controls(f"/{i}")
server.scene._add_gaussian_splats(
gs_handle = server.scene._add_gaussian_splats(
f"/{i}/gaussian_splats",
centers=splat_data["centers"],
rgbs=splat_data["rgbs"],
opacities=splat_data["opacities"],
covariances=splat_data["covariances"],
)

remove_button = server.gui.add_button(f"Remove splat object {i}")

@remove_button.on_click
def _(_, gs_handle=gs_handle) -> None:
gs_handle.remove()
remove_button.remove()

while True:
time.sleep(10.0)

Expand Down
21 changes: 8 additions & 13 deletions src/viser/client/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,9 @@ import { useDisclosure } from "@mantine/hooks";
import { rayToViserCoords } from "./WorldTransformUtils";
import { ndcFromPointerXy, opencvXyFromPointerXy } from "./ClickUtils";
import { theme } from "./AppTheme";
import {
GaussianSplatsContext,
useGaussianSplatStore,
} from "./Splatting/SplatContext";
import { FrameSynchronizedMessageHandler } from "./MessageHandler";
import { PlaybackFromFile } from "./FilePlayback";
import GlobalGaussianSplats from "./Splatting/GaussianSplats";
import { SplatRenderContext } from "./Splatting/GaussianSplats";
import { BrowserWarning } from "./BrowserWarning";

export type ViewerContextContents = {
Expand Down Expand Up @@ -267,12 +263,9 @@ function ViewerContents({ children }: { children: React.ReactNode }) {
})}
>
<Viewer2DCanvas />
<GaussianSplatsContext.Provider value={useGaussianSplatStore()}>
<ViewerCanvas>
<GlobalGaussianSplats />
<FrameSynchronizedMessageHandler />
</ViewerCanvas>
</GaussianSplatsContext.Provider>
<ViewerCanvas>
<FrameSynchronizedMessageHandler />
</ViewerCanvas>
{viewer.useGui((state) => state.theme.show_logo) &&
viewer.messageSource == "websocket" ? (
<ViserLogo />
Expand Down Expand Up @@ -453,7 +446,9 @@ function ViewerCanvas({ children }: { children: React.ReactNode }) {
<AdaptiveDpr />
<SceneContextSetter />
<SynchronizedCameraControls />
<SceneNodeThreeObject name="" parent={null} />
<SplatRenderContext>
<SceneNodeThreeObject name="" parent={null} />
</SplatRenderContext>
<Environment path="hdri/" files="potsdamer_platz_1k.hdr" />
<directionalLight color={0xffffff} intensity={1.0} position={[0, 1, 0]} />
<directionalLight
Expand All @@ -473,7 +468,7 @@ function AdaptiveDpr() {
ms={100}
iterations={5}
step={0.2}
bounds={(refreshrate) => (refreshrate > 90 ? [40, 80] : [40, 50])}
bounds={(refreshrate) => (refreshrate > 90 ? [80, 90] : [50, 60])}
onChange={({ factor, fps, refreshrate }) => {
const dpr = window.devicePixelRatio * (0.2 + 0.8 * factor);
console.log(
Expand Down
39 changes: 16 additions & 23 deletions src/viser/client/src/MessageHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import GeneratedGuiContainer from "./ControlPanel/Generated";
import { Paper, Progress } from "@mantine/core";
import { IconCheck } from "@tabler/icons-react";
import { computeT_threeworld_world } from "./WorldTransformUtils";
import { GaussianSplatsContext } from "./Splatting/SplatContext";
import { SplatObject } from "./Splatting/GaussianSplats";

/** Convert raw RGB color buffers to linear color buffers. **/
function threeColorBufferFromUint8Buffer(colors: ArrayBuffer) {
Expand All @@ -51,10 +51,6 @@ function useMessageHandler() {
const viewer = useContext(ViewerContext)!;
const ContextBridge = useContextBridge();

const splatContext = useContext(GaussianSplatsContext)!;
const setGaussianBuffer = splatContext((state) => state.setBuffer);
const removeGaussianBuffer = splatContext((state) => state.removeBuffer);

// We could reduce the redundancy here if we wanted to.
// https://github.com/nerfstudio-project/viser/issues/39
const removeSceneNode = viewer.useSceneTree((state) => state.removeSceneNode);
Expand Down Expand Up @@ -970,25 +966,22 @@ function useMessageHandler() {
return;
}
case "GaussianSplatsMessage": {
setGaussianBuffer(
message.name,
new Uint32Array(
message.buffer.buffer.slice(
message.buffer.byteOffset,
message.buffer.byteOffset + message.buffer.byteLength,
),
),
);
addSceneNodeMakeParents(
new SceneNode<THREE.Group>(
message.name,
(ref) => {
return <group ref={ref}></group>;
},
() => {
removeGaussianBuffer(message.name);
},
),
new SceneNode<THREE.Group>(message.name, (ref) => {
return (
<SplatObject
ref={ref}
buffer={
new Uint32Array(
message.buffer.buffer.slice(
message.buffer.byteOffset,
message.buffer.byteOffset + message.buffer.byteLength,
),
)
}
/>
);
}),
);
return;
}
Expand Down
162 changes: 129 additions & 33 deletions src/viser/client/src/Splatting/GaussianSplats.tsx
Original file line number Diff line number Diff line change
@@ -1,14 +1,85 @@
/** Gaussian splatting implementation for viser.
*
* This borrows heavily from existing open-source implementations. Particularly
* useful references:
* - https://github.com/quadjr/aframe-gaussian-splatting
* - https://github.com/antimatter15/splat
* - https://github.com/pmndrs/drei
* - https://github.com/vincent-lecrubier-skydio/react-three-fiber-gaussian-splat
*
* Usage should look like:
*
* <Canvas>
* <SplatRenderContext>
* <SplatObject buffer={buffer} />
* </SplatRenderContext>
* </Canvas>
*
* Where `buffer` contains serialized Gaussian attributes. SplatObjects are
* globally sorted by a worker (with some help from WebAssembly + SIMD
* intrinsics), and then rendered as a single threejs mesh. Unlike other R3F
* implementations that we're aware of, this enables correct compositing
* between multiple splat objects.
*/

import React from "react";
import * as THREE from "three";
import SplatSortWorker from "./SplatSortWorker?worker";
import { useFrame, useThree } from "@react-three/fiber";
import { shaderMaterial } from "@react-three/drei";
import { GaussianSplatsContext } from "./SplatContext";
import { ViewerContext } from "../App";
import { SorterWorkerIncoming } from "./SplatSortWorker";
import { create } from "zustand";
import { Object3D } from "three";

/**Global splat state.*/
interface SplatState {
groupBufferFromId: { [id: string]: Uint32Array };
nodeRefFromId: React.MutableRefObject<{
[name: string]: undefined | Object3D;
}>;
setBuffer: (id: string, buffer: Uint32Array) => void;
removeBuffer: (id: string) => void;
}

function postToWorker(worker: Worker, message: SorterWorkerIncoming) {
worker.postMessage(message);
/**Hook for creating global splat state.*/
function useGaussianSplatStore() {
const nodeRefFromId = React.useRef({});
return React.useState(() =>
create<SplatState>((set) => ({
groupBufferFromId: {},
nodeRefFromId: nodeRefFromId,
setBuffer: (id, buffer) => {
return set((state) => ({
groupBufferFromId: { ...state.groupBufferFromId, [id]: buffer },
}));
},
removeBuffer: (id) => {
return set((state) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { [id]: _, ...buffers } = state.groupBufferFromId;
return { groupBufferFromId: buffers };
});
},
})),
)[0];
}
const GaussianSplatsContext = React.createContext<ReturnType<
typeof useGaussianSplatStore
> | null>(null);

/**Provider for creating splat rendering context.*/
export function SplatRenderContext({
children,
}: {
children: React.ReactNode;
}) {
const store = useGaussianSplatStore();
return (
<GaussianSplatsContext.Provider value={store}>
<SplatRenderer />
{children}
</GaussianSplatsContext.Provider>
);
}

const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
Expand Down Expand Up @@ -52,14 +123,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
out vec4 vRgba;
out vec2 vPosition;

float hash2D(vec2 value) {
return fract( 1.0e4 * sin( 17.0 * value.x + 0.1 * value.y ) * ( 0.1 + abs( sin( 13.0 * value.y + value.x ) ) ) );
}

float hash3D(vec3 value) {
return hash2D( vec2( hash2D( value.xy ), value.z ) );
}

// Function to fetch and construct the i-th transform matrix using texelFetch
mat4 getGroupTransform(uint i) {
// Calculate the base index for the i-th transform.
Expand All @@ -80,14 +143,16 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
void main () {
// Get position + scale from float buffer.
ivec2 texSize = textureSize(textureBuffer, 0);
ivec2 texPos0 = ivec2((sortedIndex * 2u) % uint(texSize.x), (sortedIndex * 2u) / uint(texSize.x));
uint texStart = sortedIndex << 1u;
ivec2 texPos0 = ivec2(texStart % uint(texSize.x), texStart / uint(texSize.x));


// Fetch from textures.
uvec4 floatBufferData = texelFetch(textureBuffer, texPos0, 0);
mat4 T_camera_group = getGroupTransform(floatBufferData.w);

// Any early return will discard the fragment.
gl_Position = vec4(0.0, 0.0, 2000.0, 1.0);
gl_Position = vec4(0.0, 0.0, 2.0, 1.0);

// Get center wrt camera. modelViewMatrix is T_cam_world.
vec3 center = uintBitsToFloat(floatBufferData.xyz);
Expand All @@ -100,7 +165,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
return;

// Read covariance terms.
ivec2 texPos1 = ivec2((sortedIndex * 2u + 1u) % uint(texSize.x), (sortedIndex * 2u + 1u) / uint(texSize.x));
ivec2 texPos1 = ivec2((texStart + 1u) % uint(texSize.x), (texStart + 1u) / uint(texSize.x));
uvec4 intBufferData = texelFetch(textureBuffer, texPos1, 0);

// Get covariance terms from int buffer.
Expand All @@ -111,7 +176,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(

// Transition in.
float startTime = 0.8 * float(sortedIndex) / float(numGaussians);
float cov_scale = clamp((transitionInState - startTime) / 0.2f, 0.0f, 1.0f); // min() can freeze here. Not sure why...
float cov_scale = smoothstep(startTime, startTime + 0.2, transitionInState);

// Do the actual splatting.
mat3 chol = mat3(
Expand Down Expand Up @@ -152,10 +217,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(

// Throw the Gaussian off the screen if it's too close, too far, or too small.
float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag);
if (weightedDeterminant < 0.1)
return;
// This is not principled. It just makes things faster.
if (weightedDeterminant < 1.0 && hash3D(center) < weightedDeterminant)
if (weightedDeterminant < 0.5)
return;
vPosition = position.xy;

Expand All @@ -182,16 +244,48 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
}`,
);

export const SplatObject = React.forwardRef<
THREE.Group,
{
buffer: Uint32Array;
}
>(function SplatObject({ buffer }, ref) {
const splatContext = React.useContext(GaussianSplatsContext)!;
const setBuffer = splatContext((state) => state.setBuffer);
const removeBuffer = splatContext((state) => state.removeBuffer);
const nodeRefFromId = splatContext((state) => state.nodeRefFromId);
const name = React.useMemo(() => crypto.randomUUID(), [buffer]);

const [obj, setRef] = React.useState<THREE.Group | null>(null);

React.useEffect(() => {
if (obj === null) return;
setBuffer(name, buffer);
if (ref !== null) {
if ("current" in ref) {
ref.current = obj;
} else {
ref(obj);
}
}
nodeRefFromId.current[name] = obj;
return () => {
removeBuffer(name);
delete nodeRefFromId.current[name];
};
}, [obj]);

return <group ref={setRef}></group>;
});

/** External interface. Component should be added to the root of canvas. */
export default function GlobalGaussianSplats() {
const viewer = React.useContext(ViewerContext)!;
function SplatRenderer() {
const splatContext = React.useContext(GaussianSplatsContext)!;
const groupBufferFromName = splatContext(
(state) => state.groupBufferFromName,
);
const groupBufferFromId = splatContext((state) => state.groupBufferFromId);
const nodeRefFromId = splatContext((state) => state.nodeRefFromId);

// Consolidate Gaussian groups into a single buffer.
const merged = mergeGaussianGroups(groupBufferFromName);
const merged = mergeGaussianGroups(groupBufferFromId);
const meshProps = useGaussianMeshProps(
merged.gaussianBuffer,
merged.numGroups,
Expand All @@ -213,7 +307,11 @@ export default function GlobalGaussianSplats() {
initializedBufferTexture = true;
}
};
postToWorker(sortWorker, {
function postToWorker(message: SorterWorkerIncoming) {
sortWorker.postMessage(message);
}

postToWorker({
setBuffer: merged.gaussianBuffer,
setGroupIndices: merged.groupIndices,
});
Expand All @@ -224,7 +322,7 @@ export default function GlobalGaussianSplats() {
meshProps.textureBuffer.dispose();
meshProps.geometry.dispose();
meshProps.material.dispose();
postToWorker(sortWorker, { close: true });
postToWorker({ close: true });
};
});

Expand Down Expand Up @@ -267,10 +365,8 @@ export default function GlobalGaussianSplats() {
const T_camera_world = state.camera.matrixWorldInverse;
const groupVisibles: boolean[] = [];
let visibilitiesChanged = false;
for (const [groupIndex, name] of Object.keys(
groupBufferFromName,
).entries()) {
const node = viewer.nodeRefFromName.current[name];
for (const [groupIndex, name] of Object.keys(groupBufferFromId).entries()) {
const node = nodeRefFromId.current[name];
if (node === undefined) continue;
tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld);
const colMajorElements = tmpT_camera_group.elements;
Expand Down Expand Up @@ -312,7 +408,7 @@ export default function GlobalGaussianSplats() {

if (groupsMovedWrtCam) {
// Gaussians need to be re-sorted.
postToWorker(sortWorker, {
postToWorker({
setTz_camera_groups: Tz_camera_groups,
});
}
Expand Down
Loading
Loading