
Eval docs
brianreicher committed Dec 14, 2023
1 parent 1d3bd47 commit d1c20c8
Showing 3 changed files with 233 additions and 9 deletions.
63 changes: 58 additions & 5 deletions src/autoseg/eval/eval_db.py
@@ -1,8 +1,36 @@
import sqlite3
import json


class Database:
"""
Simple SQLite Database Wrapper for Storing and Retrieving Scores.

This class provides a simple wrapper around an SQLite database for storing and retrieving scores.
Each score entry is associated with a network, checkpoint, threshold, and a dictionary of scores.

Args:
networks (str):
The name of the SQLite database file.
table_name (str):
The name of the table within the database (default is 'scores_table').

Attributes:
conn (sqlite3.Connection):
The SQLite database connection.
cursor (sqlite3.Cursor):
The SQLite database cursor.
table_name (str):
The name of the table within the database.

Methods:
add_score(network, checkpoint, threshold, scores_dict):
Add a score entry to the database.
get_scores(networks=None, checkpoints=None, thresholds=None):
Retrieve scores from the database based on specified conditions.
"""

def __init__(self, db_name, table_name="scores_table"):
self.conn = sqlite3.connect(f"{db_name}.db", check_same_thread=False)
self.table_name = table_name
@@ -18,6 +46,19 @@ def __init__(self, db_name, table_name="scores_table"):
self.conn.commit()

def add_score(self, network, checkpoint, threshold, scores_dict):
"""
Add a score entry to the database.

Args:
network (str):
The name of the network.
checkpoint (int):
The checkpoint number.
threshold (float):
The threshold value.
scores_dict (dict):
A dictionary containing scores.
"""
assert type(network) is str
assert type(checkpoint) is int
assert type(threshold) is float
@@ -29,6 +70,20 @@ def add_score(self, network, checkpoint, threshold, scores_dict):
self.conn.commit()

def get_scores(self, networks=None, checkpoints=None, thresholds=None):
"""
Retrieve scores from the database based on specified conditions.

Args:
networks (str, list):
The name or list of names of networks to filter on.
checkpoints (int, list):
The checkpoint number or list of checkpoint numbers to filter on.
thresholds (float, list):
The threshold value or list of threshold values to filter on.

Returns:
list: A list of score entries, where each entry is [network, checkpoint, threshold, scores_dict], with scores_dict parsed from its stored JSON.
"""
assert type(networks) is str or networks is None or type(networks) is list
assert (
type(checkpoints) is int or checkpoints is None or type(checkpoints) is list
@@ -55,18 +110,16 @@ def add_where(var, var_name, query):
ret += f"{var_name} = '{var}'"
else:
ret += f"{var_name} in ({to_csv_list(var)})"
# print(ret)
return ret

query = f"SELECT * FROM {self.table_name}"
query += add_where(networks, "network", query)
query += add_where(checkpoints, "checkpoint", query)
query += add_where(thresholds, "threshold", query)
# print(query)

ret = self.cursor.execute(query).fetchall()
# print(ret)
ret = [list(k) for k in ret]
for item in ret:
item[3] = json.loads(item[3])
# print(ret)

return ret
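
Note: a minimal usage sketch for this wrapper, with hypothetical names and values; it assumes the module is importable as autoseg.eval.eval_db.

from autoseg.eval.eval_db import Database

db = Database("eval_scores")  # opens/creates eval_scores.db with table "scores_table"
db.add_score("affs_net", 250000, 0.5, {"erl": 1234.5, "voi_split": 0.8})

# Filters accept a single value or a list; each returned entry is
# [network, checkpoint, threshold, scores_dict].
rows = db.get_scores(networks="affs_net", thresholds=[0.5])
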
45 changes: 45 additions & 0 deletions src/autoseg/eval/evaluate.py
@@ -16,6 +24 @@ def segment_and_validate(
checkpoint_num=250000,
setup_num="1738",
) -> dict:
"""
Segment and validate a given checkpoint of the affinity model.

This function performs segmentation using the specified model checkpoint and validates the results.
It logs information about the segmentation and validation process.

Args:
model_checkpoint (str):
The checkpoint of the segmentation model to use (default is "latest").
checkpoint_num (int):
The checkpoint number for the affinity model (default is 250000).
setup_num (str):
The setup number for the affinity model (default is "1738").

Returns:
dict:
A dictionary containing scores and evaluation metrics.
"""
logger.info(
msg=f"Segmenting checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num}..."
)
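
Note: a minimal call sketch using only the documented defaults; it assumes the surrounding pipeline (trained checkpoints on disk, validation data) is in place.

scores = segment_and_validate(
    model_checkpoint="latest",
    checkpoint_num=250000,
    setup_num="1738",
)
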
@@ -63,6 +81,33 @@ def validate(
print_in_xyz=False,
downsample=None,
) -> None:
"""
Validate segmentation results using specified parameters.

Args:
checkpoint (str):
The checkpoint identifier.
threshold (float):
The threshold value.
offset (str):
The offset of the ROI (default is "3960,3960,3960").
roi_shape (str):
The shape of the ROI (default is "31680,31680,31680").
skel (str):
The path to the skeleton data file (default is "../../data/XPRESS_validation_skels.npz").
zarr (str):
The path to the Zarr file for storing segmentation data (default is "./validation.zarr").
h5 (str):
The path to the HDF5 file for storing validation data (default is "validation.h5").
ds (str):
The dataset name (default is "pred_seg").
print_errors (bool):
Whether to print errors during validation (default is False).
print_in_xyz (bool):
Whether to print coordinates in XYZ format (default is False).
downsample (int):
Downsample factor for evaluation (default is None).
"""
network = os.path.abspath(".").split(os.path.sep)[-1]
aff_setup, aff_checkpoint = str(threshold).split(".")[::-1]

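
Note: a hedged call sketch with hypothetical values. Per the parsing above, str(threshold).split(".")[::-1] reads the fractional digits of threshold as the affinity setup and the integer digits as the affinity checkpoint.

# e.g. threshold=250000.1738 -> aff_setup "1738", aff_checkpoint "250000"
validate(checkpoint="latest", threshold=250000.1738)
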
134 changes: 130 additions & 4 deletions src/autoseg/eval/metrics.py
@@ -15,7 +15,17 @@ def generate_graphs_with_seg_labels(segment_array, skeleton_path):
Add predicted labels to the ground-truth graph. We also re-assign unique ids to
each cluster of connected nodes after removal of nodes outside of the ROI.
This is to differentiate between sets of nodes that are discontinuous in
the ROI but originally belonged to the same skeleton ID.

Args:
segment_array (daisy.Array):
Array containing predicted segmentation labels.
skeleton_path (str):
Path to the skeleton data file.

Returns:
networkx.Graph:
Ground-truth graph with predicted labels added.
"""
gt_graph = np.load(skeleton_path, allow_pickle=True)
next_highest_seg_label = int(segment_array.data.max()) + 1
@@ -52,6 +62,19 @@ def generate_graphs_with_seg_labels(segment_array, skeleton_path):


def eval_erl(graph):
"""
Compute Expected Run Length (ERL) and normalized ERL for a given graph.

Args:
graph (networkx.Graph):
Graph representing the ground-truth.

Returns:
tuple (float, float):
The ERL value and the normalized ERL value.
"""
node_seg_lut = {}
for node, attr in graph.nodes(data=True):
node_seg_lut[node] = attr["seg_label"]
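
Note: a minimal sketch of the intended flow, assuming segment_array is a daisy.Array of predicted labels, the skeleton path is hypothetical, and eval_erl returns the (ERL, normalized ERL) pair its docstring describes.

graph = generate_graphs_with_seg_labels(segment_array, "XPRESS_validation_skels.npz")
erl, normalized_erl = eval_erl(graph)
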
@@ -79,6 +102,19 @@


def build_segment_label_subgraph(segment_nodes, graph):
"""
Build a subgraph using a set of segment nodes from the given graph.

Args:
segment_nodes (Iterable):
Nodes representing segments.
graph (networkx.Graph):
Original graph.

Returns:
networkx.Graph:
Subgraph containing specified segment nodes.
"""
subgraph = graph.subgraph(segment_nodes)
skeleton_clusters = nx.connected_components(subgraph)
seg_graph = nx.Graph()
@@ -97,6 +133,21 @@

# Returns the closest pair of nodes on 2 skeletons
def get_closest_node_pair_between_two_skeletons(skel1, skel2, graph):
"""
Get the closest pair of nodes between two skeletons in the given graph.

Args:
skel1 (Iterable):
Nodes of the first skeleton.
skel2 (Iterable):
Nodes of the second skeleton.
graph (networkx.Graph):
Original graph.

Returns:
tuple:
Closest pair of nodes and their edge attributes.
"""
multiplier = (1, 1, 1)
shortest_len = math.inf
for node1, node2 in product(skel1, skel2):
@@ -113,6 +164,17 @@ def get_closest_node_pair_between_two_skeletons(skel1, skel2, graph):


def find_merge_errors(graph):
"""
Find merge errors in the given graph.

Args:
graph (networkx.Graph):
Original graph.

Returns:
set:
Set of merge errors.
"""
seg_dict = {}
for nid, attr in graph.nodes(data=True):
seg_label = attr["seg_label"]
@@ -144,6 +206,17 @@


def get_split_merges(graph):
"""
Count split errors in the given graph.

Args:
graph (networkx.Graph):
Original graph.

Returns:
set:
Set of split errors, i.e. ground-truth skeleton edges whose endpoints carry different segment ids.
"""
# Count split errors. An error is an edge in the GT skeleton graph connecting two nodes
# of different segment ids.
split_errors = []
@@ -155,15 +228,37 @@


def set_point_in_array(array, point_coord, val):
"""Helper function to set value using real-world nm coordinates"""
"""
Set a specific point in the array to a given value.
Args:
array (daisy.Array):
Target array.
point_coord (Tuple):
Coordinates of the point.
val:
Value to set.
"""
point_coord = daisy.Coordinate(point_coord)
vox_aligned_offset = (point_coord // array.voxel_size) * array.voxel_size
point_roi = daisy.Roi(vox_aligned_offset, array.voxel_size)
array[point_roi] = val

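Note: a minimal sketch, assuming seg_array is a daisy.Array; the point is given in world (nm) coordinates and snapped to its containing voxel by the alignment above.

# Hypothetical: write label 7 at the point (1200, 840, 960) nm
set_point_in_array(seg_array, (1200, 840, 960), 7)
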

def make_voxel_gt_array(test_array, gt_graph):
"""Rasterize GT points to an empty array to compute Rand/VOI"""
"""
Rasterize ground-truth points to an empty array for computing Rand/VOI.
Args:
test_array (daisy.Array):
Target array.
gt_graph (networkx.Graph):
Ground-truth graph.
Returns:
daisy.Array:
Voxel array containing ground-truth information.
"""
gt_ndarray = np.zeros_like(test_array.data).astype(np.uint64)
gt_array = daisy.Array(
gt_ndarray, roi=test_array.roi, voxel_size=test_array.voxel_size
@@ -176,7 +271,19 @@ def make_voxel_gt_array(test_array, gt_graph):


def get_voi(segment_array, gt_graph):
"""Wrapper fn to compute Rand/VOI scores"""
"""
Wrapper function to compute Rand/VOI scores.
Args:
segment_array (daisy.Array):
Array containing predicted segmentation labels.
gt_graph (networkx.Graph):
Ground-truth graph.
Returns:
Dict:
Dictionary containing Rand/VOI scores.
"""
voxel_gt = make_voxel_gt_array(segment_array, gt_graph)
res = funlib.evaluate.rand_voi(
truth=voxel_gt.data, test=segment_array.data.astype(np.uint64)
@@ -185,6 +292,25 @@ def get_voi(segment_array, gt_graph):


def run_eval(skeleton_file, segmentation_file, segmentation_ds, roi, downsampling=None):
"""
Run evaluation on the predicted segmentation.

Args:
skeleton_file (str):
Path to the skeleton data file.
segmentation_file (str):
Path to the segmentation data file.
segmentation_ds (str):
Dataset name in the segmentation file.
roi (daisy.Roi):
Region of interest.
downsampling (int):
Downsample factor for evaluation (default is None).

Returns:
dict:
Dictionary containing evaluation metrics.
"""
# load segmentation
segment_array = daisy.open_ds(segmentation_file, segmentation_ds)
if roi is None:
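
Note: a hedged end-to-end sketch with hypothetical paths mirroring the validate() defaults above; per the code, roi may also be None to evaluate the full dataset.

import daisy

roi = daisy.Roi((3960, 3960, 3960), (31680, 31680, 31680))
metrics = run_eval(
    "../../data/XPRESS_validation_skels.npz",  # skeleton_file
    "./validation.zarr",                       # segmentation_file
    "pred_seg",                                # segmentation_ds
    roi,
)
print(metrics)  # dict of ERL and Rand/VOI metrics
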
