Skip to content

Commit

Permalink
workaround eigen bug with matrix inversion
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahdhn committed Sep 24, 2024
1 parent f1fec89 commit db0b94a
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
42 changes: 42 additions & 0 deletions include/rxmesh/util/inverse.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once

#include <cuda.h>
#include "rxmesh/types.h"

#include <Eigen/Dense>

namespace rxmesh {
/**
* @brief since 3x3 matrix inverse in eigen is buggy on device (it results into
* "unspecified launch failure"), we convert eigen matrix into glm, inverse it,
* then convert it back to eigen matrix
* @tparam T the floating point type of the matrix
* @tparam n the size of the matrix, expected/tested sizes are 2,3, and 4.
*/

template <typename T, int n>
__device__ __host__ __inline__ Eigen::Matrix<T, n, n> inverse(
const Eigen::Matrix<T, n, n>& in)
{
glm::mat<n, n, T, glm::defaultp> glm_mat;

for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
glm_mat[i][j] = in(i, j);
}
}

auto glm_inv = glm::inverse(glm_mat);

Eigen::Matrix<T, n, n> eig_inv;

for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
eig_inv(i, j) = glm_inv[i][j];
}
}
return eig_inv;
}


} // namespace rxmesh
6 changes: 5 additions & 1 deletion tests/RXMesh_test/test_scalar.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "rxmesh/diff/scalar.h"

#include "rxmesh/rxmesh_static.h"
#include "rxmesh/util/inverse.h"


#define RX_ASSERT_NEAR(val1, val2, eps, d_err) \
if (abs(val1 - val2) > eps) { \
Expand Down Expand Up @@ -749,7 +751,9 @@ __global__ static void test_min_quadratic(int* d_err, T eps = 1e-9)
2.0 * x[1] + 6.0 * x[2] + 10;

// Solve for minimum
const Eigen::Vector<T, 3> x_min = -f.Hess.inverse() * f.grad;
typename Real3::HessType f_hess_inv = inverse(f.Hess);

const Eigen::Vector<T, 3> x_min = -f_hess_inv * f.grad;

RX_ASSERT_NEAR(x_min.x(), -0.5, eps, d_err);
RX_ASSERT_NEAR(x_min.y(), 0.5, eps, d_err);
Expand Down

0 comments on commit db0b94a

Please sign in to comment.