diff --git a/gtsam/sfm/DsfTrackGenerator.cpp b/gtsam/sfm/DsfTrackGenerator.cpp index cf989f282c..0efda06295 100644 --- a/gtsam/sfm/DsfTrackGenerator.cpp +++ b/gtsam/sfm/DsfTrackGenerator.cpp @@ -130,7 +130,7 @@ std::vector tracksFromPairwiseMatches( } // TODO(johnwlambert): return the Transitivity failure percentage here. - return tracks2d; + return validTracks; } } // namespace gtsfm diff --git a/python/gtsam/tests/test_DsfTrackGenerator.py b/python/gtsam/tests/test_DsfTrackGenerator.py index be6aa0796f..618f8a8cc1 100644 --- a/python/gtsam/tests/test_DsfTrackGenerator.py +++ b/python/gtsam/tests/test_DsfTrackGenerator.py @@ -4,6 +4,7 @@ """ import unittest +from typing import Dict, List, Tuple import numpy as np from gtsam.gtsfm import Keypoints @@ -16,6 +17,29 @@ class TestDsfTrackGenerator(GtsamTestCase): """Tests for DsfTrackGenerator.""" + def test_generate_tracks_from_pairwise_matches_nontransitive( + self, + ) -> None: + """Tests DSF for non-transitive matches. + + Test will result in no tracks since nontransitive tracks are naively discarded by DSF. + """ + keypoints_list = get_dummy_keypoints_list() + nontransitive_matches_dict = get_nontransitive_matches() # contains one non-transitive track + + # For each image pair (i1,i2), we provide a (K,2) matrix + # of corresponding keypoint indices (k1,k2). + matches_dict = {} + for (i1,i2), corr_idxs in nontransitive_matches_dict.items(): + matches_dict[IndexPair(i1, i2)] = corr_idxs + + tracks = gtsam.gtsfm.tracksFromPairwiseMatches( + matches_dict, + keypoints_list, + verbose=True, + ) + self.assertEqual(len(tracks), 0, "Tracks not filtered correctly") + def test_track_generation(self) -> None: """Ensures that DSF generates three tracks from measurements in 3 images (H=200,W=400).""" @@ -29,7 +53,7 @@ def test_track_generation(self) -> None: keypoints_list.append(kps_i2) # For each image pair (i1,i2), we provide a (K,2) matrix - # of corresponding image indices (k1,k2). + # of corresponding keypoint indices (k1,k2). matches_dict = {} matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]]) matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]]) @@ -93,5 +117,73 @@ def test_sfm_track_2d_constructor(self) -> None: assert track.numberMeasurements() == 1 +def get_dummy_keypoints_list() -> List[Keypoints]: + """ """ + img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]]) + img1_kp_scale = np.array([6.0, 9.0, 8.5]) + img2_kp_coords = np.array( + [ + [1, 1.], + [2, 2], + [3, 3], + [4, 4], + [5, 5], + [6, 6], + [7, 7], + [8, 8], + ] + ) + img3_kp_coords = np.array( + [ + [1, 1.], + [2, 2], + [3, 3], + [4, 4], + [5, 5], + [6, 6], + [7, 7], + [8, 8], + [9, 9], + [10, 10], + ] + ) + img4_kp_coords = np.array( + [ + [1, 1.], + [2, 2], + [3, 3], + [4, 4], + [5, 5], + ] + ) + keypoints_list = [ + Keypoints(coordinates=img1_kp_coords), + Keypoints(coordinates=img2_kp_coords), + Keypoints(coordinates=img3_kp_coords), + Keypoints(coordinates=img4_kp_coords), + ] + return keypoints_list + + +def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]: + """Set up correspondences for each (i1,i2) pair that violates transitivity. + + (i=0, k=0) (i=0, k=1) + | \\ | + | \\ | + (i=1, k=2)--(i=2,k=3)--(i=3, k=4) + + Transitivity is violated due to the match between frames 0 and 3. + """ + nontransitive_matches_dict = { + (0, 1): np.array([[0, 2]]), + (1, 2): np.array([[2, 3]]), + (0, 2): np.array([[0, 3]]), + (0, 3): np.array([[1, 4]]), + (2, 3): np.array([[3, 4]]), + } + return nontransitive_matches_dict + + if __name__ == "__main__": unittest.main()