diff --git a/gunpowder/contrib/nodes/add_boundary_distance_gradients.py b/gunpowder/contrib/nodes/add_boundary_distance_gradients.py index b2897272..38306d23 100644 --- a/gunpowder/contrib/nodes/add_boundary_distance_gradients.py +++ b/gunpowder/contrib/nodes/add_boundary_distance_gradients.py @@ -4,7 +4,7 @@ from gunpowder.array import Array from gunpowder.batch_request import BatchRequest from gunpowder.nodes.batch_filter import BatchFilter -from scipy.ndimage.morphology import distance_transform_edt +from scipy.ndimage import distance_transform_edt logger = logging.getLogger(__name__) diff --git a/gunpowder/morphology.py b/gunpowder/morphology.py index cd32de10..80b57004 100644 --- a/gunpowder/morphology.py +++ b/gunpowder/morphology.py @@ -1,5 +1,5 @@ import numpy as np -from scipy.ndimage.morphology import distance_transform_edt +from scipy.ndimage import distance_transform_edt def enlarge_binary_map( diff --git a/gunpowder/nodes/__init__.py b/gunpowder/nodes/__init__.py index 1a152f8e..3c2a410f 100644 --- a/gunpowder/nodes/__init__.py +++ b/gunpowder/nodes/__init__.py @@ -45,3 +45,5 @@ from .upsample import UpSample from .zarr_source import ZarrSource from .zarr_write import ZarrWrite +from .gp_array_source import ArraySource as GPArraySource +from .gp_graph_source import GraphSource as GPGraphSource diff --git a/gunpowder/nodes/defect_augment.py b/gunpowder/nodes/defect_augment.py index c9ae4fe1..8ca6b6f3 100644 --- a/gunpowder/nodes/defect_augment.py +++ b/gunpowder/nodes/defect_augment.py @@ -4,9 +4,7 @@ # imports for deformed slice from skimage.draw import line -from scipy.ndimage.measurements import label -from scipy.ndimage.interpolation import map_coordinates -from scipy.ndimage.morphology import binary_dilation +from scipy.ndimage import label, map_coordinates, binary_dilation from gunpowder.batch_request import BatchRequest from gunpowder.coordinate import Coordinate diff --git a/gunpowder/nodes/deform_augment.py b/gunpowder/nodes/deform_augment.py index f606696d..c78ddc1f 100644 --- a/gunpowder/nodes/deform_augment.py +++ b/gunpowder/nodes/deform_augment.py @@ -119,18 +119,18 @@ def __init__( self.graph_raster_voxel_size = ( Coordinate(graph_raster_voxel_size) if graph_raster_voxel_size is not None - else None + else control_point_spacing * 0 + 1 ) - self.p = p - assert self.control_point_spacing.dims == self.jitter_sigma.dims, ( - self.control_point_spacing, - self.jitter_sigma, + assert ( + self.control_point_spacing.dims + == self.jitter_sigma.dims + == self.graph_raster_voxel_size.dims + ), ( + f"control_point_spacing: {self.control_point_spacing}, " + f"jitter_sigma: {self.jitter_sigma}, " + f"and graph_raster_voxel_size must have the same number of dimensions" ) - if self.graph_raster_voxel_size is not None: - assert self.graph_raster_voxel_size.dims == self.jitter_sigma.dims, ( - self.graph_raster_voxel_size, - self.jitter_sigma, - ) + self.p = p def setup(self): if self.transform_key is not None: @@ -506,6 +506,9 @@ def __create_transformation(self, target_spec: ArraySpec): local_transformation += rot_transformation if self.subsample > 1: + global_transformation = upscale_transformation( + global_transformation, target_shape + ) local_transformation = upscale_transformation( local_transformation, target_shape ) diff --git a/gunpowder/nodes/elastic_augment.py b/gunpowder/nodes/elastic_augment.py index 0beb3c57..c5da5f81 100644 --- a/gunpowder/nodes/elastic_augment.py +++ b/gunpowder/nodes/elastic_augment.py @@ -461,7 +461,7 @@ def __fast_point_projection(self, transformation, nodes, source_roi, target_roi) ) missing_points = [] - projected_locs = ndimage.measurements.center_of_mass(data > 0, data, ids) + projected_locs = ndimage.center_of_mass(data > 0, data, ids) projected_locs = [ np.array(loc[-self.spatial_dims :]) * self.voxel_size + target_roi.begin for loc in projected_locs diff --git a/gunpowder/nodes/exclude_labels.py b/gunpowder/nodes/exclude_labels.py index 2592dc25..cb812def 100644 --- a/gunpowder/nodes/exclude_labels.py +++ b/gunpowder/nodes/exclude_labels.py @@ -1,6 +1,6 @@ import logging import numpy as np -from scipy.ndimage.morphology import distance_transform_edt +from scipy.ndimage import distance_transform_edt from .batch_filter import BatchFilter from gunpowder.array import Array diff --git a/gunpowder/nodes/gp_array_source.py b/gunpowder/nodes/gp_array_source.py new file mode 100644 index 00000000..0277abd3 --- /dev/null +++ b/gunpowder/nodes/gp_array_source.py @@ -0,0 +1,24 @@ +import copy +from typing import TYPE_CHECKING + +from .batch_provider import BatchProvider + +if TYPE_CHECKING: + from gunpowder import Array, ArrayKey, Batch + + +class ArraySource(BatchProvider): + def __init__(self, key: "ArrayKey", array: "Array"): + self.key = key + self.array = array + + def setup(self): + self.provides(self.key, self.array.spec.copy()) + + def provide(self, request): + outputs = Batch() + if self.array.spec.nonspatial: + outputs[self.key] = copy.deepcopy(self.array) + else: + outputs[self.key] = copy.deepcopy(self.array.crop(request[self.key].roi)) + return outputs diff --git a/gunpowder/nodes/gp_graph_source.py b/gunpowder/nodes/gp_graph_source.py new file mode 100644 index 00000000..f7fb4aa1 --- /dev/null +++ b/gunpowder/nodes/gp_graph_source.py @@ -0,0 +1,23 @@ +import copy +from typing import TYPE_CHECKING + +from .batch_provider import BatchProvider + +if TYPE_CHECKING: + from gunpowder import Batch, Graph, GraphKey + + +class GraphSource(BatchProvider): + def __init__(self, key: "GraphKey", graph: "Graph"): + self.key = key + self.graph = graph + + def setup(self): + self.provides(self.key, self.graph.spec) + + def provide(self, request): + outputs = Batch() + outputs[self.key] = copy.deepcopy( + self.graph.crop(request[self.key].roi).trim(request[self.key].roi) + ) + return outputs diff --git a/gunpowder/nodes/random_location.py b/gunpowder/nodes/random_location.py index d5b6c1e2..192d7722 100644 --- a/gunpowder/nodes/random_location.py +++ b/gunpowder/nodes/random_location.py @@ -12,6 +12,7 @@ from gunpowder.array import Array from gunpowder.array_spec import ArraySpec from .batch_filter import BatchFilter +from gunpowder.profiling import Timing logger = logging.getLogger(__name__) @@ -210,6 +211,35 @@ def prepare(self, request): return request + def provide(self, request): + + timing_prepare = Timing(self, "prepare") + timing_prepare.start() + + downstream_request = request.copy() + + self.prepare(request) + + self.remove_provided(request) + + timing_prepare.stop() + + batch = self.get_upstream_provider().request_batch(request) + + timing_process = Timing(self, "process") + timing_process.start() + + downstream_request.remove_placeholders() + + self.process(batch, downstream_request) + + timing_process.stop() + + batch.profiling_stats.add(timing_prepare) + batch.profiling_stats.add(timing_process) + + return batch + def process(self, batch, request): if self.random_shift_key is not None: batch[self.random_shift_key] = Array( @@ -435,7 +465,6 @@ def __select_random_location_with_points( point, points, ) - num_points = len(points) return random_shift diff --git a/gunpowder/nodes/rasterize_graph.py b/gunpowder/nodes/rasterize_graph.py index de6c4436..8301b9ae 100644 --- a/gunpowder/nodes/rasterize_graph.py +++ b/gunpowder/nodes/rasterize_graph.py @@ -1,6 +1,6 @@ import logging import numpy as np -from scipy.ndimage.filters import gaussian_filter +from scipy.ndimage import gaussian_filter from skimage import draw from .batch_filter import BatchFilter diff --git a/gunpowder/nodes/snapshot.py b/gunpowder/nodes/snapshot.py index 78ba8f0c..acc3f624 100644 --- a/gunpowder/nodes/snapshot.py +++ b/gunpowder/nodes/snapshot.py @@ -11,7 +11,7 @@ class Snapshot(BatchFilter): - """Save a passing batch in an HDF file. + """Save a passing batch in an HDF or Zarr file. The default behaviour is to periodically save a snapshot after ``every`` iterations. @@ -37,7 +37,9 @@ class Snapshot(BatchFilter): Template for output filenames. ``{id}`` in the string will be replaced with the ID of the batch. ``{iteration}`` with the training - iteration (if training was performed on this batch). + iteration (if training was performed on this batch). Snapshot will + be saved as zarr file if output_filename ends in ``.zarr`` and as + HDF otherwise. every (``int``): @@ -209,8 +211,8 @@ def process(self, batch, request): if self.store_value_range: dataset.attrs["value_range"] = ( - np.asscalar(array.data.min()), - np.asscalar(array.data.max()), + array.data.min().item(), + array.data.max().item(), ) # if array has attributes, add them to the dataset diff --git a/gunpowder/version_info.py b/gunpowder/version_info.py index e45efbfd..1aa73336 100644 --- a/gunpowder/version_info.py +++ b/gunpowder/version_info.py @@ -1,6 +1,6 @@ __major__ = 1 -__minor__ = 3 -__patch__ = 2 +__minor__ = 4 +__patch__ = 0 __tag__ = "" __version__ = "{}.{}.{}{}".format(__major__, __minor__, __patch__, __tag__).strip(".") diff --git a/pyproject.toml b/pyproject.toml index 01441435..ead9be13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,12 @@ tensorflow = [ 'tensorflow<2.0; python_version<"3.8"', # https://stackoverflow.com/a/72493690 'protobuf==3.20.*; python_version=="3.7"', ] +jax = [ + 'jax', + 'jaxlib', + 'haiku', + 'optax', +] full = [ 'torch', 'tensorflow<2.0; python_version<"3.8"', @@ -74,3 +80,7 @@ target_version = ['py38', 'py39', 'py310'] [tool.setuptools.packages.find] include = ["gunpowder*"] + +[tool.ruff] +# pyflakes, pycodestyle, isort +select = ["F", "E", "W", "I001"] \ No newline at end of file diff --git a/tests/cases/deform_augment.py b/tests/cases/deform_augment.py index 17176479..904a73f9 100644 --- a/tests/cases/deform_augment.py +++ b/tests/cases/deform_augment.py @@ -123,19 +123,20 @@ def provide(self, request): @pytest.mark.parametrize("rotate", [True, False]) @pytest.mark.parametrize("spatial_dims", [2, 3]) @pytest.mark.parametrize("fast_points", [True, False]) -def test_3d_basics(rotate, spatial_dims, fast_points): +@pytest.mark.parametrize("subsampling", [1, 2, 4]) +def test_3d_basics(rotate, spatial_dims, fast_points, subsampling): test_labels = ArrayKey("TEST_LABELS") test_labels2 = ArrayKey("TEST_LABELS2") test_graph = GraphKey("TEST_GRAPH") pipeline = GraphTestSource3D(test_graph, test_labels, test_labels2) + DeformAugment( - [10] * spatial_dims, + [4] * spatial_dims, [1] * spatial_dims, graph_raster_voxel_size=[1] * spatial_dims, rotate=rotate, spatial_dims=spatial_dims, use_fast_points_transform=fast_points, - subsample=2, + subsample=subsampling, ) for _ in range(5): @@ -152,7 +153,9 @@ def test_3d_basics(rotate, spatial_dims, fast_points): labels2 = batch[test_labels2] graph = batch[test_graph] - assert Node(id=1, location=np.array([0, 0, 0])) in list(graph.nodes) + assert Node(id=1, location=np.array([0, 0, 0])) in list(graph.nodes), list( + graph.nodes + ) # graph should have moved together with the voxels for node in graph.nodes: diff --git a/tests/cases/random_location.py b/tests/cases/random_location.py index b505c603..65bef230 100644 --- a/tests/cases/random_location.py +++ b/tests/cases/random_location.py @@ -60,15 +60,16 @@ def test_random_shift(): a = ArrayKey("A") b = ArrayKey("B") random_shift_key = ArrayKey("RANDOM_SHIFT") - source_a = ExampleSourceRandomLocation(a) - source_b = ExampleSourceRandomLocation(b) pipeline = ( - (source_a, source_b) + (ExampleSourceRandomLocation(a), ExampleSourceRandomLocation(b)) + MergeProvider() - + CustomRandomLocation(a, random_store_key=random_shift_key) + + CustomRandomLocation(a, random_shift_key=random_shift_key) ) - pipeline_no_random = (source_a, source_b) + MergeProvider() + pipeline_no_random = ( + ExampleSourceRandomLocation(a), + ExampleSourceRandomLocation(b), + ) + MergeProvider() with build(pipeline), build(pipeline_no_random): sums = set() @@ -96,7 +97,6 @@ def test_random_shift(): b: ArraySpec( roi=Roi(batch[random_shift_key].data, (20, 20, 20)) ), - random_shift_key: ArraySpec(nonspatial=True), } ) ) @@ -106,8 +106,8 @@ def test_random_shift(): sums.add(batch[a].data.sum()) # Request a ROI with the same shape as the entire ROI - full_roi_a = Roi((0, 0, 0), source_a.roi.shape) - full_roi_b = Roi((0, 0, 0), source_b.roi.shape) + full_roi_a = Roi((0, 0, 0), ExampleSourceRandomLocation(a).roi.shape) + full_roi_b = Roi((0, 0, 0), ExampleSourceRandomLocation(b).roi.shape) batch = pipeline.request_batch( BatchRequest( {a: ArraySpec(roi=full_roi_a), b: ArraySpec(roi=full_roi_b)}