Modify depth dimensions to match our input #2

Open · wants to merge 4 commits into main
scripts/splatam.py (16 changes: 13 additions & 3 deletions)
@@ -173,6 +173,8 @@ def initialize_first_timestep(dataset, num_frames, scene_radius_depth_ratio,

     # Process RGB-D Data
     color = color.permute(2, 0, 1) / 255 # (H, W, C) -> (C, H, W)
+    # Flatten to match expected dimensions
+    depth = torch.flatten(depth, start_dim=2)
     depth = depth.permute(2, 0, 1) # (H, W, C) -> (C, H, W)

     # Process Camera Parameters
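
For context, a minimal standalone sketch (not part of the diff; the input shape is an assumption) of what the added flatten does when the dataset yields depth with an extra trailing dimension:

```python
import torch

# Hypothetical depth tensor with an extra trailing dim (illustrative shape)
depth = torch.rand(480, 640, 1, 1)

depth = torch.flatten(depth, start_dim=2)  # (480, 640, 1, 1) -> (480, 640, 1)
depth = depth.permute(2, 0, 1)             # (480, 640, 1)    -> (1, 480, 640), i.e. (C, H, W)

print(depth.shape)  # torch.Size([1, 480, 640])
```

For depth that is already (H, W, 1), `flatten(start_dim=2)` is a no-op, so the existing datasets should be unaffected.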
@@ -186,14 +188,17 @@ def initialize_first_timestep(dataset, num_frames, scene_radius_depth_ratio,
         # Get Densification RGB-D Data & Camera Parameters
         color, depth, densify_intrinsics, _ = densify_dataset[0]
         color = color.permute(2, 0, 1) / 255 # (H, W, C) -> (C, H, W)
+        # Flatten to match expected dimensions
+        depth = torch.flatten(depth, start_dim=2)
         depth = depth.permute(2, 0, 1) # (H, W, C) -> (C, H, W)
         densify_intrinsics = densify_intrinsics[:3, :3]
         densify_cam = setup_camera(color.shape[2], color.shape[1], densify_intrinsics.cpu().numpy(), w2c.detach().cpu().numpy())
     else:
         densify_intrinsics = intrinsics

     # Get Initial Point Cloud (PyTorch CUDA Tensor)
-    mask = (depth > 0) # Mask out invalid depth values
+    depth_z = depth[0] # Take only the 1st channel
+    mask = (depth_z > 0) # Mask out invalid depth values
     mask = mask.reshape(-1)
     init_pt_cld, mean3_sq_dist = get_pointcloud(color, depth, densify_intrinsics, w2c,
                                                 mask=mask, compute_mean_sq_dist=True,
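
A similar sketch for the mask change (illustrative, not from the diff): if the depth tensor carries more than one channel, masking the whole tensor yields more entries than there are pixels, whereas `depth[0]` keeps the mask at one value per pixel, which is presumably what `get_pointcloud` expects after `reshape(-1)`.

```python
import torch

depth = torch.rand(3, 480, 640)       # hypothetical multi-channel depth, (C, H, W)

mask_all = (depth > 0).reshape(-1)    # 3 * 480 * 640 = 921,600 entries
mask_z = (depth[0] > 0).reshape(-1)   # 480 * 640 = 307,200 entries, one per pixel

print(mask_all.shape, mask_z.shape)
```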
@@ -281,7 +286,7 @@ def get_loss(params, curr_data, variables, iter_time_idx, loss_weights, use_sil_

     # RGB Loss
     if tracking and (use_sil_for_loss or ignore_outlier_depth_loss):
-        color_mask = torch.tile(mask, (3, 1, 1))
+        color_mask = torch.tile(mask, (1, 1, 1))
         color_mask = color_mask.detach()
         losses['im'] = torch.abs(curr_data['im'] - im)[color_mask].sum()
     elif tracking:
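
For reference, `torch.tile` with the two reps arguments behaves as below (standalone sketch, shapes illustrative): with `(3, 1, 1)` a single-channel mask is expanded to three channels for an RGB image, while `(1, 1, 1)` leaves it unchanged, which would match a single-channel `im` tensor.

```python
import torch

mask = torch.rand(1, 4, 5) > 0.5          # (1, H, W) boolean mask

print(torch.tile(mask, (3, 1, 1)).shape)  # torch.Size([3, 4, 5])
print(torch.tile(mask, (1, 1, 1)).shape)  # torch.Size([1, 4, 5])
```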
@@ -632,6 +637,8 @@ def rgbd_slam(config: dict):
     curr_w2c[:3, 3] = curr_cam_tran
     # Initialize Keyframe Info
     color = color.permute(2, 0, 1) / 255
+    # Flatten to match expected dimensions
+    depth = torch.flatten(depth, start_dim=2)
     depth = depth.permute(2, 0, 1)
     curr_keyframe = {'id': time_idx, 'est_w2c': curr_w2c, 'color': color, 'depth': depth}
     # Add to keyframe list
@@ -647,6 +654,8 @@ def rgbd_slam(config: dict):
     gt_w2c = torch.linalg.inv(gt_pose)
     # Process RGB-D Data
     color = color.permute(2, 0, 1) / 255
+    # Flatten to match expected dimensions
+    depth = torch.flatten(depth, start_dim=2)
     depth = depth.permute(2, 0, 1)
     gt_w2c_all_frames.append(gt_w2c)
     curr_gt_w2c = gt_w2c_all_frames
@@ -782,6 +791,7 @@ def rgbd_slam(config: dict):
     # Load RGBD frames incrementally instead of all frames
     densify_color, densify_depth, _, _ = densify_dataset[time_idx]
     densify_color = densify_color.permute(2, 0, 1) / 255
+    densify_depth = torch.flatten(densify_depth, start_dim=2)
     densify_depth = densify_depth.permute(2, 0, 1)
     densify_curr_data = {'cam': densify_cam, 'im': densify_color, 'depth': densify_depth, 'id': time_idx,
                          'intrinsics': densify_intrinsics, 'w2c': first_frame_w2c, 'iter_gt_w2c_list': curr_gt_w2c}
@@ -1011,4 +1021,4 @@ def rgbd_slam(config: dict):
     os.makedirs(results_dir, exist_ok=True)
     shutil.copy(args.experiment, os.path.join(results_dir, "config.py"))

-    rgbd_slam(experiment.config)
+    rgbd_slam(experiment.config)
utils/eval_helpers.py (5 changes: 3 additions & 2 deletions)
@@ -132,7 +132,7 @@ def plot_rgbd_silhouette(color, depth, rastered_color, rastered_depth, presence_
     axs[0, 2].imshow(presence_sil_mask, cmap='gray')
     axs[0, 2].set_title("Rasterized Silhouette")
     diff_depth_l1 = diff_depth_l1.cpu().squeeze(0)
-    axs[1, 2].imshow(diff_depth_l1, cmap='jet', vmin=0, vmax=6)
+    axs[1, 2].imshow(diff_depth_l1.permute(1, 2, 0), cmap='jet', vmin=0, vmax=6)
     axs[1, 2].set_title("Diff Depth L1")
     for ax in axs.flatten():
         ax.axis('off')
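
For context, `imshow` expects images in (H, W) or (H, W, C) layout, so a channel-first tensor needs a permute before plotting; a minimal sketch with illustrative shapes (recent matplotlib accepts an (H, W, 1) array and applies the colormap):

```python
import torch
import matplotlib.pyplot as plt

diff_depth_l1 = torch.rand(1, 480, 640)       # channel-first (C, H, W)
img = diff_depth_l1.permute(1, 2, 0).numpy()  # -> (H, W, 1) for imshow

plt.imshow(img, cmap='jet', vmin=0, vmax=6)
plt.savefig("diff_depth_l1.png", bbox_inches='tight')
```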
@@ -435,6 +435,7 @@ def eval(dataset, final_params, num_frames, eval_dir, sil_thres,

     # Process RGB-D Data
     color = color.permute(2, 0, 1) / 255 # (H, W, C) -> (C, H, W)
+    depth = depth.flatten(start_dim=2)
     depth = depth.permute(2, 0, 1) # (H, W, C) -> (C, H, W)

     if time_idx == 0:
@@ -838,4 +839,4 @@ def eval_nvs(dataset, final_params, num_frames, eval_dir, sil_thres,
     plt.savefig(os.path.join(eval_dir, "metrics.png"), bbox_inches='tight')
     if wandb_run is not None:
         wandb_run.log({"Eval/Metrics": fig})
-    plt.close()
+    plt.close()
venv_requirements.txt (4 changes: 3 additions & 1 deletion)
@@ -11,4 +11,6 @@ torchmetrics
 cyclonedds
 pytorch-msssim
 plyfile==0.8.1
-git+https://github.com/JonathonLuiten/diff-gaussian-rasterization-w-depth.git@cb65e4b86bc3bd8ed42174b72a62e8d3a3a71110
+opencv-python
Collaborator:

@Santoi a missing dependency? If opencv is not used for visualization, consider pulling opencv-python-headless instead. opencv-python and matplotlib don't play well together in certain cases (due to PyQt compatibility issues).

Collaborator (Author):

Got it, thanks!

+open3d
+git+https://github.com/JonathonLuiten/diff-gaussian-rasterization-w-depth.git@cb65e4b86bc3bd8ed42174b72a62e8d3a3a71110
viz_scripts/final_recon.py (5 changes: 5 additions & 0 deletions)
@@ -186,6 +186,9 @@ def visualize(scene_path, cfg):
     pcd = o3d.geometry.PointCloud()
     pcd.points = init_pts
     pcd.colors = init_cols
+    path = cfg['output']
+    o3d.io.write_point_cloud(path, pcd);
+    print("PCD written at: ", path);
     vis.add_geometry(pcd)

     w = cfg['viz_w']
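
As a quick sanity check on the new output (a usage sketch with an illustrative path, not part of the diff), the written file can be loaded back with Open3D; note that `write_point_cloud` infers the format from the file extension, so the output path should end in a supported extension such as `.ply` or `.pcd`.

```python
import open3d as o3d

# Illustrative path; use whatever was passed as the new `pointcloud` argument
pcd = o3d.io.read_point_cloud("final_recon.ply")
print(pcd)  # e.g. "PointCloud with N points."
o3d.visualization.draw_geometries([pcd])
```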
@@ -279,6 +282,7 @@ def visualize(scene_path, cfg):
     parser = argparse.ArgumentParser()

     parser.add_argument("experiment", type=str, help="Path to experiment file")
+    parser.add_argument("pointcloud", type=str, help="Path to write output pointcloud file")

     args = parser.parse_args()

@@ -296,6 +300,7 @@ def visualize(scene_path, cfg):
     else:
         scene_path = experiment.config["scene_path"]
     viz_cfg = experiment.config["viz"]
+    viz_cfg["output"] = args.pointcloud

     # Visualize Final Reconstruction
     visualize(scene_path, viz_cfg)
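
With these changes the visualizer takes a second positional argument for the output path, so an invocation would look something like `python viz_scripts/final_recon.py <experiment config> <output>.ply` (paths illustrative).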