From d1c20c8401ab9ee78ea9261a0c59ddce907905ad Mon Sep 17 00:00:00 2001
From: brianreicher
Date: Thu, 14 Dec 2023 11:25:30 -0500
Subject: [PATCH] Eval docs

---
 src/autoseg/eval/eval_db.py  |  63 ++++++++++++++--
 src/autoseg/eval/evaluate.py |  45 ++++++++++++
 src/autoseg/eval/metrics.py  | 134 +++++++++++++++++++++++++++++++++--
 3 files changed, 233 insertions(+), 9 deletions(-)

diff --git a/src/autoseg/eval/eval_db.py b/src/autoseg/eval/eval_db.py
index 2e3ee19..e6376d8 100644
--- a/src/autoseg/eval/eval_db.py
+++ b/src/autoseg/eval/eval_db.py
@@ -1,8 +1,36 @@
 import sqlite3
 import json
 
-
 class Database:
+    """
+    Simple SQLite database wrapper for storing and retrieving scores.
+
+    Each score entry is associated with a network, a checkpoint, a threshold,
+    and a dictionary of scores.
+
+    Args:
+        db_name (str):
+            The base name of the SQLite database file (a ".db" suffix is appended).
+        table_name (str):
+            The name of the table within the database (default is 'scores_table').
+
+    Attributes:
+        conn (sqlite3.Connection):
+            The SQLite database connection.
+        cursor (sqlite3.Cursor):
+            The SQLite database cursor.
+        table_name (str):
+            The name of the table within the database.
+
+    Methods:
+        add_score(network, checkpoint, threshold, scores_dict):
+            Add a score entry to the database.
+
+        get_scores(networks=None, checkpoints=None, thresholds=None):
+            Retrieve scores from the database based on specified conditions.
+    """
+
     def __init__(self, db_name, table_name="scores_table"):
         self.conn = sqlite3.connect(f"{db_name}.db", check_same_thread=False)
         self.table_name = table_name
@@ -18,6 +46,19 @@ def __init__(self, db_name, table_name="scores_table"):
         self.conn.commit()
 
     def add_score(self, network, checkpoint, threshold, scores_dict):
+        """
+        Add a score entry to the database.
+
+        Args:
+            network (str):
+                The name of the network.
+            checkpoint (int):
+                The checkpoint number.
+            threshold (float):
+                The threshold value.
+            scores_dict (dict):
+                A dictionary mapping score names to values.
+        """
         assert type(network) is str
         assert type(checkpoint) is int
         assert type(threshold) is float
@@ -29,6 +70,20 @@ def add_score(self, network, checkpoint, threshold, scores_dict):
         self.conn.commit()
 
     def get_scores(self, networks=None, checkpoints=None, thresholds=None):
+        """
+        Retrieve scores from the database based on specified conditions.
+
+        Args:
+            networks (str, list):
+                The name or list of names of networks to filter on.
+            checkpoints (int, list):
+                The checkpoint number or list of checkpoint numbers to filter on.
+            thresholds (float, list):
+                The threshold value or list of threshold values to filter on.
+
+        Returns:
+            list: A list of score entries, each with its scores column decoded
+            from JSON back into a dictionary.
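+
+        Example (illustrative usage; the database name, network name, and
+        score values below are hypothetical)::
+
+            db = Database("eval_results")
+            db.add_score("my_net", 10000, 0.5, {"voi_split": 0.1})
+            rows = db.get_scores(networks="my_net", thresholds=0.5)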
+ """ assert type(networks) is str or networks is None or type(networks) is list assert ( type(checkpoints) is int or checkpoints is None or type(checkpoints) is list @@ -55,18 +110,16 @@ def add_where(var, var_name, query): ret += f"{var_name} = '{var}'" else: ret += f"{var_name} in ({to_csv_list(var)})" - # print(ret) return ret query = f"SELECT * FROM {self.table_name}" query += add_where(networks, "network", query) query += add_where(checkpoints, "checkpoint", query) query += add_where(thresholds, "threshold", query) - # print(query) + ret = self.cursor.execute(query).fetchall() - # print(ret) ret = [list(k) for k in ret] for item in ret: item[3] = json.loads(item[3]) - # print(ret) + return ret diff --git a/src/autoseg/eval/evaluate.py b/src/autoseg/eval/evaluate.py index 4fab4e2..7cfd552 100644 --- a/src/autoseg/eval/evaluate.py +++ b/src/autoseg/eval/evaluate.py @@ -16,6 +16,24 @@ def segment_and_validate( checkpoint_num=250000, setup_num="1738", ) -> dict: + """ + Segment and Validate a given checkpoint for the Affinity Model. + + This function performs segmentation using the specified model checkpoint and validates the results. + It logs information about the segmentation and validation process. + + Args: + model_checkpoint (str): + The checkpoint of the segmentation model to use (default is "latest"). + checkpoint_num (int): + The checkpoint number for the affinity model (default is 250000). + setup_num (str): + The setup number for the affinity model (default is "1738"). + + Returns: + dict: + A dictionary containing scores and evaluation metrics. + """ logger.info( msg=f"Segmenting checkpoint {model_checkpoint}, aff_model checkpoint {checkpoint_num}..." ) @@ -63,6 +81,33 @@ def validate( print_in_xyz=False, downsample=None, ) -> None: + """ + Validate segmentation results using specified parameters. + + Args: + checkpoint (str): + The checkpoint identifier. + threshold (float): + The threshold value. + offset (str): + The offset for ROI (default is "3960,3960,3960"). + roi_shape (str): + The shape of ROI (default is "31680,31680,31680"). + skel (str): + The path to the skeleton data file (default is "../../data/XPRESS_validation_skels.npz"). + zarr (str): + The path to the Zarr file for storing segmentation data (default is "./validation.zarr"). + h5 (str): + The path to the HDF5 file for storing validation data (default is "validation.h5"). + ds (str): + The dataset name (default is "pred_seg"). + print_errors (bool): + Print errors during validation (default is False). + print_in_xyz (bool): + Print coordinates in XYZ format (default is False). + downsample (int): + Downsample factor for evaluation (default is None). + """ network = os.path.abspath(".").split(os.path.sep)[-1] aff_setup, aff_checkpoint = str(threshold).split(".")[::-1] diff --git a/src/autoseg/eval/metrics.py b/src/autoseg/eval/metrics.py index d252e2f..c34c914 100644 --- a/src/autoseg/eval/metrics.py +++ b/src/autoseg/eval/metrics.py @@ -15,7 +15,17 @@ def generate_graphs_with_seg_labels(segment_array, skeleton_path): Add predicted labels to the ground-truth graph. We also re-assign unique ids to each cluster of connected nodes after removal of nodes outside of the ROI. This is to differentiate between sets of nodes that are discontinuous in - the ROI but originally belonged to the same skeleton ID + the ROI but originally belonged to the same skeleton ID. + + Args: + segment_array (daisy.Array): + Array containing predicted segmentation labels. + skeleton_path (str): + Path to the skeleton data file. 
+    """
     network = os.path.abspath(".").split(os.path.sep)[-1]
     aff_setup, aff_checkpoint = str(threshold).split(".")[::-1]
diff --git a/src/autoseg/eval/metrics.py b/src/autoseg/eval/metrics.py
index d252e2f..c34c914 100644
--- a/src/autoseg/eval/metrics.py
+++ b/src/autoseg/eval/metrics.py
@@ -15,7 +15,17 @@ def generate_graphs_with_seg_labels(segment_array, skeleton_path):
     Add predicted labels to the ground-truth graph.
     We also re-assign unique ids to each cluster of connected nodes after
     removal of nodes outside of the ROI. This is to differentiate between
     sets of nodes that are discontinuous in
-    the ROI but originally belonged to the same skeleton ID
+    the ROI but originally belonged to the same skeleton ID.
+
+    Args:
+        segment_array (daisy.Array):
+            Array containing predicted segmentation labels.
+        skeleton_path (str):
+            Path to the skeleton data file.
+
+    Returns:
+        networkx.Graph:
+            Ground-truth graph with predicted labels added.
     """
     gt_graph = np.load(skeleton_path, allow_pickle=True)
     next_highest_seg_label = int(segment_array.data.max()) + 1
@@ -52,6 +62,19 @@ def generate_graphs_with_seg_labels(segment_array, skeleton_path):
 
 
 def eval_erl(graph):
+    """
+    Compute the expected run length (ERL) and normalized ERL for a given graph.
+
+    Args:
+        graph (networkx.Graph):
+            Graph representing the ground-truth, with predicted segment labels.
+
+    Returns:
+        float:
+            The ERL value.
+        float:
+            The normalized ERL value.
+    """
     node_seg_lut = {}
     for node, attr in graph.nodes(data=True):
         node_seg_lut[node] = attr["seg_label"]
@@ -79,6 +102,19 @@ def eval_erl(graph):
 
 
 def build_segment_label_subgraph(segment_nodes, graph):
+    """
+    Build a subgraph using a set of segment nodes from the given graph.
+
+    Args:
+        segment_nodes (Iterable):
+            Nodes representing segments.
+        graph (networkx.Graph):
+            Original graph.
+
+    Returns:
+        networkx.Graph:
+            Subgraph containing the specified segment nodes.
+    """
     subgraph = graph.subgraph(segment_nodes)
     skeleton_clusters = nx.connected_components(subgraph)
     seg_graph = nx.Graph()
@@ -97,6 +133,21 @@ def build_segment_label_subgraph(segment_nodes, graph):
 
 
 # Returns the closest pair of nodes on 2 skeletons
 def get_closest_node_pair_between_two_skeletons(skel1, skel2, graph):
+    """
+    Get the closest pair of nodes between two skeletons in the given graph.
+
+    Args:
+        skel1 (Iterable):
+            Nodes of the first skeleton.
+        skel2 (Iterable):
+            Nodes of the second skeleton.
+        graph (networkx.Graph):
+            Original graph.
+
+    Returns:
+        Tuple:
+            The closest pair of nodes and their edge attributes.
+    """
     multiplier = (1, 1, 1)
     shortest_len = math.inf
     for node1, node2 in product(skel1, skel2):
@@ -113,6 +164,17 @@ def find_merge_errors(graph):
+    """
+    Find merge errors in the given graph.
+
+    Args:
+        graph (networkx.Graph):
+            Original graph.
+
+    Returns:
+        set:
+            Set of merge errors.
+    """
     seg_dict = {}
     for nid, attr in graph.nodes(data=True):
         seg_label = attr["seg_label"]
@@ -144,6 +206,17 @@ def get_split_merges(graph):
+    """
+    Find split and merge errors in the given graph.
+
+    Args:
+        graph (networkx.Graph):
+            Original graph.
+
+    Returns:
+        set:
+            The split and merge errors found in the graph.
+    """
     # Count split errors. An error is an edge in the GT skeleton graph connecting two nodes
     # of different segment ids.
     split_errors = []
@@ -155,7 +228,17 @@ def set_point_in_array(array, point_coord, val):
-    """Helper function to set value using real-world nm coordinates"""
+    """
+    Set a specific point in the array to a given value.
+
+    Args:
+        array (daisy.Array):
+            Target array.
+        point_coord (Tuple):
+            Coordinates of the point, in real-world (nm) units.
+        val:
+            Value to set.
+    """
     point_coord = daisy.Coordinate(point_coord)
     vox_aligned_offset = (point_coord // array.voxel_size) * array.voxel_size
     point_roi = daisy.Roi(vox_aligned_offset, array.voxel_size)
@@ -163,7 +246,19 @@ def make_voxel_gt_array(test_array, gt_graph):
-    """Rasterize GT points to an empty array to compute Rand/VOI"""
+    """
+    Rasterize ground-truth points to an empty array for computing Rand/VOI.
+
+    Args:
+        test_array (daisy.Array):
+            Target array.
+        gt_graph (networkx.Graph):
+            Ground-truth graph.
+
+    Returns:
+        daisy.Array:
+            Voxel array containing ground-truth information.
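+
+    Example (sketch; assumes ``segment_array`` is a daisy.Array of predicted
+    labels and ``gt_graph`` was produced by generate_graphs_with_seg_labels,
+    mirroring the call inside get_voi below)::
+
+        voxel_gt = make_voxel_gt_array(segment_array, gt_graph)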
+ """ gt_ndarray = np.zeros_like(test_array.data).astype(np.uint64) gt_array = daisy.Array( gt_ndarray, roi=test_array.roi, voxel_size=test_array.voxel_size @@ -176,7 +271,19 @@ def make_voxel_gt_array(test_array, gt_graph): def get_voi(segment_array, gt_graph): - """Wrapper fn to compute Rand/VOI scores""" + """ + Wrapper function to compute Rand/VOI scores. + + Args: + segment_array (daisy.Array): + Array containing predicted segmentation labels. + gt_graph (networkx.Graph): + Ground-truth graph. + + Returns: + Dict: + Dictionary containing Rand/VOI scores. + """ voxel_gt = make_voxel_gt_array(segment_array, gt_graph) res = funlib.evaluate.rand_voi( truth=voxel_gt.data, test=segment_array.data.astype(np.uint64) @@ -185,6 +292,25 @@ def get_voi(segment_array, gt_graph): def run_eval(skeleton_file, segmentation_file, segmentation_ds, roi, downsampling=None): + """ + Run evaluation on the predicted segmentation. + + Args: + skeleton_file (str): + Path to the skeleton data file. + segmentation_file (str): + Path to the segmentation data file. + segmentation_ds (str): + Dataset name in the segmentation file. + roi (daisy.Roi): + Region of interest. + downsampling (int): + Downsample factor for evaluation. + + Returns: + Dict: + Dictionary containing evaluation metrics. + """ # load segmentation segment_array = daisy.open_ds(segmentation_file, segmentation_ds) if roi is None: