diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..1bbc356 --- /dev/null +++ b/.clang-format @@ -0,0 +1,7 @@ +--- +Language: Cpp +BasedOnStyle: Google +ColumnLimit: 100 +DerivePointerAlignment: false +PointerAlignment: Left +... diff --git a/README.md b/README.md index b8e9a8f..8b8331b 100644 --- a/README.md +++ b/README.md @@ -34,13 +34,14 @@ Refer to [`transforms_graph.h`](include/transforms_graph/transforms_graph.h) for ## Required types The `TransformsGraph` requires two template parameters: 1. `Transform` type. This is the type that stores the transforms themselves (i.e., poses). This type should have the following defined: - - A default constructor that sets the transform to the identity transform - - A valid Multiplication operator `*` operator (i.e., `T1 = T2 * T2`) - - A valid `Transform inverse() const` method (i.e., `T1.inverse() * T1` should return an identity transform) - - An output stream operator `<<` (i.e., `std::cout << T`) + - A default constructor that sets the transform to the identity transform. + - A valid Multiplication operator `*` operator (i.e., `T1 = T2 * T2`). + - A valid `Transform inverse() const` method (i.e., `T1.inverse() * T1` should return an identity transform). Alternatively, a function `Transform inverse(Transform)` can be injected/passed to the `TransformsGraph` upon construction. + - An output stream operator `<<` (i.e., `std::cout << T`). 2. `Frame` type (default set to `char`). This is the type that keeps track of the frames. The type should have the following defined: - - Greater-than comparison operator `>` (i.e., `frame_i > frame_j`) - - An output stream operator `<<` (i.e., `std::cout << T`) + - Greater-than comparison operator `>` (i.e., `frame_i > frame_j`). + - An output stream operator `<<` (i.e., `std::cout << T`). +3. `Inv` function object (defaults to `Transform::inverse`). The function object is passed in the constructor. The classes from [Sophus](https://github.com/strasdat/Sophus) (e.g., `Sophus::SE2d`) and [Eigen](https://eigen.tuxfamily.org/dox/group__TutorialGeometry.html) (e.g., `Eigen::Affine2d`) already satisfy the `Transform` requirements, except for the output stream operator `<<` requirement. diff --git a/examples/eigen_pose_example.cpp b/examples/eigen_pose_example.cpp index c283213..64c34e8 100644 --- a/examples/eigen_pose_example.cpp +++ b/examples/eigen_pose_example.cpp @@ -29,7 +29,7 @@ class Pose { Pose inverse() const { Pose p; - p.pose_ = std::move(pose_.inverse()); + p.pose_ = pose_.inverse(); return p; } Eigen::Affine2d Affine() const { return pose_; } diff --git a/examples/minimal_graph_example.cpp b/examples/minimal_graph_example.cpp index a950ddd..e7574ec 100644 --- a/examples/minimal_graph_example.cpp +++ b/examples/minimal_graph_example.cpp @@ -20,7 +20,6 @@ class Displacement { Displacement() { d_ = 0; } Displacement(double d) : d_(d) {} double x() const { return d_; } - Displacement inverse() const { return Displacement(-d_); } private: double d_; @@ -43,7 +42,8 @@ int main(int argc, char* argv[]) { using Transform = Displacement; // Construct a graph that consists of two unconnected subgraphs - tg::TransformsGraph transforms; + auto displacement_inverse = [](const Transform& t) -> Transform { return -t.x(); }; + tg::TransformsGraph transforms(100, displacement_inverse); transforms.AddTransform('a', 'b', 1); transforms.AddTransform('a', 'c', 2); transforms.AddTransform('b', 'd', 3); diff --git a/include/transforms_graph/graph_search.h b/include/transforms_graph/graph_search.h index ece14b3..17acdc6 100644 --- a/include/transforms_graph/graph_search.h +++ b/include/transforms_graph/graph_search.h @@ -27,7 +27,6 @@ namespace tg { */ template >> std::vector DFS(const Graph& graph, Node start, Node end) { - std::vector path; std::unordered_set visited; std::unordered_map parent; std::stack stack; @@ -43,6 +42,8 @@ std::vector DFS(const Graph& graph, Node start, Node end) { found_solution = true; break; } + + // Mark as visited if (visited.count(current)) continue; visited.insert(current); @@ -55,7 +56,8 @@ std::vector DFS(const Graph& graph, Node start, Node end) { if (!found_solution) return {}; - // Found the end + // Get path from start -> end + std::vector path; Node node = end; while (node != start) { path.push_back(node); diff --git a/include/transforms_graph/transforms_graph.h b/include/transforms_graph/transforms_graph.h index df17b47..43defc0 100644 --- a/include/transforms_graph/transforms_graph.h +++ b/include/transforms_graph/transforms_graph.h @@ -22,8 +22,11 @@ namespace tg { * @brief Transform graph class * @tparam Transform Transform/pose type. Should have `*`, `<<`, and `inverse()` defined * @tparam Frame Frame type (e.g. char, int, etc.). Should have `<<` defined. + * @tparam Inv Function object, where the instance takes an argument of type Transform and returns + * an inverse Transform. */ -template +template > class TransformsGraph { public: /** Id/key used to store transforms in the Transforms map */ @@ -32,6 +35,9 @@ class TransformsGraph { /** Raw transforms */ using Transforms = std::unordered_map; + /** Inverse function object */ + using TransformInverse = Inv; + /** Adjacency matrix of adjacent frames */ using AdjacentFrames = std::unordered_map>; @@ -45,8 +51,13 @@ class TransformsGraph { * @details The maximum number of frames cannot be changed after construction * * @param[in] max_frames Maximum number of frames allowed in the graph + * @param[in] transform_inverse Function object to invert a transform. Should take a + * transform as an argument and return its inverse. */ - TransformsGraph(int max_frames = 100) : max_frames_(max_frames) {} + TransformsGraph(int max_frames = 100, + TransformInverse transform_inverse = std::bind(&Transform::inverse, + std::placeholders::_1)) + : max_frames_(max_frames), transform_inverse_(transform_inverse) {} /** * @brief Get maximum number allowed in the transform graph @@ -130,7 +141,7 @@ class TransformsGraph { const auto transform_id = ComputeTransformId(parent, child); auto transform = raw_transforms_.at(transform_id); - return ShouldInvertFrames(parent, child) ? transform.inverse() : transform; + return ShouldInvertFrames(parent, child) ? transform_inverse_(transform) : transform; } /** @@ -160,7 +171,7 @@ class TransformsGraph { const auto transform_id = ComputeTransformId(prev, frame); auto T_prev_curr = raw_transforms_.at(transform_id); if (ShouldInvertFrames(prev, frame)) { - T_prev_curr = T_prev_curr.inverse(); + T_prev_curr = transform_inverse_(T_prev_curr); } T_parent_child = T_parent_child * T_prev_curr; @@ -354,7 +365,8 @@ class TransformsGraph { throw std::runtime_error("Transform does not exist in the graph"); } const auto transform_id = ComputeTransformId(parent, child); - raw_transforms_[transform_id] = ShouldInvertFrames(parent, child) ? pose.inverse() : pose; + raw_transforms_[transform_id] = + ShouldInvertFrames(parent, child) ? transform_inverse_(pose) : pose; } /** @@ -457,7 +469,8 @@ class TransformsGraph { if (!HasFrame(child)) AddFrame(child); const auto transform_id = ComputeTransformId(parent, child); - raw_transforms_[transform_id] = ShouldInvertFrames(parent, child) ? pose.inverse() : pose; + raw_transforms_[transform_id] = + ShouldInvertFrames(parent, child) ? transform_inverse_(pose) : pose; } /** @@ -508,6 +521,9 @@ class TransformsGraph { /** Maximum number of frames expected to be in the graph */ int max_frames_ = 100; + /** Function object to invert transform */ + TransformInverse transform_inverse_; + /** Acyclic graph where the vertices are the frames and the edges are transforms between the two * frames */ AdjacentFrames adjacent_frames_;