Skip to content

Commit

Permalink
Flip cam_R_gcam, fix type issues in augmentations
Browse files Browse the repository at this point in the history
  • Loading branch information
AlanSavio25 committed Mar 26, 2024
1 parent 5cd4dd1 commit 9580558
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
22 changes: 12 additions & 10 deletions maploc/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ def get_view(self, idx, scene, seq, name, seed, bbox_tile):
world_t_cam = self.data["t_c2w"][idx].numpy()

image = read_image(self.image_dirs[scene] / (name + self.image_ext))
image = (
torch.from_numpy(np.ascontiguousarray(image))
.permute(2, 0, 1)
.float()
.div_(255)
)

if self.cfg.force_camera_height is not None:
data["camera_height"] = torch.tensor(self.cfg.force_camera_height)
Expand All @@ -153,19 +159,19 @@ def get_view(self, idx, scene, seq, name, seed, bbox_tile):
world_T_tile = Transform2D.from_Rt(torch.eye(2), canvas.bbox.min_)
tile_T_cam = (world_T_tile.inv() @ world_T_cam2d).float()

image, valid, cam = self.process_image(image, cam, seed, cam_R_gcam)

# Map augmentations
if self.stage == "train":
if self.cfg.augmentation.rot90:
raster, tile_T_cam = random_rot90(raster, tile_T_cam, canvas.ppm)
if self.cfg.augmentation.flip:
image, valid, raster, tile_T_cam = random_flip(
image, valid, raster, tile_T_cam, canvas.ppm
image, raster, tile_T_cam, cam_R_gcam = random_flip(
image, raster, tile_T_cam, cam_R_gcam, canvas.ppm
)
map_T_cam = Transform2D.to_pixels(tile_T_cam, 1 / canvas.ppm)
# map_T_cam will be deprecated, tile_T_cam is sufficient.

image, valid, cam = self.process_image(image, cam, seed, cam_R_gcam)

# Spatial to memory layout
raster = torch.rot90(raster, -1, dims=(-2, -1))

Expand Down Expand Up @@ -208,19 +214,15 @@ def get_view(self, idx, scene, seq, name, seed, bbox_tile):
"camera": cam,
"canvas": canvas,
"map": raster,
"cam_R_gcam": cam_R_gcam,
"tile_T_cam": tile_T_cam,
"map_T_cam": map_T_cam,
"map_t_init": map_t_init,
"pixels_per_meter": torch.tensor(canvas.ppm).float(),
}

def process_image(self, image, cam, seed, cam_R_gcam):
image = (
torch.from_numpy(np.ascontiguousarray(image))
.permute(2, 0, 1)
.float()
.div_(255)
)

assert self.cfg.rectify_pitch
image, valid = rectify_image(image, cam, cam_R_gcam)

Expand Down
24 changes: 15 additions & 9 deletions maploc/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def crop_map(raster, xy, size, seed=None):
xy -= np.array([left, top])
return raster, xy


def random_rot90(
raster: torch.Tensor,
tile_T_cam: Transform2D,
Expand All @@ -28,33 +29,38 @@ def random_rot90(
raster = torch.rot90(raster, rot, dims=(-2, -1))

# Rotate the camera position around tile's center
map_t_center = np.array(raster.shape[-2:]) / 2.0
tile_t_center = Transform2D.from_pixels(map_t_center, 1 / pixels_per_meter)
map_t_center = torch.tensor(raster.shape[-2:]) / 2.0
tile_t_center = Transform2D.from_pixels(map_t_center, 1 / pixels_per_meter).float()
center_t_cam = tile_T_cam.t - tile_t_center
R = Transform2D.from_degrees(torch.Tensor([rot * 90]), torch.zeros(2)).float()
R = Transform2D.from_degrees(torch.tensor([rot * 90]), torch.zeros(2)).float()
center_t_rotcam = R @ center_t_cam.T.float()
tile_t_rotcam = center_t_rotcam.squeeze(0) + tile_t_center
tile_r_rotcam = (tile_T_cam.angle + rot * 90) % 360
tile_T_rotcam = Transform2D.from_degrees(tile_r_rotcam, tile_t_rotcam)
tile_T_rotcam = Transform2D.from_degrees(tile_r_rotcam, tile_t_rotcam).float()
return raster, tile_T_rotcam


def random_flip(
image: torch.Tensor,
valid: torch.Tensor,
raster: torch.Tensor,
tile_T_cam: Transform2D,
cam_R_gcam: torch.Tensor,
pixels_per_meter: float,
seed: int = None,
):
state = np.random.RandomState(seed)
if state.rand() > 0.5: # no flip
return image, valid, raster, tile_T_cam
return image, raster, tile_T_cam, cam_R_gcam

image = torch.flip(image, (-1,))
valid = torch.flip(valid, (-1,))

map_t_center = np.array(raster.shape[-2:]) / 2.0
# Flip cam_R_gcam
gcam_R_cam = cam_R_gcam.T
roll = torch.rad2deg(torch.arctan2(gcam_R_cam[1, 0], gcam_R_cam[0, 0]))
R = Rotation.from_euler("z", -2 * roll, degrees=True).as_matrix()
gcam_R_cam = torch.tensor(R) @ gcam_R_cam

map_t_center = torch.tensor(raster.shape[-2:]) / 2.0
tile_t_center = Transform2D.from_pixels(map_t_center, 1 / pixels_per_meter)
center_t_cam = tile_T_cam.t - tile_t_center
if state.rand() > 0.5: # flip x
Expand All @@ -67,7 +73,7 @@ def random_flip(
center_t_flipcam = center_t_cam * torch.tensor([1, -1])
tile_t_flipcam = center_t_flipcam + tile_t_center
tile_T_flipcam = Transform2D.from_degrees(tile_r_flipcam % 360, tile_t_flipcam)
return image, valid, raster, tile_T_flipcam
return image, raster, tile_T_flipcam.float(), gcam_R_cam.T


def decompose_rotmat(R_c2w):
Expand Down

0 comments on commit 9580558

Please sign in to comment.