Skip to content

Commit

Permalink
Merge pull request #14 from mbari-org/vitbatch
Browse files Browse the repository at this point in the history
perf: transformer batching
  • Loading branch information
danellecline authored Jul 29, 2024
2 parents c582a17 + c5fe725 commit 427931e
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 229 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ torch==2.3.1
piexif
yolov5==7.0.13
torchvision==0.18.1
transformers
transformers[torch]
timm
pandas>=1.2.4
ultralytics
Expand Down
72 changes: 33 additions & 39 deletions sdcat/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _run_hdbscan_assign(
if not numerical.empty:
numerical = numerical.fillna(0)

# Normalize the numerical data from 0 to 1
# Normalize the numerical data from 0 to 1 and add it to the dataframe
numerical = (numerical - numerical.min()) / (numerical.max() - numerical.min())

df = pd.merge(df, numerical, left_index=True, right_index=True, how='left')
Expand Down Expand Up @@ -107,30 +107,15 @@ def _run_hdbscan_assign(
labels = scan.fit_predict(x)
else:
scan = HDBSCAN(
metric='l2',
allow_single_cluster=True,
min_cluster_size=min_cluster_size,
min_samples=min_samples,
alpha=alpha,
cluster_selection_epsilon=cluster_selection_epsilon,
cluster_selection_method='leaf')
metric='l2',
allow_single_cluster=True,
min_cluster_size=min_cluster_size,
min_samples=min_samples,
alpha=alpha,
cluster_selection_epsilon=cluster_selection_epsilon,
cluster_selection_method='leaf')
labels = scan.fit_predict(x)

# title_tree = f'HDBSCAN Tree Distances {cluster_selection_epsilon} min_cluster_size {min_cluster_size} min_samples {min_samples} alpha {alpha}'
# title_linkage = title_tree.replace('Tree Distances', 'Linkage')

# scan.condensed_tree_.plot(select_clusters=True,
# selection_palette=sns.color_palette('deep', 8))
# plt.title(title_tree)
# plt.xlabel('Index')
# plt.savefig(f"{out_path}/{prefix}_condensed_tree.png")

# plt.figure(figsize=(10, 6))
# scan.single_linkage_tree_.plot(cmap='viridis', colorbar=True)
# plt.title(title_linkage)
# plt.xlabel('Index')
# plt.savefig(f"{out_path}/{prefix}_tree.png")

# Get the unique clusters and sort them; -1 are unassigned clusters
cluster_df = pd.DataFrame(labels, columns=['cluster'])
unique_clusters = cluster_df['cluster'].unique().tolist()
Expand All @@ -149,7 +134,7 @@ def _run_hdbscan_assign(
if len(unique_clusters) == 1 and unique_clusters[0] == -1:
avg_sim_scores = []
exemplar_df = pd.DataFrame()
exemplar_df['cluster'] = len(x)*['Unknown']
exemplar_df['cluster'] = len(x) * ['Unknown']
exemplar_df['embedding'] = x.tolist()
exemplar_df['image_path'] = ancillary_df['image_path'].tolist()
clusters = []
Expand Down Expand Up @@ -191,6 +176,9 @@ def _run_hdbscan_assign(
avg_sim_scores = []
for i, c in enumerate(clusters):
debug(f'Computing similarity for cluster {i} with {len(c)} samples')
if len(c) == 0:
avg_sim_scores.append(0)
continue
cosine_sim_matrix = cosine_similarity(image_emb[c])
avg_sim_scores.append(np.mean(cosine_sim_matrix))

Expand Down Expand Up @@ -223,7 +211,7 @@ def _run_hdbscan_assign(
else:
init = 'spectral'

# Reduce the dimensionality of the embeddings using UMAP to 2 dimensions for visualization
# Reduce the dimensionality of the embeddings using UMAP to 2 dimensions to visualize the clusters
if have_gpu:
xx = cuUMAP(init=init,
n_components=2,
Expand All @@ -233,8 +221,6 @@ def _run_hdbscan_assign(
else:
xx = UMAP(init=init,
n_components=2,
n_neighbors=3,
min_dist=0.1,
metric='cosine',
low_memory=True).fit_transform(df.values)

Expand Down Expand Up @@ -285,14 +271,14 @@ def cluster_vits(
# Skip cropping if all the crops are already done
if num_crop != len(df_dets):
num_processes = min(multiprocessing.cpu_count(), len(df_dets))
if roi == True:
info(f'ROI crops already exist. Creating square crops in parallel using {multiprocessing.cpu_count()} processes...')
if roi is True:
info(f'ROI crops already exist. Creating square crops in parallel using {num_processes} processes...')
with multiprocessing.Pool(num_processes) as pool:
args = [(row, 224) for index, row in df_dets.iterrows()]
pool.starmap(square_image, args)
else:
# Crop and square the images in parallel using multiprocessing to speed up processing
info(f'Cropping {len(df_dets)} detections in parallel using {multiprocessing.cpu_count()} processes...')
info(f'Cropping {len(df_dets)} detections in parallel using {num_processes} processes...')
with multiprocessing.Pool(num_processes) as pool:
args = [(row, 224) for index, row in df_dets.iterrows()]
pool.starmap(crop_square_image, args)
Expand All @@ -317,9 +303,17 @@ def cluster_vits(
for filename in images:
emb = fetch_embedding(model, filename)
if len(emb) == 0:
# If the embeddings are zero, then the extraction failed; add a zero array
image_emb.append(np.zeros(384, dtype=np.float32))
else:
image_emb.append(emb)

# An all-zero embedding marks a failed extraction; give up if every image failed
num_failed = [i for i, e in enumerate(image_emb) if np.all(e == 0)]
if len(num_failed) == len(images):
warn('Failed to extract embeddings from all images')
return pd.DataFrame()

image_emb = np.array(image_emb)

if not (output_path / prefix).exists():
Expand All @@ -338,15 +332,15 @@ def cluster_vits(

# Cluster the images
cluster_sim, exemplar_df, unique_clusters, cluster_means, coverage = _run_hdbscan_assign(prefix,
image_emb,
alpha,
cluster_selection_epsilon,
min_similarity,
min_cluster_size,
min_samples,
use_tsne,
ancillary_df,
output_path / prefix)
image_emb,
alpha,
cluster_selection_epsilon,
min_similarity,
min_cluster_size,
min_samples,
use_tsne,
ancillary_df,
output_path / prefix)

# Get the average similarity across all clusters
avg_similarity = np.mean(cluster_sim)
Expand Down
5 changes: 4 additions & 1 deletion sdcat/cluster/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,11 @@ def is_day(utc_dt):
info(df.head(5))

if len(df) > 0:
# Replace / with _ in the model name
model_machine_friendly = model.replace('/', '_')

# A prefix for the output files to make sure the output is unique for each execution
prefix = f'{model}_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
prefix = f'{model_machine_friendly}_{datetime.now().strftime("%Y%m%d_%H%M%S")}'

# Cluster the detections
df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, min_similarity,
Expand Down
162 changes: 63 additions & 99 deletions sdcat/cluster/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,53 +13,58 @@
from sahi.utils.torch import torch
from torchvision import transforms as pth_transforms
import cv2
from transformers import ViTModel, ViTImageProcessor

from sdcat.logger import info, err


def cache_embedding(embedding, model_name: str, filename: str):
# save numpy array as npy file
save(f'{filename}_{model_name}.npy', embedding)
class ViTWrapper:
MODEL_NAME = "google/vit-base-patch16-224"
VECTOR_DIMENSIONS = 768

def __init__(self, device: str = "cpu", reset: bool = False, batch_size: int = 32):
self.batch_size = batch_size

self.model = ViTModel.from_pretrained(self.MODEL_NAME)
self.processor = ViTImageProcessor.from_pretrained(self.MODEL_NAME)

def cache_attention(attention, model_name: str, filename: str):
# Move the model to the requested GPU if CUDA is available; otherwise stay on the CPU
if 'cuda' in device and torch.cuda.is_available():
device_num = int(device.split(":")[-1])
info(f"Using GPU device {device_num}")
torch.cuda.set_device(device_num)
self.device = "cuda"
self.model.to("cuda")
else:
self.device = "cpu"


def cache_embedding(embedding, model_name: str, filename: str):
model_machine_friendly_name = model_name.replace("/", "_")
# save numpy array as npy file
save(f'{filename}_{model_name}_a.npy', attention)
save(f'{filename}_{model_machine_friendly_name}.npy', embedding)


def fetch_embedding(model_name: str, filename: str) -> np.array:
model_machine_friendly_name = model_name.replace("/", "_")
# if the npy file exists, return it
if os.path.exists(f'{filename}_{model_name}.npy'):
data = load(f'{filename}_{model_name}.npy')
if os.path.exists(f'{filename}_{model_machine_friendly_name}.npy'):
data = load(f'{filename}_{model_machine_friendly_name}.npy')
return data
else:
info(f'No embedding found for {filename}')
return []


def fetch_attention(model_name: str, filename: str) -> np.array:
"""
Fetch the attention map for the given filename and model name
:param model_name: Name of the model
:param filename: Name of the file
:return: Numpy array of the attention map
"""
# if the npy file exists, return it
if os.path.exists(f'{filename}_{model_name}_a.npy'):
data = load(f'{filename}_{model_name}_a.npy')
return data
else:
info(f'No attention map found for {filename}')
return []


def has_cached_embedding(model_name: str, filename: str) -> int:
"""
Check if the given filename has a cached embedding
:param model_name: Name of the model
:param filename: Name of the file
:return: 1 if the image has a cached embedding, otherwise 0
"""
if os.path.exists(f'{filename}_{model_name}.npy'):
model_machine_friendly_name = model_name.replace("/", "_")
if os.path.exists(f'{filename}_{model_machine_friendly_name}.npy'):
return 1
return 0

Expand All @@ -71,89 +76,48 @@ def encode_image(filename):
return keep


def compute_embedding(images: list, model_name: str):
def compute_embedding_vits(images: list, model_name: str, device: str = "cpu"):
"""
Compute the embedding for the given images using the given model
:param images: List of image filenames
:param model_name: Name of the model
:param model_name: Name of the model (e.g. google/vit-base-patch16-224, dinov2_vits16, etc.)
:param device: Device to use for the computation (cpu or cuda:0, cuda:1, etc.)
"""

# Load the model
if 'dinov2' in model_name:
info(f'Loading model {model_name} from facebookresearch/dinov2...')
model = torch.hub.load('facebookresearch/dinov2', model_name)
elif 'dino' in model_name:
info(f'Loading model {model_name} from facebookresearch/dino:main...')
model = torch.hub.load('facebookresearch/dino:main', model_name)
else:
# TODO: Add more models
err(f'Unknown model {model_name}!')
return

# The patch size is in the model name, e.g. dino_vits16 is a 16x16 patch size, dino_vits8 is a 8x8 patch size
res = re.findall(r'\d+$', model_name)
if len(res) > 0:
patch_size = int(res[0])
batch_size = 8
vit_model = ViTModel.from_pretrained(model_name)
processor = ViTImageProcessor.from_pretrained(model_name)

if 'cuda' in device and torch.cuda.is_available():
device_num = int(device.split(":")[-1])
info(f"Using GPU device {device_num}")
torch.cuda.set_device(device_num)
vit_model.to("cuda")
device = "cuda"
else:
raise ValueError(f'Could not find patch size in model name {model_name}')
info(f'Using patch size {patch_size} for model {model_name}')

# Load images and generate embeddings
device = 'cuda' if torch.cuda.is_available() else 'cpu'
with torch.no_grad():
# Set the cuda device
if torch.cuda.is_available():
model = model.to(device)

for filename in images:
# Skip if the embedding already exists
if Path(f'{filename}_{model_name}.npy').exists():
device = "cpu"

# Batch process the images
batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
for batch in batches:
try:
# Skip running the model if the embeddings already exist
if all([has_cached_embedding(model_name, filename) for filename in batch]):
continue

try:
# Load the image
square_img = Image.open(filename)

# Do some image processing to reduce the noise in the image
# Gaussian blur
square_img = square_img.filter(ImageFilter.GaussianBlur(radius=1))

image = np.array(square_img)

norm_transform = pth_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
# Normalize the tensor with the mean and std of the ImageNet dataset
img_tensor = norm_transform(img_tensor)
img_tensor = img_tensor.unsqueeze(0) # Add batch dimension
if 'cuda' in device:
img_tensor = img_tensor.to(device)
features = model(img_tensor)

# TODO: add attention map caching as an option
# attentions = model.get_last_selfattention(img_tensor)

# nh = attentions.shape[1] # number of head

# w_featmap = 224 // patch_size
# h_featmap = 224 // patch_size
images = [Image.open(filename).convert("RGB") for filename in batch]
inputs = processor(images=images, return_tensors="pt").to(device)

# Keep only the output patch attention
# attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
# attentions = attentions.reshape(nh, w_featmap, h_featmap)
# attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[
# 0].cpu().numpy()
#
# # Resize the attention map to the original image size
# attentions = np.uint8(255 * attentions[0])
with torch.no_grad():
embeddings = vit_model(**inputs)

# Get the feature embeddings
embeddings = features.squeeze(dim=0) # Remove batch dimension
embeddings = embeddings.cpu().numpy() # Convert to numpy array
batch_embeddings = embeddings.last_hidden_state[:, 0, :].cpu().numpy()

cache_embedding(embeddings, model_name, filename) # save the embedding to disk
#cache_attention(attentions, model_name, filename) # save the attention map to disk
except Exception as e:
err(f'Error processing {filename}: {e}')
# Save the embeddings
for emb, filename in zip(batch_embeddings, batch):
emb = emb.astype(np.float32)
cache_embedding(emb, model_name, filename)
except Exception as e:
err(f'Error processing {batch}: {e}')


def compute_norm_embedding(model_name: str, images: list):
Expand All @@ -172,15 +136,15 @@ def compute_norm_embedding(model_name: str, images: list):

# If a GPU is available, skip the parallel CPU processing and compute the embeddings in this process
if torch.cuda.is_available():
compute_embedding(images, model_name)
compute_embedding_vits(images, model_name)
else:
# Use a pool of processes to speed up the embedding generation 20 images at a time on each process
num_processes = min(multiprocessing.cpu_count(), len(images) // 20)
num_processes = max(1, num_processes)
info(f'Using {num_processes} processes to compute {len(images)} embeddings 20 at a time ...')
with multiprocessing.Pool(num_processes) as pool:
args = [(images[i:i + 20], model_name) for i in range(0, len(images), 20)]
pool.starmap(compute_embedding, args)
pool.starmap(compute_embedding_vits, args)


def calc_mean_std(image_files: list) -> tuple:
Expand Down
Loading

0 comments on commit 427931e

Please sign in to comment.