Skip to content

Commit

Permalink
Get labels, rasters options
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 30, 2023
1 parent d8dbb24 commit 5033f40
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
27 changes: 25 additions & 2 deletions src/autoseg/train_job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from more_itertools import raise_
from .train import mtlsd_train, aclsd_train, stelarr_train
from .utils import tiff_to_zarr, create_masks
from .utils import tiff_to_zarr, create_masks, wkw_seg_to_zarr, download_wk_skeleton, rasterize_skeleton


def train_model(
Expand All @@ -11,9 +11,14 @@ def train_model(
rewrite_file: str = "./rewritten.zarr",
rewrite_ds: str = "volumes/training_raw",
out_file: str = "./raw_predictions.zarr",
get_labels: bool = False,
get_rasters: bool = False,
generate_masks: bool = False,
voxel_size: int = 33,
save_every=2500,
save_every: int =2500,
annotation_id: str = None,
wk_token="YqSgxzFJpP2eyjtqymCTPg",

) -> None:

# TODO: add util funcs for generating masks, pulling paintings
Expand All @@ -26,6 +31,24 @@ def train_model(
except:
raise("Could not convert TIFF file to zarr volume")

if get_labels:
try:
wkw_seg_to_zarr(annotation_id=annotation_id,
save_path=".",
zarr_path=raw_file,
wk_token=wk_token,
gt_name="training_labels")
except:
raise("Could not fetch and convert paintings to zarr format")

if get_rasters:
try:
zip_path: str = download_wk_skeleton(annotation_id=annotation_id,
token=wk_token)
rasterize_skeleton(zip_path=zip_path, raw_file=raw_file)
except:
raise("Could not fetch and convert skeletons to zarr format")

if generate_masks:
try:
create_masks(raw_file, "volumes/training_gt_labels")
Expand Down
6 changes: 3 additions & 3 deletions src/autoseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def download_wk_skeleton(
url="http://catmaid2.hms.harvard.edu:9000",
annotation_id=None,
token = None,
overwrite=None,
overwrite=True,
zip_suffix=None,
):
# print(f"Downloading {wk_url}/annotations/Explorational/{annotation_ID}...")
Expand Down Expand Up @@ -441,14 +441,14 @@ def wkw_seg_to_zarr(
annotation_id,
save_path,
zarr_path,
raw_name="volumes/raw",
raw_name="volumes/training_raw",
wk_url="http://catmaid2.hms.harvard.edu:9000",
wk_token="YqSgxzFJpP2eyjtqymCTPg",
gt_name=None,
gt_name_prefix="volumes/",
overwrite=None,
):
print(f"Downloading {annotation_ID} from {wk_url}...")
print(f"Downloading {annotation_id} from {wk_url}...")
with wk.webknossos_context(token=wk_token, url=wk_url):
annotation = wk.Annotation.download(
annotation_id
Expand Down

0 comments on commit 5033f40

Please sign in to comment.