Commit bbd98b1

Eval run

brianreicher committed Nov 20, 2023
1 parent ae2ed0e commit bbd98b1

Showing 3 changed files with 382 additions and 224 deletions.

File renamed without changes.

366 changes: 142 additions & 224 deletions src/autoseg/eval/evaluate.py

@@ -1,240 +1,158 @@
import os.path
import math
import networkx as nx
import numpy as np
import argparse
from itertools import combinations, product
from collections import defaultdict

import os
import daisy
import funlib.evaluate


def generate_graphs_with_seg_labels(segment_array, skeleton_path):
"""
Add predicted labels to the ground-truth graph. We also re-assign unique ids to
each cluster of connected nodes after removal of nodes outside of the ROI.
This is to differentiate between sets of nodes that are discontinuous in
the ROI but originally belonged to the same skeleton ID
"""
gt_graph = np.load(skeleton_path, allow_pickle=True)
next_highest_seg_label = int(segment_array.data.max()) + 1
nodes_outside_roi = []
for i, (treenode, attr) in enumerate(gt_graph.nodes(data=True)):
pos = attr["position"]
attr["zyx_coord"] = (pos[2], pos[1], pos[0])
try:
attr["seg_label"] = segment_array[daisy.Coordinate(attr["zyx_coord"])]
except AssertionError as e:
nodes_outside_roi.append(treenode)
continue
if attr["seg_label"] == 0:
            # Relabel each zero-valued node to its own unique non-zero value so
            # the Rand/VOI computation works; contiguous skeleton nodes
            # predicted to be 0 are also counted as split errors.
attr["seg_label"] = next_highest_seg_label
set_point_in_array(
array=segment_array,
point_coord=attr["zyx_coord"],
val=next_highest_seg_label,
)
next_highest_seg_label += 1

for node in nodes_outside_roi:
gt_graph.remove_node(node)

# reassign `skeleton_id` attribute used in eval functions
skel_clusters = nx.connected_components(gt_graph)
for i, cluster in enumerate(skel_clusters):
for node in cluster:
gt_graph.nodes[node]["skeleton_id"] = i
return gt_graph

def eval_erl(graph):
node_seg_lut = {}
for node, attr in graph.nodes(data=True):
node_seg_lut[node] = attr["seg_label"]

# get total skel length
skeleton_lengths = funlib.evaluate.run_length.get_skeleton_lengths(
skeletons=graph,
skeleton_position_attributes=["zyx_coord"],
skeleton_id_attribute="skeleton_id",
)
skeleton_lengths = [l for _, l in skeleton_lengths.items() if l > 0]
average_skel_length = np.mean(skeleton_lengths)

erl = funlib.evaluate.expected_run_length(
skeletons=graph,
skeleton_id_attribute="skeleton_id",
node_segment_lut=node_seg_lut,
skeleton_position_attributes=["zyx_coord"],
return_merge_split_stats=False,
edge_length_attribute="edge_length",
)
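    # ERL is the expected length of the correctly segmented run containing a
    # uniformly sampled skeleton point; normalizing by the mean skeleton
    # length makes the score roughly comparable across volumes.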
erl_norm = erl / average_skel_length

return erl, erl_norm

def build_segment_label_subgraph(segment_nodes, graph):
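    # Build a graph over one predicted segment's skeleton nodes, then reconnect
    # components that share a skeleton_id; components left disconnected belong
    # to different ground-truth skeletons, i.e. candidate merge errors.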
subgraph = graph.subgraph(segment_nodes)
skeleton_clusters = nx.connected_components(subgraph)
seg_graph = nx.Graph()
seg_graph.add_nodes_from(subgraph.nodes)
seg_graph.add_edges_from(subgraph.edges)
for skeleton_1, skeleton_2 in combinations(skeleton_clusters, 2):
try:
node_1 = skeleton_1.pop()
node_2 = skeleton_2.pop()
if graph.nodes[node_1]["skeleton_id"] == graph.nodes[node_2]["skeleton_id"]:
seg_graph.add_edge(node_1, node_2)
except KeyError:
pass
return seg_graph


# Returns the closest pair of nodes on 2 skeletons
def get_closest_node_pair_between_two_skeletons(skel1, skel2, graph):
multiplier = (1, 1, 1)
shortest_len = math.inf
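    # Brute-force scan over all cross-skeleton node pairs (assumes the
    # per-segment skeletons are small enough for this to be cheap).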
for node1, node2 in product(skel1, skel2):
coord1, coord2 = (
graph.nodes[node1]["zyx_coord"],
graph.nodes[node2]["zyx_coord"],
)
distance = math.sqrt(sum([(a - b) ** 2 for a, b in zip(coord1, coord2)]))
if distance < shortest_len:
shortest_len = distance
edge_attributes = {"distance": shortest_len}
closest_pair = (node1, node2, edge_attributes)
return closest_pair


def find_merge_errors(graph):
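    # Group ground-truth skeleton nodes by their predicted segment label.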
seg_dict = {}
for nid, attr in graph.nodes(data=True):
seg_label = attr["seg_label"]
assert seg_label != 0, "Processed predicted labels cannot be 0"
try:
seg_dict[seg_label].add(nid)
except KeyError:
seg_dict[seg_label] = {nid}

merge_errors = set()
for seg_label, nodes in seg_dict.items():
seg_graph = build_segment_label_subgraph(nodes, graph)
skel_clusters = list(nx.connected_components(seg_graph))
if len(skel_clusters) <= 1:
continue
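        # More than one skeleton maps to this segment: record the closest
        # node pair between each pair of skeletons as the likely merge site.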
potential_merge_sites = []
for skeleton_1, skeleton_2 in combinations(skel_clusters, 2):
shortest_connection = get_closest_node_pair_between_two_skeletons(
skeleton_1, skeleton_2, graph
            )
potential_merge_sites.append(shortest_connection)

merge_sites = [
(error_site[0], error_site[1]) for error_site in potential_merge_sites
]
merge_errors |= set(merge_sites)

return merge_errors


def get_split_merges(graph):
# Count split errors. An error is an edge in the GT skeleton graph connecting two nodes
# of different segment ids.
split_errors = []
for edge in graph.edges():
if graph.nodes[edge[0]]["seg_label"] != graph.nodes[edge[1]]["seg_label"]:
split_errors.append(edge)
merge_errors = find_merge_errors(graph)
return split_errors, merge_errors


def set_point_in_array(array, point_coord, val):
"""Helper function to set value using real-world nm coordinates"""
point_coord = daisy.Coordinate(point_coord)
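    # Snap the real-world nm coordinate down to its containing voxel, then
    # write a single-voxel ROI at that position.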
vox_aligned_offset = (point_coord // array.voxel_size) * array.voxel_size
point_roi = daisy.Roi(vox_aligned_offset, array.voxel_size)
array[point_roi] = val


def make_voxel_gt_array(test_array, gt_graph):
"""Rasterize GT points to an empty array to compute Rand/VOI"""
gt_ndarray = np.zeros_like(test_array.data).astype(np.uint64)
gt_array = daisy.Array(
gt_ndarray, roi=test_array.roi, voxel_size=test_array.voxel_size
    )
for neuron_id, cluster in enumerate(nx.connected_components(gt_graph)):
for point in cluster:
point_coord = gt_graph.nodes[point]["zyx_coord"]
set_point_in_array(array=gt_array, point_coord=point_coord, val=neuron_id)
return gt_array


def get_voi(segment_array, gt_graph):
"""Wrapper fn to compute Rand/VOI scores"""
voxel_gt = make_voxel_gt_array(segment_array, gt_graph)
res = funlib.evaluate.rand_voi(
truth=voxel_gt.data, test=segment_array.data.astype(np.uint64)
    )
return res


def run_eval(skeleton_file, segmentation_file, segmentation_ds, roi, downsampling=None):

# load segmentation
segment_array = daisy.open_ds(segmentation_file, segmentation_ds)
if roi is None:
roi = segment_array.roi
if not segment_array.roi.contains(roi):
raise RuntimeError(
f"Provided segmentation ROI ({segment_array.roi}) does not contain test ROI ({roi})"
)
segment_array = segment_array[roi]

if downsampling is not None:
        assert isinstance(downsampling, int)
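        # Stride-slice every `downsampling`-th voxel per axis; the voxel size
        # grows by the same factor, and the ROI is re-aligned below.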
ndarray = segment_array.data[::downsampling, ::downsampling, ::downsampling]
ds_voxel_size = segment_array.voxel_size * downsampling
# align ROI
roi_begin = (segment_array.roi.begin // ds_voxel_size) * ds_voxel_size
roi_shape = daisy.Coordinate(ndarray.shape) * ds_voxel_size
segment_array = daisy.Array(
data=ndarray, roi=daisy.Roi(roi_begin, roi_shape), voxel_size=ds_voxel_size
)

# load to mem
segment_array.materialize()
voxel_size = segment_array.voxel_size
ret = {}

# compute GT graph
gt_graph = generate_graphs_with_seg_labels(segment_array, skeleton_file)
n_neurons = len(list(nx.connected_components(gt_graph)))
ret["n_neurons"] = n_neurons
ret["gt_graph"] = gt_graph

# Compute ERL
ret["erl"], ret["erl_norm"] = eval_erl(gt_graph)

# Compute merge-split
split_edges, merged_edges = get_split_merges(gt_graph)
ret["split_edges"] = split_edges
ret["merged_edges"] = merged_edges

# Compute Rand/VOI
ret["rand_voi"] = get_voi(segment_array, gt_graph)

    # Compute XPRESS scores: a VOI-based and a Rand-based score, plus 50/50
    # blends of each with the normalized ERL.
rand_voi = ret["rand_voi"]
ret["xpress_voi"] = 1 - 0.5 * (rand_voi["nvi_split"] + rand_voi["nvi_merge"])
ret["xpress_rand"] = 0.5 * (rand_voi["rand_split"] + rand_voi["rand_merge"])
ret["xpress_erl_voi"] = 0.5 * ret["xpress_voi"] + 0.5 * ret["erl_norm"]
ret["xpress_erl_rand"] = 0.5 * ret["xpress_rand"] + 0.5 * ret["erl_norm"]

return ret


import os
import logging
from collections import defaultdict

import daisy

from evaluate import run_eval
from eval_db import Database
from ..postprocess import get_validation_segmentation

logger: logging.Logger = logging.getLogger(__name__)


def segment_and_validate(
    model_checkpoint="latest",
    checkpoint_num=250000,
    setup_num="1738",
) -> dict:
    logger.info(
        f"Segmenting checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num}..."
    )
    success: bool = get_validation_segmentation(iteration=checkpoint_num)
    if success:
        logger.info("Successfully returned validation segmentation, now validating...")
        try:
            logger.info(
                f"Validating checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num}..."
            )
            score_dict: dict = validate(
                checkpoint=model_checkpoint,
                # threshold packs two IDs into one float:
                # {checkpoint of LSD->AFF model}.{model number}
                threshold=float(f"{checkpoint_num}.{setup_num}"),
                ds="segmentation_mws",
            )
            logger.info(
                f"Validation for checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num} successful"
            )
            return score_dict
        except Exception as e:
            logger.warning(
                f"Validation for checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num} failed: {e}"
            )
    else:
        logger.warning(
            f"Validation for checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num} failed"
        )
    return {}


def validate(
    checkpoint,
    threshold,
    offset: str = "3960,3960,3960",
    roi_shape: str = "31680,31680,31680",
    skel="../../data/XPRESS_validation_skels.npz",
    zarr="./validation.zarr",
    h5="validation.h5",
    ds="pred_seg",
    print_errors=False,
    print_in_xyz=False,
    downsample=None,
) -> dict:
    network = os.path.abspath(".").split(os.path.sep)[-1]
    aff_setup, aff_checkpoint = str(threshold).split(".")[::-1]

    logger.info(f"Preparing {ds}")
    cmd_str = f"python ../../data/convert_to_zarr_h5.py {zarr} {ds} {h5} {ds}"
    if downsample is not None:
        cmd_str += f" --downsample {downsample}"
    os.system(cmd_str)

    # Alternative inner ROI: offset "8316,8316,8316", shape "23067,23067,23067"
    roi_begin = [float(k) for k in offset.split(",")]
    roi_shape = [float(k) for k in roi_shape.split(",")]
    roi = daisy.Roi(roi_begin, roi_shape)

    logger.info(
        f"Evaluating {ds} for network {network}, checkpoint {checkpoint}, Raw->AFF setup {aff_setup}, checkpoint {aff_checkpoint}"
    )

    score_dict = run_eval(skel, h5, ds, roi, downsampling=downsample)
    logger.info(
        f"Finished evaluating {ds} for network {network}, checkpoint {checkpoint}. Saving results..."
    )

    split_edges = score_dict.pop("split_edges")
    merged_edges = score_dict.pop("merged_edges")
    gt_graph = score_dict.pop("gt_graph")

    try:
        db = Database("validation_results")
        # threshold is set as {checkpoint of LSD->AFF model}.{model number}
        db.add_score(network, checkpoint, threshold, score_dict)
    except Exception:
        logger.warning("Could not save validation results to the database")

    # Terminal outputs
    logger.info(f'n_neurons: {score_dict["n_neurons"]}')
    logger.info(f'Expected run-length: {score_dict["erl"]}')
    logger.info(f'Normalized ERL: {score_dict["erl_norm"]}')

    logger.info("Count results:")
    logger.info(
        f'\tSplit count (total, per-neuron): {len(split_edges)}, {len(split_edges) / score_dict["n_neurons"]}'
    )
    logger.info(
        f'\tMerge count (total, per-neuron): {len(merged_edges)}, {len(merged_edges) / score_dict["n_neurons"]}'
    )

    if print_errors:

        def print_coords(node1, node2):
            # zyx_coord is in nm; divide by the (assumed) 33 nm voxel size to
            # report voxel coordinates.
            node1_coord = daisy.Coordinate(gt_graph.nodes[node1]["zyx_coord"]) / 33
            node2_coord = daisy.Coordinate(gt_graph.nodes[node2]["zyx_coord"]) / 33
            if print_in_xyz:
                node1_coord = node1_coord[::-1]
                node2_coord = node2_coord[::-1]
            logger.info(f"{node1_coord} to {node2_coord}")

        logger.info("Split errors:")
        splits_by_skel = defaultdict(list)
        for edge in split_edges:
            splits_by_skel[gt_graph.nodes[edge[0]]["skeleton_id"]].append(edge)
        for skel_id in splits_by_skel:
            logger.info(f"Skeleton #{skel_id}")
            for edge in splits_by_skel[skel_id]:
                print_coords(edge[0], edge[1])
        logger.info("Split error histogram:")
        split_histogram = defaultdict(int)
        for i in range(score_dict["n_neurons"]):
            split_histogram[len(splits_by_skel[i])] += 1
        for k in sorted(split_histogram):
            logger.info(f"{k}: {split_histogram[k]}")

        logger.info("Merge errors:")
        for node1, node2 in merged_edges:
            print_coords(node1, node2)

    rand_voi = score_dict["rand_voi"]
    logger.info("Rand results (higher is better):")
    logger.info(f"\tRand split: {rand_voi['rand_split']}")
    logger.info(f"\tRand merge: {rand_voi['rand_merge']}")
    logger.info("VOI results (lower is better):")
    logger.info(f"\tNormalized VOI split: {rand_voi['nvi_split']}")
    logger.info(f"\tNormalized VOI merge: {rand_voi['nvi_merge']}")

    logger.info("XPRESS scores (higher is better):")
    logger.info(f"\tERL+VOI : {score_dict['xpress_erl_voi']}")
    logger.info(f"\tERL+RAND: {score_dict['xpress_erl_rand']}")
    logger.info(f"\tVOI     : {score_dict['xpress_voi']}")
    logger.info(f"\tRAND    : {score_dict['xpress_rand']}")
    return score_dict
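

# Usage sketch: a minimal driver, assuming the imports above resolve and that
# a trained checkpoint plus validation data exist at the default paths.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    scores = segment_and_validate(
        model_checkpoint="latest",
        checkpoint_num=250000,
        setup_num="1738",
    )
    logger.info(f"XPRESS ERL+VOI: {scores.get('xpress_erl_voi')}")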