diff --git a/src/autoseg/train_job.py b/src/autoseg/train_job.py index 1508f30..a42bdf6 100644 --- a/src/autoseg/train_job.py +++ b/src/autoseg/train_job.py @@ -7,15 +7,20 @@ def train_model( iterations: int = 100000, warmup: int = 100000, raw_file: str = "path/to/.zarr/or/.n5/or/.tiff", + rewrite_file: str = "./rewritten.zarr", + rewrite_ds: str = "rewritten_volume", out_file: str = "./raw_predictions.zarr", voxel_size: int = 33, save_every=2500, ) -> None: - # TODO: call ztools to rewrite .tiff file to zarr format + # TODO: add util funcs for generating masks if raw_file.endswith(".tiff"): - tiff_to_zarr(tiff_file=raw_file) - # raw_file: str = TODO: reassign raw file name + tiff_to_zarr(tiff_file=raw_file, + out_file=rewrite_file, + out_ds=rewrite_ds) + + model_type = model_type.lower() if model_type == "mtlsd": mtlsd_train( diff --git a/src/autoseg/utils.py b/src/autoseg/utils.py index 672cb06..0c8b2b5 100644 --- a/src/autoseg/utils.py +++ b/src/autoseg/utils.py @@ -2,6 +2,8 @@ from funlib.persistence import prepare_ds from funlib.geometry import Coordinate, Roi import tifffile +import numpy as np +import zarr neighborhood: list[list[int]] = [ @@ -51,3 +53,45 @@ def tiff_to_zarr(tiff_file:str="path/to/.tiff", ds[roi] = tiff_stack print("TIFF Image stack saved as Zarr dataset.") + + +def create_masks(raw_file: str, labels_ds: str) -> None: + f = zarr.open(raw_file, "a") + + labels = f[labels_ds] + offset = labels.attrs["offset"] + resolution = labels.attrs["resolution"] + + labels = labels[:] + + labels_mask = np.ones_like(labels).astype(np.uint8) + unlabelled_mask = (labels > 0).astype(np.uint8) + + for ds_name, data in [ + ("volumes/training_labels_cropped_mask", labels_mask), + ("volumes/training_unlabelled_cropped_mask", unlabelled_mask), + ]: + f[ds_name] = data + f[ds_name].attrs["offset"] = offset + f[ds_name].attrs["resolution"] = resolution + + try: + labels = f["volumes/training_gt_rasters"] + offset = labels.attrs["offset"] + resolution = labels.attrs["resolution"] + + labels = labels[:] + + labels_mask = np.ones_like(labels).astype(np.uint8) + unlabelled_mask = (labels > 0).astype(np.uint8) + + for ds_name, data in [ + ("volumes/training_raster_mask", labels_mask), + ("volumes/training_unrastered_mask", unlabelled_mask), + ]: + f[ds_name] = data + f[ds_name].attrs["offset"] = offset + f[ds_name].attrs["resolution"] = resolution + + except KeyError: + pass