diff --git a/gunpowder/build.py b/gunpowder/build.py index 6dbb528c..fd545c89 100644 --- a/gunpowder/build.py +++ b/gunpowder/build.py @@ -25,36 +25,17 @@ def __exit__(self, type, value, traceback): logger.debug("tear down completed") -import neuroglancer -from .neuroglancer.event import step_next +from .observers import NeuroglancerObserver class build_neuroglancer(object): def __init__(self, pipeline): self.pipeline = pipeline + self.observer = NeuroglancerObserver("neuroglancer", pipeline) def __enter__(self): - neuroglancer.set_server_bind_address("0.0.0.0") - viewer = neuroglancer.Viewer() - - viewer.actions.add("continue", step_next) - - with viewer.config_state.txn() as s: - s.input_event_bindings.data_view["keyt"] = "continue" - with viewer.txn() as s: - s.layout = neuroglancer.row_layout( - [ - neuroglancer.column_layout( - [ - neuroglancer.LayerGroupViewer(layers=[]), - neuroglancer.LayerGroupViewer(layers=[]), - ] - ), - ] - ) - try: - self.pipeline.setup(viewer) + self.pipeline.setup([self.observer]) except: logger.error( "something went wrong during the setup of the pipeline, calling tear down" @@ -63,7 +44,6 @@ def __enter__(self): logger.debug("tear down completed") raise - print(viewer) return self.pipeline def __exit__(self, type, value, traceback): diff --git a/gunpowder/neuroglancer/add_layer.py b/gunpowder/neuroglancer/add_layer.py index 1a77098a..04097a7f 100644 --- a/gunpowder/neuroglancer/add_layer.py +++ b/gunpowder/neuroglancer/add_layer.py @@ -52,10 +52,6 @@ def parse_dims(array): spatial_dims = array.spec.roi.dims channel_dims = dims - spatial_dims - print("dims :", dims) - print("spatial dims:", spatial_dims) - print("channel dims:", channel_dims) - return dims, spatial_dims, channel_dims @@ -73,10 +69,6 @@ def create_coordinate_space(array, spatial_dim_names, channel_dim_names, unit): units = [""] * channel_dims + [unit] * spatial_dims scales = [1] * channel_dims + list(array.spec.voxel_size) - print("Names :", names) - print("Units :", units) - print("Scales :", scales) - return neuroglancer.CoordinateSpace( names=names, units=units, diff --git a/gunpowder/nodes/batch_provider.py b/gunpowder/nodes/batch_provider.py index 3faf694c..055cf139 100644 --- a/gunpowder/nodes/batch_provider.py +++ b/gunpowder/nodes/batch_provider.py @@ -12,8 +12,7 @@ from gunpowder.array_spec import ArraySpec from gunpowder.graph import GraphKey from gunpowder.graph_spec import GraphSpec -from gunpowder.neuroglancer.event import wait_for_step -from gunpowder.neuroglancer.visualize import visualize +from gunpowder.observers import Observer logger = logging.getLogger(__name__) @@ -54,7 +53,13 @@ class BatchProvider(object): instead. """ - viewer = None + _observers = None + + @property + def observers(self): + if self._observers is None: + self._observers = [] + return self._observers def add_upstream_provider(self, provider): self.get_upstream_providers().append(provider) @@ -219,20 +224,23 @@ def request_batch(self, request): return batch - def setup_viewer(self, viewer): - self.viewer = viewer + def register_observer(self, observer: Observer): + self.observers.append(observer) + self.observe_sources(observer) + + def observe_sources(self, observer): + """ + to be implemented in subclasses + """ + pass def observe_request(self, request): - if self.viewer is not None: - print("Waiting for step...") - visualize(self.viewer, request) - wait_for_step() + for observer in self.observers: + observer.update(request, self) def observe_batch(self, batch): - if self.viewer is not None: - print("Waiting for step...") - visualize(self.viewer, batch) - wait_for_step() + for observer in self.observers: + observer.update(batch, self) def set_seeds(self, request): seed = request.random_seed diff --git a/gunpowder/nodes/zarr_source.py b/gunpowder/nodes/zarr_source.py index 519308f9..aa40a54c 100644 --- a/gunpowder/nodes/zarr_source.py +++ b/gunpowder/nodes/zarr_source.py @@ -125,16 +125,16 @@ def setup(self): self.provides(array_key, spec) - def setup_viewer(self, viewer): - self.viewer = viewer - with viewer.txn() as s: - with self._open_file(self.store) as data_file: - for array_key, ds_name in self.datasets.items(): - if ds_name not in data_file: - raise RuntimeError("%s not in %s" % (ds_name, self.store)) - - spec = self.__read_spec(array_key, data_file, ds_name) - add_layer(s, Array(data_file[ds_name], spec), f"{array_key}_SOURCE") + def observe_sources(self, observer): + with self._open_file(self.store) as data_file: + for array_key, ds_name in self.datasets.items(): + if ds_name not in data_file: + raise RuntimeError("%s not in %s" % (ds_name, self.store)) + + spec = self.__read_spec(array_key, data_file, ds_name) + observer.add_source( + Array(data_file[ds_name], spec), f"{array_key}_SOURCE" + ) def provide(self, request): timing = Timing(self) diff --git a/gunpowder/observers.py b/gunpowder/observers.py new file mode 100644 index 00000000..3199a118 --- /dev/null +++ b/gunpowder/observers.py @@ -0,0 +1,109 @@ +from .batch import Batch +from .batch_request import BatchRequest +from .neuroglancer.event import step_next +from .neuroglancer.event import wait_for_step +from .neuroglancer.visualize import visualize +from .neuroglancer.add_layer import add_layer + +# from .nodes import BatchProvider + +import neuroglancer + +from abc import ABC +from typing import Optional + + +class Observer(ABC): + def __init__(self, name, pipeline): + self.name = name + self.pipeline = pipeline + + def update(self, request_or_batch: BatchRequest or Batch): + """ + Take a BatchRequest or Batch and update the observer's state with + their contents + """ + pass + + def add_source(self, *args, **kwargs): + """ + Add a source to the observer. This is a no-op for observers that do not + provide an array source. + """ + pass + + +class NeuroglancerObserver(Observer): + def __init__(self, name, pipeline, host="0.0.0.0", port=0): + super().__init__(name, pipeline) + self.host = host + self.port = port + + neuroglancer.set_server_bind_address(self.host, self.port) + self.viewer = neuroglancer.Viewer() + self.viewer.actions.add("continue", step_next) + + with self.viewer.config_state.txn() as s: + s.input_event_bindings.data_view["keyt"] = "continue" + + with self.viewer.txn() as s: + s.layout = neuroglancer.row_layout( + [ + neuroglancer.column_layout( + [ + neuroglancer.LayerGroupViewer(layers=[]), + neuroglancer.LayerGroupViewer(layers=[]), + ] + ), + ] + ) + + print(self.viewer) + print("Hit T in neuroglancer viewer to step through the pipeline") + + def update(self, request_or_batch: BatchRequest or Batch, node: Optional = None): + visualize(self.viewer, request_or_batch) + string = self.pipeline.to_string(bold=node) + print( + "\r" + + ( + "REQUESTING: " + if isinstance(request_or_batch, BatchRequest) + else "PROVIDING: " + ) + + string + + " " * 2, + end="", + ) + # print(self.pipeline.to_string(bold=node)) + wait_for_step() + + def add_source( + self, + array, + name, + ): + spatial_dim_names = ["t", "z", "y", "x"] + channel_dim_names = ["b^", "c^"] + opacity = None + shader = None + rgb_channels = None + color = None + visible = True + value_scale_factor = 1.0 + units = "nm" + with self.viewer.txn() as s: + add_layer( + s, + array, + name, + spatial_dim_names, + channel_dim_names, + opacity, + shader, + rgb_channels, + color, + visible, + value_scale_factor, + units, + ) diff --git a/gunpowder/pipeline.py b/gunpowder/pipeline.py index 73d5dc2f..f3a70803 100644 --- a/gunpowder/pipeline.py +++ b/gunpowder/pipeline.py @@ -1,8 +1,10 @@ from gunpowder.nodes import BatchProvider from gunpowder.nodes.batch_provider import BatchRequestError +from .observers import Observer import logging import traceback +from typing import Optional logger = logging.getLogger(__name__) @@ -77,10 +79,12 @@ def copy(self): return pipeline - def setup(self, viewer=None): + def setup(self, observers: Optional[list[Observer]] = None): """Connect all batch providers in the pipeline and call setup for each, from source to sink.""" + observers = observers if observers is not None else [] + def connect(node): for child in node.children: node.output.add_upstream_provider(child.output) @@ -94,8 +98,8 @@ def connect(node): def node_setup(node): try: node.output.setup() - if viewer is not None: - node.output.setup_viewer(viewer) + for observer in observers: + node.output.register_observer(observer) except Exception as e: raise PipelineSetupError(node.output) from e @@ -198,6 +202,17 @@ def to_string(node): reprs = self.traverse(to_string, reverse=True) return self.__rec_repr__(reprs) + + def to_string(self, bold=None): + def to_string(node): + if node.output == bold: + return f"\033[1m{node.output.name()}\033[0m" + else: + return node.output.name() + + reprs = self.traverse(to_string, reverse=True) + + return self.__rec_repr__(reprs) def __rec_repr__(self, reprs): if not isinstance(reprs, list): diff --git a/neuroglancer_fun.py b/neuroglancer_fun.py new file mode 100644 index 00000000..92edd4b3 --- /dev/null +++ b/neuroglancer_fun.py @@ -0,0 +1,64 @@ +import matplotlib.pyplot as plt +import numpy as np +import random +import zarr +import torch +from skimage import data +from skimage import filters + +# make sure we all see the same +torch.manual_seed(1961923) +np.random.seed(1961923) +random.seed(1961923) + +# open a sample image (channels first) +raw_data = data.astronaut().transpose(2, 0, 1) + +# create some dummy "ground-truth" to train on +gt_data = filters.gaussian(raw_data[0], sigma=3.0) > 0.75 +gt_data = gt_data[np.newaxis, :].astype(np.float32) + +# store image in zarr container +f = zarr.open("sample_data.zarr", "w") +f["raw"] = raw_data +f["raw"].attrs["resolution"] = (1, 1) +f["ground_truth"] = gt_data +f["ground_truth"].attrs["resolution"] = (1, 1) + +import gunpowder as gp + +# declare arrays to use in the pipeline +raw = gp.ArrayKey("RAW") +gt = gp.ArrayKey("GT") + +# create "pipeline" consisting only of a data source +source = gp.ZarrSource( + "sample_data.zarr", # the zarr container + {raw: "raw", gt: "ground_truth"}, # which dataset to associate to the array key + { + raw: gp.ArraySpec(interpolatable=True), + gt: gp.ArraySpec(interpolatable=False), + }, # meta-information +) +pipeline = source +pipeline += gp.Normalize(raw) +pipeline += gp.RandomLocation() +pipeline += gp.DeformAugment( + gp.Coordinate(5, 5), + gp.Coordinate(2, 2), + graph_raster_voxel_size=gp.Coordinate(1, 1), +) + +# formulate a request for "raw" +request = gp.BatchRequest() +request.add(raw, gp.Coordinate(64, 64), gp.Coordinate(1, 1)) +request.add(gt, gp.Coordinate(32, 32), gp.Coordinate(1, 1)) + +# build the pipeline... +with gp.build_neuroglancer(pipeline): + for _ in range(10): + # ...and request a batch + batch = pipeline.request_batch(request) + +# show the content of the batch +print(f"batch returned: {batch}")