diff --git a/tiled/catalog/adapter.py b/tiled/catalog/adapter.py index d7e3ce9d8..0fdc17fe3 100644 --- a/tiled/catalog/adapter.py +++ b/tiled/catalog/adapter.py @@ -259,11 +259,12 @@ def __init__( context, node, *, + structure_family=None, + data_sources=None, conditions=None, queries=None, sorting=None, access_policy=None, - data_source_name=None, ): self.context = context self.engine = self.context.engine @@ -277,23 +278,18 @@ def __init__( self.order_by_clauses = order_by_clauses(self.sorting) self.conditions = conditions or [] self.queries = queries or [] - self.structure_family = node.structure_family self.specs = [Spec.parse_obj(spec) for spec in node.specs] self.ancestors = node.ancestors self.key = node.key self.access_policy = access_policy self.startup_tasks = [self.startup] self.shutdown_tasks = [self.shutdown] - if data_source_name is not None: - for data_source in self.data_sources: - if data_source_name == data_source.name: - self.data_source_structure_family = data_source.structure_family - break - else: - raise ValueError(f"No DataSource named {data_source_name} on this node") - self.data_source = data_source - elif len(self.data_sources) == 1: - (self.data_source,) = self.data_sources + self.structure_family = structure_family or node.structure_family + if data_sources is None: + data_sources = [ + DataSource.from_orm(ds) for ds in self.node.data_sources or [] + ] + self.data_sources = data_sources def metadata(self): return self.node.metadata_ @@ -332,10 +328,6 @@ async def __aiter__(self): async with self.context.session() as db: return (await db.execute(statement)).scalar().all() - @property - def data_sources(self): - return [DataSource.from_orm(ds) for ds in self.node.data_sources or []] - async def asset_by_id(self, asset_id): statement = ( select(orm.Asset) @@ -394,7 +386,6 @@ async def async_len(self): async def lookup_adapter( self, segments, - data_source_name=None, ): # TODO: Accept filter for predicate-pushdown. if not segments: return self @@ -404,17 +395,13 @@ async def lookup_adapter( # this node, either via user search queries or via access # control policy queries. Look up first the _direct_ child of this # node, if it exists within the filtered results. - first_level = await self.lookup_adapter( - segments[:1], data_source_name=data_source_name - ) + first_level = await self.lookup_adapter(segments[:1]) if first_level is None: return None # Now proceed to traverse further down the tree, if needed. # Search queries and access controls apply only at the top level. assert not first_level.conditions - return await first_level.lookup_adapter( - segments[1:], data_source_name=data_source_name - ) + return await first_level.lookup_adapter(segments[1:]) statement = ( select(orm.Node) .filter(orm.Node.ancestors == self.segments + ancestors) @@ -437,9 +424,7 @@ async def lookup_adapter( # HDF5 file begins. for i in range(len(segments)): - catalog_adapter = await self.lookup_adapter( - segments[:i], data_source_name=data_source_name - ) + catalog_adapter = await self.lookup_adapter(segments[:i]) if catalog_adapter.data_sources: adapter = await catalog_adapter.get_adapter() for segment in segments[i:]: @@ -451,25 +436,19 @@ async def lookup_adapter( return STRUCTURES[node.structure_family]( self.context, node, - data_source_name=data_source_name, access_policy=self.access_policy, ) async def get_adapter(self): - if (self.structure_family == StructureFamily.union) and not self.data_source: - raise RuntimeError( - "A data_source_name must be specified at construction time." - ) + (data_source,) = self.data_sources try: - adapter_factory = self.context.adapters_by_mimetype[ - self.data_source.mimetype - ] + adapter_factory = self.context.adapters_by_mimetype[data_source.mimetype] except KeyError: raise RuntimeError( - f"Server configuration has no adapter for mimetype {self.data_source.mimetype!r}" + f"Server configuration has no adapter for mimetype {data_source.mimetype!r}" ) parameters = collections.defaultdict(list) - for asset in self.data_source.assets: + for asset in data_source.assets: if asset.parameter is None: continue scheme = urlparse(asset.data_uri).scheme @@ -498,10 +477,10 @@ async def get_adapter(self): else: parameters[asset.parameter].append(asset.data_uri) adapter_kwargs = dict(parameters) - adapter_kwargs.update(self.data_source.parameters) + adapter_kwargs.update(data_source.parameters) adapter_kwargs["specs"] = self.node.specs adapter_kwargs["metadata"] = self.node.metadata_ - adapter_kwargs["structure"] = self.data_source.structure + adapter_kwargs["structure"] = data_source.structure adapter_kwargs["access_policy"] = self.access_policy adapter = await anyio.to_thread.run_sync( partial(adapter_factory, **adapter_kwargs) @@ -1008,29 +987,21 @@ async def write_partition(self, *args, **kwargs): class CatalogUnionAdapter(CatalogNodeAdapter): - # def get(self, key): - # for data_source in data_sources: - # if data_source.name == - - async def read(self, *args, **kwargs): - return await ensure_awaitable((await self.get_adapter()).read, *args, **kwargs) - - async def read_block(self, *args, **kwargs): - return await ensure_awaitable( - (await self.get_adapter()).read_block, *args, **kwargs - ) + def get(self, key): + ... - async def write(self, *args, **kwargs): - return await ensure_awaitable((await self.get_adapter()).write, *args, **kwargs) - - async def read_partition(self, *args, **kwargs): - return await ensure_awaitable( - (await self.get_adapter()).read_partition, *args, **kwargs - ) - - async def write_partition(self, *args, **kwargs): - return await ensure_awaitable( - (await self.get_adapter()).write_partition, *args, **kwargs + def for_data_source(self, data_source_name): + for data_source in self.data_sources: + if data_source_name == data_source.name: + break + else: + raise ValueError(f"No DataSource named {data_source_name} on this node") + return STRUCTURES[data_source.structure_family]( + self.context, + self.node, + access_policy=self.access_policy, + structure_family=data_source.structure_family, + data_sources=[data_source], ) diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py index 7e1e964e1..520741c52 100644 --- a/tiled/server/dependencies.py +++ b/tiled/server/dependencies.py @@ -84,9 +84,7 @@ async def inner( # It can jump directly to the node of interest. if hasattr(entry, "lookup_adapter"): - entry = await entry.lookup_adapter( - path_parts[i:], data_source_name=data_source - ) + entry = await entry.lookup_adapter(path_parts[i:]) if entry is None: raise NoEntry(path_parts) break @@ -129,16 +127,21 @@ async def inner( if entry.structure_family == StructureFamily.union: if not data_source: raise HTTPException( - status_code=400, detail="A data_source query parameter is required." + status_code=400, + detail=( + "A data_source query parameter is required on this endpoint " + "when addressing a 'union' structure." + ), ) - if entry.data_source.structure_family in structure_families: - return entry + entry_for_data_source = entry.for_data_source(data_source) + if entry_for_data_source.structure_family in structure_families: + return entry_for_data_source raise HTTPException( status_code=404, detail=( f"The data source named {data_source} backing the node " - f"at {path} has structure family {entry.data_source.structure_family} " - "and this endpoint is compatible with structure families " + f"at {path} has structure family {entry_for_data_source.structure_family} " + "and this endpoint is compatible only with structure families " f"{structure_families}" ), ) @@ -146,7 +149,7 @@ async def inner( status_code=404, detail=( f"The node at {path} has structure family {entry.structure_family} " - "and this endpoint is compatible with structure families " + "and this endpoint is compatible only with structure families " f"{structure_families}" ), ) diff --git a/tiled/server/router.py b/tiled/server/router.py index 67930515a..fc1c3a551 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -364,10 +364,7 @@ async def array_block( """ Fetch a chunk of array-like data. """ - if data_source is None: - shape = entry.structure().shape - else: - shape = entry.data_source.structure.shape + shape = entry.structure().shape # Check that block dimensionality matches array dimensionality. ndim = len(shape) if len(block) != ndim: @@ -1292,7 +1289,9 @@ async def put_array_block( @router.put("/node/full/{path:path}", deprecated=True) async def put_node_full( request: Request, - entry=SecureEntry(scopes=["write:data"]), + entry=SecureEntry( + scopes=["write:data"], structure_families={StructureFamily.table} + ), deserialization_registry=Depends(get_deserialization_registry), ): if not hasattr(entry, "write"):