diff --git a/src/autoseg/postprocess/segment_mws.py b/src/autoseg/postprocess/segment_mws.py index a5f6b9d..6efbc7a 100644 --- a/src/autoseg/postprocess/segment_mws.py +++ b/src/autoseg/postprocess/segment_mws.py @@ -1,15 +1,26 @@ # UTILIZES RUSTY_MWS, HGLOM FOR ALL SEGMENTATION, ORIGINALLY WRITTEN BY BRIAN REICHER (2023) - +import torch import rusty_mws import hglom from ..utils import neighborhood from ..predict.network_predictions import predict_task +from ..networks.FLibUNet import setup_unet +from ..models import * +from ..losses import * + + +def _setup_model(network=setup_unet(), model=MTLSDModel): + model = model(unet=network, num_fmaps=network.out_channels) + return model def get_validation_segmentation( segmentation_style: str = "mws", + model=_setup_model(), + model_type="MTLSD", iteration="latest", + model_path="./", raw_file="../../data/xpress-challenge.zarr", raw_dataset="volumes/validation_raw", out_file="./validation.zarr", @@ -44,7 +55,10 @@ def get_validation_segmentation( if pred_affs: predict_task( # Raw --> Affinities + model=model, + model_type=model_type, iteration=iteration, + model_path=model_path, raw_file=raw_file, raw_dataset=raw_dataset, out_file=out_file, diff --git a/src/autoseg/predict/network_predictions.py b/src/autoseg/predict/network_predictions.py index 1119f63..a915bbc 100644 --- a/src/autoseg/predict/network_predictions.py +++ b/src/autoseg/predict/network_predictions.py @@ -30,7 +30,7 @@ def predict_task( voxel_size: int = 33, ) -> None: """ - Predict affinities using a trained deep learning model. + Predict affinities using a trained learning model. Parameters: model: