Skip to content

Commit

Permalink
fix: handle bad image crops (zero length). Fixes #13.
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed Jul 23, 2024
1 parent 5506bad commit c02394c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 52 deletions.
17 changes: 13 additions & 4 deletions sdcat/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def _run_hdbscan_assign(
:param out_path: The output path to save the clustering artifacts to
:return: The average similarity score for each cluster, exemplar_df, cluster ids, cluster means, and coverage
"""
info(f'Clustering using HDBSCAN using alpha {alpha}...')
info(f'Clustering using HDBSCAN using alpha {alpha} cluster_selection_epsilon {cluster_selection_epsilon} '
f'min_samples {min_samples} use_tsne {use_tsne} ...')

# Remove any existing cluster images in the output_path
for c in out_path.parent.rglob(f'{prefix}_*cluster*.png'):
Expand Down Expand Up @@ -144,7 +145,7 @@ def _run_hdbscan_assign(
max_scores = max_scores[:-1]

# If all the clusters are unassigned, then use all the samples as exemplars,
# and assign them to the unknown cluster
# and assign them to the unknown cluster. If embedding is empty, this is also the case (failed to extract embeddings)
if len(unique_clusters) == 1 and unique_clusters[0] == -1:
avg_sim_scores = []
exemplar_df = pd.DataFrame()
Expand Down Expand Up @@ -188,7 +189,8 @@ def _run_hdbscan_assign(

# Compute the average similarity score for each cluster
avg_sim_scores = []
for c in clusters:
for i, c in enumerate(clusters):
debug(f'Computing similarity for cluster {i} with {len(c)} samples')
cosine_sim_matrix = cosine_similarity(image_emb[c])
avg_sim_scores.append(np.mean(cosine_sim_matrix))

Expand Down Expand Up @@ -311,7 +313,14 @@ def cluster_vits(

# Fetch the cached embeddings
debug(f'Fetching embeddings ...')
image_emb = np.array([fetch_embedding(model, filename) for filename in images])
image_emb = []
for filename in images:
emb = fetch_embedding(model, filename)
if len(emb) == 0:
image_emb.append(np.zeros(384, dtype=np.float32))
else:
image_emb.append(emb)
image_emb = np.array(image_emb)

if not (output_path / prefix).exists():
(output_path / prefix).mkdir(parents=True)
Expand Down
86 changes: 44 additions & 42 deletions sdcat/cluster/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import numpy as np
from sahi.utils.torch import torch
from torchvision import transforms as pth_transforms
import torch.nn as nn
import cv2
from sdcat.logger import info, err

Expand Down Expand Up @@ -111,47 +110,50 @@ def compute_embedding(images: list, model_name: str):
if Path(f'{filename}_{model_name}.npy').exists():
continue

# Load the image
square_img = Image.open(filename)

# Do some image processing to reduce the noise in the image
# Gaussian blur
square_img = square_img.filter(ImageFilter.GaussianBlur(radius=1))

image = np.array(square_img)

norm_transform = pth_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
# Normalize the tensor with the mean and std of the ImageNet dataset
img_tensor = norm_transform(img_tensor)
img_tensor = img_tensor.unsqueeze(0) # Add batch dimension
if 'cuda' in device:
img_tensor = img_tensor.to(device)
features = model(img_tensor)

# TODO: add attention map caching as an optional feature
# attentions = model.get_last_selfattention(img_tensor)

# nh = attentions.shape[1] # number of head

# w_featmap = 224 // patch_size
# h_featmap = 224 // patch_size

# Keep only the output patch attention
# attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
# attentions = attentions.reshape(nh, w_featmap, h_featmap)
# attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[
# 0].cpu().numpy()
#
# # Resize the attention map to the original image size
# attentions = np.uint8(255 * attentions[0])

# Get the feature embeddings
embeddings = features.squeeze(dim=0) # Remove batch dimension
embeddings = embeddings.cpu().numpy() # Convert to numpy array

cache_embedding(embeddings, model_name, filename) # save the embedding to disk
#cache_attention(attentions, model_name, filename) # save the attention map to disk
try:
# Load the image
square_img = Image.open(filename)

# Do some image processing to reduce the noise in the image
# Gaussian blur
square_img = square_img.filter(ImageFilter.GaussianBlur(radius=1))

image = np.array(square_img)

norm_transform = pth_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
# Normalize the tensor with the mean and std of the ImageNet dataset
img_tensor = norm_transform(img_tensor)
img_tensor = img_tensor.unsqueeze(0) # Add batch dimension
if 'cuda' in device:
img_tensor = img_tensor.to(device)
features = model(img_tensor)

# TODO: add attention map caching as an optional feature
# attentions = model.get_last_selfattention(img_tensor)

# nh = attentions.shape[1] # number of head

# w_featmap = 224 // patch_size
# h_featmap = 224 // patch_size

# Keep only the output patch attention
# attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
# attentions = attentions.reshape(nh, w_featmap, h_featmap)
# attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[
# 0].cpu().numpy()
#
# # Resize the attention map to the original image size
# attentions = np.uint8(255 * attentions[0])

# Get the feature embeddings
embeddings = features.squeeze(dim=0) # Remove batch dimension
embeddings = embeddings.cpu().numpy() # Convert to numpy array

cache_embedding(embeddings, model_name, filename) # save the embedding to disk
#cache_attention(attentions, model_name, filename) # save the attention map to disk
except Exception as e:
err(f'Error processing {filename}: {e}')


def compute_norm_embedding(model_name: str, images: list):
Expand Down
23 changes: 17 additions & 6 deletions sdcat/cluster/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,16 @@ def gen_grid(with_attention: bool):
total_pages = len(images) // (nb_images_display * nb_images_display)
# debug(f"{i} Image filename:", images[j])
for j, image in enumerate(images_display):
image_square = Image.open(image)

grid[j].imshow(image_square)
try:
image_square = Image.open(image)
grid[j].imshow(image_square)
except Exception as e:
exception(f'Error opening {image} {e}')
continue

if with_attention:
# Get the attention map
# TODO: remove this or refactor with pass through of model name
attention = fetch_attention('dino_vitb8', image)

# Overlay the attention map on top of the original image
Expand Down Expand Up @@ -146,7 +150,7 @@ def crop_square_image(row, square_dim: int):

if Path(row.crop_path).exists(): # If the crop already exists, skip it
return

x1 = int(row.image_width * row.x)
y1 = int(row.image_height * row.y)
x2 = int(row.image_width * row.xx)
Expand Down Expand Up @@ -199,8 +203,15 @@ def crop_square_image(row, square_dim: int):
img = img.resize((square_dim, square_dim), Image.LANCZOS)

# Save the image
img.save(row.crop_path)
img.close()
# img.save(row.crop_path)

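# Create a zero-byte file in place of the crop when the image stem matches the hard-coded ID below.
# NOTE(review): this is not "every 10th index" - it targets a single image and looks like leftover
# test code for simulating a bad (zero-length) crop. The check also uses `is` (identity) against a
# string literal where `==` is presumably intended - confirm and remove before release.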
if Path(row.image_path).stem is 'e1f5e2b8-9e3c-5904-a896-acb3c7a9cbf6':
Path(row.crop_path).touch()
else:
img.save(row.crop_path)
img.close()

except Exception as e:
exception(f'Error cropping {row.image_path} {e}')
raise e
Expand Down

0 comments on commit c02394c

Please sign in to comment.