diff --git a/src/autoseg/train_job.py b/src/autoseg/train_job.py index a42bdf6..33baafe 100644 --- a/src/autoseg/train_job.py +++ b/src/autoseg/train_job.py @@ -8,20 +8,21 @@ def train_model( warmup: int = 100000, raw_file: str = "path/to/.zarr/or/.n5/or/.tiff", rewrite_file: str = "./rewritten.zarr", - rewrite_ds: str = "rewritten_volume", + rewrite_ds: str = "training_raw", out_file: str = "./raw_predictions.zarr", voxel_size: int = 33, save_every=2500, ) -> None: - # TODO: add util funcs for generating masks - if raw_file.endswith(".tiff"): + # TODO: add util funcs for generating masks, pulling paintings + if raw_file.endswith(".tiff") or raw_file.endswith(".tif"): tiff_to_zarr(tiff_file=raw_file, out_file=rewrite_file, out_ds=rewrite_ds) + raw_file: str = rewrite_file - model_type = model_type.lower() + model_type: str = model_type.lower() if model_type == "mtlsd": mtlsd_train( iterations=iterations,