Skip to content

Commit

Permalink
Start new splatting impl
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 22, 2024
1 parent ba56408 commit a0d26e0
Showing 1 changed file with 25 additions and 58 deletions.
83 changes: 25 additions & 58 deletions src/viser/client/src/Splatting/GaussianSplats.tsx
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import React from "react";
import { create } from "zustand";
import * as THREE from "three";
import SplatSortWorker from "./SplatSortWorker?worker";
import { useFrame, useThree } from "@react-three/fiber";
import { shaderMaterial } from "@react-three/drei";

export const GaussianSplatsContext =
React.createContext<null | React.MutableRefObject<{
numSorting: number;
sortUpdateCallbacks: (() => void)[];
}>>(null);
export function useGaussianSplatStore() {
return React.useState(() =>
create((set) => {
x: null;
}),
);
}
export const GaussianSplatsContext = React.createContext<{ x: null } | null>(
null,
);

export type GaussianBuffers = {
buffer: Uint32Array;
Expand All @@ -22,10 +28,9 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
near: 1.0,
far: 100.0,
depthTest: true,
depthWrite: true,
depthWrite: false,
transparent: true,
bufferTexture: null,
sortSynchronizedModelViewMatrix: new THREE.Matrix4(),
transitionInState: 0.0,
},
`precision highp usampler2D; // Most important: ints must be 32-bit.
Expand All @@ -45,12 +50,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
uniform float near;
uniform float far;
// Depth testing is useful for compositing multiple splat objects, but causes
// artifacts when closer Gaussians are rendered before further ones.
// Synchronizing the modelViewMatrix updates used for depth computation with
// the splat sorter mitigates this for Gaussians within the same object.
uniform mat4 sortSynchronizedModelViewMatrix;
// Fade in state between [0, 1].
uniform float transitionInState;
Expand Down Expand Up @@ -83,9 +82,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
if (pos2d.x < -clip || pos2d.x > clip || pos2d.y < -clip || pos2d.y > clip)
return;
vec4 c_camstable = sortSynchronizedModelViewMatrix * vec4(center, 1);
vec4 stablePos2d = projectionMatrix * c_camstable;
float perGaussianShift = 1.0 - (float(numGaussians * 2u) - float(sortedIndex)) / float(numGaussians * 2u);
float cov_scale = max(0.0, transitionInState - perGaussianShift) / (1.0 - perGaussianShift);
Expand Down Expand Up @@ -145,7 +141,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
gl_Position = vec4(
vec2(pos2d) / pos2d.w
+ position.x * v1 / viewport * 2.0
+ position.y * v2 / viewport * 2.0, stablePos2d.z / stablePos2d.w, 1.);
+ position.y * v2 / viewport * 2.0, pos2d.z / pos2d.w, 1.);
}
`,
`precision mediump float;
Expand Down Expand Up @@ -176,7 +172,6 @@ export default function GaussianSplats({
const maxTextureSize = useThree((state) => state.gl).capabilities
.maxTextureSize;
const initializedTextures = React.useRef<boolean>(false);
const [sortSynchronizedModelViewMatrix] = React.useState(new THREE.Matrix4());

const splatContext = React.useContext(GaussianSplatsContext)!.current;

Expand Down Expand Up @@ -230,7 +225,6 @@ export default function GaussianSplats({
bufferTexture: bufferTexture,
numGaussians: 0,
transitionInState: 0.0,
sortSynchronizedModelViewMatrix: new THREE.Matrix4(),
});

// Update component state.
Expand All @@ -241,44 +235,18 @@ export default function GaussianSplats({
const sortWorker = new SplatSortWorker();
sortWorker.onmessage = (e) => {
sortedIndexAttribute.set(e.data.sortedIndices as Int32Array);
const synchronizedSortUpdateCallback = () => {
isSortingRef.current = false;

// Wait for onmessage to be triggered for all Gaussians.
sortedIndexAttribute.needsUpdate = true;
material.uniforms.sortSynchronizedModelViewMatrix.value.copy(
sortSynchronizedModelViewMatrix,
);
// A simple but reasonably effective heuristic for render ordering.
//
// To minimize artifacts:
// - When there are multiple splat objects, we want to render the closest
// ones *last*. This improves the likelihood of correct alpha
// compositing and reduces reliance on alpha testing.
// - We generally want to render other objects like meshes *before*
// Gaussians. They're usually opaque.
meshRef.current!.renderOrder = (-e.data.minDepth as number) + 1000.0;

// Trigger initial render.
if (!initializedTextures.current) {
material.uniforms.numGaussians.value = numGaussians;
bufferTexture.needsUpdate = true;
initializedTextures.current = true;
}
};

// Synchronize sort updates across multiple Gaussian splats. This
// prevents z-fighting.
splatContext.numSorting -= 1;
if (splatContext.numSorting === 0) {
synchronizedSortUpdateCallback();
console.log(splatContext.sortUpdateCallbacks.length);
for (const callback of splatContext.sortUpdateCallbacks) {
callback();
}
splatContext.sortUpdateCallbacks.length = 0;
} else {
splatContext.sortUpdateCallbacks.push(synchronizedSortUpdateCallback);
sortedIndexAttribute.needsUpdate = true;

isSortingRef.current = false;

// Render Gaussians last.
meshRef.current!.renderOrder = 1000.0;

// Trigger initial render.
if (!initializedTextures.current) {
material.uniforms.numGaussians.value = numGaussians;
bufferTexture.needsUpdate = true;
initializedTextures.current = true;
}
};
sortWorker.postMessage({
Expand Down Expand Up @@ -340,7 +308,6 @@ export default function GaussianSplats({
!isSortingRef.current &&
(prevT_camera_obj === undefined || !T_camera_obj.equals(prevT_camera_obj))
) {
sortSynchronizedModelViewMatrix.copy(T_camera_obj);
sortWorker.postMessage({ setT_camera_obj: T_camera_obj.elements });
splatContext.numSorting += 1;
isSortingRef.current = true;
Expand Down

0 comments on commit a0d26e0

Please sign in to comment.