Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Global sort + SIMD for Gaussian rendering #252

Merged
merged 25 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 34 additions & 39 deletions examples/experimental/gaussian_splats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tyro
import viser
from plyfile import PlyData
from tqdm import tqdm
from viser import transforms as tf


Expand Down Expand Up @@ -72,49 +73,43 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile:
def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:
plydata = PlyData.read(ply_file_path)
vert = plydata["vertex"]
sorted_indices = onp.argsort(
-onp.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
/ (1 + onp.exp(-vert["opacity"]))
)
numgaussians = len(vert)
print("Number of gaussians to render: ", numgaussians)
colors = onp.zeros((numgaussians, 3))
opacities = onp.zeros((numgaussians, 1))
positions = onp.zeros((numgaussians, 3))
wxyzs = onp.zeros((numgaussians, 4))
scales = onp.zeros((numgaussians, 3))
for idx in sorted_indices:
v = plydata["vertex"][idx]
position = onp.array([v["x"], v["y"], v["z"]], dtype=onp.float32)
scale = onp.exp(
onp.array([v["scale_0"], v["scale_1"], v["scale_2"]], dtype=onp.float32)
)

rot = onp.array(
[v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], dtype=onp.float32
)
SH_C0 = 0.28209479177387814
color = onp.array(
[
0.5 + SH_C0 * v["f_dc_0"],
0.5 + SH_C0 * v["f_dc_1"],
0.5 + SH_C0 * v["f_dc_2"],
]
)
opacity = 1 / (1 + onp.exp(-v["opacity"]))
wxyz = rot / onp.linalg.norm(rot) # normalize
scales[idx] = scale
colors[idx] = color
opacities[idx] = onp.array([opacity])
positions[idx] = position
wxyzs[idx] = wxyz

Rs = onp.array([tf.SO3(wxyz).as_matrix() for wxyz in wxyzs])
num_gaussians = len(vert)
print("Number of gaussians to render: ", num_gaussians)
colors = onp.zeros((num_gaussians, 3))
opacities = onp.zeros((num_gaussians, 1))
positions = onp.zeros((num_gaussians, 3))
wxyzs = onp.zeros((num_gaussians, 4))
scales = onp.zeros((num_gaussians, 3))

for i in tqdm(range(num_gaussians)):
v = plydata["vertex"][i]
positions[i, 0] = v["x"]
positions[i, 1] = v["y"]
positions[i, 2] = v["z"]
scales[i, 0] = v["scale_0"]
scales[i, 1] = v["scale_1"]
scales[i, 2] = v["scale_2"]
wxyzs[i, 0] = v["rot_0"]
wxyzs[i, 1] = v["rot_1"]
wxyzs[i, 2] = v["rot_2"]
wxyzs[i, 3] = v["rot_3"]
colors[i, 0] = v["f_dc_0"]
colors[i, 1] = v["f_dc_1"]
colors[i, 2] = v["f_dc_2"]
opacities[i, 0] = v["opacity"]

SH_C0 = 0.28209479177387814
scales = onp.exp(scales)
colors = 0.5 + SH_C0 * colors
opacities = 1.0 / (1.0 + onp.exp(-opacities))

Rs = tf.SO3(wxyzs).as_matrix()
covariances = onp.einsum(
"nij,njk,nlk->nil", Rs, onp.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs
)
if center:
positions -= onp.mean(positions, axis=0, keepdims=True)

print("PLY file loaded")
return {
"centers": positions,
Expand All @@ -127,7 +122,7 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:


def main(splat_paths: tuple[Path, ...]) -> None:
server = viser.ViserServer(share=True)
server = viser.ViserServer()
server.gui.configure_theme(dark_mode=True)
gui_reset_up = server.gui.add_button(
"Reset up direction",
Expand Down
2 changes: 1 addition & 1 deletion src/viser/_gui_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class GuiApi:
"""Interface for working with the 2D GUI in viser.

Used by both our global server object, for sharing the same GUI elements
with all clients, and by invidividual client handles."""
with all clients, and by individual client handles."""

_target_container_from_thread_id: dict[int, str] = {}
"""ID of container to put GUI elements into."""
Expand Down
8 changes: 4 additions & 4 deletions src/viser/_scene_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class SceneApi:
"""Interface for adding 3D primitives to the scene.

Used by both our global server object, for sharing the same GUI elements
with all clients, and by invidividual client handles."""
with all clients, and by individual client handles."""

def __init__(
self,
Expand Down Expand Up @@ -946,8 +946,7 @@ def _add_gaussian_splats(
position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0),
visible: bool = True,
) -> GaussianSplatHandle:
"""Add a model to render using Gaussian Splatting. Does not yet support
spherical harmonics.
"""Add a model to render using Gaussian Splatting.

**Work-in-progress.** This feature is experimental and still under
development. It may be changed or removed.
Expand All @@ -971,7 +970,8 @@ def _add_gaussian_splats(
assert opacities.shape == (num_gaussians, 1)
assert covariances.shape == (num_gaussians, 3, 3)

# Get cholesky factor of covariance.
# Get cholesky factor of covariance. This helps retain precision when
# we convert to float16.
cov_cholesky_triu = (
onp.linalg.cholesky(covariances.astype(onp.float64) + onp.ones(3) * 1e-7)
.swapaxes(-1, -2) # tril => triu
Expand Down
1 change: 1 addition & 0 deletions src/viser/client/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"clsx": "^2.1.0",
"colortranslator": "^4.1.0",
"dayjs": "^1.11.10",
"detect-browser": "^5.3.0",
"fflate": "^0.8.2",
"hold-event": "^1.1.0",
"immer": "^10.0.4",
Expand Down
13 changes: 9 additions & 4 deletions src/viser/client/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,14 @@ import { useDisclosure } from "@mantine/hooks";
import { rayToViserCoords } from "./WorldTransformUtils";
import { ndcFromPointerXy, opencvXyFromPointerXy } from "./ClickUtils";
import { theme } from "./AppTheme";
import { GaussianSplatsContext } from "./Splatting/GaussianSplats";
import {
GaussianSplatsContext,
useGaussianSplatStore,
} from "./Splatting/SplatContext";
import { FrameSynchronizedMessageHandler } from "./MessageHandler";
import { PlaybackFromFile } from "./FilePlayback";
import GlobalGaussianSplats from "./Splatting/GaussianSplats";
import { BrowserWarning } from "./BrowserWarning";

export type ViewerContextContents = {
messageSource: "websocket" | "file_playback";
Expand Down Expand Up @@ -228,6 +233,7 @@ function ViewerContents({ children }: { children: React.ReactNode }) {
},
}}
/>
<BrowserWarning />
<ViserModal />
<Box
style={{
Expand Down Expand Up @@ -259,10 +265,9 @@ function ViewerContents({ children }: { children: React.ReactNode }) {
})}
>
<Viewer2DCanvas />
<GaussianSplatsContext.Provider
value={React.useRef({ numSorting: 0, sortUpdateCallbacks: [] })}
>
<GaussianSplatsContext.Provider value={useGaussianSplatStore()}>
<ViewerCanvas>
<GlobalGaussianSplats />
<FrameSynchronizedMessageHandler />
</ViewerCanvas>
</GaussianSplatsContext.Provider>
Expand Down
45 changes: 45 additions & 0 deletions src/viser/client/src/BrowserWarning.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { notifications } from "@mantine/notifications";
import { detect } from "detect-browser";
import { useEffect } from "react";

export function BrowserWarning() {
useEffect(() => {
const browser = detect();

// Browser version are based loosely on support for SIMD, OffscreenCanvas.
//
// https://caniuse.com/?search=simd
// https://caniuse.com/?search=OffscreenCanvas
if (browser === null || browser.version === null) {
console.log("Failed to detect browser");
notifications.show({
title: "Could not detect browser version",
message:
"Your browser version could not be detected. It may not be supported.",
autoClose: false,
color: "red",
});
} else {
const version = parseFloat(browser.version);
console.log(`Detected ${browser.name} version ${version}`);
if (
(browser.name === "chrome" && version < 91) ||
(browser.name === "edge" && version < 91) ||
(browser.name === "firefox" && version < 89) ||
(browser.name === "opera" && version < 77) ||
(browser.name === "safari" && version < 16.4)
)
notifications.show({
title: "Unsuppported browser",
message: `Your browser (${
browser.name.slice(0, 1).toUpperCase() + browser.name.slice(1)
}/${
browser.version
}) is outdated, which may cause problems. Consider updating.`,
autoClose: false,
color: "red",
});
}
});
return null;
}
32 changes: 16 additions & 16 deletions src/viser/client/src/FilePlayback.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,10 @@ export function PlaybackFromFile({ fileUrl }: { fileUrl: string }) {
shadow="0.1em 0 1em 0 rgba(0,0,0,0.1)"
style={{
position: "fixed",
bottom: "0.75em",
bottom: "1em",
left: "50%",
transform: "translateX(-50%)",
width: "20em",
width: "25em",
maxWidth: "95%",
zIndex: 1,
padding: "0.5em",
Expand All @@ -206,11 +206,19 @@ export function PlaybackFromFile({ fileUrl }: { fileUrl: string }) {
gap: "0.375em",
}}
>
<ActionIcon size="xs" variant="subtle">
<ActionIcon size="md" variant="subtle">
{paused ? (
<IconPlayerPlayFilled onClick={() => setPaused(false)} />
<IconPlayerPlayFilled
onClick={() => setPaused(false)}
height="1.125em"
width="1.125em"
/>
) : (
<IconPlayerPauseFilled onClick={() => setPaused(true)} />
<IconPlayerPauseFilled
onClick={() => setPaused(true)}
height="1.125em"
width="1.125em"
/>
)}
</ActionIcon>
<NumberInput
Expand All @@ -220,15 +228,12 @@ export function PlaybackFromFile({ fileUrl }: { fileUrl: string }) {
step={0.01}
styles={{
wrapper: {
width:
(recording.durationSeconds.toFixed(1).length * 0.8).toString() +
"em",
width: "3.1em",
},
input: {
padding: "0.2em",
fontFamily: theme.fontFamilyMonospace,
padding: "0.5em",
minHeight: "1.25rem",
height: "1.25rem",
textAlign: "center",
},
}}
onChange={(value) =>
Expand Down Expand Up @@ -257,11 +262,6 @@ export function PlaybackFromFile({ fileUrl }: { fileUrl: string }) {
data={["0.5x", "1x", "2x", "4x", "8x"]}
styles={{
wrapper: { width: "3.25em" },
input: {
padding: "0.5em",
minHeight: "1.25rem",
height: "1.25rem",
},
}}
comboboxProps={{ zIndex: 5, width: "5.25em" }}
/>
Expand Down
40 changes: 23 additions & 17 deletions src/viser/client/src/MessageHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import GeneratedGuiContainer from "./ControlPanel/Generated";
import { Paper, Progress } from "@mantine/core";
import { IconCheck } from "@tabler/icons-react";
import { computeT_threeworld_world } from "./WorldTransformUtils";
import GaussianSplats from "./Splatting/GaussianSplats";
import { GaussianSplatsContext } from "./Splatting/SplatContext";

/** Convert raw RGB color buffers to linear color buffers. **/
function threeColorBufferFromUint8Buffer(colors: ArrayBuffer) {
Expand All @@ -51,6 +51,10 @@ function useMessageHandler() {
const viewer = useContext(ViewerContext)!;
const ContextBridge = useContextBridge();

const splatContext = useContext(GaussianSplatsContext)!;
const setGaussianBuffer = splatContext((state) => state.setBuffer);
const removeGaussianBuffer = splatContext((state) => state.removeBuffer);

// We could reduce the redundancy here if we wanted to.
// https://github.com/nerfstudio-project/viser/issues/39
const removeSceneNode = viewer.useSceneTree((state) => state.removeSceneNode);
Expand Down Expand Up @@ -966,23 +970,25 @@ function useMessageHandler() {
return;
}
case "GaussianSplatsMessage": {
setGaussianBuffer(
message.name,
new Uint32Array(
message.buffer.buffer.slice(
message.buffer.byteOffset,
message.buffer.byteOffset + message.buffer.byteLength,
),
),
);
addSceneNodeMakeParents(
new SceneNode<THREE.Group>(message.name, (ref) => {
return (
<group ref={ref}>
<GaussianSplats
buffers={{
buffer: new Uint32Array(
message.buffer.buffer.slice(
message.buffer.byteOffset,
message.buffer.byteOffset + message.buffer.byteLength,
),
),
}}
/>
</group>
);
}),
new SceneNode<THREE.Group>(
message.name,
(ref) => {
return <group ref={ref}></group>;
},
() => {
removeGaussianBuffer(message.name);
},
),
);
return;
}
Expand Down
Loading
Loading