diff --git a/src/rank_llm/api/server.py b/src/rank_llm/api/server.py
new file mode 100644
index 0000000..f3a2c63
--- /dev/null
+++ b/src/rank_llm/api/server.py
@@ -0,0 +1,122 @@
+import argparse
+from flask import Flask, jsonify, request
+
+from rank_llm import retrieve_and_rerank
+from rank_llm.rerank.rank_listwise_os_llm import RankListwiseOSLLM
+from rank_llm.rerank.api_keys import get_openai_api_key, get_azure_openai_args
+from rank_llm.rerank.rank_gpt import SafeOpenai
+from rank_llm.rerank.rankllm import PromptMode
+
+""" API URL FORMAT
+
+http://localhost:8082/api/model/{model_name}/index/{index_name}/{retriever_base_host}?query={query}&hits_retriever={top_k_retriever}&hits_reranker={top_k_reranker}&qid={qid}&num_passes={num_passes}
+
+hits_retriever, hits_reranker, qid, and num_passes are OPTIONAL.
+They default to 20, 5, None, and 1, respectively.
+
+"""
+
+
+def create_app(model, port, use_azure_openai=False):
+
+    app = Flask(__name__)
+    if model == "rank_zephyr":
+        print(f"Loading {model} model...")
+        # Load specified model upon server initialization
+        default_agent = RankListwiseOSLLM(
+            model=f"castorini/{model}_7b_v1_full",
+            context_size=4096,
+            prompt_mode=PromptMode.RANK_GPT,
+            num_few_shot_examples=0,
+            device="cuda",
+            num_gpus=1,
+            variable_passages=True,
+            window_size=20,
+            system_message="You are RankLLM, an intelligent assistant that can rank passages based on their relevancy to the query.",
+        )
+    elif model == "rank_vicuna":
+        print(f"Loading {model} model...")
+        # Load specified model upon server initialization
+        default_agent = RankListwiseOSLLM(
+            model=f"castorini/{model}_7b_v1",
+            context_size=4096,
+            prompt_mode=PromptMode.RANK_GPT,
+            num_few_shot_examples=0,
+            device="cuda",
+            num_gpus=1,
+            variable_passages=False,
+            window_size=20,
+        )
+    elif "gpt" in model:
+        print(f"Loading {model} model...")
+        # Load specified model upon server initialization
+        openai_keys = get_openai_api_key()
+        print(openai_keys)
+        default_agent = SafeOpenai(
+            model=model,
+            context_size=8192,
+            prompt_mode=PromptMode.RANK_GPT,
+            num_few_shot_examples=0,
+            keys=openai_keys,
+            **(get_azure_openai_args() if use_azure_openai else {}),
+        )
+    else:
+        raise ValueError(f"Unsupported model: {model}")
+
+    @app.route(
+        "/api/model/<string:model_path>/index/<string:dataset>/<string:host>",
+        methods=["GET"],
+    )
+    def search(model_path, dataset, host):
+
+        query = request.args.get("query", type=str)
+        top_k_retrieve = request.args.get("hits_retriever", default=20, type=int)
+        top_k_rerank = request.args.get("hits_reranker", default=5, type=int)
+        qid = request.args.get("qid", default=None, type=str)
+        num_passes = request.args.get("num_passes", default=1, type=int)
+
+        try:
+            # Assuming the function is called with these parameters and returns a response
+            response = retrieve_and_rerank.retrieve_and_rerank(
+                dataset=dataset,
+                query=query,
+                model_path=model_path,
+                host="http://localhost:" + host,
+                interactive=True,
+                top_k_rerank=top_k_rerank,
+                top_k_retrieve=top_k_retrieve,
+                qid=qid,
+                populate_exec_summary=False,
+                default_agent=default_agent,
+                num_passes=num_passes,
+            )
+
+            return jsonify(response[0]), 200
+        except Exception as e:
+            return jsonify({"error": str(e)}), 500
+
+    return app, port
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Start the RankLLM Flask server.")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="rank_zephyr",
+        help="The model to load (e.g., rank_zephyr).",
+    )
+    parser.add_argument(
+        "--port", type=int, default=8082, help="The port to run the Flask server on."
+    )
+    parser.add_argument(
+        "--use_azure_openai", action="store_true", help="Use Azure OpenAI API."
+    )
+    args = parser.parse_args()
+
+    app, port = create_app(args.model, args.port, args.use_azure_openai)
+    app.run(host="0.0.0.0", port=port, debug=False)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/rank_llm/demo/rerank_stored_retrieved_results.py b/src/rank_llm/demo/rerank_stored_retrieved_results.py
index 669b6ae..da921a7 100644
--- a/src/rank_llm/demo/rerank_stored_retrieved_results.py
+++ b/src/rank_llm/demo/rerank_stored_retrieved_results.py
@@ -10,9 +10,7 @@
 from rank_llm.data import read_requests_from_file, DataWriter
 from rank_llm.rerank.zephyr_reranker import ZephyrReranker
 
-file_name = (
-    "retrieve_results/BM25/retrieve_results_dl23_top20.json"
-)
+file_name = "retrieve_results/BM25/retrieve_results_dl23_top20.json"
 
 requests = read_requests_from_file(file_name)
 reranker = ZephyrReranker()
diff --git a/src/rank_llm/rerank/rank_gpt.py b/src/rank_llm/rerank/rank_gpt.py
index f85473c..35e9798 100644
--- a/src/rank_llm/rerank/rank_gpt.py
+++ b/src/rank_llm/rerank/rank_gpt.py
@@ -127,7 +127,7 @@ def run_llm(
         prompt: Union[str, List[Dict[str, str]]],
         current_window_size: Optional[int] = None,
     ) -> Tuple[str, int]:
-        model_key = "engine" if self.use_azure_ai else "model"
+        model_key = "model"
         response = self._call_completion(
             messages=prompt,
             temperature=0,
diff --git a/src/rank_llm/rerank/rankllm.py b/src/rank_llm/rerank/rankllm.py
index 2fa66e4..2f3af8a 100644
--- a/src/rank_llm/rerank/rankllm.py
+++ b/src/rank_llm/rerank/rankllm.py
@@ -114,6 +114,7 @@ def permutation_pipeline(
         rank_start: int,
         rank_end: int,
         logging: bool = False,
+        populate_exec_summary: bool = True,
     ) -> Result:
         """
         Runs the permutation pipeline on the passed in result set within the passed in rank range.
@@ -135,10 +136,11 @@ def permutation_pipeline(
         )
         if logging:
             print(f"output: {permutation}")
-        ranking_exec_info = RankingExecInfo(
-            prompt, permutation, in_token_count, out_token_count
-        )
-        result.ranking_exec_summary.append(ranking_exec_info)
+        if populate_exec_summary:
+            ranking_exec_info = RankingExecInfo(
+                prompt, permutation, in_token_count, out_token_count
+            )
+            result.ranking_exec_summary.append(ranking_exec_info)
         result = self.receive_permutation(result, permutation, rank_start, rank_end)
         return result
 
@@ -151,6 +153,7 @@ def sliding_windows(
         step: int,
         shuffle_candidates: bool = False,
         logging: bool = False,
+        populate_exec_summary: bool = True,
     ) -> Result:
         """
         Applies the sliding window algorithm to the reranking process.
@@ -188,7 +191,11 @@ def sliding_windows(
         while end_pos > rank_start and start_pos + step != rank_start:
             start_pos = max(start_pos, rank_start)
             rerank_result = self.permutation_pipeline(
-                rerank_result, start_pos, end_pos, logging
+                rerank_result,
+                start_pos,
+                end_pos,
+                logging,
+                populate_exec_summary=populate_exec_summary,
             )
             end_pos = end_pos - step
             start_pos = start_pos - step
@@ -331,6 +338,8 @@ def covert_doc_to_prompt_content(self, doc: Dict[str, Any], max_length: int) ->
             content = doc["segment"]
         elif "contents" in doc:
             content = doc["contents"]
+        elif "body" in doc:
+            content = doc["body"]
         else:
             content = doc["passage"]
         if "title" in doc and doc["title"]:
diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py
index 681da85..fedb7f6 100644
--- a/src/rank_llm/rerank/reranker.py
+++ b/src/rank_llm/rerank/reranker.py
@@ -21,6 +21,7 @@ def rerank_batch(
         step: int = 10,
         shuffle_candidates: bool = False,
         logging: bool = False,
+        populate_exec_summary: bool = True,
     ) -> List[Result]:
         """
         Reranks a list of requests using the RankLLM agent.
@@ -50,6 +51,7 @@ def rerank_batch(
                 step=step,
                 shuffle_candidates=shuffle_candidates,
                 logging=logging,
+                populate_exec_summary=populate_exec_summary,
             )
             results.append(result)
         return results
diff --git a/src/rank_llm/retrieve/service_retriever.py b/src/rank_llm/retrieve/service_retriever.py
new file mode 100644
index 0000000..1beb3c0
--- /dev/null
+++ b/src/rank_llm/retrieve/service_retriever.py
@@ -0,0 +1,98 @@
+import json
+import requests
+from urllib import parse
+from enum import Enum
+from typing import Any, Dict, List, Union
+
+from rank_llm.data import Request, Candidate, Query
+from rank_llm.retrieve.pyserini_retriever import PyseriniRetriever, RetrievalMethod
+from rank_llm.retrieve.repo_info import HITS_INFO
+from rank_llm.retrieve.utils import compute_md5, download_cached_hits
+from rank_llm.retrieve.retriever import RetrievalMode, Retriever
+
+
+class ServiceRetriever:
+    def __init__(
+        self,
+        retrieval_mode: RetrievalMode = RetrievalMode.DATASET,
+        retrieval_method: RetrievalMethod = RetrievalMethod.BM25,
+    ) -> None:
+        """
+        Creates a ServiceRetriever instance with a specified retrieval method and mode.
+
+        Args:
+            retrieval_mode (RetrievalMode): The retrieval mode to be used. Defaults to DATASET. Only DATASET mode is currently supported.
+            retrieval_method (RetrievalMethod): The retrieval method to be used. Defaults to BM25.
+
+        Raises:
+            ValueError: If retrieval mode or retrieval method is invalid or missing.
+        """
+        self._retrieval_mode = retrieval_mode
+        self._retrieval_method = retrieval_method
+
+        if retrieval_mode != RetrievalMode.DATASET:
+            raise ValueError(
+                f"{retrieval_mode} is not supported for ServiceRetriever. Only DATASET mode is currently supported."
+            )
+
+        if retrieval_method != RetrievalMethod.BM25:
+            raise ValueError(
+                f"{retrieval_method} is not supported for ServiceRetriever. Only BM25 is currently supported."
+            )
+
+        if not retrieval_method:
+            raise ValueError("Please provide a retrieval method.")
+
+        if retrieval_method == RetrievalMethod.UNSPECIFIED:
+            raise ValueError(
+                f"Invalid retrieval method: {retrieval_method}. Please provide a specific retrieval method."
+            )
+
+    def retrieve(
+        self,
+        dataset: str,
+        request: Request,
+        k: int = 50,
+        host: str = "http://localhost:8081",
+        timeout: int = 10,
+    ) -> Request:
+        """
+        Executes the retrieval process based on the configuration provided with the Retriever instance.
+        Takes in a Request object with a query and an empty candidates list, and the top k items to retrieve.
+
+        Args:
+            request (Request): The request containing the query and qid.
+            dataset (str): The name of the dataset.
+            k (int, optional): The top k hits to retrieve. Defaults to 50.
+            host (str): The Anserini API host address. Defaults to http://localhost:8081.
+
+        Returns:
+            Request: Contains a query and a list of candidates.
+
+        Raises:
+            ValueError: If the retrieval mode is invalid or the result format is not as expected.
+        """
+
+        url = f"{host}/api/index/{dataset}/search?query={parse.quote(request.query.text)}&hits={str(k)}&qid={request.query.qid}"
+
+        try:
+            response = requests.get(url, timeout=timeout)
+            response.raise_for_status()
+        except requests.exceptions.RequestException as e:
+            raise type(e)(
+                f"Failed to retrieve data from Anserini server: {str(e)}"
+            ) from e
+
+        data = response.json()
+        retrieved_results = Request(
+            query=Query(text=data["query"]["text"], qid=data["query"]["qid"])
+        )
+
+        for candidate in data["candidates"]:
+            retrieved_results.candidates.append(
+                Candidate(
+                    docid=candidate["docid"],
+                    score=candidate["score"],
+                    doc=candidate["doc"],
+                )
+            )
+
+        return retrieved_results
diff --git a/src/rank_llm/retrieve_and_rerank.py b/src/rank_llm/retrieve_and_rerank.py
index b8d28c3..cf960cd 100644
--- a/src/rank_llm/retrieve_and_rerank.py
+++ b/src/rank_llm/retrieve_and_rerank.py
@@ -1,24 +1,26 @@
 import copy
 from typing import Any, Dict, List, Union
 
-from rank_llm.data import Request
+from rank_llm.data import Request, Query
 from rank_llm.evaluation.trec_eval import EvalFunction
 from rank_llm.rerank.api_keys import get_azure_openai_args, get_openai_api_key
 from rank_llm.rerank.rank_gpt import SafeOpenai
 from rank_llm.rerank.rank_listwise_os_llm import RankListwiseOSLLM
-from rank_llm.rerank.rankllm import PromptMode
+from rank_llm.rerank.rankllm import RankLLM, PromptMode
 from rank_llm.rerank.reranker import Reranker
 from rank_llm.retrieve.pyserini_retriever import RetrievalMethod
 from rank_llm.retrieve.retriever import RetrievalMode, Retriever
+from rank_llm.retrieve.service_retriever import ServiceRetriever
 from rank_llm.retrieve.topics_dict import TOPICS
 
 
 def retrieve_and_rerank(
     model_path: str,
     dataset: Union[str, List[str], List[Dict[str, Any]]],
-    retrieval_mode: RetrievalMode,
-    retrieval_method: RetrievalMethod,
-    top_k_candidates: int = 100,
+    retrieval_mode: RetrievalMode = RetrievalMode.DATASET,
+    retrieval_method: RetrievalMethod = RetrievalMethod.BM25,
+    top_k_retrieve: int = 50,
+    top_k_rerank: int = 10,
     context_size: int = 4096,
     device: str = "cuda",
     num_gpus: int = 1,
@@ -27,6 +29,7 @@
     shuffle_candidates: bool = False,
     print_prompts_responses: bool = False,
     query: str = "",
+    qid: int = 1,
     use_azure_openai: bool = False,
     variable_passages: bool = False,
     num_passes: int = 1,
@@ -36,9 +39,16 @@
     index_path: str = None,
     topics_path: str = None,
     index_type: str = None,
+    interactive: bool = False,
+    host: str = "http://localhost:8081",
+    populate_exec_summary: bool = False,
+    default_agent: RankLLM = None,
 ):
+    model_full_path = ""
+    if interactive and default_agent is not None:
+        agent = default_agent
     # Construct Rerank Agent
-    if "gpt" in model_path or use_azure_openai:
+    elif "gpt" in model_path or use_azure_openai:
         openai_keys = get_openai_api_key()
         agent = SafeOpenai(
             model=model_path,
@@ -49,8 +59,17 @@
             **(get_azure_openai_args() if use_azure_openai else {}),
         )
     elif "vicuna" in model_path.lower() or "zephyr" in model_path.lower():
+        if model_path.lower() == "rank_zephyr":
+            model_full_path = "castorini/rank_zephyr_7b_v1_full"
+        elif model_path.lower() == "rank_vicuna":
+            model_full_path = "castorini/rank_vicuna_7b_v1"
+        else:
+            model_full_path = model_path
+
+        print(f"Loading {model_path} ...")
+
         agent = RankListwiseOSLLM(
-            model=model_path,
+            model=model_full_path,
             context_size=context_size,
             prompt_mode=prompt_mode,
             num_few_shot_examples=num_few_shot_examples,
@@ -64,62 +83,91 @@
         raise ValueError(f"Unsupported model: {model_path}")
 
     # Retrieve
-    print("Retrieving:")
-    if retrieval_mode == RetrievalMode.DATASET:
-        requests = Retriever.from_dataset_with_prebuilt_index(
-            dataset_name=dataset, retrieval_method=retrieval_method
+    print(f"Retrieving top {top_k_retrieve} passages...")
+    if interactive and retrieval_mode != RetrievalMode.DATASET:
+        raise ValueError(
+            f"Unsupported retrieval mode for interactive retrieval. Currently only DATASET mode is supported."
         )
+
+    if retrieval_mode == RetrievalMode.DATASET:
+        if interactive:
+
+            service_retriever = ServiceRetriever(
+                retrieval_method=retrieval_method, retrieval_mode=retrieval_mode
+            )
+            requests = [
+                service_retriever.retrieve(
+                    dataset=dataset,
+                    request=Request(query=Query(text=query, qid=qid)),
+                    k=top_k_retrieve,
+                    host=host,
+                )
+            ]
+        else:
+            requests = Retriever.from_dataset_with_prebuilt_index(
+                dataset_name=dataset, retrieval_method=retrieval_method
+            )
+
     elif retrieval_mode == RetrievalMode.CUSTOM:
         requests = Retriever.from_custom_index(
             index_path=index_path, topics_path=topics_path, index_type=index_type
         )
     else:
         raise ValueError(f"Invalid retrieval mode: {retrieval_mode}")
 
-    print("Reranking:")
+    print(f"Retrieval complete!")
+
+    # Reranking
+    print(f"Reranking and returning {top_k_rerank} passages...")
     reranker = Reranker(agent)
     for pass_ct in range(num_passes):
         print(f"Pass {pass_ct + 1} of {num_passes}:")
-        rerank_results = reranker.rerank_batach(
+        rerank_results = reranker.rerank_batch(
             requests,
-            rank_end=top_k_candidates,
-            window_size=min(window_size, top_k_candidates),
+            rank_end=top_k_retrieve,
+            window_size=min(window_size, top_k_retrieve),
             shuffle_candidates=shuffle_candidates,
             logging=print_prompts_responses,
             step=step_size,
+            populate_exec_summary=populate_exec_summary,
         )
 
-        # generate trec_eval file & evaluate for named datasets only
-        if isinstance(dataset, str):
-            file_name = reranker.write_rerank_results(
-                retrieval_method.name,
-                rerank_results,
-                shuffle_candidates,
-                top_k_candidates=top_k_candidates,
-                pass_ct=None if num_passes == 1 else pass_ct,
-                window_size=window_size,
-                dataset_name=dataset,
-            )
-            if (
-                dataset in TOPICS
-                and dataset not in ["dl22", "dl22-passage", "news"]
-                and TOPICS[dataset] not in ["dl22", "dl22-passage", "news"]
-            ):
-                print("Evaluating:")
-                EvalFunction.eval(
-                    ["-c", "-m", "ndcg_cut.1", TOPICS[dataset], file_name]
-                )
-                EvalFunction.eval(
-                    ["-c", "-m", "ndcg_cut.5", TOPICS[dataset], file_name]
-                )
-                EvalFunction.eval(
-                    ["-c", "-m", "ndcg_cut.10", TOPICS[dataset], file_name]
-                )
-            else:
-                print(f"Skipping evaluation as {dataset} is not in TOPICS.")
         if num_passes > 1:
             requests = [
-                Request(copy.deepycopy(r.query), copy.deepcopy(r.candidates))
+                Request(copy.deepcopy(r.query), copy.deepcopy(r.candidates))
                 for r in rerank_results
             ]
+    print(f"Reranking with {num_passes} passes complete!")
+
+    rerank_results = [
+        rr._replace(candidates=rr.candidates[:top_k_rerank]) for rr in rerank_results
+    ]
+
+    # generate trec_eval file & evaluate for named datasets only
+    if isinstance(dataset, str):
+        file_name = reranker.write_rerank_results(
+            retrieval_method.name,
+            rerank_results,
+            shuffle_candidates,
+            top_k_candidates=top_k_retrieve,
+            pass_ct=None if num_passes == 1 else pass_ct,
+            window_size=window_size,
+            dataset_name=dataset,
+        )
+        if (
+            dataset in TOPICS
+            and dataset not in ["dl22", "dl22-passage", "news"]
+            and TOPICS[dataset] not in ["dl22", "dl22-passage", "news"]
+        ):
+            print("Evaluating:")
+            EvalFunction.eval(
+                ["-c", "-m", "ndcg_cut.1", TOPICS[dataset], file_name]
+            )
+            EvalFunction.eval(
+                ["-c", "-m", "ndcg_cut.5", TOPICS[dataset], file_name]
+            )
+            EvalFunction.eval(
+                ["-c", "-m", "ndcg_cut.10", TOPICS[dataset], file_name]
+            )
+        else:
+            print(f"Skipping evaluation as {dataset} is not in TOPICS.")
 
     return rerank_results
diff --git a/test/retrieve/test_ServiceRetriever.py b/test/retrieve/test_ServiceRetriever.py
new file mode 100644
index 0000000..a6ff384
--- /dev/null
+++ b/test/retrieve/test_ServiceRetriever.py
@@ -0,0 +1,46 @@
+import unittest
+
+from rank_llm.retrieve.service_retriever import ServiceRetriever
+from rank_llm.data import Request, Query, Candidate
+from rank_llm import retrieve_and_rerank
+from rank_llm.retrieve.pyserini_retriever import RetrievalMethod
+from rank_llm.retrieve.retriever import RetrievalMode, Retriever
+
+
+class TestServiceRetriever(unittest.TestCase):
+    def test_from_dataset_with_prebuilt_index(self):
+
+        service_retriever = ServiceRetriever(
+            retrieval_method=RetrievalMethod.BM25, retrieval_mode=RetrievalMode.DATASET
+        )
+        response = [
+            service_retriever.retrieve(
+                dataset="msmarco-v2.1-doc",
+                request=Request(query=Query(text="hello", qid="1234")),
+                k=20,
+                host="http://localhost:8081",
+            )
+        ]
+
+        assert len(response[0].candidates) == 20
+        assert type(response[0].candidates[0]) == Candidate
+        assert response[0].query == Query(text="hello", qid="1234")
+
+    def test_retrieve_and_rerank_interactive(self):
+        top_k = 14
+
+        response = retrieve_and_rerank.retrieve_and_rerank(
+            dataset="msmarco-v2.1-doc",
+            query="hello",
+            model_path="rank_zephyr",
+            interactive=True,
+            top_k_retrieve=top_k,
+            populate_exec_summary=False,
+        )
+
+        response = response[0]
+        assert len(response.candidates) == top_k
+        for candidate in response.candidates:
+            print(candidate.docid)
+            print(candidate.score)
+            # print(candidate.doc)
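
For context on how the new endpoint in server.py is meant to be exercised, here is a minimal client sketch (not part of the diff). It assumes the server was started with `python src/rank_llm/api/server.py --model rank_zephyr --port 8082`, that an Anserini retrieval service is listening on localhost:8081, and that the `msmarco-v2.1-doc` index is available there (as in the new test); the query string is a placeholder.

```python
# Hypothetical client call against the URL format documented in server.py:
# /api/model/{model_name}/index/{index_name}/{retriever_base_host}?query=...&hits_retriever=...&hits_reranker=...
import requests

# path segments: model to use, index/dataset name, port of the retrieval service
base = "http://localhost:8082/api/model/rank_zephyr/index/msmarco-v2.1-doc/8081"
params = {
    "query": "what is a lobster roll",  # placeholder query
    "hits_retriever": 20,  # optional, defaults to 20
    "hits_reranker": 5,  # optional, defaults to 5
}
resp = requests.get(base, params=params, timeout=300)
resp.raise_for_status()
print(resp.json())  # JSON-serialized top reranked result
```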
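
And a sketch of the same interactive path called directly from Python, mirroring test_retrieve_and_rerank_interactive; it likewise assumes a local Anserini service on port 8081, and the query text is a placeholder.

```python
from rank_llm import retrieve_and_rerank

# interactive=True routes retrieval through ServiceRetriever instead of
# prebuilt Pyserini indexes; "rank_zephyr" is expanded to
# "castorini/rank_zephyr_7b_v1_full" by the new model_full_path logic.
results = retrieve_and_rerank.retrieve_and_rerank(
    dataset="msmarco-v2.1-doc",
    query="what is a lobster roll",  # placeholder query
    model_path="rank_zephyr",
    interactive=True,
    host="http://localhost:8081",
    top_k_retrieve=20,
    top_k_rerank=5,
    populate_exec_summary=False,
)

for candidate in results[0].candidates:
    print(candidate.docid, candidate.score)
```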