Skip to content

Commit

Permalink
Writing and reading arrays works.
Browse files Browse the repository at this point in the history
  • Loading branch information
danielballan committed Feb 26, 2024
1 parent 80eef8c commit 927d6cb
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 74 deletions.
91 changes: 31 additions & 60 deletions tiled/catalog/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,12 @@ def __init__(
context,
node,
*,
structure_family=None,
data_sources=None,
conditions=None,
queries=None,
sorting=None,
access_policy=None,
data_source_name=None,
):
self.context = context
self.engine = self.context.engine
Expand All @@ -277,23 +278,18 @@ def __init__(
self.order_by_clauses = order_by_clauses(self.sorting)
self.conditions = conditions or []
self.queries = queries or []
self.structure_family = node.structure_family
self.specs = [Spec.parse_obj(spec) for spec in node.specs]
self.ancestors = node.ancestors
self.key = node.key
self.access_policy = access_policy
self.startup_tasks = [self.startup]
self.shutdown_tasks = [self.shutdown]
if data_source_name is not None:
for data_source in self.data_sources:
if data_source_name == data_source.name:
self.data_source_structure_family = data_source.structure_family
break
else:
raise ValueError(f"No DataSource named {data_source_name} on this node")
self.data_source = data_source
elif len(self.data_sources) == 1:
(self.data_source,) = self.data_sources
self.structure_family = structure_family or node.structure_family
if data_sources is None:
data_sources = [
DataSource.from_orm(ds) for ds in self.node.data_sources or []
]
self.data_sources = data_sources

def metadata(self):
return self.node.metadata_
Expand Down Expand Up @@ -332,10 +328,6 @@ async def __aiter__(self):
async with self.context.session() as db:
return (await db.execute(statement)).scalar().all()

@property
def data_sources(self):
return [DataSource.from_orm(ds) for ds in self.node.data_sources or []]

async def asset_by_id(self, asset_id):
statement = (
select(orm.Asset)
Expand Down Expand Up @@ -394,7 +386,6 @@ async def async_len(self):
async def lookup_adapter(
self,
segments,
data_source_name=None,
): # TODO: Accept filter for predicate-pushdown.
if not segments:
return self
Expand All @@ -404,17 +395,13 @@ async def lookup_adapter(
# this node, either via user search queries or via access
# control policy queries. Look up first the _direct_ child of this
# node, if it exists within the filtered results.
first_level = await self.lookup_adapter(
segments[:1], data_source_name=data_source_name
)
first_level = await self.lookup_adapter(segments[:1])
if first_level is None:
return None
# Now proceed to traverse further down the tree, if needed.
# Search queries and access controls apply only at the top level.
assert not first_level.conditions
return await first_level.lookup_adapter(
segments[1:], data_source_name=data_source_name
)
return await first_level.lookup_adapter(segments[1:])
statement = (
select(orm.Node)
.filter(orm.Node.ancestors == self.segments + ancestors)
Expand All @@ -437,9 +424,7 @@ async def lookup_adapter(
# HDF5 file begins.

for i in range(len(segments)):
catalog_adapter = await self.lookup_adapter(
segments[:i], data_source_name=data_source_name
)
catalog_adapter = await self.lookup_adapter(segments[:i])
if catalog_adapter.data_sources:
adapter = await catalog_adapter.get_adapter()
for segment in segments[i:]:
Expand All @@ -451,25 +436,19 @@ async def lookup_adapter(
return STRUCTURES[node.structure_family](
self.context,
node,
data_source_name=data_source_name,
access_policy=self.access_policy,
)

async def get_adapter(self):
if (self.structure_family == StructureFamily.union) and not self.data_source:
raise RuntimeError(
"A data_source_name must be specified at construction time."
)
(data_source,) = self.data_sources
try:
adapter_factory = self.context.adapters_by_mimetype[
self.data_source.mimetype
]
adapter_factory = self.context.adapters_by_mimetype[data_source.mimetype]
except KeyError:
raise RuntimeError(
f"Server configuration has no adapter for mimetype {self.data_source.mimetype!r}"
f"Server configuration has no adapter for mimetype {data_source.mimetype!r}"
)
parameters = collections.defaultdict(list)
for asset in self.data_source.assets:
for asset in data_source.assets:
if asset.parameter is None:
continue
scheme = urlparse(asset.data_uri).scheme
Expand Down Expand Up @@ -498,10 +477,10 @@ async def get_adapter(self):
else:
parameters[asset.parameter].append(asset.data_uri)
adapter_kwargs = dict(parameters)
adapter_kwargs.update(self.data_source.parameters)
adapter_kwargs.update(data_source.parameters)
adapter_kwargs["specs"] = self.node.specs
adapter_kwargs["metadata"] = self.node.metadata_
adapter_kwargs["structure"] = self.data_source.structure
adapter_kwargs["structure"] = data_source.structure
adapter_kwargs["access_policy"] = self.access_policy
adapter = await anyio.to_thread.run_sync(
partial(adapter_factory, **adapter_kwargs)
Expand Down Expand Up @@ -1008,29 +987,21 @@ async def write_partition(self, *args, **kwargs):


class CatalogUnionAdapter(CatalogNodeAdapter):
# def get(self, key):
# for data_source in data_sources:
# if data_source.name ==

async def read(self, *args, **kwargs):
return await ensure_awaitable((await self.get_adapter()).read, *args, **kwargs)

async def read_block(self, *args, **kwargs):
return await ensure_awaitable(
(await self.get_adapter()).read_block, *args, **kwargs
)
def get(self, key):
...

async def write(self, *args, **kwargs):
return await ensure_awaitable((await self.get_adapter()).write, *args, **kwargs)

async def read_partition(self, *args, **kwargs):
return await ensure_awaitable(
(await self.get_adapter()).read_partition, *args, **kwargs
)

async def write_partition(self, *args, **kwargs):
return await ensure_awaitable(
(await self.get_adapter()).write_partition, *args, **kwargs
def for_data_source(self, data_source_name):
for data_source in self.data_sources:
if data_source_name == data_source.name:
break
else:
raise ValueError(f"No DataSource named {data_source_name} on this node")
return STRUCTURES[data_source.structure_family](
self.context,
self.node,
access_policy=self.access_policy,
structure_family=data_source.structure_family,
data_sources=[data_source],
)


Expand Down
21 changes: 12 additions & 9 deletions tiled/server/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ async def inner(
# It can jump directly to the node of interest.

if hasattr(entry, "lookup_adapter"):
entry = await entry.lookup_adapter(
path_parts[i:], data_source_name=data_source
)
entry = await entry.lookup_adapter(path_parts[i:])
if entry is None:
raise NoEntry(path_parts)
break
Expand Down Expand Up @@ -129,24 +127,29 @@ async def inner(
if entry.structure_family == StructureFamily.union:
if not data_source:
raise HTTPException(
status_code=400, detail="A data_source query parameter is required."
status_code=400,
detail=(
"A data_source query parameter is required on this endpoint "
"when addressing a 'union' structure."
),
)
if entry.data_source.structure_family in structure_families:
return entry
entry_for_data_source = entry.for_data_source(data_source)
if entry_for_data_source.structure_family in structure_families:
return entry_for_data_source
raise HTTPException(
status_code=404,
detail=(
f"The data source named {data_source} backing the node "
f"at {path} has structure family {entry.data_source.structure_family} "
"and this endpoint is compatible with structure families "
f"at {path} has structure family {entry_for_data_source.structure_family} "
"and this endpoint is compatible only with structure families "
f"{structure_families}"
),
)
raise HTTPException(
status_code=404,
detail=(
f"The node at {path} has structure family {entry.structure_family} "
"and this endpoint is compatible with structure families "
"and this endpoint is compatible only with structure families "
f"{structure_families}"
),
)
Expand Down
9 changes: 4 additions & 5 deletions tiled/server/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,7 @@ async def array_block(
"""
Fetch a chunk of array-like data.
"""
if data_source is None:
shape = entry.structure().shape
else:
shape = entry.data_source.structure.shape
shape = entry.structure().shape
# Check that block dimensionality matches array dimensionality.
ndim = len(shape)
if len(block) != ndim:
Expand Down Expand Up @@ -1292,7 +1289,9 @@ async def put_array_block(
@router.put("/node/full/{path:path}", deprecated=True)
async def put_node_full(
request: Request,
entry=SecureEntry(scopes=["write:data"]),
entry=SecureEntry(
scopes=["write:data"], structure_families={StructureFamily.table}
),
deserialization_registry=Depends(get_deserialization_registry),
):
if not hasattr(entry, "write"):
Expand Down

0 comments on commit 927d6cb

Please sign in to comment.