From a9ddcb9e1bd7527f8ff5dbb344ef628f5791f0e3 Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Thu, 30 Nov 2023 14:58:33 -0500
Subject: [PATCH] Black linting

---
 src/autoseg/train_job.py | 45 +++++++++++++---------
 src/autoseg/utils.py     | 83 +++++++++++++++++++---------------------
 2 files changed, 65 insertions(+), 63 deletions(-)

diff --git a/src/autoseg/train_job.py b/src/autoseg/train_job.py
index 84c3a20..b934133 100644
--- a/src/autoseg/train_job.py
+++ b/src/autoseg/train_job.py
@@ -1,6 +1,12 @@
 from more_itertools import raise_
 from .train import mtlsd_train, aclsd_train, stelarr_train
-from .utils import tiff_to_zarr, create_masks, wkw_seg_to_zarr, download_wk_skeleton, rasterize_skeleton
+from .utils import (
+    tiff_to_zarr,
+    create_masks,
+    wkw_seg_to_zarr,
+    download_wk_skeleton,
+    rasterize_skeleton,
+)
 
 
 def train_model(
@@ -15,45 +21,46 @@ def train_model(
     get_rasters: bool = False,
     generate_masks: bool = False,
     voxel_size: int = 33,
-    save_every: int =2500,
+    save_every: int = 2500,
     annotation_id: str = None,
     wk_token="YqSgxzFJpP2eyjtqymCTPg",
-
 ) -> None:
-
     # TODO: add util funcs for generating masks, pulling paintings
     if raw_file.endswith(".tiff") or raw_file.endswith(".tif"):
         try:
-            tiff_to_zarr(tiff_file=raw_file,
-                         out_file=rewrite_file,
-                         out_ds=rewrite_ds)
+            tiff_to_zarr(tiff_file=raw_file, out_file=rewrite_file, out_ds=rewrite_ds)
             raw_file: str = rewrite_file
         except:
-            raise("Could not convert TIFF file to zarr volume")
-
+            raise ("Could not convert TIFF file to zarr volume")
+
     if get_labels:
         try:
-            wkw_seg_to_zarr(annotation_id=annotation_id,
-                            save_path=".",
-                            zarr_path=raw_file,
-                            wk_token=wk_token,
-                            gt_name="training_labels")
+            wkw_seg_to_zarr(
+                annotation_id=annotation_id,
+                save_path=".",
+                zarr_path=raw_file,
+                wk_token=wk_token,
+                gt_name="training_labels",
+            )
         except:
-            raise("Could not fetch and convert paintings to zarr format")
+            raise ("Could not fetch and convert paintings to zarr format")
 
     if get_rasters:
         try:
-            zip_path: str = download_wk_skeleton(annotation_id=annotation_id,
-                                                 token=wk_token)
+            zip_path: str = download_wk_skeleton(
+                annotation_id=annotation_id, token=wk_token
+            )
             rasterize_skeleton(zip_path=zip_path, raw_file=raw_file)
         except:
-            raise("Could not fetch and convert skeletons to zarr format")
+            raise ("Could not fetch and convert skeletons to zarr format")
 
     if generate_masks:
         try:
             create_masks(raw_file, "volumes/training_gt_labels")
         except:
-            raise("Could not generate masks - check to make sure a painting labels volume exists")
+            raise (
+                "Could not generate masks - check to make sure a painting labels volume exists"
+            )
 
     model_type: str = model_type.lower()
     if model_type == "mtlsd":
diff --git a/src/autoseg/utils.py b/src/autoseg/utils.py
index a48ff72..7201b51 100644
--- a/src/autoseg/utils.py
+++ b/src/autoseg/utils.py
@@ -41,19 +41,22 @@
     [0, 0, 10],
 ]
 
-def tiff_to_zarr(tiff_file:str="path/to/.tiff",
-                 out_file:str="tiffAsZarr.zarr",
-                 out_ds:str="volumes/raw",
-                 voxel_size: int = 33,
-                 offset: int = 0,
-                 dtype=np.uint8,
-                 transpose:bool=False) -> None:
+
+def tiff_to_zarr(
+    tiff_file: str = "path/to/.tiff",
+    out_file: str = "tiffAsZarr.zarr",
+    out_ds: str = "volumes/raw",
+    voxel_size: int = 33,
+    offset: int = 0,
+    dtype=np.uint8,
+    transpose: bool = False,
+) -> None:
     tiff_stack: np.ndarray = tifffile.imread(tiff_file)
     if transpose:
         tiff_stack = np.transpose(tiff_stack, (2, 1, 0))
 
-    voxel_size: Coordinate = Coordinate((voxel_size)*3)
-    roi: Roi = Roi(offset=(offset)*3, shape=tiff_stack.shape * np.array(voxel_size))
+    voxel_size: Coordinate = Coordinate((voxel_size) * 3)
+    roi: Roi = Roi(offset=(offset) * 3, shape=tiff_stack.shape * np.array(voxel_size))
 
     print("Roi: ", roi)
     voxel_size: Coordinate = Coordinate(100, 100, 100)
@@ -113,6 +116,7 @@ def create_masks(raw_file: str, labels_ds: str) -> None:
     except KeyError:
         pass
 
+
 def generate_graph(test_array, skeleton_path):
     print("Loading from file . . .")
     gt_graph = np.load(skeleton_path, allow_pickle=True)
@@ -154,8 +158,9 @@ def create_array(test_array, gt_graph):
 
     return gt_ndarray
 
-def rasterized_skeletons(raw_file: str, raw_ds: str, out_file: str, skeleton_path: str) -> None:
-
+def rasterized_skeletons(
+    raw_file: str, raw_ds: str, out_file: str, skeleton_path: str
+) -> None:
     array = daisy.open_ds(raw_file, raw_ds)
     gt_graph = generate_graph(array, skeleton_path)
     gt_array = create_array(array, gt_graph)
@@ -164,27 +169,23 @@ def rasterized_skeletons(raw_file: str, raw_ds: str, out_file: str, skeleton_pat
     unabelled_mask = (gt_array > 0).astype(np.uint8)
 
     out["volumes/validation_gt_rasters"] = gt_array
-    out["volumes/validation_gt_rasters"].attrs[
-        "resolution"
-    ] = array.voxel_size
+    out["volumes/validation_gt_rasters"].attrs["resolution"] = array.voxel_size
     out["volumes/validation_gt_rasters"].attrs["offset"] = array.roi.offset
 
 
 logger = logging.getLogger(__name__)
 
+
 def download_wk_skeleton(
     save_path=".",
     url="http://catmaid2.hms.harvard.edu:9000",
     annotation_id=None,
-    token = None,
+    token=None,
     overwrite=True,
     zip_suffix=None,
 ):
     # print(f"Downloading {wk_url}/annotations/Explorational/{annotation_ID}...")
-    with wk.webknossos_context(
-        token=token,
-        url=url
-    ):
+    with wk.webknossos_context(token=token, url=url):
         annotation = wk.Annotation.download(
             annotation_id,
             annotation_type="Explorational",
@@ -217,8 +218,8 @@ def download_wk_skeleton(
     annotation.save(zip_path)
     return zip_path
 
-def parse_skeleton(zip_path):
 
+def parse_skeleton(zip_path):
     fin = zip_path
     if not fin.endswith(".zip"):
         try:
@@ -240,9 +241,8 @@ def parse_skeleton(zip_path):
 
     return skel_coor
 
 
-def get_updated_skeleton(zip_path):
-
+def get_updated_skeleton(zip_path):
     if not os.path.exists(zip_path):
         path = os.path.dirname(os.path.realpath(zip_path))
         search_path = os.path.join(path, "skeletons/*")
@@ -255,10 +255,12 @@ def get_updated_skeleton(zip_path):
 
     return skel_file
 
-def rasterize_skeleton(zip_path="/n/groups/htem/users/br128/xray-challenge-entry/monkeyv1axonseg001_KevinOhgami_20231010.zip",
-                       raw_file="./data/monkey_xnh.zarr",
-                       raw_ds="volumes/training_raw"):
 
+def rasterize_skeleton(
+    zip_path="/n/groups/htem/users/br128/xray-challenge-entry/monkeyv1axonseg001_KevinOhgami_20231010.zip",
+    raw_file="./data/monkey_xnh.zarr",
+    raw_ds="volumes/training_raw",
+):
     logger.info(f"Rasterizing skeleton...")
     skel_coor = parse_skeleton(zip_path)
 
@@ -269,11 +271,11 @@ def rasterize_skeleton(zip_path="/n/groups/htem/users/br128/xray-challenge-entry
     dataset_shape = raw.data.shape
     print(dataset_shape)
     voxel_size = raw.voxel_size
-    offset = raw.roi.begin # unhardcode for nonzero offset
+    offset = raw.roi.begin  # unhardcode for nonzero offset
     image = np.zeros(dataset_shape, dtype=np.uint8)
 
     def adjust(coor):
-        ds_under = [x-1 for x in dataset_shape]
+        ds_under = [x - 1 for x in dataset_shape]
         return np.min([coor - offset, ds_under], 0)
 
     print("adjusting . . .")
@@ -283,7 +285,6 @@ def adjust(coor):
             line = line_nd(adjust(start), adjust(end))
             image[line] = id
 
-
     # Save GT rasterization #TODO: implement daisy blockwise option
     total_roi = Roi(
         Coordinate(offset) * Coordinate(voxel_size),
@@ -303,6 +304,7 @@ def adjust(coor):
 
     return image
 
+
 def get_wk_mask(
     annotation_ID,
     save_path,
@@ -450,9 +452,7 @@ def wkw_seg_to_zarr(
 ):
     print(f"Downloading {annotation_id} from {wk_url}...")
     with wk.webknossos_context(token=wk_token, url=wk_url):
-        annotation = wk.Annotation.download(
-            annotation_id
-        )
+        annotation = wk.Annotation.download(annotation_id)
 
     time_str = strftime("%Y%m%d", gmtime())
     annotation_name = (
@@ -484,8 +484,8 @@ def wkw_seg_to_zarr(
     # ds = daisy.open_ds(zarr_path, raw_name)
     # offset = ds.roi.get_offset() #/ ds.voxel_size
     # offset = Coordinate(0, 0, 0)
-    roi = Roi((12600, 14100, 51100), (20000,20000,20000))
-    shape = Coordinate(200,200,200)
+    roi = Roi((12600, 14100, 51100), (20000, 20000, 20000))
+    shape = Coordinate(200, 200, 200)
     # shape = ds.roi.get_shape() / ds.voxel_size
     # shape = Roi()
 
@@ -502,10 +502,7 @@ def wkw_seg_to_zarr(
     # Open the WKW dataset (as the `1` folder)
     print(f"Opening {zf_data_tmpdir + '/1'}...")
     dataset = wkw.wkw.Dataset.open(zf_data_tmpdir + "/1")
-    data = dataset.read(
-        off=(126, 141, 511),
-        shape=shape
-    ).squeeze()
+    data = dataset.read(off=(126, 141, 511), shape=shape).squeeze()
     print(f"Sum of all data: {data.sum()}")
 
     # Save annotations to zarr
@@ -513,7 +510,7 @@ def wkw_seg_to_zarr(
     gt_name = f'{gt_name_prefix}gt_{annotation.dataset_name}_{annotation.username.replace(" ","")}_{time_str}'
 
     target_roi = roi
-    gt_array = daisy.Array(data, roi, (100,100,100))
+    gt_array = daisy.Array(data, roi, (100, 100, 100))
 
     chunk_size = 1000
     num_channels = 1
@@ -550,9 +547,9 @@ def save_chunk(block: daisy.Roi):
                 block.write_roi, gt_array.__getitem__(block.read_roi)
             )
             # destination[block.write_roi] = gt_array[block.read_roi]
-            return 0 # success
+            return 0  # success
         except:
-            return 1 # error
+            return 1  # error
 
     # Write data to new dataset
     task = daisy.Task(
@@ -569,9 +566,7 @@ def save_chunk(block: daisy.Roi):
     success = daisy.run_blockwise([task])
 
     if success:
-        print(
-            f"{target_roi} from {annotation_name} written to {zarr_path}/{gt_name}"
-        )
+        print(f"{target_roi} from {annotation_name} written to {zarr_path}/{gt_name}")
         return gt_name
     else:
-        print("Failed to save annotation layer.")
\ No newline at end of file
+        print("Failed to save annotation layer.")