Add back parallelized retriever (#721)
* add back parallel retriever

* improve docs

* reformat

* swap loop order

* reformat

* add back repr

* allow no matches, then fix rank-2 to shape=(0,2) for coordinates

* add no retrieval variant

* remove max_frame_lookahead from loader

* remove lookahead from loader

* remove lookahead from loader

---------
johnwlambert authored Sep 26, 2023
1 parent 82e4bba commit 201b762
Showing 16 changed files with 136 additions and 125 deletions.
gtsfm/frontend/cacher/image_matcher_cacher.py (20 changes: 5 additions & 15 deletions)

@@ -64,36 +64,26 @@ def _save_result_to_cache(
data = {"keypoints_i1": keypoints_i1, "keypoints_i2": keypoints_i2}
io_utils.write_to_bz2_file(data, cache_path)

def match(
self,
image_i1: Image,
image_i2: Image,
) -> Tuple[Keypoints, Keypoints]:
def match(self, image_i1: Image, image_i2: Image) -> Tuple[Keypoints, Keypoints]:
"""Identify feature matches across two images.
If the results are in the cache, they are fetched and returned. Otherwise, the `match()` of the
underlying object's API is called and the results are cached.
Args:
image_i1: first input image of pair.
image_i2: second input image of pair.
image_i1: First input image of pair.
image_i2: Second input image of pair.
Returns:
Keypoints from image 1 (N keypoints will exist).
Corresponding keypoints from image 2 (there will also be N keypoints). These represent feature matches.
"""
cached_data = self._load_result_from_cache(
image_i1=image_i1,
image_i2=image_i2,
)
cached_data = self._load_result_from_cache(image_i1=image_i1, image_i2=image_i2)

if cached_data is not None:
return cached_data

keypoints_i1, keypoints_i2 = self._matcher.match(
image_i1=image_i1,
image_i2=image_i2,
)
keypoints_i1, keypoints_i2 = self._matcher.match(image_i1=image_i1, image_i2=image_i2)

self._save_result_to_cache(
image_i1=image_i1, image_i2=image_i2, keypoints_i1=keypoints_i1, keypoints_i2=keypoints_i2
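For readers unfamiliar with the cacher, here is a minimal, self-contained sketch of the load-or-compute pattern that match() implements. The helper names and the hash-of-pixels key scheme are hypothetical; gtsfm's io_utils similarly persists bz2-compressed pickles:

import bz2
import hashlib
import pickle
from pathlib import Path
from typing import Tuple

import numpy as np

CACHE_ROOT = Path("cache/image_matcher")  # hypothetical cache location

def match_with_cache(matcher, image_i1: np.ndarray, image_i2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    # Key the entry on image contents, so an identical pair always hits the same file.
    key = hashlib.sha1(image_i1.tobytes() + image_i2.tobytes()).hexdigest()
    cache_path = CACHE_ROOT / f"{key}.pbz2"
    if cache_path.exists():  # cache hit: skip the expensive matcher call
        with bz2.open(cache_path, "rb") as f:
            data = pickle.load(f)
        return data["keypoints_i1"], data["keypoints_i2"]
    keypoints_i1, keypoints_i2 = matcher.match(image_i1=image_i1, image_i2=image_i2)
    CACHE_ROOT.mkdir(parents=True, exist_ok=True)
    with bz2.open(cache_path, "wb") as f:  # cache miss: persist for the next run
        pickle.dump({"keypoints_i1": keypoints_i1, "keypoints_i2": keypoints_i2}, f)
    return keypoints_i1, keypoints_i2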
@@ -79,7 +79,6 @@ def apply_image_matcher(
         )

         keypoints_list, putative_corr_idxs_dict = self._aggregator.aggregate(keypoints_dict=pairwise_correspondences)
-
         return keypoints_list, putative_corr_idxs_dict

     def generate_correspondences_and_estimate_two_view(
@@ -99,9 +99,7 @@ def aggregate(
         # Have to merge keypoints across different views here (or turn off transitivity check).

         for (i1, i2), (keypoints_i1, keypoints_i2) in keypoints_dict.items():
-            _assert_keypoints_rank(keypoints_i1)
-            _assert_keypoints_rank(keypoints_i2)
-
+            # NOTE: `Keypoints` coordinates with shape (0,2) are allowed here, when no matches are identified.
             per_image_kpt_coordinates, i1_indices = self.append_unique_keypoints(
                 i=i1, keypoints=keypoints_i1, per_image_kpt_coordinates=per_image_kpt_coordinates
             )
@@ -115,7 +113,7 @@
         # Reset global state.
         self.duplicates_found = 0

-        keypoints_list: List[Keypoints] = [Keypoints(coordinates=np.array([]))] * (max_img_idx + 1)
+        keypoints_list: List[Keypoints] = [Keypoints(coordinates=np.zeros(shape=(0, 2)))] * (max_img_idx + 1)
         for i in per_image_kpt_coordinates.keys():
             keypoints_list[i] = Keypoints(coordinates=per_image_kpt_coordinates[i])
             _assert_keypoints_rank(keypoints_list[i])
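The np.zeros(shape=(0, 2)) change above is the commit's "fix rank-2 to shape=(0,2) for coordinates": an empty rank-2 array is a valid "no keypoints" value, while np.array([]) is rank 1 and fails shape checks. A quick illustration:

import numpy as np

no_matches_bad = np.array([])           # shape (0,): rank 1, fails a rank-2 coordinates check
no_matches_ok = np.zeros(shape=(0, 2))  # shape (0, 2): rank 2, a valid "zero keypoints" array
assert no_matches_bad.ndim == 1 and no_matches_ok.ndim == 2

# Rank-2 empties also stack cleanly against real (N, 2) coordinates:
stacked = np.vstack([no_matches_ok, np.array([[10.0, 20.0]])])
assert stacked.shape == (1, 2)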
gtsfm/frontend/global_descriptor/netvlad_global_descriptor.py (32 changes: 15 additions & 17 deletions)

@@ -1,21 +1,17 @@
"""Wrapper around the NetVLAD global image descriptor.
Based on Arandjelovic16cvpr:
"NetVLAD: CNN architecture for weakly supervised place recognition"
https://arxiv.org/pdf/1511.07247.pdf
NetVLAD, is a new generalized VLAD layer, inspired by the “Vector of Locally Aggregated Descriptors”
image representation commonly used in image retrieval
Whereas bag-of-visual-words aggregation keeps counts of visual words, VLAD stores the sum of residuals
(difference vector between the descriptor and its corresponding cluster centre) for each visual word.
Authors: John Lambert, Travis Driver
"""

import numpy as np
import torch
from torch import nn

from gtsfm.common.image import Image
from gtsfm.frontend.global_descriptor.global_descriptor_base import GlobalDescriptorBase
from thirdparty.hloc.netvlad import NetVLAD
@@ -24,27 +20,29 @@
 class NetVLADGlobalDescriptor(GlobalDescriptorBase):
     """NetVLAD global descriptor"""

-    def __init__(self, use_cuda: bool = True) -> None:
+    def __init__(self) -> None:
         """ """
-        self._use_cuda = use_cuda
-        self._model: nn.Module = NetVLAD().eval()
+        pass

     def describe(self, image: Image) -> np.ndarray:
         """Compute the NetVLAD global descriptor for a single image query.

         Args:
-            image: input image.
+            image: Input image.

         Returns:
-            img_desc: array of shape (D,) representing global image descriptor.
+            img_desc: Array of shape (D,) representing global image descriptor.
         """
-        device = torch.device("cuda" if self._use_cuda and torch.cuda.is_available() else "cpu")
-        self._model.to(device)
+        # Load model.
+        # Note: Initializing in the constructor leads to OOM.
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model: nn.Module = NetVLAD().to(device)
+        model.eval()

-        img_tensor = (
-            torch.from_numpy(image.value_array).to(device).permute(2, 0, 1).unsqueeze(0).type(torch.float32) / 255
-        )
         with torch.no_grad():
-            img_desc = self._model({"image": img_tensor})
+            img_tensor = (
+                torch.from_numpy(image.value_array).permute(2, 0, 1).unsqueeze(0).type(torch.float32).to(device) / 255
+            )
+            img_desc = model({"image": img_tensor})

         return img_desc["global_descriptor"].detach().squeeze().cpu().numpy()
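The substantive change above is lifecycle, not math: the network is now constructed inside describe() rather than in __init__, since (per the new comment) constructor-time initialization led to OOM, for example when the descriptor object is shipped to many workers. A minimal sketch of the same pattern with a stand-in network (this is not the NetVLAD architecture):

import torch
from torch import nn

class LazyGlobalDescriptor:
    """Holds no model state; the network is built on first use inside describe()."""

    def describe(self, image_tensor: torch.Tensor) -> torch.Tensor:
        # Constructing here keeps the object trivially picklable, and no worker
        # holds a live model (or its GPU memory) for the whole run.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten()).to(device)
        model.eval()
        with torch.no_grad():
            return model(image_tensor.to(device)).squeeze(0).cpu()  # shape (8,)

Here image_tensor is expected as a (1, 3, H, W) float batch, mirroring the permute/unsqueeze normalization in the diff above.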
gtsfm/loader/colmap_loader.py (17 changes: 0 additions & 17 deletions)

@@ -44,7 +44,6 @@ def __init__(
         images_dir: str,
         use_gt_intrinsics: bool = True,
         use_gt_extrinsics: bool = True,
-        max_frame_lookahead: int = 1,
         max_resolution: int = 760,
     ) -> None:
         """Initializes to load from a specified folder on disk.
@@ -56,9 +55,6 @@
             use_gt_intrinsics: Whether to use ground truth intrinsics. If COLMAP calibration is
                 not found on disk, then use_gt_intrinsics will be set to false automatically.
             use_gt_extrinsics: Whether to use ground truth extrinsics.
-            max_frame_lookahead: Maximum number of consecutive frames to consider for
-                matching/co-visibility. Any value of max_frame_lookahead less than the size of
-                the dataset assumes data is sequentially captured
             max_resolution: Integer representing maximum length of image's short side, i.e.
                 the smaller of the height/width of the image. e.g. for 1080p (1920 x 1080),
                 max_resolution would be 1080. If the image resolution max(height, width) is
@@ -67,7 +63,6 @@
         super().__init__(max_resolution)
         self._use_gt_intrinsics = use_gt_intrinsics
         self._use_gt_extrinsics = use_gt_extrinsics
-        self._max_frame_lookahead = max_frame_lookahead

         wTi_list, img_fnames, self._calibrations, _, _, _ = io_utils.read_scene_data_from_colmap_format(
             colmap_files_dirpath
@@ -175,15 +170,3 @@ def get_camera_pose(self, index: int) -> Optional[Pose3]:

         wTi = self._wTi_list[index]
         return wTi
-
-    def is_valid_pair(self, idx1: int, idx2: int) -> bool:
-        """Checks if (idx1, idx2) is a valid pair. idx1 < idx2 is required.
-
-        Args:
-            idx1: First index of the pair.
-            idx2: Second index of the pair.
-
-        Returns:
-            Validation result.
-        """
-        return super().is_valid_pair(idx1, idx2) and abs(idx1 - idx2) <= self._max_frame_lookahead
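With this deletion the loader no longer filters pairs by lookahead; co-visibility is decided entirely by the retriever. A self-contained sketch of lookahead-based sequential pair generation (a hypothetical standalone function, not gtsfm's SequentialRetriever itself):

from typing import List, Tuple

def sequential_pairs(num_images: int, max_frame_lookahead: int) -> List[Tuple[int, int]]:
    # Pair each frame i1 with up to `max_frame_lookahead` subsequent frames.
    pairs = []
    for i1 in range(num_images):
        for i2 in range(i1 + 1, min(i1 + max_frame_lookahead + 1, num_images)):
            pairs.append((i1, i2))
    return pairs

assert sequential_pairs(4, 1) == [(0, 1), (1, 2), (2, 3)]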
gtsfm/multi_view_optimizer.py (6 changes: 4 additions & 2 deletions)

@@ -139,6 +139,7 @@ def create_computation_graph(
             relative_pose_priors,
             gt_wTi_list=gt_wTi_list,
         )
+
         init_cameras_graph = dask.delayed(init_cameras)(wTi_graph, all_intrinsics)

         ba_input_graph, data_assoc_metrics_graph = self.data_association_module.create_computation_graph(
@@ -184,8 +185,9 @@

     camera_class = gtsfm_types.get_camera_class_for_calibration(intrinsics_list[0])
     for idx, (wTi) in enumerate(wTi_list):
-        if wTi is not None:
-            cameras[idx] = camera_class(wTi, intrinsics_list[idx])
+        if wTi is None:
+            continue
+        cameras[idx] = camera_class(wTi, intrinsics_list[idx])

     return cameras

gtsfm/retriever/joint_netvlad_sequential_retriever.py (41 changes: 28 additions & 13 deletions)

@@ -2,10 +2,12 @@
 Authors: John Lambert
 """

 from pathlib import Path
 from typing import List, Optional, Tuple

+import dask
+from dask.delayed import Delayed
+
 import gtsfm.utils.logger as logger_utils
 from gtsfm.loader.loader_base import LoaderBase
 from gtsfm.retriever.netvlad_retriever import NetVLADRetriever
@@ -16,16 +18,18 @@


 class JointNetVLADSequentialRetriever(RetrieverBase):
-    """Note: this class contains no .run() method."""
+    """Retriever that includes both sequential and retrieval links."""

     def __init__(self, num_matched: int, min_score: float, max_frame_lookahead: int) -> None:
-        """
+        """Initializes sub-retrievers.

         Args:
-            num_matched: number of K potential matches to provide per query. These are the top "K" matches per query.
+            num_matched: Number of K potential matches to provide per query. These are the top "K" matches per query.
             min_score: Minimum allowed similarity score to accept a match.
-            max_frame_lookahead: maximum number of consecutive frames to consider for matching/co-visibility.
+            max_frame_lookahead: Maximum number of consecutive frames to consider for matching/co-visibility.
         """
         super().__init__(matching_regime=ImageMatchingRegime.SEQUENTIAL_WITH_RETRIEVAL)
+        self._num_matched = num_matched
         self._similarity_retriever = NetVLADRetriever(num_matched=num_matched, min_score=min_score)
         self._seq_retriever = SequentialRetriever(max_frame_lookahead=max_frame_lookahead)

@@ -36,33 +40,44 @@ def __repr__(self) -> str:
             Sequential retriever: {self._seq_retriever}
         """

-    def get_image_pairs(self, loader: LoaderBase, plots_output_dir: Optional[Path] = None) -> List[Tuple[int, int]]:
+    def create_computation_graph(self, loader: LoaderBase, plots_output_dir: Optional[Path] = None) -> Delayed:
+        """Compute potential image pairs.
+
+        Args:
+            loader: Image loader. The length of this loader will provide the total number of images
+                for exhaustive global descriptor matching.
+
+        Return:
+            pair_indices: (i1,i2) image pairs.
+        """
+        return self.get_image_pairs(loader=loader, plots_output_dir=plots_output_dir)
+
+    def get_image_pairs(self, loader: LoaderBase, plots_output_dir: Optional[Path] = None) -> Delayed:
         """Compute potential image pairs.

         Args:
-            loader: image loader. The length of this loader will provide the total number of images
+            loader: Image loader. The length of this loader will provide the total number of images
                 for exhaustive global descriptor matching.
             plots_output_dir: Directory to save plots to. If None, plots are not saved.

         Return:
             pair_indices: (i1,i2) image pairs.
         """
-        sim_pairs = self._similarity_retriever.get_image_pairs(loader, plots_output_dir=plots_output_dir)
-        seq_pairs = self._seq_retriever.get_image_pairs(loader)
+        sim_pairs = self._similarity_retriever.create_computation_graph(loader, plots_output_dir=plots_output_dir)
+        seq_pairs = self._seq_retriever.create_computation_graph(loader)

-        return self.aggregate_pairs(sim_pairs=sim_pairs, seq_pairs=seq_pairs)
+        return dask.delayed(self.aggregate_pairs)(sim_pairs=sim_pairs, seq_pairs=seq_pairs)

     def aggregate_pairs(
         self, sim_pairs: List[Tuple[int, int]], seq_pairs: List[Tuple[int, int]]
     ) -> List[Tuple[int, int]]:
         """Aggregate all image pair indices from both similarity-based and sequential retrieval.

         Args:
-            sim_pairs: image pairs (i1,i2) from similarity-based retrieval.
-            seq_pairs: image pairs (i1,i2) from sequential retrieval.
+            sim_pairs: Image pairs (i1,i2) from similarity-based retrieval.
+            seq_pairs: Image pairs (i1,i2) from sequential retrieval.

         Returns:
-            pairs: unique pairs (i1,i2) representing union of the input sets.
+            Unique pairs (i1,i2) representing union of the input sets.
         """
         pairs = list(set(sim_pairs).union(set(seq_pairs)))
         logger.info("Found %d pairs from the NetVLAD + Sequential Retriever.", len(pairs))
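Note that get_image_pairs now returns a dask Delayed node rather than a concrete list: both sub-retrievers become graph nodes, and the union itself runs inside the graph. A toy but runnable sketch of that wiring (union_pairs stands in for aggregate_pairs; the pair lists are made up):

from typing import List, Tuple

import dask

def union_pairs(sim_pairs: List[Tuple[int, int]], seq_pairs: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    # Deduplicate: a pair proposed by both retrievers is matched only once.
    return sorted(set(sim_pairs) | set(seq_pairs))

sim = dask.delayed(lambda: [(0, 5), (1, 6)])()  # similarity-based proposals
seq = dask.delayed(lambda: [(0, 1), (1, 2)])()  # sequential proposals
pairs = dask.delayed(union_pairs)(sim, seq).compute()
assert pairs == [(0, 1), (0, 5), (1, 2), (1, 6)]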