Skip to content

Commit

Permalink
refactor post processing seg
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Dec 15, 2023
1 parent 541ae3e commit 1bb9f5a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
16 changes: 15 additions & 1 deletion src/autoseg/postprocess/segment_mws.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
# UTILIZES RUSTY_MWS, HGLOM FOR ALL SEGMENTATION, ORIGINALLY WRITTEN BY BRIAN REICHER (2023)

import torch
import rusty_mws
import hglom

from ..utils import neighborhood
from ..predict.network_predictions import predict_task
from ..networks.FLibUNet import setup_unet
from ..models import *
from ..losses import *


def _setup_model(network=setup_unet(), model=MTLSDModel):
model = model(unet=network, num_fmaps=network.out_channels)
return model


def get_validation_segmentation(
segmentation_style: str = "mws",
model=_setup_model(),
model_type="MTLSD",
iteration="latest",
model_path="./",
raw_file="../../data/xpress-challenge.zarr",
raw_dataset="volumes/validation_raw",
out_file="./validation.zarr",
Expand Down Expand Up @@ -44,7 +55,10 @@ def get_validation_segmentation(

if pred_affs:
predict_task( # Raw --> Affinities
model=model,
model_type=model_type,
iteration=iteration,
model_path=model_path,
raw_file=raw_file,
raw_dataset=raw_dataset,
out_file=out_file,
Expand Down
2 changes: 1 addition & 1 deletion src/autoseg/predict/network_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def predict_task(
voxel_size: int = 33,
) -> None:
"""
Predict affinities using a trained deep learning model.
Predict affinities using a trained learning model.
Parameters:
model:
Expand Down

0 comments on commit 1bb9f5a

Please sign in to comment.