Skip to content

Commit

Permalink
feat: add to cluster command option --skip-visualization since this t…
Browse files Browse the repository at this point in the history
…akes some time and is not needed for prod ml workflows
  • Loading branch information
danellecline committed Sep 18, 2024
1 parent a64f12b commit 48d1f19
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 16 deletions.
25 changes: 14 additions & 11 deletions sdcat/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def cluster_vits(
min_samples: int,
device: str = "cpu",
use_tsne: bool = False,
skip_visualization: bool = False,
roi: bool = False) -> pd.DataFrame:
""" Cluster the crops using the VITS embeddings.
:param prefix: A unique prefix to save artifacts from clustering
Expand All @@ -286,6 +287,7 @@ def cluster_vits(
:param min_cluster_size: The minimum number of samples in a cluster
:param min_samples:The number of samples in a neighborhood for a point
:param device: The device to use for clustering, 'cpu' or 'cuda'
:param skip_visualization: Whether to skip the visualization of the clusters
:param use_tsne: Whether to use t-SNE for dimensionality reduction
:return: a dataframe with the assigned cluster indexes, or -1 for non-assigned."""

Expand Down Expand Up @@ -400,17 +402,18 @@ def cluster_vits(
err(f'No processes available to visualize the clusters')
return None

# Use a pool of processes to speed up the visualization of the clusters
with multiprocessing.Pool(num_processes) as pool:
args = [(prefix, # prefix
cluster_sim[cluster_id], # average similarity for the cluster
cluster_id, # cluster id
unique_clusters[cluster_id], # cluster indices
4 if len(unique_clusters[cluster_id]) < 50 else 8, # grid size; larger clusters get larger grids
[images[idx] for idx in unique_clusters[cluster_id]], # images in the cluster
output_path / prefix) for cluster_id in
range(0, len(unique_clusters))]
pool.starmap(cluster_grid, args)
if not skip_visualization:
# Use a pool of processes to speed up the visualization of the clusters
with multiprocessing.Pool(num_processes) as pool:
args = [(prefix, # prefix
cluster_sim[cluster_id], # average similarity for the cluster
cluster_id, # cluster id
unique_clusters[cluster_id], # cluster indices
4 if len(unique_clusters[cluster_id]) < 50 else 8, # grid size; larger clusters get larger grids
[images[idx] for idx in unique_clusters[cluster_id]], # images in the cluster
output_path / prefix) for cluster_id in
range(0, len(unique_clusters))]
pool.starmap(cluster_grid, args)

# Save the exemplar embeddings with the model type
exemplar_df['model'] = model
Expand Down
12 changes: 8 additions & 4 deletions sdcat/cluster/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@
@common_args.start_image
@common_args.end_image
@common_args.use_tsne
@common_args.skip_visualization
@common_args.alpha
@common_args.cluster_selection_epsilon
@common_args.cluster_selection_method
@common_args.min_cluster_size
@click.option('--det-dir', help='Input folder(s) with raw detection results', multiple=True, required=True)
@click.option('--save-dir', help='Output directory to save clustered detection results', required=True)
@click.option('--device', help='Device to use, e.g. cpu or cuda:0', type=str)
def run_cluster_det(det_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, start_image, end_image, use_tsne):
def run_cluster_det(det_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, start_image, end_image, use_tsne, skip_visualization):
config = cfg.Config(config_ini)
max_area = int(config('cluster', 'max_area'))
min_area = int(config('cluster', 'min_area'))
Expand Down Expand Up @@ -256,7 +257,8 @@ def is_day(utc_dt):

# Cluster the detections
df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, cluster_selection_method,
min_similarity, min_cluster_size, min_samples, device, use_tsne, roi=False)
min_similarity, min_cluster_size, min_samples, device, use_tsne,
skip_visualization=skip_visualization, roi=False)

# Merge the results with the original DataFrame
df.update(df_cluster)
Expand All @@ -270,14 +272,15 @@ def is_day(utc_dt):
@click.command('roi', help='Cluster roi. See cluster --config-ini to override cluster defaults.')
@common_args.config_ini
@common_args.use_tsne
@common_args.skip_visualization
@common_args.alpha
@common_args.cluster_selection_epsilon
@common_args.cluster_selection_method
@common_args.min_cluster_size
@click.option('--roi-dir', help='Input folder(s) with raw ROI images', multiple=True, required=True)
@click.option('--save-dir', help='Output directory to save clustered detection results', required=True)
@click.option('--device', help='Device to use, e.g. cpu or cuda:0', type=str)
def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, use_tsne):
def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, use_tsne, skip_visualization):
config = cfg.Config(config_ini)
min_samples = int(config('cluster', 'min_samples'))
alpha = alpha if alpha else float(config('cluster', 'alpha'))
Expand Down Expand Up @@ -357,7 +360,8 @@ def run_cluster_roi(roi_dir, save_dir, device, config_ini, alpha, cluster_select

# Cluster the detections
df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, cluster_selection_method,
min_similarity, min_cluster_size, min_samples, device, use_tsne, roi=True)
min_similarity, min_cluster_size, min_samples, device, use_tsne,
skip_visualization=skip_visualization, roi=True)

# Merge the results with the original DataFrame
df.update(df_cluster)
Expand Down
6 changes: 5 additions & 1 deletion sdcat/common_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,8 @@

use_tsne = click.option('--use-tsne',
is_flag=True,
help='Use t-SNE for dimensionality reduction. Default is False')
help='Use t-SNE for dimensionality reduction. Default is False')

skip_visualization = click.option('--skip-visualization',
is_flag=True,
help='Skip visualization. Default is False')

0 comments on commit 48d1f19

Please sign in to comment.