Skip to content

Commit

Permalink
fix: fixes a bug in handling zero clusters which occasionally happens …
Browse files Browse the repository at this point in the history
…for very small roi collections. Closes #12.
  • Loading branch information
danellecline committed Jul 22, 2024
1 parent ff2a46c commit 0f69805
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docker/Dockerfile.cuda
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ RUN pip install poetry && poetry build && python3 -m pip install dist/*.whl

FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04

ARG GIT_VERSION=latest
ARG IMAGE_URI=mbari/sdcat:${GIT_VERSION}

LABEL vendor="MBARI"
LABEL maintainer="[email protected]"
LABEL license="Apache License 2.0"
Expand Down
16 changes: 16 additions & 0 deletions sdcat/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,19 @@ def _run_hdbscan_assign(
# Remove the last index which is the -1 cluster
max_scores = max_scores[:-1]

# If every sample is unassigned (the only cluster label is -1), use all the samples as exemplars
# and assign them to the 'Unknown' cluster
if len(unique_clusters) == 1 and unique_clusters[0] == -1:
avg_sim_scores = []
exemplar_df = pd.DataFrame()
exemplar_df['cluster'] = len(x)*['Unknown']
exemplar_df['embedding'] = x.tolist()
exemplar_df['image_path'] = ancillary_df['image_path'].tolist()
clusters = []
cluster_means = []
coverage = 0.0
return avg_sim_scores, exemplar_df, clusters, cluster_means, coverage

# Get the representative embeddings for the max scoring exemplars for each cluster and store them in a numpy array
exemplar_emb = [image_emb[i] for i in max_scores]
exemplar_emb = np.array(exemplar_emb)
Expand Down Expand Up @@ -333,6 +346,9 @@ def cluster_vits(

if len(unique_clusters) == 0:
warn('No clusters found')
# Save the exemplar embeddings with the model type
exemplar_df['model'] = model
exemplar_df.to_csv(output_path / f'{prefix}_exemplars.csv', index=False)
return None

info(f'Found {len(unique_clusters)} clusters with an average similarity of {avg_similarity:.2f} ')
Expand Down

0 comments on commit 0f69805

Please sign in to comment.