From 2f2d6546d1aa587a2da4e6ecc58f7c84d6ad03e1 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 2 Sep 2023 16:55:38 -0700 Subject: [PATCH] Fix 4.3 python style --- python/gtsam/tests/test_DsfTrackGenerator.py | 70 ++++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/python/gtsam/tests/test_DsfTrackGenerator.py b/python/gtsam/tests/test_DsfTrackGenerator.py index 618f8a8cc1..a1027bebd7 100644 --- a/python/gtsam/tests/test_DsfTrackGenerator.py +++ b/python/gtsam/tests/test_DsfTrackGenerator.py @@ -4,14 +4,15 @@ """ import unittest -from typing import Dict, List, Tuple +from typing import Dict, Tuple import numpy as np from gtsam.gtsfm import Keypoints from gtsam.utils.test_case import GtsamTestCase import gtsam -from gtsam import IndexPair, Point2, SfmTrack2d +from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2, + SfmMeasurementVector, SfmTrack2d) class TestDsfTrackGenerator(GtsamTestCase): @@ -22,20 +23,21 @@ def test_generate_tracks_from_pairwise_matches_nontransitive( ) -> None: """Tests DSF for non-transitive matches. - Test will result in no tracks since nontransitive tracks are naively discarded by DSF. + Test will result in no tracks since nontransitive tracks are naively + discarded by DSF. """ - keypoints_list = get_dummy_keypoints_list() - nontransitive_matches_dict = get_nontransitive_matches() # contains one non-transitive track + keypoints = get_dummy_keypoints_list() + nontransitive_matches = get_nontransitive_matches() # For each image pair (i1,i2), we provide a (K,2) matrix # of corresponding keypoint indices (k1,k2). - matches_dict = {} - for (i1,i2), corr_idxs in nontransitive_matches_dict.items(): - matches_dict[IndexPair(i1, i2)] = corr_idxs + matches = MatchIndicesMap() + for (i1, i2), correspondences in nontransitive_matches.items(): + matches[IndexPair(i1, i2)] = correspondences tracks = gtsam.gtsfm.tracksFromPairwiseMatches( - matches_dict, - keypoints_list, + matches, + keypoints, verbose=True, ) self.assertEqual(len(tracks), 0, "Tracks not filtered correctly") @@ -47,20 +49,20 @@ def test_track_generation(self) -> None: kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]])) kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]])) - keypoints_list = [] - keypoints_list.append(kps_i0) - keypoints_list.append(kps_i1) - keypoints_list.append(kps_i2) + keypoints = KeypointsVector() + keypoints.append(kps_i0) + keypoints.append(kps_i1) + keypoints.append(kps_i2) # For each image pair (i1,i2), we provide a (K,2) matrix - # of corresponding keypoint indices (k1,k2). - matches_dict = {} - matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]]) - matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]]) + # of corresponding image indices (k1,k2). + matches = MatchIndicesMap() + matches[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]]) + matches[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]]) tracks = gtsam.gtsfm.tracksFromPairwiseMatches( - matches_dict, - keypoints_list, + matches, + keypoints, verbose=False, ) assert len(tracks) == 3 @@ -110,17 +112,16 @@ class TestSfmTrack2d(GtsamTestCase): def test_sfm_track_2d_constructor(self) -> None: """Test construction of 2D SfM track.""" - measurements = [] + measurements = SfmMeasurementVector() measurements.append((0, Point2(10, 20))) track = SfmTrack2d(measurements=measurements) track.measurement(0) assert track.numberMeasurements() == 1 -def get_dummy_keypoints_list() -> List[Keypoints]: - """ """ +def get_dummy_keypoints_list() -> KeypointsVector: + """Generate a list of dummy keypoints for testing.""" img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]]) - img1_kp_scale = np.array([6.0, 9.0, 8.5]) img2_kp_coords = np.array( [ [1, 1.], @@ -156,33 +157,32 @@ def get_dummy_keypoints_list() -> List[Keypoints]: [5, 5], ] ) - keypoints_list = [ - Keypoints(coordinates=img1_kp_coords), - Keypoints(coordinates=img2_kp_coords), - Keypoints(coordinates=img3_kp_coords), - Keypoints(coordinates=img4_kp_coords), - ] - return keypoints_list + keypoints = KeypointsVector() + keypoints.append(Keypoints(coordinates=img1_kp_coords)) + keypoints.append(Keypoints(coordinates=img2_kp_coords)) + keypoints.append(Keypoints(coordinates=img3_kp_coords)) + keypoints.append(Keypoints(coordinates=img4_kp_coords)) + return keypoints def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]: """Set up correspondences for each (i1,i2) pair that violates transitivity. - + (i=0, k=0) (i=0, k=1) | \\ | | \\ | (i=1, k=2)--(i=2,k=3)--(i=3, k=4) - Transitivity is violated due to the match between frames 0 and 3. + Transitivity is violated due to the match between frames 0 and 3. """ - nontransitive_matches_dict = { + nontransitive_matches = { (0, 1): np.array([[0, 2]]), (1, 2): np.array([[2, 3]]), (0, 2): np.array([[0, 3]]), (0, 3): np.array([[1, 4]]), (2, 3): np.array([[3, 4]]), } - return nontransitive_matches_dict + return nontransitive_matches if __name__ == "__main__":