Skip to content

Commit

Permalink
perf: migrated to transformers library with batch size of 8, moved so…
Browse files Browse the repository at this point in the history
…me imports to only where needed for some speed-up, and removed unused activation maps.
  • Loading branch information
danellecline committed Jul 29, 2024
1 parent c582a17 commit c5fe725
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 229 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ torch==2.3.1
piexif
yolov5==7.0.13
torchvision==0.18.1
transformers
transformers[torch]
timm
pandas>=1.2.4
ultralytics
Expand Down
72 changes: 33 additions & 39 deletions sdcat/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _run_hdbscan_assign(
if not numerical.empty:
numerical = numerical.fillna(0)

# Normalize the numerical data from 0 to 1
# Normalize the numerical data from 0 to 1 and add it to the dataframe
numerical = (numerical - numerical.min()) / (numerical.max() - numerical.min())

df = pd.merge(df, numerical, left_index=True, right_index=True, how='left')
Expand Down Expand Up @@ -107,30 +107,15 @@ def _run_hdbscan_assign(
labels = scan.fit_predict(x)
else:
scan = HDBSCAN(
metric='l2',
allow_single_cluster=True,
min_cluster_size=min_cluster_size,
min_samples=min_samples,
alpha=alpha,
cluster_selection_epsilon=cluster_selection_epsilon,
cluster_selection_method='leaf')
metric='l2',
allow_single_cluster=True,
min_cluster_size=min_cluster_size,
min_samples=min_samples,
alpha=alpha,
cluster_selection_epsilon=cluster_selection_epsilon,
cluster_selection_method='leaf')
labels = scan.fit_predict(x)

# title_tree = f'HDBSCAN Tree Distances {cluster_selection_epsilon} min_cluster_size {min_cluster_size} min_samples {min_samples} alpha {alpha}'
# title_linkage = title_tree.replace('Tree Distances', 'Linkage')

# scan.condensed_tree_.plot(select_clusters=True,
# selection_palette=sns.color_palette('deep', 8))
# plt.title(title_tree)
# plt.xlabel('Index')
# plt.savefig(f"{out_path}/{prefix}_condensed_tree.png")

# plt.figure(figsize=(10, 6))
# scan.single_linkage_tree_.plot(cmap='viridis', colorbar=True)
# plt.title(title_linkage)
# plt.xlabel('Index')
# plt.savefig(f"{out_path}/{prefix}_tree.png")

# Get the unique clusters and sort them; -1 are unassigned clusters
cluster_df = pd.DataFrame(labels, columns=['cluster'])
unique_clusters = cluster_df['cluster'].unique().tolist()
Expand All @@ -149,7 +134,7 @@ def _run_hdbscan_assign(
if len(unique_clusters) == 1 and unique_clusters[0] == -1:
avg_sim_scores = []
exemplar_df = pd.DataFrame()
exemplar_df['cluster'] = len(x)*['Unknown']
exemplar_df['cluster'] = len(x) * ['Unknown']
exemplar_df['embedding'] = x.tolist()
exemplar_df['image_path'] = ancillary_df['image_path'].tolist()
clusters = []
Expand Down Expand Up @@ -191,6 +176,9 @@ def _run_hdbscan_assign(
avg_sim_scores = []
for i, c in enumerate(clusters):
debug(f'Computing similarity for cluster {i} with {len(c)} samples')
if len(c) == 0:
avg_sim_scores.append(0)
continue
cosine_sim_matrix = cosine_similarity(image_emb[c])
avg_sim_scores.append(np.mean(cosine_sim_matrix))

Expand Down Expand Up @@ -223,7 +211,7 @@ def _run_hdbscan_assign(
else:
init = 'spectral'

# Reduce the dimensionality of the embeddings using UMAP to 2 dimensions for visualization
# Reduce the dimensionality of the embeddings using UMAP to 2 dimensions to visualize the clusters
if have_gpu:
xx = cuUMAP(init=init,
n_components=2,
Expand All @@ -233,8 +221,6 @@ def _run_hdbscan_assign(
else:
xx = UMAP(init=init,
n_components=2,
n_neighbors=3,
min_dist=0.1,
metric='cosine',
low_memory=True).fit_transform(df.values)

Expand Down Expand Up @@ -285,14 +271,14 @@ def cluster_vits(
# Skip cropping if all the crops are already done
if num_crop != len(df_dets):
num_processes = min(multiprocessing.cpu_count(), len(df_dets))
if roi == True:
info(f'ROI crops already exist. Creating square crops in parallel using {multiprocessing.cpu_count()} processes...')
if roi is True:
info(f'ROI crops already exist. Creating square crops in parallel using {num_processes} processes...')
with multiprocessing.Pool(num_processes) as pool:
args = [(row, 224) for index, row in df_dets.iterrows()]
pool.starmap(square_image, args)
else:
# Crop and squaring the images in parallel using multiprocessing to speed up the processing
info(f'Cropping {len(df_dets)} detections in parallel using {multiprocessing.cpu_count()} processes...')
info(f'Cropping {len(df_dets)} detections in parallel using {num_processes} processes...')
with multiprocessing.Pool(num_processes) as pool:
args = [(row, 224) for index, row in df_dets.iterrows()]
pool.starmap(crop_square_image, args)
Expand All @@ -317,9 +303,17 @@ def cluster_vits(
for filename in images:
emb = fetch_embedding(model, filename)
if len(emb) == 0:
# If the embeddings are zero, then the extraction failed; add a zero array
image_emb.append(np.zeros(384, dtype=np.float32))
else:
image_emb.append(emb)

# If the embeddings are zero, then the extraction failed
num_failed = [i for i, e in enumerate(image_emb) if np.all(e == 0)]
if len(num_failed) == len(images):
warn('Failed to extract embeddings from all images')
return pd.DataFrame()

image_emb = np.array(image_emb)

if not (output_path / prefix).exists():
Expand All @@ -338,15 +332,15 @@ def cluster_vits(

# Cluster the images
cluster_sim, exemplar_df, unique_clusters, cluster_means, coverage = _run_hdbscan_assign(prefix,
image_emb,
alpha,
cluster_selection_epsilon,
min_similarity,
min_cluster_size,
min_samples,
use_tsne,
ancillary_df,
output_path / prefix)
image_emb,
alpha,
cluster_selection_epsilon,
min_similarity,
min_cluster_size,
min_samples,
use_tsne,
ancillary_df,
output_path / prefix)

# Get the average similarity across all clusters
avg_similarity = np.mean(cluster_sim)
Expand Down
5 changes: 4 additions & 1 deletion sdcat/cluster/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,11 @@ def is_day(utc_dt):
info(df.head(5))

if len(df) > 0:
# Replace / with _ in the model name
model_machine_friendly = model.replace('/', '_')

# A prefix for the output files to make sure the output is unique for each execution
prefix = f'{model}_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
prefix = f'{model_machine_friendly}_{datetime.now().strftime("%Y%m%d_%H%M%S")}'

# Cluster the detections
df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, min_similarity,
Expand Down
162 changes: 63 additions & 99 deletions sdcat/cluster/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,53 +13,58 @@
from sahi.utils.torch import torch
from torchvision import transforms as pth_transforms
import cv2
from transformers import ViTModel, ViTImageProcessor

from sdcat.logger import info, err


def cache_embedding(embedding, model_name: str, filename: str):
# save numpy array as npy file
save(f'{filename}_{model_name}.npy', embedding)
class ViTWrapper:
MODEL_NAME = "google/vit-base-patch16-224"
VECTOR_DIMENSIONS = 768

def __init__(self, device: str = "cpu", reset: bool = False, batch_size: int = 32):
self.batch_size = batch_size

self.model = ViTModel.from_pretrained(self.MODEL_NAME)
self.processor = ViTImageProcessor.from_pretrained(self.MODEL_NAME)

def cache_attention(attention, model_name: str, filename: str):
# Load the model and processor
if 'cuda' in device and torch.cuda.is_available():
device_num = int(device.split(":")[-1])
info(f"Using GPU device {device_num}")
torch.cuda.set_device(device_num)
self.device = "cuda"
self.model.to("cuda")
else:
self.device = "cpu"


def cache_embedding(embedding, model_name: str, filename: str):
model_machine_friendly_name = model_name.replace("/", "_")
# save numpy array as npy file
save(f'{filename}_{model_name}_a.npy', attention)
save(f'{filename}_{model_machine_friendly_name}.npy', embedding)


def fetch_embedding(model_name: str, filename: str) -> np.array:
model_machine_friendly_name = model_name.replace("/", "_")
# if the npy file exists, return it
if os.path.exists(f'{filename}_{model_name}.npy'):
data = load(f'{filename}_{model_name}.npy')
if os.path.exists(f'{filename}_{model_machine_friendly_name}.npy'):
data = load(f'{filename}_{model_machine_friendly_name}.npy')
return data
else:
info(f'No embedding found for {filename}')
return []


def fetch_attention(model_name: str, filename: str) -> np.array:
"""
Fetch the attention map for the given filename and model name
:param model_name: Name of the model
:param filename: Name of the file
:return: Numpy array of the attention map
"""
# if the npy file exists, return it
if os.path.exists(f'{filename}_{model_name}_a.npy'):
data = load(f'{filename}_{model_name}_a.npy')
return data
else:
info(f'No attention map found for {filename}')
return []


def has_cached_embedding(model_name: str, filename: str) -> int:
"""
Check if the given filename has a cached embedding
:param model_name: Name of the model
:param filename: Name of the file
:return: 1 if the image has a cached embedding, otherwise 0
"""
if os.path.exists(f'{filename}_{model_name}.npy'):
model_machine_friendly_name = model_name.replace("/", "_")
if os.path.exists(f'{filename}_{model_machine_friendly_name}.npy'):
return 1
return 0

Expand All @@ -71,89 +76,48 @@ def encode_image(filename):
return keep


def compute_embedding(images: list, model_name: str):
def compute_embedding_vits(images: list, model_name: str, device: str = "cpu"):
"""
Compute the embedding for the given images using the given model
:param images: List of image filenames
:param model_name: Name of the model
:param model_name: Name of the model (i.e. google/vit-base-patch16-224, dinov2_vits16, etc.)
:param device: Device to use for the computation (cpu or cuda:0, cuda:1, etc.)
"""

# Load the model
if 'dinov2' in model_name:
info(f'Loading model {model_name} from facebookresearch/dinov2...')
model = torch.hub.load('facebookresearch/dinov2', model_name)
elif 'dino' in model_name:
info(f'Loading model {model_name} from facebookresearch/dino:main...')
model = torch.hub.load('facebookresearch/dino:main', model_name)
else:
# TODO: Add more models
err(f'Unknown model {model_name}!')
return

# The patch size is in the model name, e.g. dino_vits16 is a 16x16 patch size, dino_vits8 is a 8x8 patch size
res = re.findall(r'\d+$', model_name)
if len(res) > 0:
patch_size = int(res[0])
batch_size = 8
vit_model = ViTModel.from_pretrained(model_name)
processor = ViTImageProcessor.from_pretrained(model_name)

if 'cuda' in device and torch.cuda.is_available():
device_num = int(device.split(":")[-1])
info(f"Using GPU device {device_num}")
torch.cuda.set_device(device_num)
vit_model.to("cuda")
device = "cuda"
else:
raise ValueError(f'Could not find patch size in model name {model_name}')
info(f'Using patch size {patch_size} for model {model_name}')

# Load images and generate embeddings
device = 'cuda' if torch.cuda.is_available() else 'cpu'
with torch.no_grad():
# Set the cuda device
if torch.cuda.is_available():
model = model.to(device)

for filename in images:
# Skip if the embedding already exists
if Path(f'{filename}_{model_name}.npy').exists():
device = "cpu"

# Batch process the images
batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
for batch in batches:
try:
# Skip running the model if the embeddings already exist
if all([has_cached_embedding(model_name, filename) for filename in batch]):
continue

try:
# Load the image
square_img = Image.open(filename)

# Do some image processing to reduce the noise in the image
# Gaussian blur
square_img = square_img.filter(ImageFilter.GaussianBlur(radius=1))

image = np.array(square_img)

norm_transform = pth_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
# Noramlize the tensor with the mean and std of the ImageNet dataset
img_tensor = norm_transform(img_tensor)
img_tensor = img_tensor.unsqueeze(0) # Add batch dimension
if 'cuda' in device:
img_tensor = img_tensor.to(device)
features = model(img_tensor)

# TODO: add attention map cach as optional
# attentions = model.get_last_selfattention(img_tensor)

# nh = attentions.shape[1] # number of head

# w_featmap = 224 // patch_size
# h_featmap = 224 // patch_size
images = [Image.open(filename).convert("RGB") for filename in batch]
inputs = processor(images=images, return_tensors="pt").to(device)

# Keep only the output patch attention
# attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
# attentions = attentions.reshape(nh, w_featmap, h_featmap)
# attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[
# 0].cpu().numpy()
#
# # Resize the attention map to the original image size
# attentions = np.uint8(255 * attentions[0])
with torch.no_grad():
embeddings = vit_model(**inputs)

# Get the feature embeddings
embeddings = features.squeeze(dim=0) # Remove batch dimension
embeddings = embeddings.cpu().numpy() # Convert to numpy array
batch_embeddings = embeddings.last_hidden_state[:, 0, :].cpu().numpy()

cache_embedding(embeddings, model_name, filename) # save the embedding to disk
#cache_attention(attentions, model_name, filename) # save the attention map to disk
except Exception as e:
err(f'Error processing {filename}: {e}')
# Save the embeddings
for emb, filename in zip(batch_embeddings, batch):
emb = emb.astype(np.float32)
cache_embedding(emb, model_name, filename)
except Exception as e:
err(f'Error processing {batch}: {e}')


def compute_norm_embedding(model_name: str, images: list):
Expand All @@ -172,15 +136,15 @@ def compute_norm_embedding(model_name: str, images: list):

# If using a GPU, set then skip the parallel CPU processing
if torch.cuda.is_available():
compute_embedding(images, model_name)
compute_embedding_vits(images, model_name)
else:
# Use a pool of processes to speed up the embedding generation 20 images at a time on each process
num_processes = min(multiprocessing.cpu_count(), len(images) // 20)
num_processes = max(1, num_processes)
info(f'Using {num_processes} processes to compute {len(images)} embeddings 20 at a time ...')
with multiprocessing.Pool(num_processes) as pool:
args = [(images[i:i + 20], model_name) for i in range(0, len(images), 20)]
pool.starmap(compute_embedding, args)
pool.starmap(compute_embedding_vits, args)


def calc_mean_std(image_files: list) -> tuple:
Expand Down
Loading

0 comments on commit c5fe725

Please sign in to comment.