Skip to content

Commit

Permalink
Merge pull request #56 from pierotofy/orient
Browse files Browse the repository at this point in the history
Preserve scale/orientation of scene input
  • Loading branch information
pierotofy committed Mar 22, 2024
2 parents 1b1db43 + c584323 commit 521121d
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 43 deletions.
13 changes: 5 additions & 8 deletions colmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,11 @@ InputData inputDataFromColmap(const std::string &projectRoot){
}

imgf.close();
auto r = autoOrientAndCenterPoses(unorientedPoses);

auto r = autoScaleAndCenterPoses(unorientedPoses);
torch::Tensor poses = std::get<0>(r);
ret.transformMatrix = std::get<1>(r);
ret.scaleFactor = 1.0f / torch::max(torch::abs(poses.index({Slice(), Slice(None, 3), 3}))).item<float>();
poses.index({Slice(), Slice(None, 3), 3}) *= ret.scaleFactor;
ret.translation = std::get<1>(r);
ret.scale = std::get<2>(r);

for (size_t i = 0; i < ret.cameras.size(); i++){
ret.cameras[i].camToWorld = poses[i];
Expand All @@ -141,9 +140,7 @@ InputData inputDataFromColmap(const std::string &projectRoot){
PointSet *pSet = readPointSet(pointsPath.string());
torch::Tensor points = pSet->pointsTensor().clone();

ret.points.xyz = torch::matmul(torch::cat({points, torch::ones_like(points.index({"...", Slice(None, 1)}))}, -1),
ret.transformMatrix.transpose(0, 1));
ret.points.xyz *= ret.scaleFactor;
ret.points.xyz = (points - ret.translation) * ret.scale;
ret.points.rgb = pSet->colorsTensor().clone();

RELEASE_POINTSET(pSet);
Expand Down
4 changes: 2 additions & 2 deletions input_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ struct Points{
};
struct InputData{
std::vector<Camera> cameras;
float scaleFactor;
torch::Tensor transformMatrix;
float scale;
torch::Tensor translation;
Points points;

std::tuple<std::vector<Camera>, Camera *> getCameras(bool validate, const std::string &valImage = "random");
Expand Down
33 changes: 31 additions & 2 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,10 @@ void Model::savePlySplat(const std::string &filename){

float zeros[] = { 0.0f, 0.0f, 0.0f };

torch::Tensor meansCpu = means.cpu();
torch::Tensor meansCpu = (means.cpu() / scale) + translation;
torch::Tensor featuresDcCpu = featuresDc.cpu();
torch::Tensor opacitiesCpu = opacities.cpu();
torch::Tensor scalesCpu = scales.cpu();
torch::Tensor scalesCpu = torch::log((torch::exp(scales.cpu()) / scale));
torch::Tensor quatsCpu = quats.cpu();

for (size_t i = 0; i < numPoints; i++) {
Expand All @@ -518,6 +518,35 @@ void Model::savePlySplat(const std::string &filename){
std::cout << "Wrote " << filename << std::endl;
}

void Model::saveDebugPly(const std::string &filename){
// A standard PLY
std::ofstream o(filename, std::ios::binary);
int numPoints = means.size(0);

o << "ply" << std::endl;
o << "format binary_little_endian 1.0" << std::endl;
o << "comment Generated by opensplat" << std::endl;
o << "element vertex " << numPoints << std::endl;
o << "property float x" << std::endl;
o << "property float y" << std::endl;
o << "property float z" << std::endl;
o << "property uchar red" << std::endl;
o << "property uchar green" << std::endl;
o << "property uchar blue" << std::endl;
o << "end_header" << std::endl;

torch::Tensor meansCpu = (means.cpu() / scale) + translation;
torch::Tensor rgbsCpu = (sh2rgb(featuresDc.cpu()) * 255.0f).toType(torch::kUInt8);

for (size_t i = 0; i < numPoints; i++) {
o.write(reinterpret_cast<const char *>(meansCpu[i].data_ptr()), sizeof(float) * 3);
o.write(reinterpret_cast<const char *>(rgbsCpu[i].data_ptr()), sizeof(uint8_t) * 3);
}

o.close();
std::cout << "Wrote " << filename << std::endl;
}

torch::Tensor Model::mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight){
torch::Tensor ssimLoss = 1.0f - ssim.eval(rgb, gt);
torch::Tensor l1Loss = l1(rgb, gt);
Expand Down
18 changes: 13 additions & 5 deletions model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ torch::Tensor psnr(const torch::Tensor& rendered, const torch::Tensor& gt);
torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt);

struct Model{
Model(const Points &points, int numCameras,
Model(const InputData &inputData, int numCameras,
int numDownscales, int resolutionSchedule, int shDegree, int shDegreeInterval,
int refineEvery, int warmupLength, int resetAlphaEvery, int stopSplitAt, float densifyGradThresh, float densifySizeThresh, int stopScreenSizeAt, float splitScreenSize,
int maxSteps,
Expand All @@ -30,17 +30,21 @@ struct Model{
refineEvery(refineEvery), warmupLength(warmupLength), resetAlphaEvery(resetAlphaEvery), stopSplitAt(stopSplitAt), densifyGradThresh(densifyGradThresh), densifySizeThresh(densifySizeThresh), stopScreenSizeAt(stopScreenSizeAt), splitScreenSize(splitScreenSize),
maxSteps(maxSteps),
device(device), ssim(11, 3){
long long numPoints = points.xyz.size(0);

long long numPoints = inputData.points.xyz.size(0);
scale = inputData.scale;
translation = inputData.translation;

torch::manual_seed(42);

means = points.xyz.to(device).requires_grad_();
scales = PointsTensor(points.xyz).scales().repeat({1, 3}).log().to(device).requires_grad_();
means = inputData.points.xyz.to(device).requires_grad_();
scales = PointsTensor(inputData.points.xyz).scales().repeat({1, 3}).log().to(device).requires_grad_();
quats = randomQuatTensor(numPoints).to(device).requires_grad_();

int dimSh = numShBases(shDegree);
torch::Tensor shs = torch::zeros({numPoints, dimSh, 3}, torch::TensorOptions().dtype(torch::kFloat32).device(device));

shs.index({Slice(), 0, Slice(None, 3)}) = rgb2sh(points.rgb.toType(torch::kFloat64) / 255.0).toType(torch::kFloat32);
shs.index({Slice(), 0, Slice(None, 3)}) = rgb2sh(inputData.points.rgb.toType(torch::kFloat64) / 255.0).toType(torch::kFloat32);
shs.index({Slice(), Slice(1, None), Slice(3, None)}) = 0.0f;

featuresDc = shs.index({Slice(), 0, Slice()}).to(device).requires_grad_();
Expand Down Expand Up @@ -78,6 +82,7 @@ struct Model{
int getDownscaleFactor(int step);
void afterTrain(int step);
void savePlySplat(const std::string &filename);
void saveDebugPly(const std::string &filename);
torch::Tensor mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight);

void addToOptimizer(torch::optim::Adam *optimizer, const torch::Tensor &newParam, const torch::Tensor &idcs, int nSamples);
Expand Down Expand Up @@ -126,6 +131,9 @@ struct Model{
int stopScreenSizeAt;
float splitScreenSize;
int maxSteps;

float scale;
torch::Tensor translation;
};


Expand Down
14 changes: 5 additions & 9 deletions nerfstudio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,10 @@ InputData inputDataFromNerfStudio(const std::string &projectRoot){

torch::Tensor unorientedPoses = posesFromTransforms(t);

auto r = autoOrientAndCenterPoses(unorientedPoses);
auto r = autoScaleAndCenterPoses(unorientedPoses);
torch::Tensor poses = std::get<0>(r);
ret.transformMatrix = std::get<1>(r);

ret.scaleFactor = 1.0f / torch::max(torch::abs(poses.index({Slice(), Slice(None, 3), 3}))).item<float>();
poses.index({Slice(), Slice(None, 3), 3}) *= ret.scaleFactor;
ret.translation = std::get<1>(r);
ret.scale = std::get<2>(r);

// aabbScale = [[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]

Expand All @@ -158,10 +156,8 @@ InputData inputDataFromNerfStudio(const std::string &projectRoot){
}

torch::Tensor points = pSet->pointsTensor().clone();

ret.points.xyz = torch::matmul(torch::cat({points, torch::ones_like(points.index({"...", Slice(None, 1)}))}, -1),
ret.transformMatrix.transpose(0, 1));
ret.points.xyz *= ret.scaleFactor;

ret.points.xyz = (points - ret.translation) * ret.scale;
ret.points.rgb = pSet->colorsTensor().clone();

RELEASE_POINTSET(pSet);
Expand Down
13 changes: 7 additions & 6 deletions opensplat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ int main(int argc, char *argv[]){
std::vector<Camera> cams = std::get<0>(t);
Camera *valCam = std::get<1>(t);

Model model(inputData.points,
cams.size(),
numDownscales, resolutionSchedule, shDegree, shDegreeInterval,
refineEvery, warmupLength, resetAlphaEvery, stopSplitAt, densifyGradThresh, densifySizeThresh, stopScreenSizeAt, splitScreenSize,
numIters,
device);
Model model(inputData,
cams.size(),
numDownscales, resolutionSchedule, shDegree, shDegreeInterval,
refineEvery, warmupLength, resetAlphaEvery, stopSplitAt, densifyGradThresh, densifySizeThresh, stopScreenSizeAt, splitScreenSize,
numIters,
device);

std::vector< size_t > camIndices( cams.size() );
std::iota( camIndices.begin(), camIndices.end(), 0 );
Expand Down Expand Up @@ -145,6 +145,7 @@ int main(int argc, char *argv[]){
}

model.savePlySplat(outputScene);
// model.saveDebugPly("debug.ply");

// Validate
if (valCam != nullptr){
Expand Down
8 changes: 7 additions & 1 deletion spherical_harmonics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@ int degFromSh(int numBases){
}
}

const double C0 = 0.28209479177387814;

torch::Tensor rgb2sh(const torch::Tensor &rgb){
// Converts from RGB values [0,1] to the 0th spherical harmonic coefficient
const double C0 = 0.28209479177387814;
return (rgb - 0.5) / C0;
}

torch::Tensor sh2rgb(const torch::Tensor &sh){
// Converts from 0th spherical harmonic coefficients to RGB values [0,1]
return (sh * C0) + 0.5;
}

#if defined(USE_HIP) || defined(USE_CUDA)

torch::Tensor SphericalHarmonics::forward(AutogradContext *ctx,
Expand Down
1 change: 1 addition & 0 deletions spherical_harmonics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using namespace torch::autograd;

int degFromSh(int numBases);
torch::Tensor rgb2sh(const torch::Tensor &rgb);
torch::Tensor sh2rgb(const torch::Tensor &sh);

#if defined(USE_HIP) || defined(USE_CUDA)

Expand Down
21 changes: 12 additions & 9 deletions tensor_math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,20 @@ torch::Tensor quatToRotMat(const torch::Tensor &quat){

}

std::tuple<torch::Tensor, torch::Tensor> autoOrientAndCenterPoses(const torch::Tensor &poses){
// Center at mean and orient up
std::tuple<torch::Tensor, torch::Tensor, float> autoScaleAndCenterPoses(const torch::Tensor &poses){
// Center at mean
torch::Tensor origins = poses.index({"...", Slice(None, 3), 3});
torch::Tensor translation = torch::mean(origins, 0);
torch::Tensor up = torch::mean(poses.index({Slice(), Slice(None, 3), 1}), 0);
up = up / up.norm();
torch::Tensor center = torch::mean(origins, 0);
origins -= center;

// Scale
float f = 1.0f / torch::max(torch::abs(origins)).item<float>();
origins *= f;

torch::Tensor rotation = rotationMatrix(up, torch::tensor({0, 0, 1}, torch::kFloat32));
torch::Tensor transform = torch::cat({rotation, torch::matmul(rotation, -translation.index({"...", None}))}, -1);
torch::Tensor orientedPoses = torch::matmul(transform, poses);
return std::make_tuple(orientedPoses, transform);
torch::Tensor transformedPoses = poses.clone();
transformedPoses.index_put_({"...", Slice(None, 3), 3}, origins);

return std::make_tuple(transformedPoses, center, f);
}


Expand Down
2 changes: 1 addition & 1 deletion tensor_math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <tuple>

torch::Tensor quatToRotMat(const torch::Tensor &quat);
std::tuple<torch::Tensor, torch::Tensor> autoOrientAndCenterPoses(const torch::Tensor &poses);
std::tuple<torch::Tensor, torch::Tensor, float> autoScaleAndCenterPoses(const torch::Tensor &poses);
torch::Tensor rotationMatrix(const torch::Tensor &a, const torch::Tensor &b);


Expand Down

0 comments on commit 521121d

Please sign in to comment.