ServiceRetriever and RankLLM REST API support (#118)
xpbowler committed May 27, 2024
1 parent d62d111 commit eb52bfc
Showing 8 changed files with 376 additions and 53 deletions.
122 changes: 122 additions & 0 deletions src/rank_llm/api/server.py
@@ -0,0 +1,122 @@
import argparse
from flask import Flask, jsonify, request

from rank_llm import retrieve_and_rerank
from rank_llm.rerank.rank_listwise_os_llm import RankListwiseOSLLM
from rank_llm.rerank.api_keys import get_openai_api_key, get_azure_openai_args
from rank_llm.rerank.rank_gpt import SafeOpenai
from rank_llm.rerank.rankllm import PromptMode

""" API URL FORMAT
http://localhost:8082/api/model/{model_name}/index/{index_name}/{retriever_base_host}?query={query}&hits_retriever={top_k_retriever}&hits_reranker={top_k_reranker}&qid={qid}&num_passes={num_passes}
hits_retriever, hits_reranker, qid, and num_passes are OPTIONAL.
They default to 20, 5, None, and 1, respectively.
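Example request (hypothetical index name and query, not values from this commit):
http://localhost:8082/api/model/rank_zephyr/index/msmarco-v1-passage/8081?query=how%20do%20rainbows%20form&hits_retriever=20&hits_reranker=5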
"""


def create_app(model, port, use_azure_openai=False):

app = Flask(__name__)
if model == "rank_zephyr":
print(f"Loading {model} model...")
# Load specified model upon server initialization
default_agent = RankListwiseOSLLM(
model=f"castorini/{model}_7b_v1_full",
context_size=4096,
prompt_mode=PromptMode.RANK_GPT,
num_few_shot_examples=0,
device="cuda",
num_gpus=1,
variable_passages=True,
window_size=20,
system_message="You are RankLLM, an intelligent assistant that can rank passages based on their relevancy to the query.",
)
elif model == "rank_vicuna":
print(f"Loading {model} model...")
# Load specified model upon server initialization
default_agent = RankListwiseOSLLM(
model=f"castorini/{model}_7b_v1",
context_size=4096,
prompt_mode=PromptMode.RANK_GPT,
num_few_shot_examples=0,
device="cuda",
num_gpus=1,
variable_passages=False,
window_size=20,
)
elif "gpt" in model:
print(f"Loading {model} model...")
# Load specified model upon server initialization
        openai_keys = get_openai_api_key()  # avoid printing API keys to logs
default_agent = SafeOpenai(
model=model,
context_size=8192,
prompt_mode=PromptMode.RANK_GPT,
num_few_shot_examples=0,
keys=openai_keys,
**(get_azure_openai_args() if use_azure_openai else {}),
)
else:
raise ValueError(f"Unsupported model: {model}")

@app.route(
"/api/model/<string:model_path>/index/<string:dataset>/<string:host>",
methods=["GET"],
)
def search(model_path, dataset, host):

query = request.args.get("query", type=str)
top_k_retrieve = request.args.get("hits_retriever", default=20, type=int)
top_k_rerank = request.args.get("hits_reranker", default=5, type=int)
qid = request.args.get("qid", default=None, type=str)
num_passes = request.args.get("num_passes", default=1, type=int)

try:
            # Run the two-stage retrieve-then-rerank pipeline against the Anserini host
response = retrieve_and_rerank.retrieve_and_rerank(
dataset=dataset,
query=query,
model_path=model_path,
host="http://localhost:" + host,
interactive=True,
top_k_rerank=top_k_rerank,
top_k_retrieve=top_k_retrieve,
qid=qid,
populate_exec_summary=False,
default_agent=default_agent,
num_passes=num_passes,
)

return jsonify(response[0]), 200
except Exception as e:
return jsonify({"error": str(e)}), 500

return app, port


def main():
parser = argparse.ArgumentParser(description="Start the RankLLM Flask server.")
parser.add_argument(
"--model",
type=str,
default="rank_zephyr",
help="The model to load (e.g., rank_zephyr).",
)
parser.add_argument(
"--port", type=int, default=8082, help="The port to run the Flask server on."
)
parser.add_argument(
"--use_azure_openai", action="store_true", help="Use Azure OpenAI API."
)
args = parser.parse_args()

app, port = create_app(args.model, args.port, args.use_azure_openai)
app.run(host="0.0.0.0", port=port, debug=False)


if __name__ == "__main__":
main()
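
For reference, a minimal sketch of exercising the server once it is running; the index name, query, and Anserini port below are hypothetical placeholders, not values from this commit:

import requests

# Start the server first, e.g.: python src/rank_llm/api/server.py --model rank_zephyr --port 8082
# The route shape follows the API URL FORMAT comment above; an Anserini REST
# service is assumed to be listening on port 8081.
response = requests.get(
    "http://localhost:8082/api/model/rank_zephyr/index/msmarco-v1-passage/8081",
    params={"query": "how do rainbows form", "hits_retriever": 20, "hits_reranker": 5},
    timeout=60,
)
response.raise_for_status()
print(response.json())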
4 changes: 1 addition & 3 deletions src/rank_llm/demo/rerank_stored_retrieved_results.py
@@ -10,9 +10,7 @@
from rank_llm.data import read_requests_from_file, DataWriter
from rank_llm.rerank.zephyr_reranker import ZephyrReranker

file_name = (
"retrieve_results/BM25/retrieve_results_dl23_top20.json"
)
file_name = "retrieve_results/BM25/retrieve_results_dl23_top20.json"
requests = read_requests_from_file(file_name)

reranker = ZephyrReranker()
2 changes: 1 addition & 1 deletion src/rank_llm/rerank/rank_gpt.py
@@ -127,7 +127,7 @@ def run_llm(
prompt: Union[str, List[Dict[str, str]]],
current_window_size: Optional[int] = None,
) -> Tuple[str, int]:
model_key = "engine" if self.use_azure_ai else "model"
model_key = "model"
response = self._call_completion(
messages=prompt,
temperature=0,
19 changes: 14 additions & 5 deletions src/rank_llm/rerank/rankllm.py
@@ -114,6 +114,7 @@ def permutation_pipeline(
rank_start: int,
rank_end: int,
logging: bool = False,
populate_exec_summary: bool = True,
) -> Result:
"""
Runs the permutation pipeline on the passed in result set within the passed in rank range.
@@ -135,10 +136,11 @@
)
if logging:
print(f"output: {permutation}")
ranking_exec_info = RankingExecInfo(
prompt, permutation, in_token_count, out_token_count
)
result.ranking_exec_summary.append(ranking_exec_info)
if populate_exec_summary:
ranking_exec_info = RankingExecInfo(
prompt, permutation, in_token_count, out_token_count
)
result.ranking_exec_summary.append(ranking_exec_info)
result = self.receive_permutation(result, permutation, rank_start, rank_end)
return result

@@ -151,6 +153,7 @@ def sliding_windows(
step: int,
shuffle_candidates: bool = False,
logging: bool = False,
populate_exec_summary: bool = True,
) -> Result:
"""
Applies the sliding window algorithm to the reranking process.
@@ -188,7 +191,11 @@
while end_pos > rank_start and start_pos + step != rank_start:
start_pos = max(start_pos, rank_start)
rerank_result = self.permutation_pipeline(
rerank_result, start_pos, end_pos, logging
rerank_result,
start_pos,
end_pos,
logging,
populate_exec_summary=populate_exec_summary,
)
end_pos = end_pos - step
start_pos = start_pos - step
@@ -331,6 +338,8 @@ def covert_doc_to_prompt_content(self, doc: Dict[str, Any], max_length: int) ->
content = doc["segment"]
elif "contents" in doc:
content = doc["contents"]
elif "body" in doc:
content = doc["body"]
else:
content = doc["passage"]
if "title" in doc and doc["title"]:
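
A standalone sketch (not the library code itself) of the content-field precedence after this change; the helper name and sample dict are hypothetical:

from typing import Any, Dict

def doc_text(doc: Dict[str, Any]) -> str:
    # Mirrors the lookup order in covert_doc_to_prompt_content:
    # "segment", then "contents", then the newly added "body", then "passage".
    for key in ("segment", "contents", "body", "passage"):
        if key in doc:
            return doc[key]
    raise KeyError("no recognized content field")

# Documents returned by the Anserini REST service may carry their text under "body".
assert doc_text({"body": "sample passage text"}) == "sample passage text"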
2 changes: 2 additions & 0 deletions src/rank_llm/rerank/reranker.py
@@ -21,6 +21,7 @@ def rerank_batch(
step: int = 10,
shuffle_candidates: bool = False,
logging: bool = False,
populate_exec_summary: bool = True,
) -> List[Result]:
"""
Reranks a list of requests using the RankLLM agent.
@@ -50,6 +51,7 @@
step=step,
shuffle_candidates=shuffle_candidates,
logging=logging,
populate_exec_summary=populate_exec_summary,
)
results.append(result)
return results
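
A minimal sketch of the new flag at a call site; everything here except the populate_exec_summary keyword visible in the diff above is assumed:

# reranker: a previously constructed Reranker; requests: a List[Request].
results = reranker.rerank_batch(
    requests,
    populate_exec_summary=False,  # skip recording RankingExecInfo per window
)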
98 changes: 98 additions & 0 deletions src/rank_llm/retrieve/service_retriever.py
@@ -0,0 +1,98 @@
import json
import requests
from urllib import parse
from enum import Enum
from typing import Any, Dict, List, Union

from rank_llm.data import Request, Candidate, Query
from rank_llm.retrieve.pyserini_retriever import PyseriniRetriever, RetrievalMethod
from rank_llm.retrieve.repo_info import HITS_INFO
from rank_llm.retrieve.utils import compute_md5, download_cached_hits
from rank_llm.retrieve.retriever import RetrievalMode, Retriever


class ServiceRetriever:
def __init__(
self,
retrieval_mode: RetrievalMode = RetrievalMode.DATASET,
retrieval_method: RetrievalMethod = RetrievalMethod.BM25,
) -> None:
"""
Creates a ServiceRetriever instance with a specified retrieval method and mode.
Args:
retrieval_mode (RetrievalMode): The retrieval mode to be used. Defaults to DATASET. Only DATASET mode is currently supported.
retrieval_method (RetrievalMethod): The retrieval method to be used. Defaults to BM25.
Raises:
ValueError: If retrieval mode or retrieval method is invalid or missing.
"""
self._retrieval_mode = retrieval_mode
self._retrieval_method = retrieval_method

if retrieval_mode != RetrievalMode.DATASET:
raise ValueError(
f"{retrieval_mode} is not supported for ServiceRetriever. Only DATASET mode is currently supported."
)

if retrieval_method != RetrievalMethod.BM25:
raise ValueError(
f"{retrieval_method} is not supported for ServiceRetriever. Only BM25 is currently supported."
)

        if not retrieval_method:
            raise ValueError("Please provide a retrieval method.")

if retrieval_method == RetrievalMethod.UNSPECIFIED:
raise ValueError(
f"Invalid retrieval method: {retrieval_method}. Please provide a specific retrieval method."
)

def retrieve(
self,
dataset: str,
request: Request,
k: int = 50,
host: str = "http://localhost:8081",
timeout: int = 10,
) -> Request:
"""
        Executes the retrieval process based on the configuration provided to the ServiceRetriever instance. Takes in a Request object containing a query and an empty candidate list, along with the number of top hits to retrieve.
Args:
request (Request): The request containing the query and qid.
dataset (str): The name of the dataset.
            k (int, optional): The top k hits to retrieve. Defaults to 50.
            host (str): The Anserini API host address. Defaults to http://localhost:8081.
            timeout (int, optional): Request timeout in seconds. Defaults to 10.
Returns:
            Request: The original query together with the list of retrieved candidates.
Raises:
ValueError: If the retrieval mode is invalid or the result format is not as expected.
"""

url = f"{host}/api/index/{dataset}/search?query={parse.quote(request.query.text)}&hits={str(k)}&qid={request.query.qid}"

try:
response = requests.get(url, timeout=timeout)
response.raise_for_status()
except requests.exceptions.RequestException as e:
raise type(e)(
f"Failed to retrieve data from Anserini server: {str(e)}"
) from e

data = response.json()
retrieved_results = Request(
query=Query(text=data["query"]["text"], qid=data["query"]["qid"])
)

for candidate in data["candidates"]:
retrieved_results.candidates.append(
Candidate(
docid=candidate["docid"],
score=candidate["score"],
doc=candidate["doc"],
)
)

return retrieved_results
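
A minimal usage sketch, assuming an Anserini REST server is running locally on port 8081; the index name, qid, and query are hypothetical:

from rank_llm.data import Query, Request
from rank_llm.retrieve.pyserini_retriever import RetrievalMethod
from rank_llm.retrieve.retriever import RetrievalMode
from rank_llm.retrieve.service_retriever import ServiceRetriever

retriever = ServiceRetriever(
    retrieval_mode=RetrievalMode.DATASET,
    retrieval_method=RetrievalMethod.BM25,
)
request = Request(query=Query(text="how do rainbows form", qid="q1"))
result = retriever.retrieve(
    dataset="msmarco-v1-passage",
    request=request,
    k=20,
    host="http://localhost:8081",
)
print(f"retrieved {len(result.candidates)} candidates")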