Skip to content

Commit

Permalink
Apply black and isort
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe committed Feb 19, 2024
1 parent e0aaa57 commit 06e9a66
Show file tree
Hide file tree
Showing 28 changed files with 89 additions and 86 deletions.
10 changes: 7 additions & 3 deletions demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"# The highest accuracy is achieved with num_rotations=360\n",
"# but num_rotations=64~128 is often sufficient.\n",
"# To reduce the memory usage, we can reduce the tile size in the next cell.\n",
"demo = Demo(num_rotations=256, device='cpu')"
"demo = Demo(num_rotations=256, device=\"cpu\")"
]
},
{
Expand Down Expand Up @@ -135,19 +135,22 @@
"\n",
"# Show the query area in an interactive map\n",
"from maploc.osm.viz import GeoPlotter\n",
"\n",
"plot = GeoPlotter(zoom=16)\n",
"plot.points(prior_latlon[:2], \"red\", name=\"location prior\", size=10)\n",
"plot.bbox(proj.unproject(bbox), \"blue\", name=\"map tile\")\n",
"plot.fig.show(\"notebook\")\n",
"\n",
"# Query OpenStreetMap for this area\n",
"from maploc.osm.tiling import TileManager\n",
"\n",
"tiler = TileManager.from_bbox(proj, bbox + 10, demo.config.data.pixel_per_meter)\n",
"canvas = tiler.query(bbox)\n",
"\n",
"# Show the inputs to the model: image and raster map\n",
"from maploc.osm.viz import Colormap, plot_nodes\n",
"from maploc.utils.viz_2d import plot_images\n",
"\n",
"map_viz = Colormap.apply(canvas.raster)\n",
"plot_images([image, map_viz], titles=[\"input image\", \"OpenStreetMap raster\"])\n",
"plot_nodes(1, canvas.raster[2], fontsize=6, size=10)"
Expand Down Expand Up @@ -1186,15 +1189,16 @@
"\n",
"# Run the inference\n",
"uv, yaw, prob, neural_map, image_rectified = demo.localize(\n",
" image, camera, canvas, roll_pitch=gravity)\n",
" image, camera, canvas, roll_pitch=gravity\n",
")\n",
"\n",
"# Visualize the predictions\n",
"overlay = likelihood_overlay(prob.numpy().max(-1), map_viz.mean(-1, keepdims=True))\n",
"(neural_map_rgb,) = features_to_RGB(neural_map.numpy())\n",
"plot_images([overlay, neural_map_rgb], titles=[\"prediction\", \"neural map\"])\n",
"ax = plt.gcf().axes[0]\n",
"ax.scatter(*canvas.to_uv(bbox.center), s=5, c=\"red\")\n",
"plot_dense_rotations(ax, prob, w=0.005, s=1/25)\n",
"plot_dense_rotations(ax, prob, w=0.005, s=1 / 25)\n",
"add_circle_inset(ax, uv)\n",
"plt.show(\"notebook\")\n",
"\n",
Expand Down
3 changes: 1 addition & 2 deletions maploc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from pathlib import Path
import logging
from pathlib import Path

import pytorch_lightning # noqa: F401


formatter = logging.Formatter(
fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
Expand Down
4 changes: 2 additions & 2 deletions maploc/data/image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Callable, Optional, Union, Sequence
import collections
from typing import Callable, Optional, Sequence, Union

import numpy as np
import torch
import torchvision.transforms.functional as tvf
import collections
from scipy.spatial.transform import Rotation

from ..utils.geometry import from_homogeneous, to_homogeneous
Expand Down
4 changes: 2 additions & 2 deletions maploc/data/kitti/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from omegaconf import OmegaConf
from scipy.spatial.transform import Rotation

from ... import logger, DATASETS_PATH
from ... import DATASETS_PATH, logger
from ...osm.tiling import TileManager
from ..dataset import MapLocDataset
from ..sequential import chunk_sequence
from ..torch import collate, worker_init_fn
from .utils import parse_split_file, parse_gps_file, get_camera_calibration
from .utils import get_camera_calibration, parse_gps_file, parse_split_file


class KittiDataModule(pl.LightningDataModule):
Expand Down
6 changes: 3 additions & 3 deletions maploc/data/kitti/prepare.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
from pathlib import Path
import shutil
import zipfile
from pathlib import Path

import numpy as np
from tqdm.auto import tqdm
Expand All @@ -12,9 +12,9 @@
from ...osm.tiling import TileManager
from ...osm.viz import GeoPlotter
from ...utils.geo import BoundaryBox, Projection
from ...utils.io import download_file, DATA_URL
from .utils import parse_gps_file
from ...utils.io import DATA_URL, download_file
from .dataset import KittiDataModule
from .utils import parse_gps_file

split_files = ["test1_files.txt", "test2_files.txt", "train_files.txt"]

Expand Down
4 changes: 2 additions & 2 deletions maploc/data/mapillary/dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import json
from collections import defaultdict
import os
import shutil
import tarfile
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, Optional

Expand All @@ -14,7 +14,7 @@
import torch.utils.data as torchdata
from omegaconf import DictConfig, OmegaConf

from ... import logger, DATASETS_PATH
from ... import DATASETS_PATH, logger
from ...osm.tiling import TileManager
from ..dataset import MapLocDataset
from ..sequential import chunk_sequence
Expand Down
12 changes: 7 additions & 5 deletions maploc/data/mapillary/download.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import asyncio
import json
from pathlib import Path

import numpy as np
import httpx
import asyncio
from aiolimiter import AsyncLimiter
import numpy as np
import tqdm

from aiolimiter import AsyncLimiter
from opensfm.pygeometry import Camera, Pose
from opensfm.pymap import Shot

from ... import logger
from ...utils.geo import Projection


semaphore = asyncio.Semaphore(100) # number of parallel threads.
image_filename = "{image_id}.jpg"
info_filename = "{image_id}.json"


def retry(times, exceptions):
def decorator(func):
async def wrapper(*args, **kwargs):
Expand All @@ -30,9 +29,12 @@ async def wrapper(*args, **kwargs):
except exceptions:
attempt += 1
return await func(*args, **kwargs)

return wrapper

return decorator


class MapillaryDownloader:
image_fields = (
"id",
Expand Down
33 changes: 16 additions & 17 deletions maploc/data/mapillary/prepare.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import asyncio
import argparse
from collections import defaultdict
import asyncio
import json
import shutil
from collections import defaultdict
from pathlib import Path
from typing import List

import numpy as np
import cv2
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map
import numpy as np
from omegaconf import DictConfig, OmegaConf
from opensfm.pygeometry import Camera
from opensfm.pymap import Shot
from opensfm.undistort import (
perspective_camera_from_fisheye,
perspective_camera_from_perspective,
)
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map

from ... import logger
from ...osm.tiling import TileManager
from ...osm.viz import GeoPlotter
from ...utils.geo import BoundaryBox, Projection
from ...utils.io import write_json, download_file, DATA_URL
from ...utils.io import DATA_URL, download_file, write_json
from ..utils import decompose_rotmat
from .dataset import MapillaryDataModule
from .download import (
MapillaryDownloader,
fetch_image_infos,
fetch_images_pixels,
image_filename,
opensfm_shot_from_info,
)
from .utils import (
CameraUndistorter,
PanoramaUndistorter,
keyframe_selection,
perspective_camera_from_pano,
scale_camera,
CameraUndistorter,
PanoramaUndistorter,
undistort_shot,
)
from .download import (
MapillaryDownloader,
opensfm_shot_from_info,
image_filename,
fetch_image_infos,
fetch_images_pixels,
)
from .dataset import MapillaryDataModule


location_to_params = {
"sanfrancisco_soma": {
Expand Down
2 changes: 1 addition & 1 deletion maploc/data/mapillary/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import cv2
import numpy as np
from opensfm import features
from opensfm.pygeometry import Camera, compute_camera_mapping, Pose
from opensfm.pygeometry import Camera, Pose, compute_camera_mapping
from opensfm.pymap import Shot
from scipy.spatial.transform import Rotation

Expand Down
6 changes: 3 additions & 3 deletions maploc/data/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import os

import torch
from lightning_fabric.utilities.apply_func import move_data_to_device
from lightning_fabric.utilities.seed import pl_worker_init_function
from lightning_utilities.core.apply_func import apply_to_collection
from torch.utils.data import get_worker_info
from torch.utils.data._utils.collate import (
default_collate_err_msg_format,
np_str_obj_array_pattern,
)
from lightning_fabric.utilities.seed import pl_worker_init_function
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_fabric.utilities.apply_func import move_data_to_device


def collate(batch):
Expand Down
14 changes: 7 additions & 7 deletions maploc/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

from typing import Optional, Tuple

import torch
import numpy as np
import torch

from . import logger
from .evaluation.run import resolve_checkpoint_path, pretrained_models
from .data.image import pad_image, rectify_image, resize_image
from .evaluation.run import pretrained_models, resolve_checkpoint_path
from .models.orienternet import OrienterNet
from .models.voting import fuse_gps, argmax_xyr
from .data.image import resize_image, pad_image, rectify_image
from .models.voting import argmax_xyr, fuse_gps
from .osm.raster import Canvas
from .utils.wrappers import Camera
from .utils.io import read_image
from .utils.geo import BoundaryBox, Projection
from .utils.exif import EXIF
from .utils.geo import BoundaryBox, Projection
from .utils.io import read_image
from .utils.wrappers import Camera

try:
from geopy.geocoders import Nominatim
Expand Down
3 changes: 1 addition & 2 deletions maploc/evaluation/kitti.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from pathlib import Path
from typing import Optional, Tuple

from omegaconf import OmegaConf, DictConfig
from omegaconf import DictConfig, OmegaConf

from .. import logger
from ..data import KittiDataModule
from .run import evaluate


default_cfg_single = OmegaConf.create({})
# For the sequential evaluation, we need to center the map around the GT location,
# since random offsets would accumulate and leave only the GT location with a valid mask.
Expand Down
3 changes: 1 addition & 2 deletions maploc/evaluation/mapillary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from pathlib import Path
from typing import Optional, Tuple

from omegaconf import OmegaConf, DictConfig
from omegaconf import DictConfig, OmegaConf

from .. import logger
from ..conf import data as conf_data_dir
from ..data import MapillaryDataModule
from .run import evaluate


split_overrides = {
"val": {
"scenes": [
Expand Down
13 changes: 6 additions & 7 deletions maploc/evaluation/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,25 @@

import functools
from itertools import islice
from typing import Callable, Dict, Optional, Tuple
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from torchmetrics import MetricCollection
from pytorch_lightning import seed_everything
from torchmetrics import MetricCollection
from tqdm import tqdm

from .. import logger, EXPERIMENTS_PATH
from .. import EXPERIMENTS_PATH, logger
from ..data.torch import collate, unbatch_to_device
from ..models.voting import argmax_xyr, fuse_gps
from ..models.metrics import AngleError, LateralLongitudinalError, Location2DError
from ..models.sequential import GPSAligner, RigidAligner
from ..models.voting import argmax_xyr, fuse_gps
from ..module import GenericModule
from ..utils.io import download_file, DATA_URL
from .viz import plot_example_single, plot_example_sequential
from ..utils.io import DATA_URL, download_file
from .utils import write_dump

from .viz import plot_example_sequential, plot_example_single

pretrained_models = dict(
OrienterNet_MGL=("orienternet_mgl.ckpt", dict(num_rotations=256)),
Expand Down
10 changes: 5 additions & 5 deletions maploc/evaluation/viz.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import matplotlib.pyplot as plt
import numpy as np
import torch
import matplotlib.pyplot as plt

from ..osm.viz import Colormap, plot_nodes
from ..utils.io import write_torch_image
from ..utils.viz_2d import plot_images, features_to_RGB, save_plot
from ..utils.viz_2d import features_to_RGB, plot_images, save_plot
from ..utils.viz_localization import (
add_circle_inset,
likelihood_overlay,
plot_pose,
plot_dense_rotations,
add_circle_inset,
plot_pose,
)
from ..osm.viz import Colormap, plot_nodes


def plot_example_single(
Expand Down
Loading

0 comments on commit 06e9a66

Please sign in to comment.