diff --git a/model.cpp b/model.cpp index e8a7c58..f6f835e 100644 --- a/model.cpp +++ b/model.cpp @@ -5,6 +5,11 @@ #include "rasterize_gaussians.hpp" #include "tensor_math.hpp" #include "vendor/gsplat/config.h" +#ifdef USE_HIP +#include +#else +#include +#endif torch::Tensor randomQuatTensor(long long n){ torch::Tensor u = torch::rand(n); @@ -386,6 +391,11 @@ void Model::afterTrain(int step){ xysGradNorm = torch::Tensor(); visCounts = torch::Tensor(); max2DSize = torch::Tensor(); +#ifdef USE_HIP + c10::hip::HIPCachingAllocator::emptyCache(); +#else + c10::cuda::CUDACachingAllocator::emptyCache(); +#endif } }