Skip to content

Commit

Permalink
Refactor structure family check into dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
danielballan committed Mar 8, 2024
1 parent 9b0f051 commit 0e65ba8
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 64 deletions.
16 changes: 14 additions & 2 deletions tiled/server/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_root_tree():
)


def SecureEntry(scopes):
def SecureEntry(scopes, structure_families=None):
async def inner(
path: str,
request: Request,
Expand Down Expand Up @@ -116,7 +116,19 @@ async def inner(
)
except NoEntry:
raise HTTPException(status_code=404, detail=f"No such entry: {path_parts}")
return entry
# Fast path for the common successful case
if (structure_families is None) or (
entry.structure_family in structure_families
):
return entry
raise HTTPException(
status_code=404,
detail=(
f"The node at {path} has structure family {entry.structure_family} "
"and this endpoint is compatible with structure families "
f"{structure_families}"
),
)

return Security(inner, scopes=scopes)

Expand Down
120 changes: 58 additions & 62 deletions tiled/server/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,10 @@ async def metadata(
)
async def array_block(
request: Request,
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(
scopes=["read:data"],
structure_families={StructureFamily.array, StructureFamily.sparse},
),
block=Depends(block),
slice=Depends(slice_),
expected_shape=Depends(expected_shape),
Expand All @@ -359,15 +362,7 @@ async def array_block(
"""
Fetch a chunk of array-like data.
"""
if entry.structure_family == "array":
shape = entry.structure().shape
elif entry.structure_family == "sparse":
shape = entry.structure().shape
else:
raise HTTPException(
status_code=404,
detail=f"Cannot read {entry.structure_family} structure with /array/block route.",
)
shape = entry.structure().shape
# Check that block dimensionality matches array dimensionality.
ndim = len(shape)
if len(block) != ndim:
Expand Down Expand Up @@ -406,10 +401,14 @@ async def array_block(
"Use slicing ('?slice=...') to request smaller chunks."
),
)
if entry.structure_family == StructureFamily.union:
structure_family = entry.data_source.structure_family
else:
structure_family = entry.structure_family
try:
with record_timing(request.state.metrics, "pack"):
return await construct_data_response(
entry.structure_family,
structure_family,
serialization_registry,
array,
entry.metadata(),
Expand All @@ -429,7 +428,10 @@ async def array_block(
)
async def array_full(
request: Request,
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(
scopes=["read:data"],
structure_families={StructureFamily.array, StructureFamily.sparse},
),
slice=Depends(slice_),
expected_shape=Depends(expected_shape),
format: Optional[str] = None,
Expand All @@ -440,20 +442,18 @@ async def array_full(
"""
Fetch a slice of array-like data.
"""
structure_family = entry.structure_family
if structure_family not in {"array", "sparse"}:
raise HTTPException(
status_code=404,
detail=f"Cannot read {entry.structure_family} structure with /array/full route.",
)
if entry.structure_family == StructureFamily.union:
structure_family = entry.data_source.structure_family
else:
structure_family = entry.structure_family
# Deferred import because this is not a required dependency of the server
# for some use cases.
import numpy

try:
with record_timing(request.state.metrics, "read"):
array = await ensure_awaitable(entry.read, slice)
if structure_family == "array":
if structure_family == StructureFamily.array:
array = numpy.asarray(array) # Force dask or PIMS or ... to do I/O.
except IndexError:
raise HTTPException(status_code=400, detail="Block index out of range")
Expand Down Expand Up @@ -495,7 +495,7 @@ async def array_full(
async def get_table_partition(
request: Request,
partition: int,
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}),
column: Optional[List[str]] = Query(None, min_length=1),
field: Optional[List[str]] = Query(None, min_length=1, deprecated=True),
format: Optional[str] = None,
Expand Down Expand Up @@ -543,7 +543,7 @@ async def get_table_partition(
async def post_table_partition(
request: Request,
partition: int,
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}),
column: Optional[List[str]] = Body(None, min_length=1),
format: Optional[str] = None,
filename: Optional[str] = None,
Expand Down Expand Up @@ -578,11 +578,6 @@ async def table_partition(
"""
Fetch a partition (continuous block of rows) from a DataFrame.
"""
if entry.structure_family != StructureFamily.table:
raise HTTPException(
status_code=404,
detail=f"Cannot read {entry.structure_family} structure with /table/partition route.",
)
try:
# The singular/plural mismatch here of "fields" and "field" is
# due to the ?field=A&field=B&field=C... encodes in a URL.
Expand Down Expand Up @@ -626,7 +621,7 @@ async def table_partition(
)
async def get_table_full(
request: Request,
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}),
column: Optional[List[str]] = Query(None, min_length=1),
format: Optional[str] = None,
filename: Optional[str] = None,
Expand Down Expand Up @@ -654,7 +649,7 @@ async def get_table_full(
)
async def post_table_full(
request: Request,
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}),
column: Optional[List[str]] = Body(None, min_length=1),
format: Optional[str] = None,
filename: Optional[str] = None,
Expand Down Expand Up @@ -687,11 +682,6 @@ async def table_full(
"""
Fetch the data for the given table.
"""
if entry.structure_family != StructureFamily.table:
raise HTTPException(
status_code=404,
detail=f"Cannot read {entry.structure_family} structure with /table/full route.",
)
try:
with record_timing(request.state.metrics, "read"):
data = await ensure_awaitable(entry.read, column)
Expand All @@ -707,10 +697,14 @@ async def table_full(
"request a smaller chunks."
),
)
if entry.structure_family == StructureFamily.union:
structure_family = entry.data_source.structure_family
else:
structure_family = entry.structure_family
try:
with record_timing(request.state.metrics, "pack"):
return await construct_data_response(
entry.structure_family,
structure_family,
serialization_registry,
data,
entry.metadata(),
Expand All @@ -732,7 +726,9 @@ async def table_full(
)
async def get_container_full(
request: Request,
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(
scopes=["read:data"], structure_families={StructureFamily.container}
),
principal: str = Depends(get_current_principal),
field: Optional[List[str]] = Query(None, min_length=1),
format: Optional[str] = None,
Expand Down Expand Up @@ -760,7 +756,9 @@ async def get_container_full(
)
async def post_container_full(
request: Request,
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(
scopes=["read:data"], structure_families={StructureFamily.container}
),
principal: str = Depends(get_current_principal),
field: Optional[List[str]] = Body(None, min_length=1),
format: Optional[str] = None,
Expand Down Expand Up @@ -793,11 +791,6 @@ async def container_full(
"""
Fetch the data for the given container.
"""
if entry.structure_family != StructureFamily.container:
raise HTTPException(
status_code=404,
detail=f"Cannot read {entry.structure_family} structure with /container/full route.",
)
try:
with record_timing(request.state.metrics, "read"):
data = await ensure_awaitable(entry.read, fields=field)
Expand Down Expand Up @@ -837,7 +830,10 @@ async def container_full(
)
async def node_full(
request: Request,
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(
scopes=["read:data"],
structure_families={StructureFamily.table, StructureFamily.container},
),
principal: str = Depends(get_current_principal),
field: Optional[List[str]] = Query(None, min_length=1),
format: Optional[str] = None,
Expand Down Expand Up @@ -900,7 +896,9 @@ async def node_full(
)
async def get_awkward_buffers(
request: Request,
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(
scopes=["read:data"], structure_families={StructureFamily.awkward}
),
form_key: Optional[List[str]] = Query(None, min_length=1),
format: Optional[str] = None,
filename: Optional[str] = None,
Expand Down Expand Up @@ -936,7 +934,9 @@ async def get_awkward_buffers(
async def post_awkward_buffers(
request: Request,
body: List[str],
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(
scopes=["read:data"], structure_families={StructureFamily.awkward}
),
format: Optional[str] = None,
filename: Optional[str] = None,
serialization_registry=Depends(get_serialization_registry),
Expand Down Expand Up @@ -974,11 +974,6 @@ async def _awkward_buffers(
):
structure_family = entry.structure_family
structure = entry.structure()
if structure_family != StructureFamily.awkward:
raise HTTPException(
status_code=404,
detail=f"Cannot read {entry.structure_family} structure with /awkward/buffers route.",
)
with record_timing(request.state.metrics, "read"):
# The plural vs. singular mismatch is due to the way query parameters
# are given as ?form_key=A&form_key=B&form_key=C.
Expand Down Expand Up @@ -1019,7 +1014,9 @@ async def _awkward_buffers(
)
async def awkward_full(
request: Request,
entry=SecureEntry(scopes=["read:data"]),
entry=SecureEntry(
scopes=["read:data"], structure_families={StructureFamily.awkward}
),
# slice=Depends(slice_),
format: Optional[str] = None,
filename: Optional[str] = None,
Expand All @@ -1030,11 +1027,6 @@ async def awkward_full(
Fetch a slice of AwkwardArray data.
"""
structure_family = entry.structure_family
if structure_family != StructureFamily.awkward:
raise HTTPException(
status_code=404,
detail=f"Cannot read {entry.structure_family} structure with /awkward/full route.",
)
# Deferred import because this is not a required dependency of the server
# for some use cases.
import awkward
Expand Down Expand Up @@ -1217,7 +1209,10 @@ async def bulk_delete(
@router.put("/array/full/{path:path}")
async def put_array_full(
request: Request,
entry=SecureEntry(scopes=["write:data"]),
entry=SecureEntry(
scopes=["write:data"],
structure_families={StructureFamily.array, StructureFamily.sparse},
),
deserialization_registry=Depends(get_deserialization_registry),
):
body = await request.body()
Expand All @@ -1243,7 +1238,10 @@ async def put_array_full(
@router.put("/array/block/{path:path}")
async def put_array_block(
request: Request,
entry=SecureEntry(scopes=["write:data"]),
entry=SecureEntry(
scopes=["write:data"],
structure_families={StructureFamily.array, StructureFamily.sparse},
),
deserialization_registry=Depends(get_deserialization_registry),
block=Depends(block),
):
Expand Down Expand Up @@ -1312,14 +1310,12 @@ async def put_table_partition(
@router.put("/awkward/full/{path:path}")
async def put_awkward_full(
request: Request,
entry=SecureEntry(scopes=["write:data"]),
entry=SecureEntry(
scopes=["write:data"], structure_families={StructureFamily.awkward}
),
deserialization_registry=Depends(get_deserialization_registry),
):
body = await request.body()
if entry.structure_family != StructureFamily.awkward:
raise HTTPException(
status_code=404, detail="This route is not applicable to this node."
)
if not hasattr(entry, "write"):
raise HTTPException(status_code=405, detail="This node cannot be written to.")
media_type = request.headers["content-type"]
Expand Down

0 comments on commit 0e65ba8

Please sign in to comment.