Skip to content

Commit

Permalink
More verbose, update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
pierotofy committed Feb 17, 2024
1 parent 22fa345 commit 0327c00
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

A free and open source implementation of 3D gaussian splatting, written in C++. It's based on [splatfacto](https://docs.nerf.studio/nerfology/methods/splat.html) and focuses on being portable, lean and fast.

![OpenSplat](https://github.com/pierotofy/OpenSplat/assets/1951843/3461e0e4-e134-4d6a-8a56-d89d00258e41)


OpenSplat takes camera poses + sparse points and computes a scene file (.ply) that can be later imported for viewing, editing and rendering in other [software](https://github.com/MrNeRF/awesome-3D-gaussian-splatting?tab=readme-ov-file#open-source-implementations).

## Build
Expand Down
8 changes: 5 additions & 3 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ void Model::afterTrain(int step){
const float cullAlphaThresh = 0.1f;

if (doDensification){
int numPointsBefore = means.size(0);
torch::Tensor avgGradNorm = (xysGradNorm / visCounts) * 0.5f * static_cast<float>( (std::max)(lastWidth, lastHeight) );
torch::Tensor highGrads = (avgGradNorm > densifyGradThresh).squeeze();

Expand Down Expand Up @@ -288,7 +289,6 @@ void Model::afterTrain(int step){
// Duplicate gaussians that are too small
torch::Tensor dups = (std::get<0>(scales.exp().max(-1)) <= densifySizeThresh).squeeze();
dups &= highGrads;
int nDups = dups.sum().item<int>();
torch::Tensor dupMeans = means.index({dups});
torch::Tensor dupFeaturesDc = featuresDc.index({dups});
torch::Tensor dupFeaturesRest = featuresRest.index({dups});
Expand Down Expand Up @@ -330,14 +330,15 @@ void Model::afterTrain(int step){
splits,
torch::full({nSplitSamples * splits.sum().item<int>() + dups.sum().item<int>()}, false, torch::TensorOptions().dtype(torch::kBool).device(device))
}, 0);

std::cout << "Added " << (means.size(0) - numPointsBefore) << " gaussians, new count " << means.size(0) << std::endl;
}

if (doDensification || step >= stopSplitAt){
// Cull
int numPointsBefore = means.size(0);

torch::Tensor culls = (torch::sigmoid(opacities) < cullAlphaThresh).squeeze();
int hugeCount = 0;
if (splitsMask.numel()){
culls |= splitsMask;
}
Expand All @@ -350,7 +351,6 @@ void Model::afterTrain(int step){
huge |= max2DSize > cullScreenSize;
}
culls |= huge;
hugeCount = torch::sum(huge).item<int>();
}

int cullCount = torch::sum(culls).item<int>();
Expand All @@ -369,6 +369,7 @@ void Model::afterTrain(int step){
removeFromOptimizer(featuresRestOpt, featuresRest, culls);
removeFromOptimizer(opacitiesOpt, opacities, culls);

std::cout << "Culled " << (numPointsBefore - means.size(0)) << " gaussians, remaining " << means.size(0) << std::endl;
}
}

Expand All @@ -382,6 +383,7 @@ void Model::afterTrain(int step){
auto paramState = std::make_unique<torch::optim::AdamParamState>(static_cast<torch::optim::AdamParamState&>(*opacitiesOpt->state()[pId]));
paramState->exp_avg(torch::zeros_like(paramState->exp_avg()));
paramState->exp_avg_sq(torch::zeros_like(paramState->exp_avg_sq()));
std::cout << "Alpha reset" << std::endl;
}

// Clear
Expand Down
10 changes: 8 additions & 2 deletions opensplat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ int main(int argc, char *argv[]){

InfiniteRandomIterator<ns::Camera> cams(inputData.cameras);

int imageSize = -1;
for (size_t step = 1; step <= numIters; step++){
ns::Camera cam = cams.next();

Expand All @@ -103,17 +104,22 @@ int main(int argc, char *argv[]){
torch::Tensor gt = cam.getImage(model.getDownscaleFactor(step));
gt = gt.to(device);

if (gt.size(0) != imageSize){
imageSize = gt.size(0) + 1;
std::cout << "Image size " << imageSize << "px" << std::endl;
}

torch::Tensor ssimLoss = 1.0f - model.ssim.eval(rgb, gt);
torch::Tensor l1Loss = ns::l1(rgb, gt);
torch::Tensor mainLoss = (1.0f - ssimWeight) * l1Loss + ssimWeight * ssimLoss;
mainLoss.backward();

if (step % 10 == 0) std::cout << "Step " << step << ": " << mainLoss.item<float>() << std::endl;

model.optimizersStep();
//model.optimizersScheduleStep(); // TODO
model.afterTrain(step);

if (step % 10 == 0) std::cout << "Step " << step << ": " << mainLoss.item<float>() << std::endl;

if (saveEvery > 0 && step % saveEvery == 0){
fs::path p(outputScene);
model.savePlySplat(p.replace_filename(fs::path(p.stem().string() + "_" + std::to_string(step) + p.extension().string()).string()));
Expand Down

0 comments on commit 0327c00

Please sign in to comment.