diff --git a/tiled/catalog/adapter.py b/tiled/catalog/adapter.py index edfa0d764..684e8811e 100644 --- a/tiled/catalog/adapter.py +++ b/tiled/catalog/adapter.py @@ -283,9 +283,18 @@ def __init__( self.ancestors = node.ancestors self.key = node.key self.access_policy = access_policy - self.data_source_name = data_source_name self.startup_tasks = [self.startup] self.shutdown_tasks = [self.shutdown] + if data_source_name is not None: + for data_source in self.data_sources: + if data_source_name == data_source.name: + self.data_source_structure_family = data_source.structure_family + break + else: + raise ValueError(f"No DataSource named {data_source_name} on this node") + self.data_source = data_source + elif len(self.data_sources) == 1: + (self.data_source,) = self.data_sources def metadata(self): return self.node.metadata_ @@ -443,30 +452,20 @@ async def lookup_adapter( ) async def get_adapter(self): - num_data_sources = len(self.data_sources) - if self.data_source_name is not None: - for data_source in self.data_sources: - if self.data_source_name == data_source.name: - break - else: - raise ValueError( - f"No DataSource named {self.data_source_name} on this node" - ) - elif num_data_sources > 1: - raise ValueError( - "A data_source_name is required because this node " - f"has {num_data_sources} data sources" + if (self.structure_family == StructureFamily.union) and not self.data_source: + raise RuntimeError( + "A data_source_name must be specified at construction time." ) - else: - (data_source,) = self.data_sources try: - adapter_factory = self.context.adapters_by_mimetype[data_source.mimetype] + adapter_factory = self.context.adapters_by_mimetype[ + self.data_source.mimetype + ] except KeyError: raise RuntimeError( - f"Server configuration has no adapter for mimetype {data_source.mimetype!r}" + f"Server configuration has no adapter for mimetype {self.data_source.mimetype!r}" ) parameters = collections.defaultdict(list) - for asset in data_source.assets: + for asset in self.data_source.assets: if asset.parameter is None: continue scheme = urlparse(asset.data_uri).scheme @@ -495,10 +494,10 @@ async def get_adapter(self): else: parameters[asset.parameter].append(asset.data_uri) adapter_kwargs = dict(parameters) - adapter_kwargs.update(data_source.parameters) + adapter_kwargs.update(self.data_source.parameters) adapter_kwargs["specs"] = self.node.specs adapter_kwargs["metadata"] = self.node.metadata_ - adapter_kwargs["structure"] = data_source.structure + adapter_kwargs["structure"] = self.data_source.structure adapter_kwargs["access_policy"] = self.access_policy adapter = await anyio.to_thread.run_sync( partial(adapter_factory, **adapter_kwargs) @@ -1007,7 +1006,9 @@ async def write_partition(self, *args, **kwargs): class CatalogUnionAdapter(CatalogNodeAdapter): - # This does not support direct reading or writing. + # def get(self, key): + # for data_source in data_sources: + # if data_source.name == async def read(self, *args, **kwargs): return await ensure_awaitable((await self.get_adapter()).read, *args, **kwargs) diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py index a98f67938..7e1e964e1 100644 --- a/tiled/server/dependencies.py +++ b/tiled/server/dependencies.py @@ -11,6 +11,7 @@ serialization_registry as default_serialization_registry, ) from ..query_registration import query_registry as default_query_registry +from ..structures.core import StructureFamily from ..validation_registration import validation_registry as default_validation_registry from .authentication import get_current_principal, get_session_state from .core import NoEntry @@ -48,7 +49,7 @@ def get_root_tree(): ) -def SecureEntry(scopes): +def SecureEntry(scopes, structure_families=None): async def inner( path: str, request: Request, @@ -119,7 +120,36 @@ async def inner( ) except NoEntry: raise HTTPException(status_code=404, detail=f"No such entry: {path_parts}") - return entry + # Fast path for the common successful case + if (structure_families is None) or ( + entry.structure_family in structure_families + ): + return entry + # Handle union structure_family + if entry.structure_family == StructureFamily.union: + if not data_source: + raise HTTPException( + status_code=400, detail="A data_source query parameter is required." + ) + if entry.data_source.structure_family in structure_families: + return entry + raise HTTPException( + status_code=404, + detail=( + f"The data source named {data_source} backing the node " + f"at {path} has structure family {entry.data_source.structure_family} " + "and this endpoint is compatible with structure families " + f"{structure_families}" + ), + ) + raise HTTPException( + status_code=404, + detail=( + f"The node at {path} has structure family {entry.structure_family} " + "and this endpoint is compatible with structure families " + f"{structure_families}" + ), + ) return Security(inner, scopes=scopes) diff --git a/tiled/server/router.py b/tiled/server/router.py index 1f2be0865..c5ed0cd65 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -347,7 +347,10 @@ async def metadata( ) async def array_block( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), block=Depends(block), slice=Depends(slice_), expected_shape=Depends(expected_shape), @@ -359,15 +362,7 @@ async def array_block( """ Fetch a chunk of array-like data. """ - if entry.structure_family == "array": - shape = entry.structure().shape - elif entry.structure_family == "sparse": - shape = entry.structure().shape - else: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /array/block route.", - ) + shape = entry.structure().shape # Check that block dimensionality matches array dimensionality. ndim = len(shape) if len(block) != ndim: @@ -406,10 +401,14 @@ async def array_block( "Use slicing ('?slice=...') to request smaller chunks." ), ) + if entry.structure_family == StructureFamily.union: + structure_family = entry.data_source.structure_family + else: + structure_family = entry.structure_family try: with record_timing(request.state.metrics, "pack"): return await construct_data_response( - entry.structure_family, + structure_family, serialization_registry, array, entry.metadata(), @@ -429,7 +428,10 @@ async def array_block( ) async def array_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), slice=Depends(slice_), expected_shape=Depends(expected_shape), format: Optional[str] = None, @@ -440,12 +442,10 @@ async def array_full( """ Fetch a slice of array-like data. """ - structure_family = entry.structure_family - if structure_family not in {"array", "sparse"}: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /array/full route.", - ) + if entry.structure_family == StructureFamily.union: + structure_family = entry.data_source.structure_family + else: + structure_family = entry.structure_family # Deferred import because this is not a required dependency of the server # for some use cases. import numpy @@ -453,7 +453,7 @@ async def array_full( try: with record_timing(request.state.metrics, "read"): array = await ensure_awaitable(entry.read, slice) - if structure_family == "array": + if structure_family == StructureFamily.array: array = numpy.asarray(array) # Force dask or PIMS or ... to do I/O. except IndexError: raise HTTPException(status_code=400, detail="Block index out of range") @@ -495,7 +495,7 @@ async def array_full( async def get_table_partition( request: Request, partition: int, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), column: Optional[List[str]] = Query(None, min_length=1), field: Optional[List[str]] = Query(None, min_length=1, deprecated=True), format: Optional[str] = None, @@ -543,7 +543,7 @@ async def get_table_partition( async def post_table_partition( request: Request, partition: int, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), column: Optional[List[str]] = Body(None, min_length=1), format: Optional[str] = None, filename: Optional[str] = None, @@ -578,11 +578,6 @@ async def table_partition( """ Fetch a partition (continuous block of rows) from a DataFrame. """ - if entry.structure_family != StructureFamily.table: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /table/partition route.", - ) try: # The singular/plural mismatch here of "fields" and "field" is # due to the ?field=A&field=B&field=C... encodes in a URL. @@ -626,7 +621,7 @@ async def table_partition( ) async def get_table_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), column: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, filename: Optional[str] = None, @@ -654,7 +649,7 @@ async def get_table_full( ) async def post_table_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), column: Optional[List[str]] = Body(None, min_length=1), format: Optional[str] = None, filename: Optional[str] = None, @@ -687,11 +682,6 @@ async def table_full( """ Fetch the data for the given table. """ - # if entry.structure_family != StructureFamily.table: - # raise HTTPException( - # status_code=404, - # detail=f"Cannot read {entry.structure_family} structure with /table/full route.", - # ) try: with record_timing(request.state.metrics, "read"): data = await ensure_awaitable(entry.read, column) @@ -707,10 +697,14 @@ async def table_full( "request a smaller chunks." ), ) + if entry.structure_family == StructureFamily.union: + structure_family = entry.data_source.structure_family + else: + structure_family = entry.structure_family try: with record_timing(request.state.metrics, "pack"): return await construct_data_response( - entry.structure_family, + structure_family, serialization_registry, data, entry.metadata(), @@ -732,7 +726,9 @@ async def table_full( ) async def get_container_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.container} + ), principal: str = Depends(get_current_principal), field: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, @@ -760,7 +756,9 @@ async def get_container_full( ) async def post_container_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.container} + ), principal: str = Depends(get_current_principal), field: Optional[List[str]] = Body(None, min_length=1), format: Optional[str] = None, @@ -793,11 +791,6 @@ async def container_full( """ Fetch the data for the given container. """ - if entry.structure_family != StructureFamily.container: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /container/full route.", - ) try: with record_timing(request.state.metrics, "read"): data = await ensure_awaitable(entry.read, fields=field) @@ -837,7 +830,10 @@ async def container_full( ) async def node_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], + structure_families={StructureFamily.table, StructureFamily.container}, + ), principal: str = Depends(get_current_principal), field: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, @@ -900,7 +896,9 @@ async def node_full( ) async def get_awkward_buffers( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.awkward} + ), form_key: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, filename: Optional[str] = None, @@ -936,7 +934,9 @@ async def get_awkward_buffers( async def post_awkward_buffers( request: Request, body: List[str], - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.awkward} + ), format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), @@ -974,11 +974,6 @@ async def _awkward_buffers( ): structure_family = entry.structure_family structure = entry.structure() - if structure_family != StructureFamily.awkward: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /awkward/buffers route.", - ) with record_timing(request.state.metrics, "read"): # The plural vs. singular mismatch is due to the way query parameters # are given as ?form_key=A&form_key=B&form_key=C. @@ -1019,7 +1014,9 @@ async def _awkward_buffers( ) async def awkward_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.awkward} + ), # slice=Depends(slice_), format: Optional[str] = None, filename: Optional[str] = None, @@ -1030,11 +1027,6 @@ async def awkward_full( Fetch a slice of AwkwardArray data. """ structure_family = entry.structure_family - if structure_family != StructureFamily.awkward: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /awkward/full route.", - ) # Deferred import because this is not a required dependency of the server # for some use cases. import awkward @@ -1254,7 +1246,10 @@ async def bulk_delete( @router.put("/array/full/{path:path}") async def put_array_full( request: Request, - entry=SecureEntry(scopes=["write:data"]), + entry=SecureEntry( + scopes=["write:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), deserialization_registry=Depends(get_deserialization_registry), ): body = await request.body() @@ -1280,7 +1275,10 @@ async def put_array_full( @router.put("/array/block/{path:path}") async def put_array_block( request: Request, - entry=SecureEntry(scopes=["write:data"]), + entry=SecureEntry( + scopes=["write:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), deserialization_registry=Depends(get_deserialization_registry), block=Depends(block), ):