Skip to content

Commit

Permalink
Merge pull request #1616 from borglab/cherrypick-transitivity-fix-dsf…
Browse files Browse the repository at this point in the history
…trackgenerator

Cherrypick commits for release/4.2, that include transitivity fix for DsfTrackGenerator
  • Loading branch information
dellaert authored Sep 3, 2023
2 parents 13c7daf + 2f2d654 commit 1a86944
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 14 deletions.
6 changes: 4 additions & 2 deletions gtsam/sfm/DsfTrackGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <algorithm>
#include <iostream>
#include <iomanip>

namespace gtsam {

Expand All @@ -38,7 +39,8 @@ static DSFMapIndexPair generateDSF(const MatchIndicesMap& matches) {
// Image pair is (i1,i2).
size_t i1 = pair_indices.first;
size_t i2 = pair_indices.second;
for (size_t k = 0; k < corr_indices.rows(); k++) {
size_t m = static_cast<size_t>(corr_indices.rows());
for (size_t k = 0; k < m; k++) {
// Measurement indices are found in a single matrix row, as (k1,k2).
size_t k1 = corr_indices(k, 0), k2 = corr_indices(k, 1);
// Unique key for DSF is (i,k), representing keypoint index in an image.
Expand Down Expand Up @@ -128,7 +130,7 @@ std::vector<SfmTrack2d> tracksFromPairwiseMatches(
}

// TODO(johnwlambert): return the Transitivity failure percentage here.
return tracks2d;
return validTracks;
}

} // namespace gtsfm
Expand Down
116 changes: 104 additions & 12 deletions python/gtsam/tests/test_DsfTrackGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,65 @@
"""

import unittest
from typing import Dict, Tuple

import gtsam
import numpy as np
from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
SfmMeasurementVector, SfmTrack2d)
from gtsam.gtsfm import Keypoints
from gtsam.utils.test_case import GtsamTestCase

import gtsam
from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
SfmMeasurementVector, SfmTrack2d)


class TestDsfTrackGenerator(GtsamTestCase):
"""Tests for DsfTrackGenerator."""

def test_generate_tracks_from_pairwise_matches_nontransitive(
self,
) -> None:
"""Tests DSF for non-transitive matches.
Test will result in no tracks since nontransitive tracks are naively
discarded by DSF.
"""
keypoints = get_dummy_keypoints_list()
nontransitive_matches = get_nontransitive_matches()

# For each image pair (i1,i2), we provide a (K,2) matrix
# of corresponding keypoint indices (k1,k2).
matches = MatchIndicesMap()
for (i1, i2), correspondences in nontransitive_matches.items():
matches[IndexPair(i1, i2)] = correspondences

tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
matches,
keypoints,
verbose=True,
)
self.assertEqual(len(tracks), 0, "Tracks not filtered correctly")

def test_track_generation(self) -> None:
"""Ensures that DSF generates three tracks from measurements
in 3 images (H=200,W=400)."""
kps_i0 = Keypoints(np.array([[10.0, 20], [30, 40]]))
kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]]))

keypoints_list = KeypointsVector()
keypoints_list.append(kps_i0)
keypoints_list.append(kps_i1)
keypoints_list.append(kps_i2)
keypoints = KeypointsVector()
keypoints.append(kps_i0)
keypoints.append(kps_i1)
keypoints.append(kps_i2)

# For each image pair (i1,i2), we provide a (K,2) matrix
# of corresponding image indices (k1,k2).
matches_dict = MatchIndicesMap()
matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
matches = MatchIndicesMap()
matches[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
matches[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])

tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
matches_dict,
keypoints_list,
matches,
keypoints,
verbose=False,
)
assert len(tracks) == 3
Expand Down Expand Up @@ -93,5 +119,71 @@ def test_sfm_track_2d_constructor(self) -> None:
assert track.numberMeasurements() == 1


def get_dummy_keypoints_list() -> KeypointsVector:
"""Generate a list of dummy keypoints for testing."""
img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]])
img2_kp_coords = np.array(
[
[1, 1.],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
[6, 6],
[7, 7],
[8, 8],
]
)
img3_kp_coords = np.array(
[
[1, 1.],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
[6, 6],
[7, 7],
[8, 8],
[9, 9],
[10, 10],
]
)
img4_kp_coords = np.array(
[
[1, 1.],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
]
)
keypoints = KeypointsVector()
keypoints.append(Keypoints(coordinates=img1_kp_coords))
keypoints.append(Keypoints(coordinates=img2_kp_coords))
keypoints.append(Keypoints(coordinates=img3_kp_coords))
keypoints.append(Keypoints(coordinates=img4_kp_coords))
return keypoints


def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]:
"""Set up correspondences for each (i1,i2) pair that violates transitivity.
(i=0, k=0) (i=0, k=1)
| \\ |
| \\ |
(i=1, k=2)--(i=2,k=3)--(i=3, k=4)
Transitivity is violated due to the match between frames 0 and 3.
"""
nontransitive_matches = {
(0, 1): np.array([[0, 2]]),
(1, 2): np.array([[2, 3]]),
(0, 2): np.array([[0, 3]]),
(0, 3): np.array([[1, 4]]),
(2, 3): np.array([[3, 4]]),
}
return nontransitive_matches


if __name__ == "__main__":
unittest.main()

0 comments on commit 1a86944

Please sign in to comment.