diff --git a/tests/test_assets_reader.py b/tests/test_assets_reader.py new file mode 100644 index 0000000..06efd31 --- /dev/null +++ b/tests/test_assets_reader.py @@ -0,0 +1,62 @@ +"""Test titiler.stacapi.stac_reader functions.""" + +import json +import os +from unittest.mock import patch + +import pytest +from rio_tiler.io import Reader +from rio_tiler.models import ImageData + +from titiler.stacapi.assets_reader import AssetsReader +from titiler.stacapi.models import AssetInfo + +from .conftest import mock_rasterio_open + +item_file = os.path.join( + os.path.dirname(__file__), "fixtures", "20200307aC0853900w361030.json" +) +item_json = json.loads(open(item_file).read()) + + +def test_get_asset_info(): + """Test get_asset_info function""" + assets_reader = AssetsReader(item_json) + expected_asset_info = AssetInfo( + url=item_json["assets"]["cog"]["href"], + type=item_json["assets"]["cog"]["type"], + env={}, + ) + assert assets_reader._get_asset_info("cog") == expected_asset_info + + +def test_get_reader_any(): + """Test reader is rio_tiler.io.Reader""" + asset_info = AssetInfo(url="https://file.tif") + empty_stac_reader = AssetsReader({"bbox": [], "assets": []}) + assert empty_stac_reader._get_reader(asset_info) == Reader + + +@pytest.mark.xfail(reason="To be implemented.") +def test_get_reader_netcdf(): + """Test reader attribute is titiler.stacapi.XarrayReader""" + asset_info = AssetInfo(url="https://file.nc", type="application/netcdf") + empty_stac_reader = AssetsReader({"bbox": [], "assets": []}) + empty_stac_reader._get_reader(asset_info) + + +@pytest.mark.skip(reason="Too slow.") +@patch("rio_tiler.io.rasterio.rasterio") +def test_tile_cog(rio): + """Test tile function with COG asset.""" + rio.open = mock_rasterio_open + + with AssetsReader(item_json) as reader: + img = reader.tile(0, 0, 0, assets=["cog"]) + assert isinstance(img, ImageData) + + +@pytest.mark.skip(reason="To be implemented.") +def test_tile_netcdf(): + """Test tile function with netcdf asset.""" + pass diff --git a/tests/test_items.py b/tests/test_items.py deleted file mode 100644 index da96630..0000000 --- a/tests/test_items.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Test titiler.stacapi Item endpoints.""" - -import json -import os -from unittest.mock import patch - -import pystac -import pytest - -from .conftest import mock_rasterio_open - -item_json = os.path.join( - os.path.dirname(__file__), "fixtures", "20200307aC0853900w361030.json" -) - - -@patch("rio_tiler.io.rasterio.rasterio") -@patch("titiler.stacapi.dependencies.get_stac_item") -def test_stac_items(get_stac_item, rio, app): - """test STAC items endpoints.""" - rio.open = mock_rasterio_open - - with open(item_json, "r") as f: - get_stac_item.return_value = pystac.Item.from_dict(json.loads(f.read())) - - response = app.get( - "/collections/noaa-emergency-response/items/20200307aC0853900w361030/assets", - ) - assert response.status_code == 200 - assert response.json() == ["cog"] - - with pytest.warns(UserWarning): - response = app.get( - "/collections/noaa-emergency-response/items/20200307aC0853900w361030/info", - ) - assert response.status_code == 200 - assert response.json()["cog"] - - response = app.get( - "/collections/noaa-emergency-response/items/20200307aC0853900w361030/info", - params={"assets": "cog"}, - ) - assert response.status_code == 200 - assert response.json()["cog"] diff --git a/titiler/stacapi/assets_reader.py b/titiler/stacapi/assets_reader.py new file mode 100644 index 0000000..b41037b --- /dev/null +++ b/titiler/stacapi/assets_reader.py @@ -0,0 +1,221 @@ +"""titiler-stacapi Asset Reader.""" + +import warnings +from typing import Any, Dict, Optional, Sequence, Set, Type, Union + +import attr +import rasterio +from morecantile import TileMatrixSet +from rio_tiler.constants import WEB_MERCATOR_TMS, WGS84_CRS +from rio_tiler.errors import ( + AssetAsBandError, + ExpressionMixingWarning, + InvalidAssetName, + MissingAssets, + TileOutsideBounds, +) +from rio_tiler.io import Reader +from rio_tiler.io.base import BaseReader, MultiBaseReader +from rio_tiler.models import ImageData +from rio_tiler.tasks import multi_arrays +from rio_tiler.types import Indexes + +from titiler.stacapi.models import AssetInfo +from titiler.stacapi.settings import STACSettings + +stac_config = STACSettings() + +valid_types = { + "image/tiff; application=geotiff", + "image/tiff; application=geotiff; profile=cloud-optimized", + "image/tiff; profile=cloud-optimized; application=geotiff", + "image/vnd.stac.geotiff; cloud-optimized=true", + "image/tiff", + "image/x.geotiff", + "image/jp2", + "application/x-hdf5", + "application/x-hdf", + "application/vnd+zarr", + "application/x-netcdf", +} + + +@attr.s +class AssetsReader(MultiBaseReader): + """ + Asset reader for STAC items. + """ + + # bounds and assets are required + input: Any = attr.ib() + tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS) + minzoom: int = attr.ib() + maxzoom: int = attr.ib() + + reader: Type[BaseReader] = attr.ib(default=Reader) + reader_options: Dict = attr.ib(factory=dict) + + ctx: Any = attr.ib(default=rasterio.Env) + + include_asset_types: Set[str] = attr.ib(default=valid_types) + + @minzoom.default + def _minzoom(self): + return self.tms.minzoom + + @maxzoom.default + def _maxzoom(self): + return self.tms.maxzoom + + def __attrs_post_init__(self): + """ + Post Init. + """ + # MultibaseReader includes the spatial mixin so these attributes are required to assert that the tile exists inside the bounds of the item + self.crs = WGS84_CRS # Per specification STAC items are in WGS84 + self.bounds = self.input["bbox"] + self.assets = list(self.input["assets"]) + + def _get_reader(self, asset_info: AssetInfo) -> Type[BaseReader]: + """Get Asset Reader.""" + asset_type = asset_info.get("type", None) + + if asset_type and asset_type in [ + "application/x-hdf5", + "application/x-hdf", + "application/vnd.zarr", + "application/x-netcdf", + "application/netcdf", + ]: + raise NotImplementedError("XarrayReader not yet implemented") + + return Reader + + def _get_asset_info(self, asset: str) -> AssetInfo: + """ + Validate asset names and return asset's info. + + Args: + asset (str): asset name. + + Returns: + AssetInfo: Asset info + + """ + if asset not in self.assets: + raise InvalidAssetName( + f"{asset} is not valid. Should be one of {self.assets}" + ) + + asset_info = self.input["assets"][asset] + + url = asset_info["href"] + if alternate := stac_config.alternate_url: + url = asset_info["alternate"][alternate]["href"] + + info = AssetInfo(url=url, env={}) + + if asset_info.get("type"): + info["type"] = asset_info["type"] + + # there is a file STAC extension for which `header_size` is the size of the header in the file + # if this value is present, we want to use the GDAL_INGESTED_BYTES_AT_OPEN env variable to read that many bytes at file open. + if header_size := asset_info.get("file:header_size"): + info["env"]["GDAL_INGESTED_BYTES_AT_OPEN"] = header_size # type: ignore + + if bands := asset_info.get("raster:bands"): + stats = [ + (b["statistics"]["minimum"], b["statistics"]["maximum"]) + for b in bands + if {"minimum", "maximum"}.issubset(b.get("statistics", {})) + ] + if len(stats) == len(bands): + info["dataset_statistics"] = stats + + return info + + def tile( # noqa: C901 + self, + tile_x: int, + tile_y: int, + tile_z: int, + assets: Union[Sequence[str], str] = (), + expression: Optional[str] = None, + asset_indexes: Optional[Dict[str, Indexes]] = None, # Indexes for each asset + asset_as_band: bool = False, + **kwargs: Any, + ) -> ImageData: + """Read and merge Wep Map tiles from multiple assets. + + Args: + tile_x (int): Tile's horizontal index. + tile_y (int): Tile's vertical index. + tile_z (int): Tile's zoom level index. + assets (sequence of str or str, optional): assets to fetch info from. + expression (str, optional): rio-tiler expression for the asset list (e.g. asset1/asset2+asset3). + asset_indexes (dict, optional): Band indexes for each asset (e.g {"asset1": 1, "asset2": (1, 2,)}). + kwargs (optional): Options to forward to the `self.reader.tile` method. + + Returns: + rio_tiler.models.ImageData: ImageData instance with data, mask and tile spatial info. + + """ + if not self.tile_exists(tile_x, tile_y, tile_z): + raise TileOutsideBounds( + f"Tile {tile_z}/{tile_x}/{tile_y} is outside image bounds" + ) + + if isinstance(assets, str): + assets = (assets,) + + if assets and expression: + warnings.warn( + "Both expression and assets passed; expression will overwrite assets parameter.", + ExpressionMixingWarning, + stacklevel=2, + ) + + if expression: + assets = self.parse_expression(expression, asset_as_band=asset_as_band) + + if not assets: + raise MissingAssets( + "assets must be passed either via `expression` or `assets` options." + ) + + # indexes comes from the bidx query-parameter. + # but for asset based backend we usually use asset_bidx option. + asset_indexes = asset_indexes or {} + + # We fall back to `indexes` if provided + indexes = kwargs.pop("indexes", None) + + def _reader(asset: str, *args: Any, **kwargs: Any) -> ImageData: + idx = asset_indexes.get(asset) or indexes # type: ignore + asset_info = self._get_asset_info(asset) + reader = self._get_reader(asset_info) + + with self.ctx(**asset_info.get("env", {})): + with reader( + asset_info["url"], tms=self.tms, **self.reader_options + ) as src: + if idx is not None: + kwargs.update({"indexes": idx}) + data = src.tile(*args, **kwargs) + + if asset_as_band: + if len(data.band_names) > 1: + raise AssetAsBandError( + "Can't use `asset_as_band` for multibands asset" + ) + data.band_names = [asset] + else: + data.band_names = [f"{asset}_{n}" for n in data.band_names] + + return data + + img = multi_arrays(assets, _reader, tile_x, tile_y, tile_z, **kwargs) + if expression: + return img.apply_expression(expression) + + return img diff --git a/titiler/stacapi/backend.py b/titiler/stacapi/backend.py index 35993c8..c7ad438 100644 --- a/titiler/stacapi/backend.py +++ b/titiler/stacapi/backend.py @@ -5,7 +5,6 @@ import attr import planetary_computer as pc -import rasterio from cachetools import TTLCache, cached from cachetools.keys import hashkey from cogeo_mosaic.backends import BaseBackend @@ -19,14 +18,12 @@ from rasterio.crs import CRS from rasterio.warp import transform, transform_bounds from rio_tiler.constants import WEB_MERCATOR_TMS, WGS84_CRS -from rio_tiler.errors import InvalidAssetName -from rio_tiler.io import Reader -from rio_tiler.io.base import BaseReader, MultiBaseReader from rio_tiler.models import ImageData from rio_tiler.mosaic import mosaic_reader -from rio_tiler.types import AssetInfo, BBox +from rio_tiler.types import BBox from urllib3 import Retry +from titiler.stacapi.assets_reader import AssetsReader from titiler.stacapi.settings import CacheSettings, RetrySettings, STACSettings from titiler.stacapi.utils import Timer @@ -35,86 +32,6 @@ stac_config = STACSettings() -@attr.s -class CustomSTACReader(MultiBaseReader): - """Simplified STAC Reader. - - Inputs should be in form of: - { - "id": "IAMASTACITEM", - "collection": "mycollection", - "bbox": (0, 0, 10, 10), - "assets": { - "COG": { - "href": "https://somewhereovertherainbow.io/cog.tif" - } - } - } - - """ - - input: Dict[str, Any] = attr.ib() - tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS) - minzoom: int = attr.ib() - maxzoom: int = attr.ib() - - reader: Type[BaseReader] = attr.ib(default=Reader) - reader_options: Dict = attr.ib(factory=dict) - - ctx: Any = attr.ib(default=rasterio.Env) - - def __attrs_post_init__(self) -> None: - """Set reader spatial infos and list of valid assets.""" - self.bounds = self.input["bbox"] - self.crs = WGS84_CRS # Per specification STAC items are in WGS84 - self.assets = list(self.input["assets"]) - - @minzoom.default - def _minzoom(self): - return self.tms.minzoom - - @maxzoom.default - def _maxzoom(self): - return self.tms.maxzoom - - def _get_asset_info(self, asset: str) -> AssetInfo: - """Validate asset names and return asset's url. - - Args: - asset (str): STAC asset name. - - Returns: - str: STAC asset href. - - """ - if asset not in self.assets: - raise InvalidAssetName( - f"{asset} is not valid. Should be one of {self.assets}" - ) - - asset_info = self.input["assets"][asset] - - url = asset_info["href"] - if alternate := stac_config.alternate_url: - url = asset_info["alternate"][alternate]["href"] - - info = AssetInfo(url=url, env={}) - - if header_size := asset_info.get("file:header_size"): - info["env"]["GDAL_INGESTED_BYTES_AT_OPEN"] = header_size - - if bands := asset_info.get("raster:bands"): - stats = [ - (b["statistics"]["minimum"], b["statistics"]["maximum"]) - for b in bands - if {"minimum", "maximum"}.issubset(b.get("statistics", {})) - ] - if len(stats) == len(bands): - info["dataset_statistics"] = stats - - return info - - @attr.s class STACAPIBackend(BaseBackend): """STACAPI Mosaic Backend.""" @@ -128,8 +45,8 @@ class STACAPIBackend(BaseBackend): minzoom: int = attr.ib() maxzoom: int = attr.ib() - # Use Custom STAC reader (outside init) - reader: Type[CustomSTACReader] = attr.ib(init=False, default=CustomSTACReader) + # Use custom asset reader (outside init) + reader: Type[AssetsReader] = attr.ib(init=False, default=AssetsReader) reader_options: Dict = attr.ib(factory=dict) # default values for bounds diff --git a/titiler/stacapi/dependencies.py b/titiler/stacapi/dependencies.py index 9d81592..6d858a1 100644 --- a/titiler/stacapi/dependencies.py +++ b/titiler/stacapi/dependencies.py @@ -1,18 +1,10 @@ """titiler-stacapi dependencies.""" -import json from typing import Dict, List, Literal, Optional, TypedDict, get_args -import planetary_computer as pc -import pystac -from cachetools import TTLCache, cached -from cachetools.keys import hashkey -from fastapi import Depends, HTTPException, Path, Query -from pystac_client import ItemSearch -from pystac_client.stac_api_io import StacApiIO +from fastapi import Path, Query from starlette.requests import Request from typing_extensions import Annotated -from urllib3 import Retry from titiler.stacapi.enums import MediaType from titiler.stacapi.settings import CacheSettings, RetrySettings @@ -103,60 +95,6 @@ def STACApiParams( ) -@cached( # type: ignore - TTLCache(maxsize=cache_config.maxsize, ttl=cache_config.ttl), - key=lambda url, collection_id, item_id, headers, **kwargs: hashkey( - url, collection_id, item_id, json.dumps(headers) - ), -) -def get_stac_item( - url: str, - collection_id: str, - item_id: str, - headers: Optional[Dict] = None, -) -> pystac.Item: - """Get STAC Item from STAC API.""" - stac_api_io = StacApiIO( - max_retries=Retry( - total=retry_config.retry, - backoff_factor=retry_config.retry_factor, - ), - headers=headers, - ) - results = ItemSearch( - f"{url}/search", - stac_io=stac_api_io, - collections=[collection_id], - ids=[item_id], - modifier=pc.sign_inplace, - ) - items = list(results.items()) - if not items: - raise HTTPException( - 404, - f"Could not find Item {item_id} in {collection_id} collection.", - ) - - return items[0] - - -def ItemIdParams( - collection_id: Annotated[ - str, - Path(description="STAC Collection Identifier"), - ], - item_id: Annotated[str, Path(description="STAC Item Identifier")], - api_params=Depends(STACApiParams), -) -> pystac.Item: - """STAC Item dependency for the MultiBaseTilerFactory.""" - return get_stac_item( - api_params["api_url"], - collection_id, - item_id, - headers=api_params.get("headers", {}), - ) - - def STACSearchParams( request: Request, collection_id: Annotated[ diff --git a/titiler/stacapi/main.py b/titiler/stacapi/main.py index aa5fc30..6ccf0b9 100644 --- a/titiler/stacapi/main.py +++ b/titiler/stacapi/main.py @@ -12,17 +12,16 @@ from typing_extensions import Annotated from titiler.core.errors import DEFAULT_STATUS_CODES, add_exception_handlers -from titiler.core.factory import AlgorithmFactory, MultiBaseTilerFactory, TMSFactory +from titiler.core.factory import AlgorithmFactory, TMSFactory from titiler.core.middleware import CacheControlMiddleware, LoggerMiddleware from titiler.core.resources.enums import OptionalHeader from titiler.mosaic.errors import MOSAIC_STATUS_CODES from titiler.stacapi import __version__ as titiler_stacapi_version from titiler.stacapi import models -from titiler.stacapi.dependencies import ItemIdParams, OutputType, STACApiParams +from titiler.stacapi.dependencies import OutputType, STACApiParams from titiler.stacapi.enums import MediaType from titiler.stacapi.factory import MosaicTilerFactory from titiler.stacapi.settings import ApiSettings, STACAPISettings -from titiler.stacapi.stac_reader import STACReader from titiler.stacapi.utils import create_html_response settings = ApiSettings() @@ -99,25 +98,6 @@ collection.router, tags=["STAC Collection"], prefix="/collections/{collection_id}" ) -############################################################################### -# STAC Item Endpoints -# Notes: The `MultiBaseTilerFactory` from titiler.core.factory expect a `URL` as query parameter -# but in this project we use a custom `path_dependency=ItemIdParams`, which define `{collection_id}` and `{item_id}` as -# `Path` dependencies. Then the `ItemIdParams` dependency will fetch the STAC API endpoint to get the STAC Item. The Item -# will then be used in our custom `STACReader`. -stac = MultiBaseTilerFactory( - reader=STACReader, - path_dependency=ItemIdParams, - optional_headers=optional_headers, - router_prefix="/collections/{collection_id}/items/{item_id}", - add_viewer=True, -) -app.include_router( - stac.router, - tags=["STAC Item"], - prefix="/collections/{collection_id}/items/{item_id}", -) - ############################################################################### # Tiling Schemes Endpoints tms = TMSFactory() diff --git a/titiler/stacapi/models.py b/titiler/stacapi/models.py index dfcd3a4..7033206 100644 --- a/titiler/stacapi/models.py +++ b/titiler/stacapi/models.py @@ -6,7 +6,7 @@ """ -from typing import List, Optional +from typing import Dict, List, Optional, Sequence, Tuple, TypedDict from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -82,3 +82,13 @@ class Landing(BaseModel): title: Optional[str] = None description: Optional[str] = None links: List[Link] + + +class AssetInfo(TypedDict, total=False): + """Asset Reader Options.""" + + url: str + env: Optional[Dict] + type: str + metadata: Optional[Dict] + dataset_statistics: Optional[Sequence[Tuple[float, float]]] diff --git a/titiler/stacapi/stac_reader.py b/titiler/stacapi/stac_reader.py deleted file mode 100644 index 10ef9d7..0000000 --- a/titiler/stacapi/stac_reader.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Custom STAC reader.""" - -from typing import Any, Dict, Optional, Set, Type - -import attr -import pystac -import rasterio -from morecantile import TileMatrixSet -from rasterio.crs import CRS -from rio_tiler.constants import WEB_MERCATOR_TMS, WGS84_CRS -from rio_tiler.errors import InvalidAssetName -from rio_tiler.io import BaseReader, Reader, stac -from rio_tiler.types import AssetInfo - -from titiler.stacapi.settings import STACSettings - -stac_config = STACSettings() - - -@attr.s -class STACReader(stac.STACReader): - """Custom STAC Reader. - - Only accept `pystac.Item` as input (while rio_tiler.io.STACReader accepts url or pystac.Item) - - """ - - input: pystac.Item = attr.ib() - - tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS) - minzoom: int = attr.ib() - maxzoom: int = attr.ib() - - geographic_crs: CRS = attr.ib(default=WGS84_CRS) - - include_assets: Optional[Set[str]] = attr.ib(default=None) - exclude_assets: Optional[Set[str]] = attr.ib(default=None) - - include_asset_types: Set[str] = attr.ib(default=stac.DEFAULT_VALID_TYPE) - exclude_asset_types: Optional[Set[str]] = attr.ib(default=None) - - reader: Type[BaseReader] = attr.ib(default=Reader) - reader_options: Dict = attr.ib(factory=dict) - - fetch_options: Dict = attr.ib(factory=dict) - - ctx: Any = attr.ib(default=rasterio.Env) - - item: pystac.Item = attr.ib(init=False) - - def __attrs_post_init__(self): - """Fetch STAC Item and get list of valid assets.""" - self.item = self.input - super().__attrs_post_init__() - - @minzoom.default - def _minzoom(self): - return self.tms.minzoom - - @maxzoom.default - def _maxzoom(self): - return self.tms.maxzoom - - def _get_asset_info(self, asset: str) -> AssetInfo: - """Validate asset names and return asset's url. - - Args: - asset (str): STAC asset name. - - Returns: - str: STAC asset href. - - """ - if asset not in self.assets: - raise InvalidAssetName( - f"'{asset}' is not valid, should be one of {self.assets}" - ) - - asset_info = self.item.assets[asset] - extras = asset_info.extra_fields - - url = asset_info.get_absolute_href() or asset_info.href - if alternate := stac_config.alternate_url: - url = asset_info.to_dict()["alternate"][alternate]["href"] - - info = AssetInfo( - url=url, - metadata=extras, - ) - - if head := extras.get("file:header_size"): - info["env"] = {"GDAL_INGESTED_BYTES_AT_OPEN": head} - - if bands := extras.get("raster:bands"): - stats = [ - (b["statistics"]["minimum"], b["statistics"]["maximum"]) - for b in bands - if {"minimum", "maximum"}.issubset(b.get("statistics", {})) - ] - if len(stats) == len(bands): - info["dataset_statistics"] = stats - - return info