Black linting
brianreicher committed Nov 30, 2023
1 parent 5033f40 commit a9ddcb9
Showing 2 changed files with 65 additions and 63 deletions.
45 changes: 26 additions & 19 deletions src/autoseg/train_job.py
@@ -1,6 +1,12 @@
 from more_itertools import raise_
 from .train import mtlsd_train, aclsd_train, stelarr_train
-from .utils import tiff_to_zarr, create_masks, wkw_seg_to_zarr, download_wk_skeleton, rasterize_skeleton
+from .utils import (
+    tiff_to_zarr,
+    create_masks,
+    wkw_seg_to_zarr,
+    download_wk_skeleton,
+    rasterize_skeleton,
+)


 def train_model(
@@ -15,45 +21,46 @@ def train_model(
     get_rasters: bool = False,
     generate_masks: bool = False,
     voxel_size: int = 33,
-    save_every: int =2500,
+    save_every: int = 2500,
     annotation_id: str = None,
     wk_token="YqSgxzFJpP2eyjtqymCTPg",
-
 ) -> None:
-
     # TODO: add util funcs for generating masks, pulling paintings
     if raw_file.endswith(".tiff") or raw_file.endswith(".tif"):
         try:
-            tiff_to_zarr(tiff_file=raw_file,
-                         out_file=rewrite_file,
-                         out_ds=rewrite_ds)
+            tiff_to_zarr(tiff_file=raw_file, out_file=rewrite_file, out_ds=rewrite_ds)
             raw_file: str = rewrite_file
         except:
-            raise("Could not convert TIFF file to zarr volume")
+            raise ("Could not convert TIFF file to zarr volume")

     if get_labels:
         try:
-            wkw_seg_to_zarr(annotation_id=annotation_id,
-                            save_path=".",
-                            zarr_path=raw_file,
-                            wk_token=wk_token,
-                            gt_name="training_labels")
+            wkw_seg_to_zarr(
+                annotation_id=annotation_id,
+                save_path=".",
+                zarr_path=raw_file,
+                wk_token=wk_token,
+                gt_name="training_labels",
+            )
         except:
-            raise("Could not fetch and convert paintings to zarr format")
+            raise ("Could not fetch and convert paintings to zarr format")

     if get_rasters:
         try:
-            zip_path: str = download_wk_skeleton(annotation_id=annotation_id,
-                                                 token=wk_token)
+            zip_path: str = download_wk_skeleton(
+                annotation_id=annotation_id, token=wk_token
+            )
             rasterize_skeleton(zip_path=zip_path, raw_file=raw_file)
         except:
-            raise("Could not fetch and convert skeletons to zarr format")
+            raise ("Could not fetch and convert skeletons to zarr format")

     if generate_masks:
         try:
             create_masks(raw_file, "volumes/training_gt_labels")
         except:
-            raise("Could not generate masks - check to make sure a painting labels volume exists")
+            raise (
+                "Could not generate masks - check to make sure a painting labels volume exists"
+            )

     model_type: str = model_type.lower()
     if model_type == "mtlsd":
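For context (not part of the commit itself): a minimal sketch of how the reformatted train_model entry point might be invoked, using only keyword arguments visible in this diff. The paths and annotation id are placeholders, the package is assumed to be importable as autoseg, and the full signature is truncated in this view, so treat this as an illustration rather than the canonical call.

    from autoseg.train_job import train_model

    # Hypothetical invocation; parameter names come from the hunk above,
    # all values are placeholders.
    train_model(
        raw_file="data/raw_stack.tif",       # a .tif/.tiff input triggers tiff_to_zarr
        rewrite_file="data/rawAsZarr.zarr",  # placeholder zarr written by tiff_to_zarr
        rewrite_ds="volumes/raw",            # placeholder dataset name
        model_type="mtlsd",                  # dispatched to mtlsd_train above
        get_labels=True,                     # pulls webKnossos paintings via wkw_seg_to_zarr
        annotation_id="<webknossos-annotation-id>",
        save_every=2500,
    )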
83 changes: 39 additions & 44 deletions src/autoseg/utils.py
@@ -41,19 +41,22 @@
     [0, 0, 10],
 ]

-def tiff_to_zarr(tiff_file:str="path/to/.tiff",
-                 out_file:str="tiffAsZarr.zarr",
-                 out_ds:str="volumes/raw",
-                 voxel_size: int = 33,
-                 offset: int = 0,
-                 dtype=np.uint8,
-                 transpose:bool=False) -> None:
+
+def tiff_to_zarr(
+    tiff_file: str = "path/to/.tiff",
+    out_file: str = "tiffAsZarr.zarr",
+    out_ds: str = "volumes/raw",
+    voxel_size: int = 33,
+    offset: int = 0,
+    dtype=np.uint8,
+    transpose: bool = False,
+) -> None:
     tiff_stack: np.ndarray = tifffile.imread(tiff_file)
     if transpose:
         tiff_stack = np.transpose(tiff_stack, (2, 1, 0))

-    voxel_size: Coordinate = Coordinate((voxel_size)*3)
-    roi: Roi = Roi(offset=(offset)*3, shape=tiff_stack.shape * np.array(voxel_size))
+    voxel_size: Coordinate = Coordinate((voxel_size) * 3)
+    roi: Roi = Roi(offset=(offset) * 3, shape=tiff_stack.shape * np.array(voxel_size))

     print("Roi: ", roi)
     voxel_size: Coordinate = Coordinate(100, 100, 100)
@@ -113,6 +116,7 @@ def create_masks(raw_file: str, labels_ds: str) -> None:
     except KeyError:
         pass

+
 def generate_graph(test_array, skeleton_path):
     print("Loading from file . . .")
     gt_graph = np.load(skeleton_path, allow_pickle=True)
@@ -154,8 +158,9 @@ def create_array(test_array, gt_graph):
     return gt_ndarray


-def rasterized_skeletons(raw_file: str, raw_ds: str, out_file: str, skeleton_path: str) -> None:
-
+def rasterized_skeletons(
+    raw_file: str, raw_ds: str, out_file: str, skeleton_path: str
+) -> None:
     array = daisy.open_ds(raw_file, raw_ds)
     gt_graph = generate_graph(array, skeleton_path)
     gt_array = create_array(array, gt_graph)
@@ -164,27 +169,23 @@ def rasterized_skeletons(raw_file: str, raw_ds: str, out_file: str, skeleton_pat
     unabelled_mask = (gt_array > 0).astype(np.uint8)

     out["volumes/validation_gt_rasters"] = gt_array
-    out["volumes/validation_gt_rasters"].attrs[
-        "resolution"
-    ] = array.voxel_size
+    out["volumes/validation_gt_rasters"].attrs["resolution"] = array.voxel_size
     out["volumes/validation_gt_rasters"].attrs["offset"] = array.roi.offset


 logger = logging.getLogger(__name__)

+
 def download_wk_skeleton(
     save_path=".",
     url="http://catmaid2.hms.harvard.edu:9000",
     annotation_id=None,
-    token = None,
+    token=None,
     overwrite=True,
     zip_suffix=None,
 ):
     # print(f"Downloading {wk_url}/annotations/Explorational/{annotation_ID}...")
-    with wk.webknossos_context(
-        token=token,
-        url=url
-    ):
+    with wk.webknossos_context(token=token, url=url):
         annotation = wk.Annotation.download(
             annotation_id,
             annotation_type="Explorational",
@@ -217,8 +218,8 @@ def download_wk_skeleton(
     annotation.save(zip_path)
     return zip_path

-def parse_skeleton(zip_path):

+def parse_skeleton(zip_path):
     fin = zip_path
     if not fin.endswith(".zip"):
         try:
@@ -240,9 +241,8 @@ def parse_skeleton(zip_path):

     return skel_coor

-def get_updated_skeleton(zip_path):

-
+def get_updated_skeleton(zip_path):
     if not os.path.exists(zip_path):
         path = os.path.dirname(os.path.realpath(zip_path))
         search_path = os.path.join(path, "skeletons/*")
@@ -255,10 +255,12 @@ def get_updated_skeleton(zip_path):

     return skel_file

-def rasterize_skeleton(zip_path="/n/groups/htem/users/br128/xray-challenge-entry/monkeyv1axonseg001_KevinOhgami_20231010.zip",
-                       raw_file="./data/monkey_xnh.zarr",
-                       raw_ds="volumes/training_raw"):

+def rasterize_skeleton(
+    zip_path="/n/groups/htem/users/br128/xray-challenge-entry/monkeyv1axonseg001_KevinOhgami_20231010.zip",
+    raw_file="./data/monkey_xnh.zarr",
+    raw_ds="volumes/training_raw",
+):
     logger.info(f"Rasterizing skeleton...")

     skel_coor = parse_skeleton(zip_path)
@@ -269,11 +271,11 @@ def rasterize_skeleton(zip_path="/n/groups/htem/users/br128/xray-challenge-entry
     dataset_shape = raw.data.shape
     print(dataset_shape)
     voxel_size = raw.voxel_size
-    offset = raw.roi.begin # unhardcode for nonzero offset
+    offset = raw.roi.begin  # unhardcode for nonzero offset
     image = np.zeros(dataset_shape, dtype=np.uint8)

     def adjust(coor):
-        ds_under = [x-1 for x in dataset_shape]
+        ds_under = [x - 1 for x in dataset_shape]
         return np.min([coor - offset, ds_under], 0)

     print("adjusting . . .")
@@ -283,7 +285,6 @@ def adjust(coor):
         line = line_nd(adjust(start), adjust(end))
         image[line] = id

-
     # Save GT rasterization #TODO: implement daisy blockwise option
     total_roi = Roi(
         Coordinate(offset) * Coordinate(voxel_size),
@@ -303,6 +304,7 @@ def adjust(coor):

     return image

+
 def get_wk_mask(
     annotation_ID,
     save_path,
@@ -450,9 +452,7 @@ def wkw_seg_to_zarr(
 ):
     print(f"Downloading {annotation_id} from {wk_url}...")
     with wk.webknossos_context(token=wk_token, url=wk_url):
-        annotation = wk.Annotation.download(
-            annotation_id
-        )
+        annotation = wk.Annotation.download(annotation_id)

     time_str = strftime("%Y%m%d", gmtime())
     annotation_name = (
@@ -484,8 +484,8 @@ def wkw_seg_to_zarr(
     # ds = daisy.open_ds(zarr_path, raw_name)
     # offset = ds.roi.get_offset() #/ ds.voxel_size
     # offset = Coordinate(0, 0, 0)
-    roi = Roi((12600, 14100, 51100), (20000,20000,20000))
-    shape = Coordinate(200,200,200)
+    roi = Roi((12600, 14100, 51100), (20000, 20000, 20000))
+    shape = Coordinate(200, 200, 200)
     # shape = ds.roi.get_shape() / ds.voxel_size
     # shape = Roi()

@@ -502,18 +502,15 @@ def wkw_seg_to_zarr(
     # Open the WKW dataset (as the `1` folder)
     print(f"Opening {zf_data_tmpdir + '/1'}...")
     dataset = wkw.wkw.Dataset.open(zf_data_tmpdir + "/1")
-    data = dataset.read(
-        off=(126, 141, 511),
-        shape=shape
-    ).squeeze()
+    data = dataset.read(off=(126, 141, 511), shape=shape).squeeze()

     print(f"Sum of all data: {data.sum()}")
     # Save annotations to zarr
     if gt_name is None:
         gt_name = f'{gt_name_prefix}gt_{annotation.dataset_name}_{annotation.username.replace(" ","")}_{time_str}'

     target_roi = roi
-    gt_array = daisy.Array(data, roi, (100,100,100))
+    gt_array = daisy.Array(data, roi, (100, 100, 100))

     chunk_size = 1000
     num_channels = 1
@@ -550,9 +547,9 @@ def save_chunk(block: daisy.Roi):
                 block.write_roi, gt_array.__getitem__(block.read_roi)
             )
             # destination[block.write_roi] = gt_array[block.read_roi]
-            return 0 # success
+            return 0  # success
         except:
-            return 1 # error
+            return 1  # error

     # Write data to new dataset
     task = daisy.Task(
@@ -569,9 +566,7 @@ def save_chunk(block: daisy.Roi):
     success = daisy.run_blockwise([task])

     if success:
-        print(
-            f"{target_roi} from {annotation_name} written to {zarr_path}/{gt_name}"
-        )
+        print(f"{target_roi} from {annotation_name} written to {zarr_path}/{gt_name}")
         return gt_name
     else:
-        print("Failed to save annotation layer.")
+        print("Failed to save annotation layer.")
