Retrieve image text pairs #14

Open
wants to merge 31 commits into base: main
Changes from 28 commits
Commits
bdc8649
Add a raw retrieval option to store queries and their retrieved candi…
sahel-sh May 22, 2024
1104db5
Add interactive retriever
sahel-sh May 27, 2024
1017242
Merge branch 'main' into raw_retrieval
sahel-sh May 28, 2024
f6d7659
add retrieval of image-text pairs to retrieval config yaml
sahel-sh May 28, 2024
cfeaaf5
left a todo for retrieving complementary candidates
sahel-sh May 28, 2024
b4bde72
Merge branch 'raw_retrieval' into interactive_retrieval
sahel-sh May 28, 2024
1a3b79a
retrieve complement candidates
sahel-sh May 28, 2024
5035a49
Merge branch 'main' into raw_retrieval
sahel-sh May 29, 2024
9b2e01b
reformated with 120 chars
sahel-sh May 29, 2024
d8f81bf
reformatted with 120
sahel-sh May 29, 2024
7c568c6
reformatted with 120
sahel-sh May 29, 2024
e1c4915
fix retrieved candidates path
sahel-sh May 31, 2024
de8a73c
Merge branch 'raw_retrieval' into interactive_retrieval
sahel-sh May 31, 2024
70a0145
Merge branch 'interactive_retrieval' into retrieve_image_text_pairs
sahel-sh May 31, 2024
9957ef3
fixed query embedder config
sahel-sh May 31, 2024
938e53f
fix distributed settings
sahel-sh May 31, 2024
a95131c
skip getting complements for candidates with text,image modality
sahel-sh May 31, 2024
db3222e
fix typpo
sahel-sh May 31, 2024
5cc9370
refactor raw retrieval
sahel-sh Jun 5, 2024
9e5ec40
refactor interactive_retriever
sahel-sh Jun 5, 2024
424ae26
refactored raw retrieval
sahel-sh Jun 5, 2024
b4acd11
Add a todo for image-txt retrieval
sahel-sh Jun 5, 2024
160fea0
add default value for not to break the existing calls
sahel-sh Jun 5, 2024
46222df
merge with raw_retrieval
sahel-sh Jun 5, 2024
410ffd3
update requirements
sahel-sh Jun 5, 2024
7323f36
temp commit
sahel-sh Jun 7, 2024
65a6b08
temp fix for complement retriever
sahel-sh Jun 7, 2024
72871c1
add complement candidates
sahel-sh Jun 7, 2024
2e21807
Merge branch 'main' into retrieve_image_text_pairs
sahel-sh Jul 9, 2024
c62880b
addressed review comments
sahel-sh Jul 13, 2024
b25ddbc
polish readme
sahel-sh Jul 13, 2024
16 changes: 15 additions & 1 deletion src/common/faiss_env.yml
@@ -10,4 +10,18 @@ dependencies:
   - faiss-gpu
   - pip
   - pip:
-    - omegaconf
+    - omegaconf
+    - python-dotenv
+    - pycocoevalcap
+    - numpy
+    - wandb
+    - ftfy
+    - tqdm
+    - regex
+    - typeguard
+    - datasets
+    - transformers
+    - git+https://github.com/openai/CLIP.git
+    - timm
+    - fairscale
+    - opencv-python
214 changes: 214 additions & 0 deletions src/common/interactive_retriever.py
@@ -0,0 +1,214 @@
"""
Retrieves candidates for a given set of queries after embedding them.
"""

from enum import Enum
import gc
import json
import os

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP

from data.mbeir_dataset import (
    MBEIRInferenceOnlyDataset,
    MBEIRInferenceOnlyCollator,
)
import dist_utils
from dist_utils import ContiguousDistributedSampler
from mbeir_embedder import generate_embeds_and_ids_for_dataset_with_gather
from utils import build_model_from_config, set_seed
from data.preprocessing.utils import unhash_did


class QueryModality(Enum):
    TEXT = "text"
    IMAGE = "image"


class InteractiveRetriever:
Collaborator

I noticed that the InteractiveRetriever requires a pre-built candidate index file to function correctly. To help users with this setup, could we add a script, such as run_interactive_retriever_pipeline.sh, that demonstrates the entire pipeline? The script would cover embedding, indexing, loading the index for the interactive retriever, and retrieving demo queries. A step-by-step guide in the README would also greatly improve the user experience.

Contributor Author

Done. I created a unirag folder next to inbatch for BLIP_FF Large and CLIP_SF Large. It has the embed, index, and retrieval configs and the run script, as you requested.
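
Because the retriever depends on a pre-built candidate index and a candidates file, a pre-flight check can fail early with a clear message before any model is loaded. The following is only a sketch: `check_retriever_inputs` is a hypothetical helper that is not part of this PR, and it assumes both inputs are plain file paths on local disk.

```python
import os


def check_retriever_inputs(cand_index_path: str, candidates_path: str) -> None:
    """Hypothetical helper: raise early if a required input file is missing."""
    missing = [
        path
        for path in (cand_index_path, candidates_path)
        if not os.path.isfile(path)
    ]
    if missing:
        # The embedding and indexing steps have to run before retrieval.
        raise FileNotFoundError(
            "Missing retriever inputs (run embedding and indexing first): "
            + ", ".join(missing)
        )
```

Calling this before constructing `InteractiveRetriever` would surface a missing index as a single readable error rather than a failure deep inside the search step.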

    def __init__(self, cand_index_path: str, candidates_path: str, config):
        # Set up seed for reproducibility
        seed = config.seed + dist_utils.get_rank()
        set_seed(seed)

        # MSCOCO's dataset id is hardcoded since the dataset id and query/candidate modalities determine the instruction part of the prompt.
        # MSCOCO's dataset supports prompt instructions for both image->text and text->image query->candidate modalities.
        self.dataset_id = 9
Collaborator

Is the InteractiveRetriever specifically designed for the MSCOCO dataset? I observed that the self.dataset_id and task_id assignments appear to be hardcoded.

Contributor Author

I changed the InteractiveRetriever to be generic. However, its current integration with the mbeir_retriever is for retrieving complement candidates to create image-text pairs, and MSCOCO is a dataset that supports both text->image and image->text queries, so the mbeir retriever now sets the dataset to MSCOCO for this task.
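
One way to picture the generic version is a query builder that takes the dataset id as a parameter instead of reading a hardcoded attribute. This is a minimal sketch, not the code in this PR: `build_query` and the two lookup tables are hypothetical names, while the task ids (0 for text queries, 3 for image queries), the `dataset_id:query_num` qid format, and the MSCOCO id of 9 are taken from the diff.

```python
from typing import Optional

# Values taken from add_queries in this PR: text queries use task 0, image queries use task 3.
TASK_ID_BY_QUERY_MODALITY = {"text": 0, "image": 3}
CANDIDATE_MODALITY = {"text": "image", "image": "text"}


def build_query(
    dataset_id: int,
    query_num: int,
    query_modality: str,
    query_txt: Optional[str] = None,
    query_img_path: Optional[str] = None,
) -> dict:
    """Hypothetical helper: build one query record for an arbitrary dataset id."""
    if query_modality not in TASK_ID_BY_QUERY_MODALITY:
        raise ValueError("Only 'text' and 'image' query modalities are supported.")
    if query_modality == "text" and (not query_txt or query_img_path is not None):
        raise ValueError("A 'text' query needs 'query_txt' and no 'query_img_path'.")
    if query_modality == "image" and (not query_img_path or query_txt is not None):
        raise ValueError("An 'image' query needs 'query_img_path' and no 'query_txt'.")
    return {
        "qid": f"{dataset_id}:{query_num}",  # qid format is dataset_id:query_num
        "query_modality": query_modality,
        "query_txt": query_txt,
        "query_img_path": query_img_path,
        "task_id": TASK_ID_BY_QUERY_MODALITY[query_modality],
        "candidate_modality": CANDIDATE_MODALITY[query_modality],
    }


# 9 is the MSCOCO dataset id used in this PR.
query = build_query(9, 1, "text", query_txt="a dog on a beach")
```

With the dataset id passed in, the caller (here, the mbeir retriever) decides which dataset's prompt instructions apply.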


        # Set up the query embedder
        model = build_model_from_config(config)
        model.eval()

        # Ensure the model exposes the required methods before generating embeddings.
        # Passing None as the getattr default lets a missing method reach the custom message.
        if not callable(getattr(model, "encode_mbeir_batch", None)):
            raise AttributeError("The provided model does not have a callable 'encode_mbeir_batch' method.")
        if not callable(getattr(model, "get_img_preprocess_fn", None)):
            raise AttributeError("The provided model does not have a callable 'get_img_preprocess_fn' method.")
        if not callable(getattr(model, "get_tokenizer", None)):
            raise AttributeError("The provided model does not have a callable 'get_tokenizer' method.")
        self.img_preprocess_fn = model.get_img_preprocess_fn()
        self.tokenizer = model.get_tokenizer()

        # Enable distributed data parallel
        model = model.to(config.dist_config.gpu_id)
        if config.dist_config.distributed_mode:
            model = DDP(model, device_ids=[config.dist_config.gpu_id])
        self.model = model
        print(f"Model is set up on GPU {config.dist_config.gpu_id}.")

        self.cand_index_path = cand_index_path
        self.config = config
        self.queries = []

        # Load did_to_candidates from a JSON Lines file (one candidate per line)
        self.did_to_candidates = {}
        with open(candidates_path, "r") as f:
            for line in f:
                c = json.loads(line.strip())
                assert c["did"] not in self.did_to_candidates, "dids must be unique"
                self.did_to_candidates[c["did"]] = c

    def add_queries(self, queries: list[tuple[str, str, str]]):
        for query_modality, query_txt, query_img_path in queries:
            if query_modality == QueryModality.TEXT.value:
                task_id = 0
                candidate_modality = QueryModality.IMAGE.value
                assert query_txt, "Query with 'text' modality must have non-null 'query_txt'"
                assert query_img_path is None, "Query with 'text' modality must have null 'query_img_path'"
            elif query_modality == QueryModality.IMAGE.value:
                task_id = 3
                candidate_modality = QueryModality.TEXT.value
                assert query_txt is None, "Query with 'image' modality must have null 'query_txt'"
                assert query_img_path, "Query with 'image' modality must have non-null 'query_img_path'"
            else:
                raise ValueError("Only 'text' and 'image' query modalities are supported.")
            self.queries.append(
                {
                    # Hardcoded qid in format of dataset_id:query_num.
                    "qid": ":".join([str(self.dataset_id), str(len(self.queries) + 1)]),
                    "query_modality": query_modality,
                    "query_txt": query_txt,
                    "query_img_path": query_img_path,
                    "task_id": task_id,
                    "candidate_modality": candidate_modality,
                }
            )

    def _embed_queries(self):
        mbeir_data_dir = self.config.mbeir_data_dir
        embed_config = self.config.embed_config

        # Config for dataset
        data_config = self.config.data_config
        query_instruct_path = data_config.query_instruct_path
        image_size = tuple(map(int, data_config.image_size.split(",")))

        print_config = False
        if dist_utils.is_main_process():
            print(f"\nEmbedder Log: Generating embeddings for {len(self.queries)} queries.")
            print_config = True

        dataset = MBEIRInferenceOnlyDataset(
            mbeir_data_dir,
            self.queries,
            query_instruct_path,
            self.img_preprocess_fn,
            enable_query_instruct=data_config.enable_query_instruct,
            print_config=print_config,
        )
        collator = MBEIRInferenceOnlyCollator(
            tokenizer=self.tokenizer,
            image_size=image_size,
        )

        # Config for data loader
        batch_size = self.config.dataloader_config.batch_size
        num_workers = self.config.dataloader_config.num_workers

        # Set up distributed data parallel
        num_tasks = dist_utils.get_world_size()
        global_rank = dist_utils.get_rank()
        sampler = ContiguousDistributedSampler(
            dataset,
            num_replicas=num_tasks,
            rank=global_rank,
        )  # Note: assume the dataset is in sorted order.
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            sampler=sampler,
            shuffle=False,  # The distributed sampler controls the order, so no shuffling here.
            collate_fn=collator,
            drop_last=False,
        )
        if dist.is_initialized():
            dist.barrier()  # Synchronize all ranks before generating embeddings.
        if dist_utils.is_main_process():
            print("Embedder Log: Data loader is set up.")
            print(f"Embedder Log: Generating embeddings for {len(self.queries)} queries ...")
            print(f"Inference with half precision: {embed_config.use_fp16}")

        # Generate embeddings and ids
        embedding_list, id_list = generate_embeds_and_ids_for_dataset_with_gather(
            self.model,
            data_loader,
            device=self.config.dist_config.gpu_id,
            use_fp16=embed_config.use_fp16,
        )

        # Save the embeddings to a temporary .npy file (rank 0 only)
        if not dist.is_initialized() or dist.get_rank() == 0:
            print(f"Embedder Log: Embedding list length: {len(embedding_list)}")
            print(f"Embedder Log: ID list length: {len(id_list)}")

            # Save the embeddings to .npy
            self.embed_file = "interactive_queries_embed.npy"
            np.save(self.embed_file, embedding_list)
            print(f"Embedder Log: Saved embeddings to {self.embed_file}.")

        if dist.is_initialized():
            dist.barrier()  # Wait for rank 0 to finish saving the embeddings.

        # Delete the embeddings and IDs to free up memory
        del embedding_list
        del id_list
        del data_loader
        del dataset
        del collator
        del sampler

        # Explicitly call the garbage collector
        gc.collect()
        torch.cuda.empty_cache()

    def retrieve(self, k: int = 1, batch_size: int = 100):
        results = []
        self._embed_queries()
        # Retrieve only, skipping the evaluation step
        from mbeir_retriever import search_index

        print(f"Retriever: Searching with k={k}")
        _, retrieved_indices = search_index(
            self.embed_file,
            self.cand_index_path,
            batch_size=batch_size,
            num_cand_to_retrieve=k,
        )

        # Map each hashed doc id back to its candidate record
        for indices in retrieved_indices:
            retrieved_cands = []
            for hashed_doc_id in indices:
                doc_id = unhash_did(hashed_doc_id)
                retrieved_cands.append(self.did_to_candidates[doc_id])
            results.append(retrieved_cands)

        # Remove the temporarily stored embeddings
        os.remove(self.embed_file)

        return results
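
The capability checks in `__init__` could be factored into one helper that passes a default to `getattr`. In Python, `getattr(obj, name)` without a default raises its own `AttributeError` when the attribute is missing, so a custom message is only reached when a default such as `None` is supplied. The sketch below is illustrative: `require_callable` and `DummyModel` are hypothetical names that do not appear in this PR.

```python
def require_callable(obj, name: str):
    """Return obj.<name> if it is callable, else raise AttributeError with a clear message."""
    attr = getattr(obj, name, None)  # the default avoids a bare AttributeError on lookup
    if not callable(attr):
        raise AttributeError(f"The provided model does not have a callable '{name}' method.")
    return attr


class DummyModel:
    """Stand-in for a real model, used only to exercise the check."""

    def get_tokenizer(self):
        return "tokenizer"


tokenizer = require_callable(DummyModel(), "get_tokenizer")()
```

The same helper could be called once per required method name, which keeps each error message in sync with the attribute actually being checked.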
10 changes: 1 addition & 9 deletions src/common/mbeir_embedder.py
@@ -19,7 +19,7 @@

 import dist_utils
 from dist_utils import ContiguousDistributedSampler
-from utils import build_model_from_config
+from utils import build_model_from_config, set_seed
 from data.mbeir_dataset import (
     MBEIRMainDataset,
     MBEIRMainCollator,
@@ -461,14 +461,6 @@ def generate_embeds_for_config(model, img_preprocess_fn, tokenizer, config):
         dist.barrier()  # Wait for rank 0 to finish saving the embeddings and ids.


-def set_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
-
-
 def main(config):
     # Set up seed for reproducibility
     seed = config.seed + dist_utils.get_rank()