Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Update mine_hard_negatives to using a full corpus and multiple positives #2848

Merged
merged 25 commits into from
Sep 11, 2024

Conversation

ArthurCamara
Copy link
Contributor

@ArthurCamara ArthurCamara commented Jul 18, 2024

Following from #2818, this PR updates the mine_hard_negatives method to allow for a corpus to be passed (thanks @ChrisGeishauser) and to a single query to have multiple positives (like the case in the TREC-Covid dataset).

The way it handles multiple positives is to check for duplicated queries in the input dataset. If the same query appears multiple times, every occurrence is considered another positive for that query. The method then only uses each query once when searching, and keep tracks of the positives retrieved.

One thing to consider is that, if the dataset has too many positives and use_triplets=True, the method will "explode" the dataset, returning n_positives*n_negatives rows per query. If use_triplets=False only n_positives rows are returned per query. An alternative would be to return a nested dataset, with a "positives" and a "negatives" column.

Comment on lines +673 to +677
if range_max > 2048 and use_faiss:
# FAISS on GPU can only retrieve up to 2048 documents per query
range_max = 2048
if verbose:
print("Using FAISS, we can only retrieve up to 2048 documents per query. Setting range_max to 2048.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice comment, I didn't realise this

query_embeddings = query_embeddings.cpu().numpy()
corpus_embeddings = corpus_embeddings.cpu().numpy()
index = faiss.IndexFlatIP(len(corpus_embeddings[0]))
index = faiss.IndexFlatIP(model.get_sentence_embedding_dimension())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@tomaarsen
Copy link
Collaborator

tomaarsen commented Sep 11, 2024

Hello!

Apologies for the delay, I've been recovering from a surgery since July.
I quite like the direction of this PR, so I'd like to get this merged before mine_hard_negatives releases. However, I did have some issues (mainly nr. 4):

  1. If we set model.encode(..., convert_to_numpy=True), then each batch will be moved to CPU after it's computed. This means we don't need the chunk_size. I did however add faiss_batch_size as I think that might be useful.
  2. The missing_negatives were incorrectly computed because it used the number of unique queries, not the total number of queries.
  3. Some printed data is duplicate, e.g. the size of the final dataset.
  4. This PR embedded 1) all queries (duplicates included), 2) all positives (duplicates included), 3) all positives + corpus texts (duplicates excluded), and 4) queries again (duplicates excluded). In practice, this meant that we did about 2x more embedding than necessary. I've updated it to just 1) queries (duplicates excluded) and 2) all positives + corpus texts (duplicates excluded), and then reconstructing the positive scores from the precomputed embeddings.

I've tested this up to datasets of 3m samples (gooaq, amazon-qa) with/without FAISS and with/without a CrossEncoder.

Here's a script to test it:

from pprint import pprint
from sentence_transformers.util import mine_hard_negatives
from sentence_transformers import SentenceTransformer, CrossEncoder
from datasets import load_dataset

# Load a Sentence Transformer model
model = SentenceTransformer("all-MiniLM-L6-v2")

# Load a dataset to mine hard negatives from
dataset = load_dataset("sentence-transformers/amazon-qa", split="train").select(range(50000))
print(dataset)

corpus = load_dataset("sentence-transformers/gooaq", split="train[:50000]")["answer"]
# cross_encoder_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# cross_encoder = CrossEncoder(cross_encoder_name)
# Mine hard negatives
dataset = mine_hard_negatives(
    dataset=dataset,
    model=model,
    corpus=corpus,
    # cross_encoder=cross_encoder,
    range_min=0,
    range_max=10,
    max_score=0.8,
    margin=0,
    num_negatives=5,
    sampling_strategy="random",
    batch_size=512,
    use_faiss=True,
)
print(dataset)
pprint(dataset[0])
# dataset.push_to_hub("natural-questions-cnn-hard-negatives", "triplet", private=True)
breakpoint()

Would love to hear what you think @ArthurCamara @ChrisGeishauser. Again sorry for the radio silence.

  • Tom Aarsen

@tomaarsen tomaarsen changed the title Update mine_hard_negatives to using a full corpus and multiple positives [feat] Update mine_hard_negatives to using a full corpus and multiple positives Sep 11, 2024
@tomaarsen
Copy link
Collaborator

I'm going to move ahead with merging this, as I'd like to include this in the upcoming release. If you find a moment, feel free to experiment with this function and report any issues.

  • Tom Aarsen

@tomaarsen tomaarsen merged commit a3f2236 into UKPLab:master Sep 11, 2024
11 checks passed
tomaarsen added a commit that referenced this pull request Sep 11, 2024
…le positives (#2848)

* updated mine_hard_negatives method to include a seperate corpus for mining hard negatives.

* Run 'make check'

* Update "corpus" to just a list of strings

* Prevent duplicate embeddings if no separate corpus

* Deduplicate corpus

Add a positive to corpus indices mapping, useful to get non-deduplicated positives and to filter away positives taken from the corpus

* Skip rescoring positive pairs via pos_to_corpus_indices instead

* Add a mine_hard_negatives_from_corpus util

* Speedup pos_to_corpus_indices for large corpora

* Fix range_max by number of max_positives in dataset

* encode in chunks, ensure at least one positive per query always

* Hard_negative_mining with corpus and multiple positives is possible

* docstring

* Fix for random sampling

* fix for return_triplets=False

* Typo on list

* Fix bug with multiple positives. More efficient creation of some tensors.

* Fix offset of positives scoring with multiple chunks

* fix pytorch copy warning

* Only embed each text once; no need for chunking if convert_to_numpy=True

* Undo unintended changes

* Fix mismatch in anchor/positive and negatives if multiple positives per query

* Don't repeat positive_scores as it inflates the positive score counts

* Remove the "Count" for Difference as it's rather confusing

---------

Co-authored-by: Christian Geishauser <[email protected]>
Co-authored-by: Tom Aarsen <[email protected]>
@ArthurCamara
Copy link
Contributor Author

Nice! Thanks for your help, @tomaarsen! I actually used this code to create a set of "NanoBEIR" datasets that I just made public here: https://huggingface.co/collections/zeta-alpha-ai/nanobeir-66e1a0af21dfd93e620cd9f6

@tomaarsen
Copy link
Collaborator

Awesome! Will you have some writing on how well it correlates to BEIR itself? Because as we all know, BEIR takes forever to run, and a faster option is definitely interesting 😅

@tomaarsen
Copy link
Collaborator

tomaarsen commented Sep 12, 2024

I've found the post already 👀 https://www.zeta-alpha.com/post/fine-tuning-an-llm-for-state-of-the-art-retrieval-zeta-alpha-s-top-10-submission-to-the-the-mteb-be

Looks quite solid! Nice collection of datasets, I'm glad some of my datasets came in useful too. And your eventual training recipe seems to be about equivalent to CachedMultipleNegativesRankingLoss, which is GradCache + InfoNCE (except this doesn't update the temperature (called scale here, it's the inverse of temperature) over time).

Also, in my opinion the Sentence Transformers evaluators are very useful, especially InformationRetrievalEvaluator, but preparing this data is always rather difficult. Perhaps NanoBEIR is a nice moment to e.g. package these datasets such that people can easily import them for use with any Sentence Transformer model?

That said, a lot of modern ST models require specific prompts for queries vs documents, which is not implemented in this evaluator (or the hard negatives mining) (yet?).

  • Tom Aarsen

@ArthurCamara
Copy link
Contributor Author

I've found the post already 👀 https://www.zeta-alpha.com/post/fine-tuning-an-llm-for-state-of-the-art-retrieval-zeta-alpha-s-top-10-submission-to-the-the-mteb-be

Looks quite solid! Nice collection of datasets, I'm glad some of my datasets came in useful too. And your eventual training recipe seems to be about equivalent to CachedMultipleNegativesRankingLoss, which is GradCache + InfoNCE (except this doesn't update the temperature (called scale here, it's the inverse of temperature) over time).

Yup. That's exactly that. I began working with the new ST trainer, but ended up building our own custom trainer, mainly because it was not clear/straightforward how to use multiple negatives per query or change between using (or not) in-batch negatives, specially with multiple GPUs (i.e., gathering negatives across devices).

Giving it another shot is is in my (rather long, I have to admit) to-do list. I'm starting to work on other large models now, so maybe that's a good moment to get working on it again.

Also, in my opinion the Sentence Transformers evaluators are very useful, especially InformationRetrievalEvaluator, but preparing this data is always rather difficult. Perhaps NanoBEIR is a nice moment to e.g. package these datasets such that people can easily import them for use with any Sentence Transformer model?

Seems like it's quite similar to the hard negative mining use case. Perhaps I can try adapting the code (or NanoBEIR) to work directly with it. NanoBEIR format is mainly how it is now because that's what the MTEB datasets look like. No hard attachments there.

That said, a lot of modern ST models require specific prompts for queries vs documents, which is not implemented in this evaluator (or the hard negatives mining) (yet?).

You mean things like E5's query: and passage: prompts? (and all the other instructions from Mistral-based models) Wouldn't it be enough to pass these to encode?

  • Tom Aarsen

@tomaarsen
Copy link
Collaborator

Yup. That's exactly that. I began working with the new ST trainer, but ended up building our own custom trainer, mainly because it was not clear/straightforward how to use multiple negatives per query or change between using (or not) in-batch negatives, specially with multiple GPUs (i.e., gathering negatives across devices).

Giving it another shot is is in my (rather long, I have to admit) to-do list. I'm starting to work on other large models now, so maybe that's a good moment to get working on it again.

I don't blame you at all, I like building things myself to really understand the core mechanics as well.
On this topic: I recently had a discussion in #2831 about the merit of sharing negatives across devices versus GradCache.

Seems like it's quite similar to the hard negative mining use case. Perhaps I can try adapting the code (or NanoBEIR) to work directly with it. NanoBEIR format is mainly how it is now because that's what the MTEB datasets look like. No hard attachments there.

I think it's totally fine to keep that format as-is (it's pretty normal). My thoughts was to create a new package like sentence-transformers-nanobeir that neatly packages these, so people can use:

from sentence_transformers_nanobeir import NFCorpusEvaluator

...

evaluator = NFCorpusEvaluator()

# Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=evaluator,
)
trainer.train()

# or

metrics = evaluator(model)
print(metrics)

I'm also considering packaging it into Sentence Transformers directly, but that might be a bit overkill.

You mean things like E5's query: and passage: prompts? (and all the other instructions from Mistral-based models) Wouldn't it be enough to pass these to encode?

Yeah. It's not hard, it's just missing still, e.g. adding "query_prompt", "query_prompt_name", "passage_prompt", "passage_prompt_name" kwargs to pass on to model.encode.

  • Tom Aarsen

@ArthurCamara
Copy link
Contributor Author

I don't blame you at all, I like building things myself to really understand the core mechanics as well. On this topic: I recently had a discussion in #2831 about the merit of sharing negatives across devices versus GradCache.

I've read #2831 and I actually agree with you there. It is easier to just don't sync across devices and just increase the batch size accordingly.

My only follow-up issue (and that I'm trying to understand how to solve) is how to handle multiple hard negatives per query when using NoDuplicatesBatchSampler and MNRL. For instance, take the training_nli_v3.py example. The dataset has a format (a_1, p_1, n_2), (a_1, p_1, n_2)... with multiple hard negatives per query (See rows 17-21 of the triplet subset), but the sample code don't take advantage of this.
In the example, the NoDuplicatesBatchSampler throws away all the negatives except for the first one, as the query and the positive are already in the batch. Even if it didn't, and only considered a duplicate sample if the positive or negative already exist between all of the negatives, the forward pass of the MNRL would compute the loss for the same query n_hard_negative times, instead of just "throwing everything" into the same negative bin.

From what I understood, if the anchor a_1 has 2 hard negatives in the samples (a_1, p_1, n_1) and (a_1, p_1, n_2), the MNRL would consider each row as an independent sample.

I guess the "proper" way around it would be to disentangle query and documents in the sampler/collator and then each anchor have a list of integers of where its negatives are in the list of all documents. Or can you see another way around it? (I should probably open another issue to discuss this)

I think it's totally fine to keep that format as-is (it's pretty normal). My thoughts was to create a new package like sentence-transformers-nanobeir that neatly packages these, so people can use:

from sentence_transformers_nanobeir import NFCorpusEvaluator

...

evaluator = NFCorpusEvaluator()

# Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=evaluator,
)
trainer.train()

# or

metrics = evaluator(model)
print(metrics)

I'm also considering packaging it into Sentence Transformers directly, but that might be a bit overkill.

Maybe to start with, adding a function to util.py could be easier? I can probably get it working quickly. I can open an issue and a PR with this later this week.

Yeah. It's not hard, it's just missing still, e.g. adding "query_prompt", "query_prompt_name", "passage_prompt", "passage_prompt_name" kwargs to pass on to model.encode.

I've opened a PR here: #2951 that adds this (and an option to deal with instruction masking in left-padded tokens)

  • Tom Aarsen

@ArthurCamara
Copy link
Contributor Author

I guess the "proper" way around it would be to disentangle query and documents in the sampler/collator and then each anchor have a list of integers of where its negatives are in the list of all documents. Or can you see another way around it? (I should probably open another issue to discuss this)

Edit: I think I found another solution for this. I've opened #2954 to discuss this

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants