Skip to content

Commit

Permalink
(Experimental) Gaussian splatting + WebGL (#110)
Browse files Browse the repository at this point in the history
* Add Gaussian splatting implementation
Shaders adapted from: https://github.com/antimatter15/splat/

* Optimize

* Fix uniform updates, performance improvements

* Sync

* Implement sort in C++/WebAssembly, 2~3x faster

* Rename hack pt 1

* Rename pt 2

* Component lifecycle cleanup

* Fix typescript

* Fix view error

* Account for pixel ratio in point cloud shader

* ruff

* Experiment

* Fix runtime error

* Use new websock_interface API

* Fix GUI container edge cases, the code needs a refactor 🙂

* 0.2.0

* Signficant performance improvements

* ruff

* "We have Luma AI at home"

* More improvements

* Fix bug caught by pyright

* Sorting cleanup

* More performance tweaks

* Cholesky, optimizations, formatting

* ruf isues

* pr changes (#237)

* pr changes

* nit

---------

Co-authored-by: Brent Yi <[email protected]>

* Nits, mark API as explicitly experimental

* Address pyright issues

* Docs / notes

* ruff

---------

Co-authored-by: Rebecca Feng <[email protected]>
Co-authored-by: beckyfeng08 <[email protected]>
  • Loading branch information
3 people committed Jul 11, 2024
1 parent 6f18e22 commit 03ce7f7
Show file tree
Hide file tree
Showing 27 changed files with 958 additions and 71 deletions.
6 changes: 6 additions & 0 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# C++ formatting rules; used for WebAssembly code.
BasedOnStyle: LLVM
AlignAfterOpenBracket: BlockIndent
BinPackArguments: false
BinPackParameters: false
IndentWidth: 4
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
# Check out source.
- uses: actions/checkout@v2
with:
fetch-depth: 0 # This ensures the entire history is fetched so we can switch branches
fetch-depth: 0 # This ensures the entire history is fetched so we can switch branches

- name: Set up Python
uses: actions/setup-python@v1
Expand Down Expand Up @@ -60,7 +60,7 @@ jobs:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./docs/build
destination_dir: ${{ env.DOCS_SUBDIR }}
keep_files: false # This will only erase the destination subdirectory.
keep_files: false # This will only erase the destination subdirectory.
cname: viser.studio
if: github.event_name != 'pull_request'

Expand Down
8 changes: 6 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@ repos:
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- repo: https://github.com/charliermarsh/ruff-pre-commit
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: "v0.0.267"
rev: v0.5.1
hooks:
# Run the linter.
- id: ruff
args: [--fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/psf/black
rev: "23.3.0"
hooks:
Expand Down
2 changes: 2 additions & 0 deletions .prettierignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.mjs
build/
139 changes: 83 additions & 56 deletions docs/source/_templates/sidebar/brand.html
Original file line number Diff line number Diff line change
@@ -1,68 +1,95 @@
<a class="sidebar-brand{% if logo %} centered{% endif %}" href="{{ pathto(master_doc) }}">
{% block brand_content %} {%- if logo_url %}
<div class="sidebar-logo-container">
<a href="{{ pathto(master_doc) }}"><img class="sidebar-logo" src="{{ logo_url }}" alt="Logo" /></a>
</div>
{%- endif %} {%- if theme_light_logo and theme_dark_logo %}
<div class="sidebar-logo-container" style="margin: .5rem 1em .5rem 0">
<img class="sidebar-logo only-light" src="{{ pathto('_static/' + theme_light_logo, 1) }}" alt="logo" />
<img class="sidebar-logo only-dark" src="{{ pathto('_static/' + theme_dark_logo, 1) }}" alt="logo" />
</div>
{%- endif %}
<!-- <span class="sidebar-brand-text">{{ project }}</span> -->
<a
class="sidebar-brand{% if logo %} centered{% endif %}"
href="{{ pathto(master_doc) }}"
>
{% block brand_content %} {%- if logo_url %}
<div class="sidebar-logo-container">
<a href="{{ pathto(master_doc) }}"
><img class="sidebar-logo" src="{{ logo_url }}" alt="Logo"
/></a>
</div>
{%- endif %} {%- if theme_light_logo and theme_dark_logo %}
<div class="sidebar-logo-container" style="margin: 0.5rem 1em 0.5rem 0">
<img
class="sidebar-logo only-light"
src="{{ pathto('_static/' + theme_light_logo, 1) }}"
alt="logo"
/>
<img
class="sidebar-logo only-dark"
src="{{ pathto('_static/' + theme_dark_logo, 1) }}"
alt="logo"
/>
</div>
{%- endif %}
<!-- <span class="sidebar-brand-text">{{ project }}</span> -->

{% endblock brand_content %}
{% endblock brand_content %}
</a>

<!-- Dropdown for different versions of the viser docs. Slightly hacky. -->
<div style="padding: 0 1em;">
<script>
var viserDocsVersionsPopulated = false;
<div style="padding: 0 1em">
<script>
var viserDocsVersionsPopulated = false;

async function getViserVersionList() {
// This index.txt file is written by the docs.yml GitHub action.
// https://github.com/nerfstudio-project/viser/blob/main/.github/workflows/docs.yml
const response = await fetch(
"https://viser.studio/versions/index.txt",
{ cache: "no-cache" }
);
return await response.text();
async function getViserVersionList() {
// This index.txt file is written by the docs.yml GitHub action.
// https://github.com/nerfstudio-project/viser/blob/main/.github/workflows/docs.yml
const response = await fetch("https://viser.studio/versions/index.txt", {
cache: "no-cache",
});
return await response.text();
}
async function viserDocsPopulateVersionDropDown() {
// Load the version list lazily...
if (viserDocsVersionsPopulated) {
return;
}
async function viserDocsPopulateVersionDropDown () {
// Load the version list lazily...
if (viserDocsVersionsPopulated) {
return;
}
viserDocsVersionsPopulated = true;
viserDocsVersionsPopulated = true;

console.log("Populating docs version list!")
const versions = (await getViserVersionList()).trim().split("\n").reverse();
console.log(versions);
let htmlString = "<ul style='margin: 0.5rem 0 0 0'>";
htmlString += `<li><a href="https://viser.studio/latest">latest</a></li>`;
for (let version of versions) {
htmlString += `<li><a href="https://viser.studio/versions/${version}">${version}</a></li>`;
}

htmlString += "</ul>";
document.getElementById("viser-version-dropdown").innerHTML = htmlString;
console.log("Populating docs version list!");
const versions = (await getViserVersionList())
.trim()
.split("\n")
.reverse();
console.log(versions);
let htmlString = "<ul style='margin: 0.5rem 0 0 0'>";
htmlString += `<li><a href="https://viser.studio/latest">latest</a></li>`;
for (let version of versions) {
htmlString += `<li><a href="https://viser.studio/versions/${version}">${version}</a></li>`;
}
</script>
<details
style="padding: 0.5rem; background: var(--color-background-primary); border-radius: 0.5rem; border: 1px solid var(--color-sidebar-background-border);"
ontoggle="viserDocsPopulateVersionDropDown()"
>
<summary style="cursor: pointer;"><strong>Version:</strong> <em>{{ version }}</em></summary>
<div id="viser-version-dropdown"></div>
</details>
<!-- End dropdown -->

htmlString += "</ul>";
document.getElementById("viser-version-dropdown").innerHTML = htmlString;
}
</script>
<details
style="
padding: 0.5rem;
background: var(--color-background-primary);
border-radius: 0.5rem;
border: 1px solid var(--color-sidebar-background-border);
"
ontoggle="viserDocsPopulateVersionDropDown()"
>
<summary style="cursor: pointer">
<strong>Version:</strong> <em>{{ version }}</em>
</summary>
<div id="viser-version-dropdown"></div>
</details>
<!-- End dropdown -->
</div>

<div style="text-align: left; padding: 1em">
<script async defer src="https://buttons.github.io/buttons.js"></script>
<a class="github-button" href="https://github.com/nerfstudio-project/viser"
data-color-scheme="no-preference: light; light: light; dark: light;" data-size="large" data-show-count="true"
aria-label="Download buttons/github-buttons on GitHub">
Github
</a>
<script async defer src="https://buttons.github.io/buttons.js"></script>
<a
class="github-button"
href="https://github.com/nerfstudio-project/viser"
data-color-scheme="no-preference: light; light: light; dark: light;"
data-size="large"
data-show-count="true"
aria-label="Download buttons/github-buttons on GitHub"
>
Github
</a>
</div>
2 changes: 2 additions & 0 deletions docs/source/scene_handles.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,6 @@ connected clients. When a scene node is added to a client (for example, via

.. autoclass:: viser.TransformControlsHandle

.. autoclass:: viser.GaussianSplatHandle

<!-- prettier-ignore-end -->
167 changes: 167 additions & 0 deletions examples/experimental/gaussian_splats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""WebGL-based Gaussian splat rendering. This is still under developmentt."""

from __future__ import annotations

import time
from pathlib import Path
from typing import TypedDict

import numpy as onp
import numpy.typing as onpt
import tyro
import viser
from plyfile import PlyData
from viser import transforms as tf


class SplatFile(TypedDict):
"""Data loaded from an antimatter15-style splat file."""

centers: onpt.NDArray[onp.floating]
"""(N, 3)."""
rgbs: onpt.NDArray[onp.floating]
"""(N, 3). Range [0, 1]."""
opacities: onpt.NDArray[onp.floating]
"""(N, 1). Range [0, 1]."""
covariances: onpt.NDArray[onp.floating]
"""(N, 3, 3)."""


def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile:
"""Load an antimatter15-style splat file."""
splat_buffer = splat_path.read_bytes()
bytes_per_gaussian = (
# Each Gaussian is serialized as:
# - position (vec3, float32)
3 * 4
# - xyz (vec3, float32)
+ 3 * 4
# - rgba (vec4, uint8)
+ 4
# - ijkl (vec4, uint8), where 0 => -1, 255 => 1.
+ 4
)
assert len(splat_buffer) % bytes_per_gaussian == 0
num_gaussians = len(splat_buffer) // bytes_per_gaussian
print("Number of gaussians to render: ", f"{num_gaussians=}")

# Reinterpret cast to dtypes that we want to extract.
splat_uint8 = onp.frombuffer(splat_buffer, dtype=onp.uint8).reshape(
(num_gaussians, bytes_per_gaussian)
)
scales = splat_uint8[:, 12:24].copy().view(onp.float32)
wxyzs = splat_uint8[:, 28:32] / 255.0 * 2.0 - 1.0
Rs = onp.array([tf.SO3(wxyz).as_matrix() for wxyz in wxyzs])
covariances = onp.einsum(
"nij,njk,nlk->nil", Rs, onp.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs
)
centers = splat_uint8[:, 0:12].copy().view(onp.float32)
if center:
centers -= onp.mean(centers, axis=0, keepdims=True)
print("Splat file loaded")
return {
"centers": centers,
# Colors should have shape (N, 3).
"rgbs": splat_uint8[:, 24:27] / 255.0,
"opacities": splat_uint8[:, 27:28] / 255.0,
# Covariances should have shape (N, 3, 3).
"covariances": covariances,
}


def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:
plydata = PlyData.read(ply_file_path)
vert = plydata["vertex"]
sorted_indices = onp.argsort(
-onp.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
/ (1 + onp.exp(-vert["opacity"]))
)
numgaussians = len(vert)
print("Number of gaussians to render: ", numgaussians)
colors = onp.zeros((numgaussians, 3))
opacities = onp.zeros((numgaussians, 1))
positions = onp.zeros((numgaussians, 3))
wxyzs = onp.zeros((numgaussians, 4))
scales = onp.zeros((numgaussians, 3))
for idx in sorted_indices:
v = plydata["vertex"][idx]
position = onp.array([v["x"], v["y"], v["z"]], dtype=onp.float32)
scale = onp.exp(
onp.array([v["scale_0"], v["scale_1"], v["scale_2"]], dtype=onp.float32)
)

rot = onp.array(
[v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], dtype=onp.float32
)
SH_C0 = 0.28209479177387814
color = onp.array(
[
0.5 + SH_C0 * v["f_dc_0"],
0.5 + SH_C0 * v["f_dc_1"],
0.5 + SH_C0 * v["f_dc_2"],
]
)
opacity = 1 / (1 + onp.exp(-v["opacity"]))
wxyz = rot / onp.linalg.norm(rot) # normalize
scales[idx] = scale
colors[idx] = color
opacities[idx] = onp.array([opacity])
positions[idx] = position
wxyzs[idx] = wxyz

Rs = onp.array([tf.SO3(wxyz).as_matrix() for wxyz in wxyzs])
covariances = onp.einsum(
"nij,njk,nlk->nil", Rs, onp.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs
)
if center:
positions -= onp.mean(positions, axis=0, keepdims=True)
print("PLY file loaded")
return {
"centers": positions,
# Colors should have shape (N, 3).
"rgbs": colors,
"opacities": opacities,
# Covariances should have shape (N, 3, 3).
"covariances": covariances,
}


def main(splat_paths: tuple[Path, ...], test_multisplat: bool = False) -> None:
server = viser.ViserServer(share=True)
server.gui.configure_theme(dark_mode=True)
gui_reset_up = server.gui.add_button(
"Reset up direction",
hint="Set the camera control 'up' direction to the current camera's 'up'.",
)

@gui_reset_up.on_click
def _(event: viser.GuiEvent) -> None:
client = event.client
assert client is not None
client.camera.up_direction = tf.SO3(client.camera.wxyz) @ onp.array(
[0.0, -1.0, 0.0]
)

for i, splat_path in enumerate(splat_paths):
if splat_path.suffix == ".splat":
splat_data = load_splat_file(splat_path, center=True)
elif splat_path.suffix == ".ply":
splat_data = load_ply_file(splat_path, center=True)
else:
raise SystemExit("Please provide a filepath to a .splat or .ply file.")

server.scene.add_transform_controls(f"/{i}")
server.scene._add_gaussian_splats(
f"/{i}/gaussian_splats",
centers=splat_data["centers"],
rgbs=splat_data["rgbs"],
opacities=splat_data["opacities"],
covariances=splat_data["covariances"],
)

while True:
time.sleep(10.0)


if __name__ == "__main__":
tyro.cli(main)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"nodeenv>=1.8.0",
"psutil>=5.9.5",
"yourdfpy>=0.0.53",
"plyfile>=1.0.2"
]

[project.optional-dependencies]
Expand All @@ -48,6 +49,7 @@ examples = [
"plotly>=5.21.0",
"robot_descriptions>=1.10.0",
"gdown>=4.6.6",
"plyfile",
]

[project.urls]
Expand Down
1 change: 1 addition & 0 deletions src/viser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ._scene_handles import BatchedAxesHandle as BatchedAxesHandle
from ._scene_handles import CameraFrustumHandle as CameraFrustumHandle
from ._scene_handles import FrameHandle as FrameHandle
from ._scene_handles import GaussianSplatHandle as GaussianSplatHandle
from ._scene_handles import GlbHandle as GlbHandle
from ._scene_handles import Gui3dContainerHandle as Gui3dContainerHandle
from ._scene_handles import ImageHandle as ImageHandle
Expand Down
Loading

0 comments on commit 03ce7f7

Please sign in to comment.