Skip to content

Commit

Permalink
Much faster dataloading for Gaussian example
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 26, 2024
1 parent 94d157a commit d4f0b3a
Showing 1 changed file with 18 additions and 35 deletions.
53 changes: 18 additions & 35 deletions examples/experimental/gaussian_splats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import tyro
import viser
from plyfile import PlyData
from tqdm import tqdm
from viser import transforms as tf


Expand All @@ -30,6 +29,7 @@ class SplatFile(TypedDict):

def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile:
"""Load an antimatter15-style splat file."""
start_time = time.time()
splat_buffer = splat_path.read_bytes()
bytes_per_gaussian = (
# Each Gaussian is serialized as:
Expand All @@ -44,7 +44,6 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile:
)
assert len(splat_buffer) % bytes_per_gaussian == 0
num_gaussians = len(splat_buffer) // bytes_per_gaussian
print("Number of gaussians to render: ", f"{num_gaussians=}")

# Reinterpret cast to dtypes that we want to extract.
splat_uint8 = onp.frombuffer(splat_buffer, dtype=onp.uint8).reshape(
Expand All @@ -59,7 +58,9 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile:
centers = splat_uint8[:, 0:12].copy().view(onp.float32)
if center:
centers -= onp.mean(centers, axis=0, keepdims=True)
print("Splat file loaded")
print(
f"Splat file with {num_gaussians=} loaded in {time.time() - start_time} seconds"
)
return {
"centers": centers,
# Colors should have shape (N, 3).
Expand All @@ -71,37 +72,18 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile:


def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:
plydata = PlyData.read(ply_file_path)
vert = plydata["vertex"]
num_gaussians = len(vert)
print("Number of gaussians to render: ", num_gaussians)
colors = onp.zeros((num_gaussians, 3))
opacities = onp.zeros((num_gaussians, 1))
positions = onp.zeros((num_gaussians, 3))
wxyzs = onp.zeros((num_gaussians, 4))
scales = onp.zeros((num_gaussians, 3))

for i in tqdm(range(num_gaussians)):
v = plydata["vertex"][i]
positions[i, 0] = v["x"]
positions[i, 1] = v["y"]
positions[i, 2] = v["z"]
scales[i, 0] = v["scale_0"]
scales[i, 1] = v["scale_1"]
scales[i, 2] = v["scale_2"]
wxyzs[i, 0] = v["rot_0"]
wxyzs[i, 1] = v["rot_1"]
wxyzs[i, 2] = v["rot_2"]
wxyzs[i, 3] = v["rot_3"]
colors[i, 0] = v["f_dc_0"]
colors[i, 1] = v["f_dc_1"]
colors[i, 2] = v["f_dc_2"]
opacities[i, 0] = v["opacity"]
"""Load Gaussians stored in a PLY file."""
start_time = time.time()

SH_C0 = 0.28209479177387814
scales = onp.exp(scales)
colors = 0.5 + SH_C0 * colors
opacities = 1.0 / (1.0 + onp.exp(-opacities))

plydata = PlyData.read(ply_file_path)
v = plydata["vertex"]
positions = onp.stack([v["x"], v["y"], v["z"]], axis=-1)
scales = onp.exp(onp.stack([v["scale_0"], v["scale_1"], v["scale_2"]], axis=-1))
wxyzs = onp.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1)
colors = 0.5 + SH_C0 * onp.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1)
opacities = 1.0 / (1.0 + onp.exp(-v["opacity"][:, None]))

Rs = tf.SO3(wxyzs).as_matrix()
covariances = onp.einsum(
Expand All @@ -110,13 +92,14 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:
if center:
positions -= onp.mean(positions, axis=0, keepdims=True)

print("PLY file loaded")
num_gaussians = len(v)
print(
f"PLY file with {num_gaussians=} loaded in {time.time() - start_time} seconds"
)
return {
"centers": positions,
# Colors should have shape (N, 3).
"rgbs": colors,
"opacities": opacities,
# Covariances should have shape (N, 3, 3).
"covariances": covariances,
}

Expand Down

0 comments on commit d4f0b3a

Please sign in to comment.