From 48d1f1955510ad2269def64dee9e14b96db36c0f Mon Sep 17 00:00:00 2001 From: danellecline Date: Wed, 18 Sep 2024 09:10:13 -0700 Subject: [PATCH] feat: add --skip-visualization option to the cluster commands since visualization takes some time and is not needed for prod ML workflows --- sdcat/cluster/cluster.py | 25 ++++++++++++++----------- sdcat/cluster/commands.py | 12 ++++++++---- sdcat/common_args.py | 6 +++++- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/sdcat/cluster/cluster.py b/sdcat/cluster/cluster.py index bba5aea..a62eaee 100755 --- a/sdcat/cluster/cluster.py +++ b/sdcat/cluster/cluster.py @@ -272,6 +272,7 @@ def cluster_vits( min_samples: int, device: str = "cpu", use_tsne: bool = False, + skip_visualization: bool = False, roi: bool = False) -> pd.DataFrame: """ Cluster the crops using the VITS embeddings. :param prefix: A unique prefix to save artifacts from clustering @@ -286,6 +287,7 @@ def cluster_vits( :param min_cluster_size: The minimum number of samples in a cluster :param min_samples:The number of samples in a neighborhood for a point :param device: The device to use for clustering, 'cpu' or 'cuda' + :param skip_visualization: Whether to skip the visualization of the clusters :param use_tsne: Whether to use t-SNE for dimensionality reduction :return: a dataframe with the assigned cluster indexes, or -1 for non-assigned.""" @@ -400,17 +402,18 @@ def cluster_vits( err(f'No processes available to visualize the clusters') return None - # Use a pool of processes to speed up the visualization of the clusters - with multiprocessing.Pool(num_processes) as pool: - args = [(prefix, # prefix - cluster_sim[cluster_id], # average similarity for the cluster - cluster_id, # cluster id - unique_clusters[cluster_id], # cluster indices - 4 if len(unique_clusters[cluster_id]) < 50 else 8, # grid size; larger clusters get larger grids - [images[idx] for idx in unique_clusters[cluster_id]], # images in the cluster - output_path / prefix) for cluster_id in -
range(0, len(unique_clusters))] - pool.starmap(cluster_grid, args) + if not skip_visualization: + # Use a pool of processes to speed up the visualization of the clusters + with multiprocessing.Pool(num_processes) as pool: + args = [(prefix, # prefix + cluster_sim[cluster_id], # average similarity for the cluster + cluster_id, # cluster id + unique_clusters[cluster_id], # cluster indices + 4 if len(unique_clusters[cluster_id]) < 50 else 8, # grid size; larger clusters get larger grids + [images[idx] for idx in unique_clusters[cluster_id]], # images in the cluster + output_path / prefix) for cluster_id in + range(0, len(unique_clusters))] + pool.starmap(cluster_grid, args) # Save the exemplar embeddings with the model type exemplar_df['model'] = model diff --git a/sdcat/cluster/commands.py b/sdcat/cluster/commands.py index 3835c33..6af573c 100644 --- a/sdcat/cluster/commands.py +++ b/sdcat/cluster/commands.py @@ -27,6 +27,7 @@ @common_args.start_image @common_args.end_image @common_args.use_tsne +@common_args.skip_visualization @common_args.alpha @common_args.cluster_selection_epsilon @common_args.cluster_selection_method @@ -34,7 +35,7 @@ @click.option('--det-dir', help='Input folder(s) with raw detection results', multiple=True, required=True) @click.option('--save-dir', help='Output directory to save clustered detection results', required=True) @click.option('--device', help='Device to use, e.g. 
cpu or cuda:0', type=str) -def run_cluster_det(det_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, start_image, end_image, use_tsne): +def run_cluster_det(det_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, start_image, end_image, use_tsne, skip_visualization): config = cfg.Config(config_ini) max_area = int(config('cluster', 'max_area')) min_area = int(config('cluster', 'min_area')) @@ -256,7 +257,8 @@ def is_day(utc_dt): # Cluster the detections df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, cluster_selection_method, - min_similarity, min_cluster_size, min_samples, device, use_tsne, roi=False) + min_similarity, min_cluster_size, min_samples, device, use_tsne, + skip_visualization=skip_visualization, roi=False) # Merge the results with the original DataFrame df.update(df_cluster) @@ -270,6 +272,7 @@ def is_day(utc_dt): @click.command('roi', help='Cluster roi. See cluster --config-ini to override cluster defaults.') @common_args.config_ini @common_args.use_tsne +@common_args.skip_visualization @common_args.alpha @common_args.cluster_selection_epsilon @common_args.cluster_selection_method @@ -277,7 +280,7 @@ def is_day(utc_dt): @click.option('--roi-dir', help='Input folder(s) with raw ROI images', multiple=True, required=True) @click.option('--save-dir', help='Output directory to save clustered detection results', required=True) @click.option('--device', help='Device to use, e.g. 
cpu or cuda:0', type=str) -def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, use_tsne): +def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, use_tsne, skip_visualization): config = cfg.Config(config_ini) min_samples = int(config('cluster', 'min_samples')) alpha = alpha if alpha else float(config('cluster', 'alpha')) @@ -357,7 +360,8 @@ def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_select # Cluster the detections df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, cluster_selection_method, - min_similarity, min_cluster_size, min_samples, device, use_tsne, roi=True) + min_similarity, min_cluster_size, min_samples, device, use_tsne, + skip_visualization=skip_visualization, roi=True) # Merge the results with the original DataFrame df.update(df_cluster) diff --git a/sdcat/common_args.py b/sdcat/common_args.py index bd34a7d..809e28e 100644 --- a/sdcat/common_args.py +++ b/sdcat/common_args.py @@ -43,4 +43,8 @@ use_tsne = click.option('--use-tsne', is_flag=True, - help='Use t-SNE for dimensionality reduction. Default is False') \ No newline at end of file + help='Use t-SNE for dimensionality reduction. Default is False') + +skip_visualization = click.option('--skip-visualization', + is_flag=True, + help='Skip visualization. Default is False') \ No newline at end of file