Skip to content

Commit

Permalink
MTLSD custom refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 29, 2023
1 parent bbc2936 commit c3c5d47
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/autoseg/train/MTLSDTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
torch.backends.cudnn.benchmark = True


def mtlsd_train(iterations: int, raw_file: str, voxel_size: int = 33):
def mtlsd_train(raw_file: str = "../../data/xpress-challenge.zarr",
voxel_size: int = 33,
iterations: int = 100000,
save_every: int = 25000,
) -> None:
raw = gp.ArrayKey("RAW")
labels = gp.ArrayKey("LABELS")
labels_mask = gp.ArrayKey("LABELS_MASK")
Expand Down Expand Up @@ -152,7 +156,7 @@ def mtlsd_train(iterations: int, raw_file: str, voxel_size: int = 33):
5: affs_weights,
},
outputs={0: pred_lsds, 1: pred_affs},
save_every=50000,
save_every=save_every,
log_dir="log",
)

Expand All @@ -168,7 +172,7 @@ def mtlsd_train(iterations: int, raw_file: str, voxel_size: int = 33):
gt_affs: "gt_affs",
pred_affs: "pred_affs",
},
output_filename="batch_{iteration}.zarr",
output_filename="batch_latest.zarr",
every=50000,
)

Expand Down

0 comments on commit c3c5d47

Please sign in to comment.