Skip to content

Commit

Permalink
Revert RasterizeGaussiansCPU changes
Browse files Browse the repository at this point in the history
  • Loading branch information
pierotofy committed Apr 15, 2024
1 parent 043c9b2 commit 8c376c8
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions rasterize_gaussians.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,6 @@ torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx,
){

int numPoints = xys.size(0);
ctx->saved_data["imgWidth"] = imgWidth;
ctx->saved_data["imgHeight"] = imgHeight;
torch::Device device = xys.device();

auto t = rasterize_forward_tensor_cpu(imgWidth, imgHeight,
xys,
Expand All @@ -169,12 +166,14 @@ torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx,
camDepths
);
// Final image
torch::Tensor outImg = std::get<0>(t).to(device);
torch::Tensor outImg = std::get<0>(t);

torch::Tensor finalTs = std::get<1>(t).to(device);
torch::Tensor finalTs = std::get<1>(t);
std::vector<int32_t> *px2gid = std::get<2>(t);

ctx->saved_data["px2gid"] = reinterpret_cast<int64_t>(px2gid);
ctx->saved_data["imgWidth"] = imgWidth;
ctx->saved_data["imgHeight"] = imgHeight;
ctx->save_for_backward({ xys, conics, colors, opacity, background, cov2d, camDepths, finalTs });

return outImg;
Expand All @@ -197,7 +196,6 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr
torch::Tensor finalTs = saved[7];

torch::Tensor v_outAlpha = torch::zeros_like(v_outImg.index({"...", 0}));
torch::Device device = xys.device();

auto t = rasterize_backward_tensor_cpu(imgHeight, imgWidth,
xys,
Expand All @@ -215,10 +213,10 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr
delete[] px2gid;


torch::Tensor v_xy = std::get<0>(t).to(device);
torch::Tensor v_conic = std::get<1>(t).to(device);
torch::Tensor v_colors = std::get<2>(t).to(device);
torch::Tensor v_opacity = std::get<3>(t).to(device);
torch::Tensor v_xy = std::get<0>(t);
torch::Tensor v_conic = std::get<1>(t);
torch::Tensor v_colors = std::get<2>(t);
torch::Tensor v_opacity = std::get<3>(t);
torch::Tensor none;

return { v_xy,
Expand Down

0 comments on commit 8c376c8

Please sign in to comment.