Skip to content

Commit

Permalink
Optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 24, 2024
1 parent d7e9ecc commit 5ca6344
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 179 deletions.
216 changes: 82 additions & 134 deletions src/viser/client/src/Splatting/GaussianSplats.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
depthWrite: false,
transparent: true,
bufferTexture: null,
groupTransformTexture: null,
T_camera_groups: new Float32Array(1),
transitionInState: 0.0,
},
`precision highp usampler2D; // Most important: ints must be 32-bit.
Expand All @@ -38,11 +38,8 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
// copy quadjr for this.
uniform usampler2D bufferTexture;
// Buffer containing SE(3) transforms, which will be independently applied to
// Gaussians. This enables multiple fast rigid "objects".
uniform sampler2D groupTransformTexture;
// Various other uniforms...
uniform mat4 T_camera_groups[128]; // We support up to 128 groups for now.
uniform uint numGaussians;
uniform vec2 focal;
uniform vec2 viewport;
Expand All @@ -62,38 +59,21 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial(
return hash2D( vec2( hash2D( value.xy ), value.z ) );
}
// Function to fetch and construct the i-th transform matrix using texelFetch
mat4 getGroupTransform(uint i) {
// Calculate the base index for the i-th transform.
uint baseIndex = i * 3u;
// Fetch the texels that represent the first 3 rows of the transform. We
// choose to use row-major here, since it lets us exclude the fourth row of
// the matrix.
vec4 row0 = texelFetch(groupTransformTexture, ivec2(baseIndex + 0u, 0), 0);
vec4 row1 = texelFetch(groupTransformTexture, ivec2(baseIndex + 1u, 0), 0);
vec4 row2 = texelFetch(groupTransformTexture, ivec2(baseIndex + 2u, 0), 0);
// Construct the mat4 with the fetched rows.
mat4 transform = mat4(row0, row1, row2, vec4(0.0, 0.0, 0.0, 1.0));
return transpose(transform);
}
void main () {
// Get position + scale from float buffer.
ivec2 texSize = textureSize(bufferTexture, 0);
ivec2 texPos0 = ivec2((sortedIndex * 2u) % uint(texSize.x), (sortedIndex * 2u) / uint(texSize.x));
// Fetch from textures.
uvec4 floatBufferData = texelFetch(bufferTexture, texPos0, 0);
mat4 groupTransform = getGroupTransform(groupIndex);
mat4 groupTransform = T_camera_groups[groupIndex];
// Any early return will discard the fragment.
gl_Position = vec4(0.0, 0.0, 2000.0, 1.0);
// Get center wrt camera. modelViewMatrix is T_cam_world.
vec3 center = uintBitsToFloat(floatBufferData.xyz);
mat4 T_world_group = modelViewMatrix * groupTransform;
mat4 T_world_group = groupTransform;
vec4 c_cam = T_world_group * vec4(center, 1);
if (-c_cam.z < near || -c_cam.z > far)
return;
Expand Down Expand Up @@ -191,7 +171,6 @@ export default function GlobalGaussianSplats() {
const maxTextureSize = useThree((state) => state.gl).capabilities
.maxTextureSize;
const initializedBufferTexture = React.useRef<boolean>(false);
const groupTransformTextureRef = React.useRef<THREE.DataTexture | null>(null);

const viewer = React.useContext(ViewerContext)!;
const splatContext = React.useContext(GaussianSplatsContext)!;
Expand All @@ -200,45 +179,33 @@ export default function GlobalGaussianSplats() {
);

// Pre-compute some buffers we need for rendering.
const [
numGaussians,
gaussianBuffer,
numGroups,
groupIndices,
groupTransformBuffer,
] = React.useMemo(() => {
// Create geometry. Each Gaussian will be rendered as a quad.
let totalBufferLength = 0;
for (const buffer of Object.values(groupBufferFromName)) {
totalBufferLength += buffer.length;
}
const numGaussians = totalBufferLength / 8;
const gaussianBuffer = new Uint32Array(totalBufferLength);
const groupIndices = new Uint32Array(numGaussians);
const [numGaussians, gaussianBuffer, numGroups, groupIndices] =
React.useMemo(() => {
// Create geometry. Each Gaussian will be rendered as a quad.
let totalBufferLength = 0;
for (const buffer of Object.values(groupBufferFromName)) {
totalBufferLength += buffer.length;
}
const numGaussians = totalBufferLength / 8;
const gaussianBuffer = new Uint32Array(totalBufferLength);
const groupIndices = new Uint32Array(numGaussians);

let offset = 0;
for (const [groupIndex, groupBuffer] of Object.values(
groupBufferFromName,
).entries()) {
groupIndices.fill(
groupIndex,
offset / 8,
(offset + groupBuffer.length) / 8,
);
gaussianBuffer.set(groupBuffer, offset);
offset += groupBuffer.length;
}
let offset = 0;
for (const [groupIndex, groupBuffer] of Object.values(
groupBufferFromName,
).entries()) {
groupIndices.fill(
groupIndex,
offset / 8,
(offset + groupBuffer.length) / 8,
);
gaussianBuffer.set(groupBuffer, offset);
offset += groupBuffer.length;
}

const numGroups = Object.keys(groupBufferFromName).length;
const groupTransformBuffer = new Float32Array(numGroups * 12);
return [
numGaussians,
gaussianBuffer,
numGroups,
groupIndices,
groupTransformBuffer,
];
}, [groupBufferFromName]);
const numGroups = Object.keys(groupBufferFromName).length;
return [numGaussians, gaussianBuffer, numGroups, groupIndices];
}, [groupBufferFromName]);

// Primary setup.
//
Expand Down Expand Up @@ -297,21 +264,9 @@ export default function GlobalGaussianSplats() {
bufferTexture.internalFormat = "RGBA32UI";
bufferTexture.needsUpdate = true;

const groupTransformTexture = new THREE.DataTexture(
groupTransformBuffer,
(numGroups * 12) / 4,
1,
THREE.RGBAFormat,
THREE.FloatType,
);
groupTransformTexture.internalFormat = "RGBA32F";
groupTransformTextureRef.current = groupTransformTexture;
groupTransformTexture.needsUpdate = true;

const material = new GaussianSplatMaterial({
// @ts-ignore
bufferTexture: bufferTexture,
groupTransformTexture: groupTransformTexture,
numGaussians: 0,
transitionInState: 0.0,
});
Expand Down Expand Up @@ -354,8 +309,6 @@ export default function GlobalGaussianSplats() {

return () => {
bufferTexture.dispose();
groupTransformTexture.dispose();
groupTransformTextureRef.current = null;
geometry.dispose();
if (material !== undefined) material.dispose();
postToWorker(sortWorker, { close: true });
Expand All @@ -370,20 +323,28 @@ export default function GlobalGaussianSplats() {
// matrices to make life easier for the garbage collector.
const meshRef = React.useRef<THREE.Mesh>(null);
const [prevT_camera_world] = React.useState(new THREE.Matrix4());
const tmpGroupTransformBuffer = React.useMemo(
() => new Float32Array(numGroups * 12),
const [tmpT_camera_group] = React.useState(new THREE.Matrix4());
// const T_camera_groups = React.useMemo(
// () => [...Array(numGroups)].map(() => new THREE.Matrix4()),
// [numGroups],
// );
const T_camera_groups = React.useMemo(
() => new Float32Array(numGroups * 16),
[numGroups],
);
const Tz_camera_groups = React.useMemo(
() => new Float32Array(numGroups * 4),
[numGroups],
);
const prevTz_camera_groups = React.useMemo(
() => new Float32Array(numGroups * 4),
[numGroups],
);
const staleSort = React.useRef(false);
useFrame((state, delta) => {
const mesh = meshRef.current;
const sortWorker = splatSortWorkerRef.current;
if (
mesh === null ||
sortWorker === null ||
groupTransformTextureRef.current == null
)
return;
if (mesh === null || sortWorker === null) return;

// Update camera parameter uniforms.
const dpr = state.viewport.dpr;
Expand All @@ -406,61 +367,48 @@ export default function GlobalGaussianSplats() {
uniforms.viewport.value = [state.size.width * dpr, state.size.height * dpr];

// Update group transforms.
if (groupTransformTextureRef.current !== null) {
for (const [groupIndex, name] of Object.keys(
groupBufferFromName,
).entries()) {
const node = viewer.nodeRefFromName.current[name];
if (node === undefined) continue;
const rowMajorElements = node.matrixWorld
.transpose()
.elements.slice(0, 12);
tmpGroupTransformBuffer.set(rowMajorElements, groupIndex * 12);

// If a group is not visible, we'll throw it off the screen with some
// Big Numbers.
let visible = node.visible && node.parent !== null;
if (visible) {
node.traverseAncestors((ancestor) => {
visible = visible && ancestor.visible;
});
}
if (!visible) {
tmpGroupTransformBuffer[groupIndex * 12 + 3] = 1e10;
tmpGroupTransformBuffer[groupIndex * 12 + 7] = 1e10;
tmpGroupTransformBuffer[groupIndex * 12 + 11] = 1e10;
}
}
if (
!groupTransformBuffer.every((v, i) => v === tmpGroupTransformBuffer[i])
) {
staleSort.current = true;
groupTransformBuffer.set(tmpGroupTransformBuffer);
groupTransformTextureRef.current.needsUpdate = true;
postToWorker(sortWorker, {
// Big values will break the counting sort precision. We just
// zero them out for now.
setT_world_groups: groupTransformBuffer.map((x) =>
x >= 1e10 ? 0.0 : x,
),
const T_camera_world = state.camera.matrixWorldInverse;
for (const [groupIndex, name] of Object.keys(
groupBufferFromName,
).entries()) {
const node = viewer.nodeRefFromName.current[name];
if (node === undefined) continue;
tmpT_camera_group.copy(T_camera_world).multiply(node.matrixWorld);
const colMajorElements = tmpT_camera_group.elements;
T_camera_groups.set(colMajorElements, groupIndex * 16);
Tz_camera_groups.set(
[
colMajorElements[2],
colMajorElements[6],
colMajorElements[10],
colMajorElements[14],
],
groupIndex * 4,
);

// If a group is not visible, we'll throw it off the screen with some
// Big Numbers.
let visible = node.visible && node.parent !== null;
if (visible) {
node.traverseAncestors((ancestor) => {
visible = visible && ancestor.visible;
});
}
if (!visible) {
T_camera_groups[groupIndex * 16 + 12] = 1e10;
T_camera_groups[groupIndex * 16 + 13] = 1e10;
T_camera_groups[groupIndex * 16 + 14] = 1e10;
}
}

// Update camera transform.
const T_camera_world = state.camera.matrixWorldInverse;

if (!T_camera_world.equals(prevT_camera_world)) {
if (!Tz_camera_groups.every((v, i) => v === prevTz_camera_groups[i])) {
prevTz_camera_groups.set(Tz_camera_groups);
staleSort.current = true;
uniforms.T_camera_groups.value = T_camera_groups;
postToWorker(sortWorker, {
setTz_camera_world: [
T_camera_world.elements[2],
T_camera_world.elements[6],
T_camera_world.elements[10],
T_camera_world.elements[14],
],
// Big values will break the counting sort precision. We just
// zero them out for now.
setTz_camera_groups: Tz_camera_groups,
});
prevT_camera_world.copy(T_camera_world);
}

// Sort Gaussians.
Expand All @@ -470,7 +418,7 @@ export default function GlobalGaussianSplats() {
triggerSort: true,
});
}
});
}, -100 /* Render early to reduce artifacts from slow texture updates. */);

return (
<mesh
Expand Down
Loading

0 comments on commit 5ca6344

Please sign in to comment.