Skip to content

Commit

Permalink
Global sort + SIMD for Gaussian rendering (#252)
Browse files Browse the repository at this point in the history
* Start new splatting impl

* Working global sorting + SIMD

* Cleanup

* Cleanup

* Guarantee inline for dot product helper

* Fix minor loading issues

* (unrelated) playback styling

* Add browser version warning

* Respect node visibility, run prettier

* Minor optimizations

* Significant optimization

* Cleanup

* typescript fix

* Fix edge cases

* More cleanup + fixes

* Optimization

* Reduce maximum number of Gaussian groups

* Bump max to 64

* Drop limit back down to 32

* Cleanup, fix hook-related edge cases

* Fix tsc

* Typos

* Move group transforms back into texture, cleanup

* Naming consistency

* More optimizations
  • Loading branch information
brentyi authored Jul 25, 2024
1 parent ba56408 commit 203d92a
Show file tree
Hide file tree
Showing 17 changed files with 668 additions and 361 deletions.
73 changes: 34 additions & 39 deletions examples/experimental/gaussian_splats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tyro
import viser
from plyfile import PlyData
from tqdm import tqdm
from viser import transforms as tf


Expand Down Expand Up @@ -72,49 +73,43 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile:
def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:
plydata = PlyData.read(ply_file_path)
vert = plydata["vertex"]
sorted_indices = onp.argsort(
-onp.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
/ (1 + onp.exp(-vert["opacity"]))
)
numgaussians = len(vert)
print("Number of gaussians to render: ", numgaussians)
colors = onp.zeros((numgaussians, 3))
opacities = onp.zeros((numgaussians, 1))
positions = onp.zeros((numgaussians, 3))
wxyzs = onp.zeros((numgaussians, 4))
scales = onp.zeros((numgaussians, 3))
for idx in sorted_indices:
v = plydata["vertex"][idx]
position = onp.array([v["x"], v["y"], v["z"]], dtype=onp.float32)
scale = onp.exp(
onp.array([v["scale_0"], v["scale_1"], v["scale_2"]], dtype=onp.float32)
)

rot = onp.array(
[v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], dtype=onp.float32
)
SH_C0 = 0.28209479177387814
color = onp.array(
[
0.5 + SH_C0 * v["f_dc_0"],
0.5 + SH_C0 * v["f_dc_1"],
0.5 + SH_C0 * v["f_dc_2"],
]
)
opacity = 1 / (1 + onp.exp(-v["opacity"]))
wxyz = rot / onp.linalg.norm(rot) # normalize
scales[idx] = scale
colors[idx] = color
opacities[idx] = onp.array([opacity])
positions[idx] = position
wxyzs[idx] = wxyz

Rs = onp.array([tf.SO3(wxyz).as_matrix() for wxyz in wxyzs])
num_gaussians = len(vert)
print("Number of gaussians to render: ", num_gaussians)
colors = onp.zeros((num_gaussians, 3))
opacities = onp.zeros((num_gaussians, 1))
positions = onp.zeros((num_gaussians, 3))
wxyzs = onp.zeros((num_gaussians, 4))
scales = onp.zeros((num_gaussians, 3))

for i in tqdm(range(num_gaussians)):
v = plydata["vertex"][i]
positions[i, 0] = v["x"]
positions[i, 1] = v["y"]
positions[i, 2] = v["z"]
scales[i, 0] = v["scale_0"]
scales[i, 1] = v["scale_1"]
scales[i, 2] = v["scale_2"]
wxyzs[i, 0] = v["rot_0"]
wxyzs[i, 1] = v["rot_1"]
wxyzs[i, 2] = v["rot_2"]
wxyzs[i, 3] = v["rot_3"]
colors[i, 0] = v["f_dc_0"]
colors[i, 1] = v["f_dc_1"]
colors[i, 2] = v["f_dc_2"]
opacities[i, 0] = v["opacity"]

SH_C0 = 0.28209479177387814
scales = onp.exp(scales)
colors = 0.5 + SH_C0 * colors
opacities = 1.0 / (1.0 + onp.exp(-opacities))

Rs = tf.SO3(wxyzs).as_matrix()
covariances = onp.einsum(
"nij,njk,nlk->nil", Rs, onp.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs
)
if center:
positions -= onp.mean(positions, axis=0, keepdims=True)

print("PLY file loaded")
return {
"centers": positions,
Expand All @@ -127,7 +122,7 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:


def main(splat_paths: tuple[Path, ...]) -> None:
server = viser.ViserServer(share=True)
server = viser.ViserServer()
server.gui.configure_theme(dark_mode=True)
gui_reset_up = server.gui.add_button(
"Reset up direction",
Expand Down
2 changes: 1 addition & 1 deletion src/viser/_gui_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class GuiApi:
"""Interface for working with the 2D GUI in viser.
Used by both our global server object, for sharing the same GUI elements
with all clients, and by invidividual client handles."""
with all clients, and by individual client handles."""

_target_container_from_thread_id: dict[int, str] = {}
"""ID of container to put GUI elements into."""
Expand Down
8 changes: 4 additions & 4 deletions src/viser/_scene_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class SceneApi:
"""Interface for adding 3D primitives to the scene.
Used by both our global server object, for sharing the same GUI elements
with all clients, and by invidividual client handles."""
with all clients, and by individual client handles."""

def __init__(
self,
Expand Down Expand Up @@ -946,8 +946,7 @@ def _add_gaussian_splats(
position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0),
visible: bool = True,
) -> GaussianSplatHandle:
"""Add a model to render using Gaussian Splatting. Does not yet support
spherical harmonics.
"""Add a model to render using Gaussian Splatting.
**Work-in-progress.** This feature is experimental and still under
development. It may be changed or removed.
Expand All @@ -971,7 +970,8 @@ def _add_gaussian_splats(
assert opacities.shape == (num_gaussians, 1)
assert covariances.shape == (num_gaussians, 3, 3)

# Get cholesky factor of covariance.
# Get cholesky factor of covariance. This helps retain precision when
# we convert to float16.
cov_cholesky_triu = (
onp.linalg.cholesky(covariances.astype(onp.float64) + onp.ones(3) * 1e-7)
.swapaxes(-1, -2) # tril => triu
Expand Down
1 change: 1 addition & 0 deletions src/viser/client/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"clsx": "^2.1.0",
"colortranslator": "^4.1.0",
"dayjs": "^1.11.10",
"detect-browser": "^5.3.0",
"fflate": "^0.8.2",
"hold-event": "^1.1.0",
"immer": "^10.0.4",
Expand Down
13 changes: 9 additions & 4 deletions src/viser/client/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,14 @@ import { useDisclosure } from "@mantine/hooks";
import { rayToViserCoords } from "./WorldTransformUtils";
import { ndcFromPointerXy, opencvXyFromPointerXy } from "./ClickUtils";
import { theme } from "./AppTheme";
import { GaussianSplatsContext } from "./Splatting/GaussianSplats";
import {
GaussianSplatsContext,
useGaussianSplatStore,
} from "./Splatting/SplatContext";
import { FrameSynchronizedMessageHandler } from "./MessageHandler";
import { PlaybackFromFile } from "./FilePlayback";
import GlobalGaussianSplats from "./Splatting/GaussianSplats";
import { BrowserWarning } from "./BrowserWarning";

export type ViewerContextContents = {
messageSource: "websocket" | "file_playback";
Expand Down Expand Up @@ -228,6 +233,7 @@ function ViewerContents({ children }: { children: React.ReactNode }) {
},
}}
/>
<BrowserWarning />
<ViserModal />
<Box
style={{
Expand Down Expand Up @@ -259,10 +265,9 @@ function ViewerContents({ children }: { children: React.ReactNode }) {
})}
>
<Viewer2DCanvas />
<GaussianSplatsContext.Provider
value={React.useRef({ numSorting: 0, sortUpdateCallbacks: [] })}
>
<GaussianSplatsContext.Provider value={useGaussianSplatStore()}>
<ViewerCanvas>
<GlobalGaussianSplats />
<FrameSynchronizedMessageHandler />
</ViewerCanvas>
</GaussianSplatsContext.Provider>
Expand Down
45 changes: 45 additions & 0 deletions src/viser/client/src/BrowserWarning.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { notifications } from "@mantine/notifications";
import { detect } from "detect-browser";
import { useEffect } from "react";

export function BrowserWarning() {
useEffect(() => {
const browser = detect();

// Browser version are based loosely on support for SIMD, OffscreenCanvas.
//
// https://caniuse.com/?search=simd
// https://caniuse.com/?search=OffscreenCanvas
if (browser === null || browser.version === null) {
console.log("Failed to detect browser");
notifications.show({
title: "Could not detect browser version",
message:
"Your browser version could not be detected. It may not be supported.",
autoClose: false,
color: "red",
});
} else {
const version = parseFloat(browser.version);
console.log(`Detected ${browser.name} version ${version}`);
if (
(browser.name === "chrome" && version < 91) ||
(browser.name === "edge" && version < 91) ||
(browser.name === "firefox" && version < 89) ||
(browser.name === "opera" && version < 77) ||
(browser.name === "safari" && version < 16.4)
)
notifications.show({
title: "Unsuppported browser",
message: `Your browser (${
browser.name.slice(0, 1).toUpperCase() + browser.name.slice(1)
}/${
browser.version
}) is outdated, which may cause problems. Consider updating.`,
autoClose: false,
color: "red",
});
}
});
return null;
}
32 changes: 16 additions & 16 deletions src/viser/client/src/FilePlayback.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,10 @@ export function PlaybackFromFile({ fileUrl }: { fileUrl: string }) {
shadow="0.1em 0 1em 0 rgba(0,0,0,0.1)"
style={{
position: "fixed",
bottom: "0.75em",
bottom: "1em",
left: "50%",
transform: "translateX(-50%)",
width: "20em",
width: "25em",
maxWidth: "95%",
zIndex: 1,
padding: "0.5em",
Expand All @@ -206,11 +206,19 @@ export function PlaybackFromFile({ fileUrl }: { fileUrl: string }) {
gap: "0.375em",
}}
>
<ActionIcon size="xs" variant="subtle">
<ActionIcon size="md" variant="subtle">
{paused ? (
<IconPlayerPlayFilled onClick={() => setPaused(false)} />
<IconPlayerPlayFilled
onClick={() => setPaused(false)}
height="1.125em"
width="1.125em"
/>
) : (
<IconPlayerPauseFilled onClick={() => setPaused(true)} />
<IconPlayerPauseFilled
onClick={() => setPaused(true)}
height="1.125em"
width="1.125em"
/>
)}
</ActionIcon>
<NumberInput
Expand All @@ -220,15 +228,12 @@ export function PlaybackFromFile({ fileUrl }: { fileUrl: string }) {
step={0.01}
styles={{
wrapper: {
width:
(recording.durationSeconds.toFixed(1).length * 0.8).toString() +
"em",
width: "3.1em",
},
input: {
padding: "0.2em",
fontFamily: theme.fontFamilyMonospace,
padding: "0.5em",
minHeight: "1.25rem",
height: "1.25rem",
textAlign: "center",
},
}}
onChange={(value) =>
Expand Down Expand Up @@ -257,11 +262,6 @@ export function PlaybackFromFile({ fileUrl }: { fileUrl: string }) {
data={["0.5x", "1x", "2x", "4x", "8x"]}
styles={{
wrapper: { width: "3.25em" },
input: {
padding: "0.5em",
minHeight: "1.25rem",
height: "1.25rem",
},
}}
comboboxProps={{ zIndex: 5, width: "5.25em" }}
/>
Expand Down
40 changes: 23 additions & 17 deletions src/viser/client/src/MessageHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import GeneratedGuiContainer from "./ControlPanel/Generated";
import { Paper, Progress } from "@mantine/core";
import { IconCheck } from "@tabler/icons-react";
import { computeT_threeworld_world } from "./WorldTransformUtils";
import GaussianSplats from "./Splatting/GaussianSplats";
import { GaussianSplatsContext } from "./Splatting/SplatContext";

/** Convert raw RGB color buffers to linear color buffers. **/
function threeColorBufferFromUint8Buffer(colors: ArrayBuffer) {
Expand All @@ -51,6 +51,10 @@ function useMessageHandler() {
const viewer = useContext(ViewerContext)!;
const ContextBridge = useContextBridge();

const splatContext = useContext(GaussianSplatsContext)!;
const setGaussianBuffer = splatContext((state) => state.setBuffer);
const removeGaussianBuffer = splatContext((state) => state.removeBuffer);

// We could reduce the redundancy here if we wanted to.
// https://github.com/nerfstudio-project/viser/issues/39
const removeSceneNode = viewer.useSceneTree((state) => state.removeSceneNode);
Expand Down Expand Up @@ -966,23 +970,25 @@ function useMessageHandler() {
return;
}
case "GaussianSplatsMessage": {
setGaussianBuffer(
message.name,
new Uint32Array(
message.buffer.buffer.slice(
message.buffer.byteOffset,
message.buffer.byteOffset + message.buffer.byteLength,
),
),
);
addSceneNodeMakeParents(
new SceneNode<THREE.Group>(message.name, (ref) => {
return (
<group ref={ref}>
<GaussianSplats
buffers={{
buffer: new Uint32Array(
message.buffer.buffer.slice(
message.buffer.byteOffset,
message.buffer.byteOffset + message.buffer.byteLength,
),
),
}}
/>
</group>
);
}),
new SceneNode<THREE.Group>(
message.name,
(ref) => {
return <group ref={ref}></group>;
},
() => {
removeGaussianBuffer(message.name);
},
),
);
return;
}
Expand Down
Loading

0 comments on commit 203d92a

Please sign in to comment.