Skip to content

Commit

Permalink
Create masks func
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 30, 2023
1 parent 6991c10 commit 68de005
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/autoseg/train_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@ def train_model(
iterations: int = 100000,
warmup: int = 100000,
raw_file: str = "path/to/.zarr/or/.n5/or/.tiff",
rewrite_file: str = "./rewritten.zarr",
rewrite_ds: str = "rewritten_volume",
out_file: str = "./raw_predictions.zarr",
voxel_size: int = 33,
save_every=2500,
) -> None:

# TODO: call ztools to rewrite .tiff file to zarr format
# TODO: add util funcs for generating masks
if raw_file.endswith(".tiff"):
tiff_to_zarr(tiff_file=raw_file)
# raw_file: str = TODO: reassign raw file name
tiff_to_zarr(tiff_file=raw_file,
out_file=rewrite_file,
out_ds=rewrite_ds)


model_type = model_type.lower()
if model_type == "mtlsd":
mtlsd_train(
Expand Down
44 changes: 44 additions & 0 deletions src/autoseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from funlib.persistence import prepare_ds
from funlib.geometry import Coordinate, Roi
import tifffile
import numpy as np
import zarr


neighborhood: list[list[int]] = [
Expand Down Expand Up @@ -51,3 +53,45 @@ def tiff_to_zarr(tiff_file:str="path/to/.tiff",
ds[roi] = tiff_stack

print("TIFF Image stack saved as Zarr dataset.")


def create_masks(raw_file: str, labels_ds: str) -> None:
f = zarr.open(raw_file, "a")

labels = f[labels_ds]
offset = labels.attrs["offset"]
resolution = labels.attrs["resolution"]

labels = labels[:]

labels_mask = np.ones_like(labels).astype(np.uint8)
unlabelled_mask = (labels > 0).astype(np.uint8)

for ds_name, data in [
("volumes/training_labels_cropped_mask", labels_mask),
("volumes/training_unlabelled_cropped_mask", unlabelled_mask),
]:
f[ds_name] = data
f[ds_name].attrs["offset"] = offset
f[ds_name].attrs["resolution"] = resolution

try:
labels = f["volumes/training_gt_rasters"]
offset = labels.attrs["offset"]
resolution = labels.attrs["resolution"]

labels = labels[:]

labels_mask = np.ones_like(labels).astype(np.uint8)
unlabelled_mask = (labels > 0).astype(np.uint8)

for ds_name, data in [
("volumes/training_raster_mask", labels_mask),
("volumes/training_unrastered_mask", unlabelled_mask),
]:
f[ds_name] = data
f[ds_name].attrs["offset"] = offset
f[ds_name].attrs["resolution"] = resolution

except KeyError:
pass

0 comments on commit 68de005

Please sign in to comment.