Commit
Refactoring to allow adding multiple Assets (dynamic tiff sequences) (#718)

* ENH: refactor put_data_source

* ENH: support sequences of tiffs with more than 2 dimensions

* ENH: support sequences of tiffs with more than 2 dimensions

* add tests for multidimensional tiff sequences

* fmt

* ENH: force the results of TiffSequenceAdapter to be arrays

* ENH: force the results of TiffSequenceAdapter to be arrays

* MNT: fix typo
genematx authored Apr 18, 2024
1 parent 84c185d commit 0753fe5
Showing 4 changed files with 84 additions and 83 deletions.
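Based on the updated tests below, a minimal sketch of what the refactor enables from the client side (the URI and tree key are illustrative; the shapes mirror the test fixture of three (5, 7, 4) TIFF files):

```python
from tiled.client import from_uri

client = from_uri("http://localhost:8000")  # assumed local Tiled server
stack = client["sequence"]                  # a directory of 3 TIFFs, each (5, 7, 4)

stack.read().shape                               # (3, 5, 7, 4): file axis first
stack.read(slice=0).shape                        # (5, 7, 4): one file
stack.read(slice=(0, ..., slice(0, 2))).shape    # (5, 7, 2): trailing channels sliced
```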
2 changes: 1 addition & 1 deletion docs/source/explanations/caching.md
@@ -76,7 +76,7 @@ from cachetools import Cache
 from tiled.adapters.resource_cache import set_resource_cache
 
 cache = Cache(maxsize=1)
-set_resouurce_cache(cache)
+set_resource_cache(cache)
 ```
 
 Any object satisfying the `cachetools.Cache` interface is acceptable.
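Since any object satisfying the `cachetools.Cache` interface is acceptable, a minimal sketch of swapping in an LRU cache instead of the fixed-size `Cache` (illustrative values, not part of this diff):

```python
from cachetools import LRUCache

from tiled.adapters.resource_cache import set_resource_cache

# Hold up to 100 cached resources (e.g. open file handles), evicting the
# least-recently-used entry first.
set_resource_cache(LRUCache(maxsize=100))
```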
18 changes: 11 additions & 7 deletions tiled/_tests/test_tiff.py
@@ -21,7 +21,7 @@ def client(tmpdir_module):
     sequence_directory.mkdir()
     filepaths = []
     for i in range(3):
-        data = numpy.random.random((5, 7))
+        data = numpy.random.random((5, 7, 4))
         filepath = sequence_directory / f"temp{i:05}.tif"
         tf.imwrite(filepath, data)
         filepaths.append(filepath)
@@ -46,19 +46,23 @@ def client(tmpdir_module):
 @pytest.mark.parametrize(
     "slice_input, correct_shape",
     [
-        (None, (3, 5, 7)),
-        (0, (5, 7)),
-        (slice(0, 3, 2), (2, 5, 7)),
-        ((1, slice(0, 3), slice(0, 3)), (3, 3)),
-        ((slice(0, 3), slice(0, 3), slice(0, 3)), (3, 3, 3)),
+        (None, (3, 5, 7, 4)),
+        (0, (5, 7, 4)),
+        (slice(0, 3, 2), (2, 5, 7, 4)),
+        ((1, slice(0, 3), slice(0, 3)), (3, 3, 4)),
+        ((slice(0, 3), slice(0, 3), slice(0, 3)), (3, 3, 3, 4)),
+        ((..., 0, 0, 0), (3,)),
+        ((0, slice(0, 1), slice(0, 2), ...), (1, 2, 4)),
+        ((0, ..., slice(0, 2)), (5, 7, 2)),
+        ((..., slice(0, 1)), (3, 5, 7, 1)),
     ],
 )
 def test_tiff_sequence(client, slice_input, correct_shape):
     arr = client["sequence"].read(slice=slice_input)
     assert arr.shape == correct_shape
 
 
-@pytest.mark.parametrize("block_input, correct_shape", [((0, 0, 0), (1, 5, 7))])
+@pytest.mark.parametrize("block_input, correct_shape", [((0, 0, 0, 0), (1, 5, 7, 4))])
 def test_tiff_sequence_block(client, block_input, correct_shape):
     arr = client["sequence"].read_block(block_input)
     assert arr.shape == correct_shape
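To see why these shapes come out as they do, here is a small numpy-only sketch of the stacking the adapter effectively performs over the three (5, 7, 4) files (illustrative, not part of the test suite):

```python
import numpy

# Three "files", each a (5, 7, 4) image, stacked along a new leading axis.
stack = numpy.stack([numpy.random.random((5, 7, 4)) for _ in range(3)])

assert stack.shape == (3, 5, 7, 4)
assert stack[..., 0, 0, 0].shape == (3,)           # one value per file
assert stack[0, 0:1, 0:2, ...].shape == (1, 2, 4)  # integer file index drops the first axis
assert stack[0, ..., 0:2].shape == (5, 7, 2)
```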
66 changes: 23 additions & 43 deletions tiled/adapters/tiff.py
@@ -1,5 +1,6 @@
 import builtins
 
+import numpy as np
 import tifffile
 
 from ..structures.array import ArrayStructure, BuiltinDtype
@@ -123,7 +124,7 @@ def __init__(
         structure = ArrayStructure(
             shape=shape,
             # one chunk per underlying TIFF file
-            chunks=((1,) * shape[0], (shape[1],), (shape[2],)),
+            chunks=((1,) * shape[0], *[(i,) for i in shape[1:]]),
             # Assume all files have the same data type
             data_type=BuiltinDtype.from_numpy_dtype(self.read(slice=0).dtype),
         )
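A quick illustration of what the new chunks expression evaluates to for the four-dimensional case used in the tests (purely illustrative):

```python
shape = (3, 5, 7, 4)  # 3 TIFF files, each 5 x 7 with 4 color channels

# One chunk per file along the leading axis; each remaining axis stays a single chunk.
chunks = ((1,) * shape[0], *[(i,) for i in shape[1:]])
assert chunks == ((1, 1, 1), (5,), (7,), (4,))
```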
@@ -133,66 +134,45 @@ def metadata(self):
         # TODO How to deal with the many headers?
         return self._provided_metadata
 
-    def read(self, slice=None):
+    def read(self, slice=Ellipsis):
         """Return a numpy array
         Receives a sequence of values to select from a collection of tiff files that were saved in a folder
-        The input order is defined as files --> X slice --> Y slice
+        The input order is defined as: files --> vertical slice --> horizontal slice --> color slice --> ...
         read() can receive one value or one slice to select all the data from one file or a sequence of files;
-        or it can receive a tuple of up to three values (int or slice) to select a more specific sequence of pixels
-        of a group of images
+        or it can receive a tuple (int or slice) to select a more specific sequence of pixels of a group of images.
         """
 
-        if slice is None:
+        if slice is Ellipsis:
             return self._seq.asarray()
         if isinstance(slice, int):
-            # e.g. read(slice=0)
+            # e.g. read(slice=0) -- return an entire image
            return tifffile.TiffFile(self._seq.files[slice]).asarray()
         # e.g. read(slice=(...))
+        if isinstance(slice, builtins.slice):
+            # e.g. read(slice=(...)) -- return a slice along the image axis
+            return tifffile.TiffSequence(self._seq.files[slice]).asarray()
         if isinstance(slice, tuple):
             if len(slice) == 0:
                 return self._seq.asarray()
             if len(slice) == 1:
                 return self.read(slice=slice[0])
             image_axis, *the_rest = slice
-            # Could be int or slice
-            # (0, slice(...)) or (0,....) are converted to a list
+            # Could be int or slice (0, slice(...)) or (0,....); the_rest is converted to a list
             if isinstance(image_axis, int):
                 # e.g. read(slice=(0, ....))
-                return tifffile.TiffFile(self._seq.files[image_axis]).asarray()
-            if isinstance(image_axis, builtins.slice):
-                if image_axis.start is None:
-                    slice_start = 0
-                else:
-                    slice_start = image_axis.start
-                if image_axis.step is None:
-                    slice_step = 1
-                else:
-                    slice_step = image_axis.step
-
-                arr = tifffile.TiffSequence(
-                    self._seq.files[
-                        slice_start : image_axis.stop : slice_step  # noqa: E203
-                    ]
-                ).asarray()
-            arr = arr[tuple(the_rest)]
-            return arr
-        if isinstance(slice, builtins.slice):
-            # Check for start and step which can be optional
-            if slice.start is None:
-                slice_start = 0
-            else:
-                slice_start = slice.start
-            if slice.step is None:
-                slice_step = 1
-            else:
-                slice_step = slice.step
-
-            arr = tifffile.TiffSequence(
-                self._seq.files[slice_start : slice.stop : slice_step]  # noqa: E203
-            ).asarray()
+                arr = tifffile.TiffFile(self._seq.files[image_axis]).asarray()
+            elif image_axis is Ellipsis:
+                # Return all images
+                arr = tifffile.TiffSequence(self._seq.files).asarray()
+                the_rest.insert(0, Ellipsis)  # Include any leading dimensions
+            elif isinstance(image_axis, builtins.slice):
+                arr = self.read(slice=image_axis)
+            arr = np.atleast_1d(arr[tuple(the_rest)])
+            return arr
 
     def read_block(self, block, slice=None):
-        if block[1:] != (0, 0):
+        if any(block[1:]):
+            # e.g. block[1:] != [0,0, ..., 0]
             raise IndexError(block)
         arr = self.read(builtins.slice(block[0], block[0] + 1))
         if slice is not None:
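One detail worth calling out: the `np.atleast_1d` call is what the "force the results of TiffSequenceAdapter to be arrays" commits refer to. Without it, a tuple selection that fully indexes an image would hand back a 0-d scalar rather than an array. A minimal sketch of the difference (illustrative only):

```python
import numpy as np

image = np.zeros((5, 7, 4))       # stands in for one file of the sequence

scalar = image[(0, 0, 0)]         # fully indexed: a 0-d result, not an array
assert np.ndim(scalar) == 0

forced = np.atleast_1d(scalar)    # matches what the tuple path in read() now
assert forced.shape == (1,)       # returns, e.g. for read(slice=(0, 0, 0, 0))
```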
81 changes: 49 additions & 32 deletions tiled/catalog/adapter.py
@@ -561,14 +561,8 @@ async def get_distinct(self, metadata, structure_families, specs, counts):
 
         return data
 
-    async def create_node(
-        self,
-        structure_family,
-        metadata,
-        key=None,
-        specs=None,
-        data_sources=None,
-    ):
+    @property
+    def insert(self):
         # The only way to do "insert if does not exist" i.e. ON CONFLICT
         # is to invoke dialect-specific insert.
         if self.context.engine.dialect.name == "sqlite":
@@ -578,6 +572,16 @@ async def create_node(
         else:
             assert False  # future-proofing
 
+        return insert
+
+    async def create_node(
+        self,
+        structure_family,
+        metadata,
+        key=None,
+        specs=None,
+        data_sources=None,
+    ):
         key = key or self.context.key_maker()
         data_sources = data_sources or []
 
@@ -636,6 +640,7 @@ async def create_node(
                         "is not one that the Tiled server knows how to read."
                     ),
                 )
+
             if data_source.structure is None:
                 structure_id = None
             else:
@@ -646,7 +651,7 @@ async def create_node(
                 )
                 structure_id = compute_structure_id(structure)
                 statement = (
-                    insert(orm.Structure).values(
+                    self.insert(orm.Structure).values(
                         id=structure_id,
                         structure=structure,
                     )
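The statement built here continues beyond this hunk. For context, a hedged sketch of how the dialect-specific insert exposed by the new `self.insert` property supports "insert if it does not exist"; the names reuse those in the surrounding diff, and the conflict-handling options shown are illustrative rather than a quote of the hidden lines:

```python
from sqlalchemy.dialects.sqlite import insert  # or sqlalchemy.dialects.postgresql

# Do nothing if a row with this id already exists (ON CONFLICT DO NOTHING),
# which a plain dialect-agnostic sqlalchemy.insert() cannot express.
statement = (
    insert(orm.Structure)
    .values(id=structure_id, structure=structure)
    .on_conflict_do_nothing(index_elements=["id"])
)
await db.execute(statement)
```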
@@ -663,20 +668,7 @@ async def create_node(
                 node.data_sources.append(data_source_orm)
                 await db.flush()  # Get data_source_orm.id.
                 for asset in data_source.assets:
-                    # Find an asset_id if it exists, otherwise create a new one
-                    statement = select(orm.Asset.id).where(
-                        orm.Asset.data_uri == asset.data_uri
-                    )
-                    result = await db.execute(statement)
-                    if row := result.fetchone():
-                        (asset_id,) = row
-                    else:
-                        statement = insert(orm.Asset).values(
-                            data_uri=asset.data_uri,
-                            is_directory=asset.is_directory,
-                        )
-                        result = await db.execute(statement)
-                        (asset_id,) = result.inserted_primary_key
+                    asset_id = await self._put_asset(db, asset)
                     assoc_orm = orm.DataSourceAssetAssociation(
                         asset_id=asset_id,
                         data_source_id=data_source_orm.id,
@@ -701,23 +693,31 @@ async def create_node(
             self.context, refreshed_node, access_policy=self.access_policy
         )
 
+    async def _put_asset(self, db, asset):
+        # Find an asset_id if it exists, otherwise create a new one
+        statement = select(orm.Asset.id).where(orm.Asset.data_uri == asset.data_uri)
+        result = await db.execute(statement)
+        if row := result.fetchone():
+            (asset_id,) = row
+        else:
+            statement = self.insert(orm.Asset).values(
+                data_uri=asset.data_uri,
+                is_directory=asset.is_directory,
+            )
+            result = await db.execute(statement)
+            (asset_id,) = result.inserted_primary_key
+
+        return asset_id
+
     async def put_data_source(self, data_source):
         # Obtain and hash the canonical (RFC 8785) representation of
         # the JSON structure.
         structure = _prepare_structure(
             data_source.structure_family, data_source.structure
         )
         structure_id = compute_structure_id(structure)
-        # The only way to do "insert if does not exist" i.e. ON CONFLICT
-        # is to invoke dialect-specific insert.
-        if self.context.engine.dialect.name == "sqlite":
-            from sqlalchemy.dialects.sqlite import insert
-        elif self.context.engine.dialect.name == "postgresql":
-            from sqlalchemy.dialects.postgresql import insert
-        else:
-            assert False  # future-proofing
         statement = (
-            insert(orm.Structure).values(
+            self.insert(orm.Structure).values(
                 id=structure_id,
                 structure=structure,
             )
@@ -741,6 +741,23 @@ async def put_data_source(self, data_source):
                     status_code=404,
                     detail=f"No data_source {data_source.id} on this node.",
                 )
+            # Add assets and associate them with the data_source
+            for asset in data_source.assets:
+                asset_id = await self._put_asset(db, asset)
+                statement = select(orm.DataSourceAssetAssociation).where(
+                    (orm.DataSourceAssetAssociation.data_source_id == data_source.id)
+                    & (orm.DataSourceAssetAssociation.asset_id == asset_id)
+                )
+                result = await db.execute(statement)
+                if not result.fetchone():
+                    assoc_orm = orm.DataSourceAssetAssociation(
+                        asset_id=asset_id,
+                        data_source_id=data_source.id,
+                        parameter=asset.parameter,
+                        num=asset.num,
+                    )
+                    db.add(assoc_orm)
+
             await db.commit()
 
     # async def patch_node(datasources=None):
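Tying this back to the commit title: registering a dynamic TIFF sequence now means one data source whose assets list has one entry per file, all feeding the same adapter parameter and ordered by `num`. A schematic sketch of such a list (the field names come from the diff above; the parameter name and URIs are assumptions for illustration):

```python
# Schematic only: one Asset per TIFF file in the sequence.
assets = [
    {
        "data_uri": f"file:///data/sequence/temp{i:05}.tif",
        "is_directory": False,
        "parameter": "data_uris",  # assumed name; every file feeds the same adapter argument
        "num": i,                  # ordering of the file within the sequence
    }
    for i in range(3)
]
```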
