Commit bbd98b1 (1 parent: ae2ed0e)
Showing 3 changed files with 382 additions and 224 deletions.
File renamed without changes.
@@ -1,240 +1,158 @@
Old file contents (removed by this commit; the new module imports run_eval from evaluate, so these utilities presumably now live in evaluate.py):

import os
import math
import argparse
from itertools import combinations, product
from collections import defaultdict

import networkx as nx
import numpy as np

import daisy
import funlib.evaluate

def generate_graphs_with_seg_labels(segment_array, skeleton_path):
    """
    Add predicted labels to the ground-truth graph. We also re-assign unique ids to
    each cluster of connected nodes after removal of nodes outside of the ROI.
    This is to differentiate between sets of nodes that are discontinuous in
    the ROI but originally belonged to the same skeleton ID.
    """
    gt_graph = np.load(skeleton_path, allow_pickle=True)
    next_highest_seg_label = int(segment_array.data.max()) + 1
    nodes_outside_roi = []
    for treenode, attr in gt_graph.nodes(data=True):
        pos = attr["position"]
        attr["zyx_coord"] = (pos[2], pos[1], pos[0])
        try:
            attr["seg_label"] = segment_array[daisy.Coordinate(attr["zyx_coord"])]
        except AssertionError:
            nodes_outside_roi.append(treenode)
            continue
        if attr["seg_label"] == 0:
            # Relabel nodes predicted as background to a unique non-zero value so
            # the Rand/VOI function can work; contiguous skeletal nodes predicted
            # to be 0 are counted as split errors.
            attr["seg_label"] = next_highest_seg_label
            set_point_in_array(
                array=segment_array,
                point_coord=attr["zyx_coord"],
                val=next_highest_seg_label,
            )
            next_highest_seg_label += 1

    for node in nodes_outside_roi:
        gt_graph.remove_node(node)

    # Re-assign the `skeleton_id` attribute used in the eval functions below.
    for i, cluster in enumerate(nx.connected_components(gt_graph)):
        for node in cluster:
            gt_graph.nodes[node]["skeleton_id"] = i
    return gt_graph
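
For intuition, a minimal sketch of the re-labelling step above on a toy graph (the graph here is illustrative, not from the repo):

import networkx as nx

# Two fragments that may have belonged to one skeleton outside the ROI:
# after cropping they are disconnected, so they get distinct skeleton_ids.
g = nx.Graph([(1, 2), (3, 4)])
for i, cluster in enumerate(nx.connected_components(g)):
    for node in cluster:
        g.nodes[node]["skeleton_id"] = i
# nodes 1 and 2 get skeleton_id 0; nodes 3 and 4 get skeleton_id 1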

def eval_erl(graph):
    node_seg_lut = {}
    for node, attr in graph.nodes(data=True):
        node_seg_lut[node] = attr["seg_label"]

    # Average ground-truth skeleton length, used to normalize the ERL.
    skeleton_lengths = funlib.evaluate.run_length.get_skeleton_lengths(
        skeletons=graph,
        skeleton_position_attributes=["zyx_coord"],
        skeleton_id_attribute="skeleton_id",
    )
    skeleton_lengths = [l for _, l in skeleton_lengths.items() if l > 0]
    average_skel_length = np.mean(skeleton_lengths)

    erl = funlib.evaluate.expected_run_length(
        skeletons=graph,
        skeleton_id_attribute="skeleton_id",
        node_segment_lut=node_seg_lut,
        skeleton_position_attributes=["zyx_coord"],
        return_merge_split_stats=False,
        edge_length_attribute="edge_length",
    )
    erl_norm = erl / average_skel_length

    return erl, erl_norm
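
A worked sketch of the normalization at the end of eval_erl; all numbers here are hypothetical:

import numpy as np

skeleton_lengths = [12000.0, 8000.0, 10000.0]    # per-skeleton lengths in nm
average_skel_length = np.mean(skeleton_lengths)  # 10000.0

erl = 7500.0  # hypothetical expected run length from funlib.evaluate
erl_norm = erl / average_skel_length  # 0.75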

def build_segment_label_subgraph(segment_nodes, graph):
    subgraph = graph.subgraph(segment_nodes)
    skeleton_clusters = nx.connected_components(subgraph)
    seg_graph = nx.Graph()
    seg_graph.add_nodes_from(subgraph.nodes)
    seg_graph.add_edges_from(subgraph.edges)
    for skeleton_1, skeleton_2 in combinations(skeleton_clusters, 2):
        try:
            node_1 = skeleton_1.pop()
            node_2 = skeleton_2.pop()
            if graph.nodes[node_1]["skeleton_id"] == graph.nodes[node_2]["skeleton_id"]:
                seg_graph.add_edge(node_1, node_2)
        except KeyError:
            pass
    return seg_graph

# Returns the closest pair of nodes on two skeletons (Euclidean distance in nm).
def get_closest_node_pair_between_two_skeletons(skel1, skel2, graph):
    shortest_len = math.inf
    closest_pair = None
    for node1, node2 in product(skel1, skel2):
        coord1, coord2 = (
            graph.nodes[node1]["zyx_coord"],
            graph.nodes[node2]["zyx_coord"],
        )
        distance = math.sqrt(sum((a - b) ** 2 for a, b in zip(coord1, coord2)))
        if distance < shortest_len:
            shortest_len = distance
            edge_attributes = {"distance": shortest_len}
            closest_pair = (node1, node2, edge_attributes)
    return closest_pair
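
The distance computation above is plain Euclidean distance on the zyx coordinates; a quick check with hypothetical coordinates:

import math

coord1, coord2 = (0, 0, 0), (33, 44, 0)  # nm
distance = math.sqrt(sum((a - b) ** 2 for a, b in zip(coord1, coord2)))
assert distance == 55.0  # a 3-4-5 triangle scaled by 11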

def find_merge_errors(graph):
    seg_dict = {}
    for nid, attr in graph.nodes(data=True):
        seg_label = attr["seg_label"]
        assert seg_label != 0, "Processed predicted labels cannot be 0"
        try:
            seg_dict[seg_label].add(nid)
        except KeyError:
            seg_dict[seg_label] = {nid}

    merge_errors = set()
    for seg_label, nodes in seg_dict.items():
        seg_graph = build_segment_label_subgraph(nodes, graph)
        skel_clusters = list(nx.connected_components(seg_graph))
        if len(skel_clusters) <= 1:
            continue
        potential_merge_sites = []
        for skeleton_1, skeleton_2 in combinations(skel_clusters, 2):
            shortest_connection = get_closest_node_pair_between_two_skeletons(
                skeleton_1, skeleton_2, graph
            )
            potential_merge_sites.append(shortest_connection)

        merge_sites = [
            (error_site[0], error_site[1]) for error_site in potential_merge_sites
        ]
        merge_errors |= set(merge_sites)

    return merge_errors

def get_split_merges(graph):
    # Count split errors: an edge in the GT skeleton graph whose two nodes
    # carry different predicted segment ids.
    split_errors = []
    for edge in graph.edges():
        if graph.nodes[edge[0]]["seg_label"] != graph.nodes[edge[1]]["seg_label"]:
            split_errors.append(edge)
    merge_errors = find_merge_errors(graph)
    return split_errors, merge_errors
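
A minimal sketch of the split-error definition on a toy three-node skeleton (labels hypothetical):

import networkx as nx

g = nx.Graph()
g.add_nodes_from([(1, {"seg_label": 7}), (2, {"seg_label": 9}), (3, {"seg_label": 7})])
g.add_edges_from([(1, 2), (2, 3)])
split_edges = [
    e for e in g.edges() if g.nodes[e[0]]["seg_label"] != g.nodes[e[1]]["seg_label"]
]
assert len(split_edges) == 2  # both edges cross a label change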

def set_point_in_array(array, point_coord, val):
    """Helper function to set a value using real-world nm coordinates."""
    point_coord = daisy.Coordinate(point_coord)
    vox_aligned_offset = (point_coord // array.voxel_size) * array.voxel_size
    point_roi = daisy.Roi(vox_aligned_offset, array.voxel_size)
    array[point_roi] = val
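
The voxel alignment above is integer snapping; a sketch with a hypothetical 33 nm voxel size:

point_coord = (100, 70, 35)  # nm, zyx
voxel_size = (33, 33, 33)    # nm
vox_aligned_offset = tuple((p // v) * v for p, v in zip(point_coord, voxel_size))
assert vox_aligned_offset == (99, 66, 33)
# The written ROI then covers exactly one voxel:
# Roi(offset=(99, 66, 33), shape=(33, 33, 33)).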

def make_voxel_gt_array(test_array, gt_graph):
    """Rasterize GT points to an empty array to compute Rand/VOI."""
    gt_ndarray = np.zeros_like(test_array.data).astype(np.uint64)
    gt_array = daisy.Array(
        gt_ndarray, roi=test_array.roi, voxel_size=test_array.voxel_size
    )
    for neuron_id, cluster in enumerate(nx.connected_components(gt_graph)):
        for point in cluster:
            point_coord = gt_graph.nodes[point]["zyx_coord"]
            set_point_in_array(array=gt_array, point_coord=point_coord, val=neuron_id)
    return gt_array

def get_voi(segment_array, gt_graph):
    """Wrapper fn to compute Rand/VOI scores."""
    voxel_gt = make_voxel_gt_array(segment_array, gt_graph)
    res = funlib.evaluate.rand_voi(
        truth=voxel_gt.data, test=segment_array.data.astype(np.uint64)
    )
    return res

def run_eval(skeleton_file, segmentation_file, segmentation_ds, roi, downsampling=None):
    # Load segmentation
    segment_array = daisy.open_ds(segmentation_file, segmentation_ds)
    if roi is None:
        roi = segment_array.roi
    if not segment_array.roi.contains(roi):
        raise RuntimeError(
            f"Provided segmentation ROI ({segment_array.roi}) does not contain test ROI ({roi})"
        )
    segment_array = segment_array[roi]

    if downsampling is not None:
        assert isinstance(downsampling, int)
        ndarray = segment_array.data[::downsampling, ::downsampling, ::downsampling]
        ds_voxel_size = segment_array.voxel_size * downsampling
        # Re-align the ROI to the downsampled voxel grid
        roi_begin = (segment_array.roi.begin // ds_voxel_size) * ds_voxel_size
        roi_shape = daisy.Coordinate(ndarray.shape) * ds_voxel_size
        segment_array = daisy.Array(
            data=ndarray, roi=daisy.Roi(roi_begin, roi_shape), voxel_size=ds_voxel_size
        )

    # Load to memory
    segment_array.materialize()
    ret = {}

    # Compute GT graph with predicted labels attached
    gt_graph = generate_graphs_with_seg_labels(segment_array, skeleton_file)
    ret["n_neurons"] = len(list(nx.connected_components(gt_graph)))
    ret["gt_graph"] = gt_graph

    # Compute ERL
    ret["erl"], ret["erl_norm"] = eval_erl(gt_graph)

    # Compute split/merge errors
    split_edges, merged_edges = get_split_merges(gt_graph)
    ret["split_edges"] = split_edges
    ret["merged_edges"] = merged_edges

    # Compute Rand/VOI
    ret["rand_voi"] = get_voi(segment_array, gt_graph)

    # Compute XPRESS scores: 50% normalized ERL plus 50% normalized VOI (or Rand)
    rand_voi = ret["rand_voi"]
    ret["xpress_voi"] = 1 - 0.5 * (rand_voi["nvi_split"] + rand_voi["nvi_merge"])
    ret["xpress_rand"] = 0.5 * (rand_voi["rand_split"] + rand_voi["rand_merge"])
    ret["xpress_erl_voi"] = 0.5 * ret["xpress_voi"] + 0.5 * ret["erl_norm"]
    ret["xpress_erl_rand"] = 0.5 * ret["xpress_rand"] + 0.5 * ret["erl_norm"]

    return ret
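
A worked sketch of the XPRESS score combination above, with hypothetical metric values:

rand_voi = {"nvi_split": 0.10, "nvi_merge": 0.06,
            "rand_split": 0.95, "rand_merge": 0.93}
erl_norm = 0.75

xpress_voi = 1 - 0.5 * (rand_voi["nvi_split"] + rand_voi["nvi_merge"])  # 0.92
xpress_rand = 0.5 * (rand_voi["rand_split"] + rand_voi["rand_merge"])   # 0.94
xpress_erl_voi = 0.5 * xpress_voi + 0.5 * erl_norm    # 0.835
xpress_erl_rand = 0.5 * xpress_rand + 0.5 * erl_norm  # 0.845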

New file contents (added): a validation driver that segments a checkpoint and scores it via run_eval.

import os
import logging
from collections import defaultdict

import daisy

from evaluate import run_eval
from eval_db import Database
from ..postprocess import get_validation_segmentation

logger: logging.Logger = logging.getLogger(__name__)


def segment_and_validate(
    model_checkpoint="latest",
    checkpoint_num=250000,
    setup_num="1738",
) -> dict:
    logger.info(
        f"Segmenting checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num}..."
    )
    success: bool = get_validation_segmentation(iteration=checkpoint_num)
    if success:
        print(
            "-----------------------------\n"
            "Successfully returned validation segmentation . . . now validating\n"
            "----------------------------------------------"
        )
        try:
            logger.info(
                f"Validating checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num}..."
            )
            score_dict: dict = validate(
                checkpoint=model_checkpoint,
                threshold=float(f"{checkpoint_num}.{setup_num}"),
                ds="segmentation_mws",
            )
            logger.info(
                f"Validation for checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num} successful"
            )
            return score_dict
        except Exception as e:
            logger.warning(
                f"Validation for checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num} failed: {e}"
            )
    else:
        logger.warning(
            f"Validation for checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num} failed"
        )
    return {}


def validate(
    checkpoint,
    threshold,
    offset: str = "3960,3960,3960",
    roi_shape: str = "31680,31680,31680",
    skel="../../data/XPRESS_validation_skels.npz",
    zarr="./validation.zarr",
    h5="validation.h5",
    ds="pred_seg",
    print_errors=False,
    print_in_xyz=False,
    downsample=None,
) -> dict:
    network = os.path.abspath(".").split(os.path.sep)[-1]
    aff_setup, aff_checkpoint = str(threshold).split(".")[::-1]

    logger.info(f"Preparing {ds}")
    cmd_str = f"python ../../data/convert_to_zarr_h5.py {zarr} {ds} {h5} {ds}"
    if downsample is not None:
        cmd_str += f" --downsample {downsample}"
    os.system(cmd_str)

    # Previously: roi_begin = "8316,8316,8316", roi_shape = "23067,23067,23067"
    roi_begin = [float(k) for k in offset.split(",")]
    roi_shape = [float(k) for k in roi_shape.split(",")]
    roi = daisy.Roi(roi_begin, roi_shape)

    logger.info(
        f"Evaluating {ds} for network {network}, checkpoint {checkpoint}, Raw->AFF setup{aff_setup}, checkpoint {aff_checkpoint}"
    )
    score_dict = run_eval(skel, h5, ds, roi, downsampling=downsample)
    logger.info(
        f"Finished evaluating {ds} for network {network}, checkpoint {checkpoint}. Saving results..."
    )

    split_edges = score_dict.pop("split_edges")
    merged_edges = score_dict.pop("merged_edges")
    gt_graph = score_dict.pop("gt_graph")

    try:
        db = Database("validation_results")
        db.add_score(
            network, checkpoint, threshold, score_dict
        )  # threshold encodes {checkpoint of LSD->AFF model}.{model number}
    except Exception:
        pass

    # Terminal outputs
    logger.info(f'n_neurons: {score_dict["n_neurons"]}')
    logger.info(f'Expected run-length: {score_dict["erl"]}')
    logger.info(f'Normalized ERL: {score_dict["erl_norm"]}')

    logger.info("Count results:")
    logger.info(
        f'\tSplit count (total, per-neuron): {len(split_edges)}, {len(split_edges)/score_dict["n_neurons"]}'
    )
    logger.info(
        f'\tMerge count (total, per-neuron): {len(merged_edges)}, {len(merged_edges)/score_dict["n_neurons"]}'
    )

    if print_errors:
        # gt_graph was already popped from score_dict above

        def print_coords(node1, node2):
            node1_coord = daisy.Coordinate(gt_graph.nodes[node1]["zyx_coord"]) / 33
            node2_coord = daisy.Coordinate(gt_graph.nodes[node2]["zyx_coord"]) / 33
            if print_in_xyz:
                node1_coord = node1_coord[::-1]
                node2_coord = node2_coord[::-1]
            logger.info(f"{node1_coord} to {node2_coord}")

        logger.info("Split errors:")
        splits_by_skel = defaultdict(list)
        for edge in split_edges:
            splits_by_skel[gt_graph.nodes[edge[0]]["skeleton_id"]].append(edge)
        for skel_id in splits_by_skel:
            logger.info(f"Skeleton #{skel_id}")
            for edge in splits_by_skel[skel_id]:
                print_coords(edge[0], edge[1])
        logger.info("Split error histogram:")
        split_histogram = defaultdict(int)
        for i in range(score_dict["n_neurons"]):
            split_histogram[len(splits_by_skel[i])] += 1
        for k in sorted(split_histogram):
            logger.info(f"{k}: {split_histogram[k]}")

        logger.info("Merge errors:")
        for node1, node2 in merged_edges:
            print_coords(node1, node2)

    rand_voi = score_dict["rand_voi"]
    logger.info("Rand results (higher better):")
    logger.info(f"\tRand split: {rand_voi['rand_split']}")
    logger.info(f"\tRand merge: {rand_voi['rand_merge']}")
    logger.info("VOI results (lower better):")
    logger.info(f"\tNormalized VOI split: {rand_voi['nvi_split']}")
    logger.info(f"\tNormalized VOI merge: {rand_voi['nvi_merge']}")

    logger.info("XPRESS score (higher is better):")
    logger.info(f"\tERL+VOI : {score_dict['xpress_erl_voi']}")
    logger.info(f"\tERL+RAND: {score_dict['xpress_erl_rand']}")
    logger.info(f"\tVOI : {score_dict['xpress_voi']}")
    logger.info(f"\tRAND : {score_dict['xpress_rand']}")
    return score_dict
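
A minimal usage sketch, assuming the defaults above (validation.zarr next to the script, skeletons under ../../data/) are in place:

import logging

logging.basicConfig(level=logging.INFO)

if __name__ == "__main__":
    scores = segment_and_validate(
        model_checkpoint="latest",
        checkpoint_num=250000,
        setup_num="1738",
    )
    if scores:
        print("XPRESS ERL+VOI:", scores["xpress_erl_voi"])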