From d4f0b3ab65642e7126567b84a2493cfc9e03d762 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sat, 27 Jul 2024 01:17:57 +0900 Subject: [PATCH] Much faster dataloading for Gaussian example --- examples/experimental/gaussian_splats.py | 53 ++++++++---------------- 1 file changed, 18 insertions(+), 35 deletions(-) diff --git a/examples/experimental/gaussian_splats.py b/examples/experimental/gaussian_splats.py index c195435b..ddc41233 100644 --- a/examples/experimental/gaussian_splats.py +++ b/examples/experimental/gaussian_splats.py @@ -11,7 +11,6 @@ import tyro import viser from plyfile import PlyData -from tqdm import tqdm from viser import transforms as tf @@ -30,6 +29,7 @@ class SplatFile(TypedDict): def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: """Load an antimatter15-style splat file.""" + start_time = time.time() splat_buffer = splat_path.read_bytes() bytes_per_gaussian = ( # Each Gaussian is serialized as: @@ -44,7 +44,6 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: ) assert len(splat_buffer) % bytes_per_gaussian == 0 num_gaussians = len(splat_buffer) // bytes_per_gaussian - print("Number of gaussians to render: ", f"{num_gaussians=}") # Reinterpret cast to dtypes that we want to extract. splat_uint8 = onp.frombuffer(splat_buffer, dtype=onp.uint8).reshape( @@ -59,7 +58,9 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: centers = splat_uint8[:, 0:12].copy().view(onp.float32) if center: centers -= onp.mean(centers, axis=0, keepdims=True) - print("Splat file loaded") + print( + f"Splat file with {num_gaussians=} loaded in {time.time() - start_time} seconds" + ) return { "centers": centers, # Colors should have shape (N, 3). @@ -71,37 +72,18 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: - plydata = PlyData.read(ply_file_path) - vert = plydata["vertex"] - num_gaussians = len(vert) - print("Number of gaussians to render: ", num_gaussians) - colors = onp.zeros((num_gaussians, 3)) - opacities = onp.zeros((num_gaussians, 1)) - positions = onp.zeros((num_gaussians, 3)) - wxyzs = onp.zeros((num_gaussians, 4)) - scales = onp.zeros((num_gaussians, 3)) - - for i in tqdm(range(num_gaussians)): - v = plydata["vertex"][i] - positions[i, 0] = v["x"] - positions[i, 1] = v["y"] - positions[i, 2] = v["z"] - scales[i, 0] = v["scale_0"] - scales[i, 1] = v["scale_1"] - scales[i, 2] = v["scale_2"] - wxyzs[i, 0] = v["rot_0"] - wxyzs[i, 1] = v["rot_1"] - wxyzs[i, 2] = v["rot_2"] - wxyzs[i, 3] = v["rot_3"] - colors[i, 0] = v["f_dc_0"] - colors[i, 1] = v["f_dc_1"] - colors[i, 2] = v["f_dc_2"] - opacities[i, 0] = v["opacity"] + """Load Gaussians stored in a PLY file.""" + start_time = time.time() SH_C0 = 0.28209479177387814 - scales = onp.exp(scales) - colors = 0.5 + SH_C0 * colors - opacities = 1.0 / (1.0 + onp.exp(-opacities)) + + plydata = PlyData.read(ply_file_path) + v = plydata["vertex"] + positions = onp.stack([v["x"], v["y"], v["z"]], axis=-1) + scales = onp.exp(onp.stack([v["scale_0"], v["scale_1"], v["scale_2"]], axis=-1)) + wxyzs = onp.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1) + colors = 0.5 + SH_C0 * onp.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) + opacities = 1.0 / (1.0 + onp.exp(-v["opacity"][:, None])) Rs = tf.SO3(wxyzs).as_matrix() covariances = onp.einsum( @@ -110,13 +92,14 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: if center: positions -= onp.mean(positions, axis=0, keepdims=True) - print("PLY file loaded") + num_gaussians = len(v) + print( + f"PLY file with {num_gaussians=} loaded in {time.time() - start_time} seconds" + ) return { "centers": positions, - # Colors should have shape (N, 3). "rgbs": colors, "opacities": opacities, - # Covariances should have shape (N, 3, 3). "covariances": covariances, }