diff --git a/rasterize_gaussians.cpp b/rasterize_gaussians.cpp index 8a32330..736d4c0 100644 --- a/rasterize_gaussians.cpp +++ b/rasterize_gaussians.cpp @@ -155,9 +155,6 @@ torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx, ){ int numPoints = xys.size(0); - ctx->saved_data["imgWidth"] = imgWidth; - ctx->saved_data["imgHeight"] = imgHeight; - torch::Device device = xys.device(); auto t = rasterize_forward_tensor_cpu(imgWidth, imgHeight, xys, @@ -169,12 +166,14 @@ torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx, camDepths ); // Final image - torch::Tensor outImg = std::get<0>(t).to(device); + torch::Tensor outImg = std::get<0>(t); - torch::Tensor finalTs = std::get<1>(t).to(device); + torch::Tensor finalTs = std::get<1>(t); std::vector *px2gid = std::get<2>(t); ctx->saved_data["px2gid"] = reinterpret_cast(px2gid); + ctx->saved_data["imgWidth"] = imgWidth; + ctx->saved_data["imgHeight"] = imgHeight; ctx->save_for_backward({ xys, conics, colors, opacity, background, cov2d, camDepths, finalTs }); return outImg; @@ -197,7 +196,6 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr torch::Tensor finalTs = saved[7]; torch::Tensor v_outAlpha = torch::zeros_like(v_outImg.index({"...", 0})); - torch::Device device = xys.device(); auto t = rasterize_backward_tensor_cpu(imgHeight, imgWidth, xys, @@ -215,10 +213,10 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr delete[] px2gid; - torch::Tensor v_xy = std::get<0>(t).to(device); - torch::Tensor v_conic = std::get<1>(t).to(device); - torch::Tensor v_colors = std::get<2>(t).to(device); - torch::Tensor v_opacity = std::get<3>(t).to(device); + torch::Tensor v_xy = std::get<0>(t); + torch::Tensor v_conic = std::get<1>(t); + torch::Tensor v_colors = std::get<2>(t); + torch::Tensor v_opacity = std::get<3>(t); torch::Tensor none; return { v_xy,