Skip to content

Commit

Permalink
Merge branch 'main' into dev-v1.4
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Jun 13, 2024
2 parents 832e475 + 185b1f0 commit 68917db
Show file tree
Hide file tree
Showing 16 changed files with 131 additions and 37 deletions.
2 changes: 1 addition & 1 deletion gunpowder/contrib/nodes/add_boundary_distance_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from gunpowder.array import Array
from gunpowder.batch_request import BatchRequest
from gunpowder.nodes.batch_filter import BatchFilter
from scipy.ndimage.morphology import distance_transform_edt
from scipy.ndimage import distance_transform_edt

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion gunpowder/morphology.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from scipy.ndimage.morphology import distance_transform_edt
from scipy.ndimage import distance_transform_edt


def enlarge_binary_map(
Expand Down
2 changes: 2 additions & 0 deletions gunpowder/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,5 @@
from .upsample import UpSample
from .zarr_source import ZarrSource
from .zarr_write import ZarrWrite
from .gp_array_source import ArraySource as GPArraySource
from .gp_graph_source import GraphSource as GPGraphSource
4 changes: 1 addition & 3 deletions gunpowder/nodes/defect_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

# imports for deformed slice
from skimage.draw import line
from scipy.ndimage.measurements import label
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.morphology import binary_dilation
from scipy.ndimage import label, map_coordinates, binary_dilation

from gunpowder.batch_request import BatchRequest
from gunpowder.coordinate import Coordinate
Expand Down
23 changes: 13 additions & 10 deletions gunpowder/nodes/deform_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,18 @@ def __init__(
self.graph_raster_voxel_size = (
Coordinate(graph_raster_voxel_size)
if graph_raster_voxel_size is not None
else None
else control_point_spacing * 0 + 1
)
self.p = p
assert self.control_point_spacing.dims == self.jitter_sigma.dims, (
self.control_point_spacing,
self.jitter_sigma,
assert (
self.control_point_spacing.dims
== self.jitter_sigma.dims
== self.graph_raster_voxel_size.dims
), (
f"control_point_spacing: {self.control_point_spacing}, "
f"jitter_sigma: {self.jitter_sigma}, "
f"and graph_raster_voxel_size must have the same number of dimensions"
)
if self.graph_raster_voxel_size is not None:
assert self.graph_raster_voxel_size.dims == self.jitter_sigma.dims, (
self.graph_raster_voxel_size,
self.jitter_sigma,
)
self.p = p

def setup(self):
if self.transform_key is not None:
Expand Down Expand Up @@ -506,6 +506,9 @@ def __create_transformation(self, target_spec: ArraySpec):
local_transformation += rot_transformation

if self.subsample > 1:
global_transformation = upscale_transformation(
global_transformation, target_shape
)
local_transformation = upscale_transformation(
local_transformation, target_shape
)
Expand Down
2 changes: 1 addition & 1 deletion gunpowder/nodes/elastic_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def __fast_point_projection(self, transformation, nodes, source_roi, target_roi)
)

missing_points = []
projected_locs = ndimage.measurements.center_of_mass(data > 0, data, ids)
projected_locs = ndimage.center_of_mass(data > 0, data, ids)
projected_locs = [
np.array(loc[-self.spatial_dims :]) * self.voxel_size + target_roi.begin
for loc in projected_locs
Expand Down
2 changes: 1 addition & 1 deletion gunpowder/nodes/exclude_labels.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import numpy as np
from scipy.ndimage.morphology import distance_transform_edt
from scipy.ndimage import distance_transform_edt

from .batch_filter import BatchFilter
from gunpowder.array import Array
Expand Down
24 changes: 24 additions & 0 deletions gunpowder/nodes/gp_array_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import copy
from typing import TYPE_CHECKING

from .batch_provider import BatchProvider

if TYPE_CHECKING:
from gunpowder import Array, ArrayKey, Batch


class ArraySource(BatchProvider):
def __init__(self, key: "ArrayKey", array: "Array"):
self.key = key
self.array = array

def setup(self):
self.provides(self.key, self.array.spec.copy())

def provide(self, request):
outputs = Batch()
if self.array.spec.nonspatial:
outputs[self.key] = copy.deepcopy(self.array)
else:
outputs[self.key] = copy.deepcopy(self.array.crop(request[self.key].roi))
return outputs
23 changes: 23 additions & 0 deletions gunpowder/nodes/gp_graph_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import copy
from typing import TYPE_CHECKING

from .batch_provider import BatchProvider

if TYPE_CHECKING:
from gunpowder import Batch, Graph, GraphKey


class GraphSource(BatchProvider):
def __init__(self, key: "GraphKey", graph: "Graph"):
self.key = key
self.graph = graph

def setup(self):
self.provides(self.key, self.graph.spec)

def provide(self, request):
outputs = Batch()
outputs[self.key] = copy.deepcopy(
self.graph.crop(request[self.key].roi).trim(request[self.key].roi)
)
return outputs
31 changes: 30 additions & 1 deletion gunpowder/nodes/random_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from gunpowder.array import Array
from gunpowder.array_spec import ArraySpec
from .batch_filter import BatchFilter
from gunpowder.profiling import Timing

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -210,6 +211,35 @@ def prepare(self, request):

return request

def provide(self, request):

timing_prepare = Timing(self, "prepare")
timing_prepare.start()

downstream_request = request.copy()

self.prepare(request)

self.remove_provided(request)

timing_prepare.stop()

batch = self.get_upstream_provider().request_batch(request)

timing_process = Timing(self, "process")
timing_process.start()

downstream_request.remove_placeholders()

self.process(batch, downstream_request)

timing_process.stop()

batch.profiling_stats.add(timing_prepare)
batch.profiling_stats.add(timing_process)

return batch

def process(self, batch, request):
if self.random_shift_key is not None:
batch[self.random_shift_key] = Array(
Expand Down Expand Up @@ -435,7 +465,6 @@ def __select_random_location_with_points(
point,
points,
)
num_points = len(points)

return random_shift

Expand Down
2 changes: 1 addition & 1 deletion gunpowder/nodes/rasterize_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import numpy as np
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage import gaussian_filter
from skimage import draw

from .batch_filter import BatchFilter
Expand Down
10 changes: 6 additions & 4 deletions gunpowder/nodes/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class Snapshot(BatchFilter):
"""Save a passing batch in an HDF file.
"""Save a passing batch in an HDF or Zarr file.
The default behaviour is to periodically save a snapshot after
``every`` iterations.
Expand All @@ -37,7 +37,9 @@ class Snapshot(BatchFilter):
Template for output filenames. ``{id}`` in the string will be
replaced with the ID of the batch. ``{iteration}`` with the training
iteration (if training was performed on this batch).
iteration (if training was performed on this batch). Snapshot will
be saved as zarr file if output_filename ends in ``.zarr`` and as
HDF otherwise.
every (``int``):
Expand Down Expand Up @@ -209,8 +211,8 @@ def process(self, batch, request):

if self.store_value_range:
dataset.attrs["value_range"] = (
np.asscalar(array.data.min()),
np.asscalar(array.data.max()),
array.data.min().item(),
array.data.max().item(),
)

# if array has attributes, add them to the dataset
Expand Down
4 changes: 2 additions & 2 deletions gunpowder/version_info.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__major__ = 1
__minor__ = 3
__patch__ = 2
__minor__ = 4
__patch__ = 0
__tag__ = ""
__version__ = "{}.{}.{}{}".format(__major__, __minor__, __patch__, __tag__).strip(".")

Expand Down
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ tensorflow = [
'tensorflow<2.0; python_version<"3.8"', # https://stackoverflow.com/a/72493690
'protobuf==3.20.*; python_version=="3.7"',
]
jax = [
'jax',
'jaxlib',
'haiku',
'optax',
]
full = [
'torch',
'tensorflow<2.0; python_version<"3.8"',
Expand All @@ -74,3 +80,7 @@ target_version = ['py38', 'py39', 'py310']

[tool.setuptools.packages.find]
include = ["gunpowder*"]

[tool.ruff]
# pyflakes, pycodestyle, isort
select = ["F", "E", "W", "I001"]
11 changes: 7 additions & 4 deletions tests/cases/deform_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,20 @@ def provide(self, request):
@pytest.mark.parametrize("rotate", [True, False])
@pytest.mark.parametrize("spatial_dims", [2, 3])
@pytest.mark.parametrize("fast_points", [True, False])
def test_3d_basics(rotate, spatial_dims, fast_points):
@pytest.mark.parametrize("subsampling", [1, 2, 4])
def test_3d_basics(rotate, spatial_dims, fast_points, subsampling):
test_labels = ArrayKey("TEST_LABELS")
test_labels2 = ArrayKey("TEST_LABELS2")
test_graph = GraphKey("TEST_GRAPH")

pipeline = GraphTestSource3D(test_graph, test_labels, test_labels2) + DeformAugment(
[10] * spatial_dims,
[4] * spatial_dims,
[1] * spatial_dims,
graph_raster_voxel_size=[1] * spatial_dims,
rotate=rotate,
spatial_dims=spatial_dims,
use_fast_points_transform=fast_points,
subsample=2,
subsample=subsampling,
)

for _ in range(5):
Expand All @@ -152,7 +153,9 @@ def test_3d_basics(rotate, spatial_dims, fast_points):
labels2 = batch[test_labels2]
graph = batch[test_graph]

assert Node(id=1, location=np.array([0, 0, 0])) in list(graph.nodes)
assert Node(id=1, location=np.array([0, 0, 0])) in list(graph.nodes), list(
graph.nodes
)

# graph should have moved together with the voxels
for node in graph.nodes:
Expand Down
16 changes: 8 additions & 8 deletions tests/cases/random_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,16 @@ def test_random_shift():
a = ArrayKey("A")
b = ArrayKey("B")
random_shift_key = ArrayKey("RANDOM_SHIFT")
source_a = ExampleSourceRandomLocation(a)
source_b = ExampleSourceRandomLocation(b)

pipeline = (
(source_a, source_b)
(ExampleSourceRandomLocation(a), ExampleSourceRandomLocation(b))
+ MergeProvider()
+ CustomRandomLocation(a, random_store_key=random_shift_key)
+ CustomRandomLocation(a, random_shift_key=random_shift_key)
)
pipeline_no_random = (source_a, source_b) + MergeProvider()
pipeline_no_random = (
ExampleSourceRandomLocation(a),
ExampleSourceRandomLocation(b),
) + MergeProvider()

with build(pipeline), build(pipeline_no_random):
sums = set()
Expand Down Expand Up @@ -96,7 +97,6 @@ def test_random_shift():
b: ArraySpec(
roi=Roi(batch[random_shift_key].data, (20, 20, 20))
),
random_shift_key: ArraySpec(nonspatial=True),
}
)
)
Expand All @@ -106,8 +106,8 @@ def test_random_shift():
sums.add(batch[a].data.sum())

# Request a ROI with the same shape as the entire ROI
full_roi_a = Roi((0, 0, 0), source_a.roi.shape)
full_roi_b = Roi((0, 0, 0), source_b.roi.shape)
full_roi_a = Roi((0, 0, 0), ExampleSourceRandomLocation(a).roi.shape)
full_roi_b = Roi((0, 0, 0), ExampleSourceRandomLocation(b).roi.shape)
batch = pipeline.request_batch(
BatchRequest(
{a: ArraySpec(roi=full_roi_a), b: ArraySpec(roi=full_roi_b)}
Expand Down

0 comments on commit 68917db

Please sign in to comment.