Skip to content

Commit

Permalink
Merge pull request #8 from aalbaali/generalize-inverse-function
Browse files Browse the repository at this point in the history
generalize inverse function
  • Loading branch information
aalbaali committed Nov 8, 2023
2 parents e5d6671 + 31bdd1b commit 64b6538
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 17 deletions.
7 changes: 7 additions & 0 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
Language: Cpp
BasedOnStyle: Google
ColumnLimit: 100
DerivePointerAlignment: false
PointerAlignment: Left
...
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ Refer to [`transforms_graph.h`](include/transforms_graph/transforms_graph.h) for
## Required types
The `TransformsGraph` requires two template parameters:
1. `Transform` type. This is the type that stores the transforms themselves (i.e., poses). This type should have the following defined:
- A default constructor that sets the transform to the identity transform
- A valid Multiplication operator `*` operator (i.e., `T1 = T2 * T2`)
- A valid `Transform inverse() const` method (i.e., `T1.inverse() * T1` should return an identity transform)
- An output stream operator `<<` (i.e., `std::cout << T`)
- A default constructor that sets the transform to the identity transform.
- A valid Multiplication operator `*` operator (i.e., `T1 = T2 * T2`).
- A valid `Transform inverse() const` method (i.e., `T1.inverse() * T1` should return an identity transform). Alternatively, a function `Transform inverse(Transform)` can be injected/passed to the `TransformsGraph` upon construction.
- An output stream operator `<<` (i.e., `std::cout << T`).
2. `Frame` type (default set to `char`). This is the type that keeps track of the frames. The type should have the following defined:
- Greater-than comparison operator `>` (i.e., `frame_i > frame_j`)
- An output stream operator `<<` (i.e., `std::cout << T`)
- Greater-than comparison operator `>` (i.e., `frame_i > frame_j`).
- An output stream operator `<<` (i.e., `std::cout << T`).
3. `Inv` function object (defaults to `Transform::inverse`). The function object is passed in the constructor.

The classes from [Sophus](https://github.com/strasdat/Sophus) (e.g., `Sophus::SE2d`) and [Eigen](https://eigen.tuxfamily.org/dox/group__TutorialGeometry.html) (e.g., `Eigen::Affine2d`) already satisfy the `Transform` requirements, except for the output stream operator `<<` requirement.

Expand Down
2 changes: 1 addition & 1 deletion examples/eigen_pose_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Pose {

Pose inverse() const {
Pose p;
p.pose_ = std::move(pose_.inverse());
p.pose_ = pose_.inverse();
return p;
}
Eigen::Affine2d Affine() const { return pose_; }
Expand Down
4 changes: 2 additions & 2 deletions examples/minimal_graph_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class Displacement {
Displacement() { d_ = 0; }
Displacement(double d) : d_(d) {}
double x() const { return d_; }
Displacement inverse() const { return Displacement(-d_); }

private:
double d_;
Expand All @@ -43,7 +42,8 @@ int main(int argc, char* argv[]) {
using Transform = Displacement;

// Construct a graph that consists of two unconnected subgraphs
tg::TransformsGraph<Transform, Frame> transforms;
auto displacement_inverse = [](const Transform& t) -> Transform { return -t.x(); };
tg::TransformsGraph<Transform, Frame> transforms(100, displacement_inverse);
transforms.AddTransform('a', 'b', 1);
transforms.AddTransform('a', 'c', 2);
transforms.AddTransform('b', 'd', 3);
Expand Down
6 changes: 4 additions & 2 deletions include/transforms_graph/graph_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ namespace tg {
*/
template <typename Node, typename Graph = std::unordered_map<Node, std::unordered_set<Node>>>
std::vector<Node> DFS(const Graph& graph, Node start, Node end) {
std::vector<Node> path;
std::unordered_set<Node> visited;
std::unordered_map<Node, Node> parent;
std::stack<Node> stack;
Expand All @@ -43,6 +42,8 @@ std::vector<Node> DFS(const Graph& graph, Node start, Node end) {
found_solution = true;
break;
}

// Mark as visited
if (visited.count(current)) continue;
visited.insert(current);

Expand All @@ -55,7 +56,8 @@ std::vector<Node> DFS(const Graph& graph, Node start, Node end) {

if (!found_solution) return {};

// Found the end
// Get path from start -> end
std::vector<Node> path;
Node node = end;
while (node != start) {
path.push_back(node);
Expand Down
28 changes: 22 additions & 6 deletions include/transforms_graph/transforms_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ namespace tg {
* @brief Transform graph class
* @tparam Transform Transform/pose type. Should have `*`, `<<`, and `inverse()` defined
* @tparam Frame Frame type (e.g. char, int, etc.). Should have `<<` defined.
* @tparam Inv Function object, where the instance takes an argument of type Transform and returns
* an inverse Transform.
*/
template <typename Transform, typename Frame = char>
template <typename Transform, typename Frame = char,
typename Inv = std::function<Transform(Transform)>>
class TransformsGraph {
public:
/** Id/key used to store transforms in the Transforms map */
Expand All @@ -32,6 +35,9 @@ class TransformsGraph {
/** Raw transforms */
using Transforms = std::unordered_map<TransformId, Transform>;

/** Inverse function object */
using TransformInverse = Inv;

/** Adjacency matrix of adjacent frames */
using AdjacentFrames = std::unordered_map<Frame, std::unordered_set<Frame>>;

Expand All @@ -45,8 +51,13 @@ class TransformsGraph {
* @details The maximum number of frames cannot be changed after construction
*
* @param[in] max_frames Maximum number of frames allowed in the graph
* @param[in] transform_inverse Function object to invert a transform. Should take a
* transform as an argument and return its inverse.
*/
TransformsGraph(int max_frames = 100) : max_frames_(max_frames) {}
TransformsGraph(int max_frames = 100,
TransformInverse transform_inverse = std::bind(&Transform::inverse,
std::placeholders::_1))
: max_frames_(max_frames), transform_inverse_(transform_inverse) {}

/**
* @brief Get maximum number allowed in the transform graph
Expand Down Expand Up @@ -130,7 +141,7 @@ class TransformsGraph {
const auto transform_id = ComputeTransformId(parent, child);
auto transform = raw_transforms_.at(transform_id);

return ShouldInvertFrames(parent, child) ? transform.inverse() : transform;
return ShouldInvertFrames(parent, child) ? transform_inverse_(transform) : transform;
}

/**
Expand Down Expand Up @@ -160,7 +171,7 @@ class TransformsGraph {
const auto transform_id = ComputeTransformId(prev, frame);
auto T_prev_curr = raw_transforms_.at(transform_id);
if (ShouldInvertFrames(prev, frame)) {
T_prev_curr = T_prev_curr.inverse();
T_prev_curr = transform_inverse_(T_prev_curr);
}

T_parent_child = T_parent_child * T_prev_curr;
Expand Down Expand Up @@ -354,7 +365,8 @@ class TransformsGraph {
throw std::runtime_error("Transform does not exist in the graph");
}
const auto transform_id = ComputeTransformId(parent, child);
raw_transforms_[transform_id] = ShouldInvertFrames(parent, child) ? pose.inverse() : pose;
raw_transforms_[transform_id] =
ShouldInvertFrames(parent, child) ? transform_inverse_(pose) : pose;
}

/**
Expand Down Expand Up @@ -457,7 +469,8 @@ class TransformsGraph {
if (!HasFrame(child)) AddFrame(child);

const auto transform_id = ComputeTransformId(parent, child);
raw_transforms_[transform_id] = ShouldInvertFrames(parent, child) ? pose.inverse() : pose;
raw_transforms_[transform_id] =
ShouldInvertFrames(parent, child) ? transform_inverse_(pose) : pose;
}

/**
Expand Down Expand Up @@ -508,6 +521,9 @@ class TransformsGraph {
/** Maximum number of frames expected to be in the graph */
int max_frames_ = 100;

/** Function object to invert transform */
TransformInverse transform_inverse_;

/** Acyclic graph where the vertices are the frames and the edges are transforms between the two
* frames */
AdjacentFrames adjacent_frames_;
Expand Down

0 comments on commit 64b6538

Please sign in to comment.