Skip to content

Commit

Permalink
Fix 4.3 python style
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Sep 2, 2023
1 parent 12f919d commit 2f2d654
Showing 1 changed file with 35 additions and 35 deletions.
70 changes: 35 additions & 35 deletions python/gtsam/tests/test_DsfTrackGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
"""

import unittest
from typing import Dict, List, Tuple
from typing import Dict, Tuple

import numpy as np
from gtsam.gtsfm import Keypoints
from gtsam.utils.test_case import GtsamTestCase

import gtsam
from gtsam import IndexPair, Point2, SfmTrack2d
from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
SfmMeasurementVector, SfmTrack2d)


class TestDsfTrackGenerator(GtsamTestCase):
Expand All @@ -22,20 +23,21 @@ def test_generate_tracks_from_pairwise_matches_nontransitive(
) -> None:
"""Tests DSF for non-transitive matches.
Test will result in no tracks since nontransitive tracks are naively discarded by DSF.
Test will result in no tracks since nontransitive tracks are naively
discarded by DSF.
"""
keypoints_list = get_dummy_keypoints_list()
nontransitive_matches_dict = get_nontransitive_matches() # contains one non-transitive track
keypoints = get_dummy_keypoints_list()
nontransitive_matches = get_nontransitive_matches()

# For each image pair (i1,i2), we provide a (K,2) matrix
# of corresponding keypoint indices (k1,k2).
matches_dict = {}
for (i1,i2), corr_idxs in nontransitive_matches_dict.items():
matches_dict[IndexPair(i1, i2)] = corr_idxs
matches = MatchIndicesMap()
for (i1, i2), correspondences in nontransitive_matches.items():
matches[IndexPair(i1, i2)] = correspondences

tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
matches_dict,
keypoints_list,
matches,
keypoints,
verbose=True,
)
self.assertEqual(len(tracks), 0, "Tracks not filtered correctly")
Expand All @@ -47,20 +49,20 @@ def test_track_generation(self) -> None:
kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]]))

keypoints_list = []
keypoints_list.append(kps_i0)
keypoints_list.append(kps_i1)
keypoints_list.append(kps_i2)
keypoints = KeypointsVector()
keypoints.append(kps_i0)
keypoints.append(kps_i1)
keypoints.append(kps_i2)

# For each image pair (i1,i2), we provide a (K,2) matrix
# of corresponding keypoint indices (k1,k2).
matches_dict = {}
matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
# of corresponding image indices (k1,k2).
matches = MatchIndicesMap()
matches[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
matches[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])

tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
matches_dict,
keypoints_list,
matches,
keypoints,
verbose=False,
)
assert len(tracks) == 3
Expand Down Expand Up @@ -110,17 +112,16 @@ class TestSfmTrack2d(GtsamTestCase):

def test_sfm_track_2d_constructor(self) -> None:
"""Test construction of 2D SfM track."""
measurements = []
measurements = SfmMeasurementVector()
measurements.append((0, Point2(10, 20)))
track = SfmTrack2d(measurements=measurements)
track.measurement(0)
assert track.numberMeasurements() == 1


def get_dummy_keypoints_list() -> List[Keypoints]:
""" """
def get_dummy_keypoints_list() -> KeypointsVector:
"""Generate a list of dummy keypoints for testing."""
img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]])
img1_kp_scale = np.array([6.0, 9.0, 8.5])
img2_kp_coords = np.array(
[
[1, 1.],
Expand Down Expand Up @@ -156,33 +157,32 @@ def get_dummy_keypoints_list() -> List[Keypoints]:
[5, 5],
]
)
keypoints_list = [
Keypoints(coordinates=img1_kp_coords),
Keypoints(coordinates=img2_kp_coords),
Keypoints(coordinates=img3_kp_coords),
Keypoints(coordinates=img4_kp_coords),
]
return keypoints_list
keypoints = KeypointsVector()
keypoints.append(Keypoints(coordinates=img1_kp_coords))
keypoints.append(Keypoints(coordinates=img2_kp_coords))
keypoints.append(Keypoints(coordinates=img3_kp_coords))
keypoints.append(Keypoints(coordinates=img4_kp_coords))
return keypoints


def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]:
"""Set up correspondences for each (i1,i2) pair that violates transitivity.
(i=0, k=0) (i=0, k=1)
| \\ |
| \\ |
(i=1, k=2)--(i=2,k=3)--(i=3, k=4)
Transitivity is violated due to the match between frames 0 and 3.
Transitivity is violated due to the match between frames 0 and 3.
"""
nontransitive_matches_dict = {
nontransitive_matches = {
(0, 1): np.array([[0, 2]]),
(1, 2): np.array([[2, 3]]),
(0, 2): np.array([[0, 3]]),
(0, 3): np.array([[1, 4]]),
(2, 3): np.array([[3, 4]]),
}
return nontransitive_matches_dict
return nontransitive_matches


if __name__ == "__main__":
Expand Down

0 comments on commit 2f2d654

Please sign in to comment.