Support for Re-ranking Retrieval Optimization in Advanced RAG #1366

Open
kevintsai1202 opened this issue Sep 15, 2024 · 0 comments
Labels: advisors, RAG (Issues related to Retrieval Augmented Generation)


kevintsai1202 commented Sep 15, 2024

I want to add a re-ranking step to RAG but have run into some limitations that need improvement.
Initially, I planned to add a re-ranking Advisor after QuestionAnswerAdvisor, but QuestionAnswerAdvisor already adds the retrieved search results to the user parameters via .withUserParams(advisedUserParams).

A subsequent Advisor can override these UserParams, but a re-ranking workflow typically retrieves 50-100 chunks by vector search and only then re-ranks them. Adding all of those chunks to UserParams first wastes both time and memory, even though a later Advisor can replace them.
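To make the cost argument concrete, here is a minimal, self-contained sketch of the two-stage flow: over-fetch candidates, re-rank them, and build the prompt context from only the top N. The Reranker interface, the TERM_OVERLAP scorer, and the sample chunks are hypothetical; the toy term-overlap score merely stands in for a real re-ranking model and is not how a cross-encoder actually scores text.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class TwoStageRetrievalSketch {

    // Hypothetical interface standing in for a re-ranking model (not a Spring AI type).
    @FunctionalInterface
    interface Reranker {
        double score(String query, String candidate);
    }

    // Toy scorer: fraction of query terms that also occur in the candidate.
    // A real system would call a cross-encoder or a hosted rerank API here.
    static final Reranker TERM_OVERLAP = (query, candidate) -> {
        Set<String> queryTerms = tokenize(query);
        if (queryTerms.isEmpty()) {
            return 0.0;
        }
        Set<String> candidateTerms = tokenize(candidate);
        long hits = queryTerms.stream().filter(candidateTerms::contains).count();
        return (double) hits / queryTerms.size();
    };

    // Lower-cases the text and splits it on non-word characters.
    static Set<String> tokenize(String text) {
        return Arrays.stream(text.toLowerCase(Locale.ROOT).split("\\W+"))
                .filter(token -> !token.isEmpty())
                .collect(Collectors.toSet());
    }

    // Stage 2: score every over-fetched candidate once, then keep the best topN
    // (the sort is stable, so ties keep their original vector-search order).
    static List<String> rerank(String query, List<String> candidates, Reranker reranker, int topN) {
        double[] scores = candidates.stream()
                .mapToDouble(candidate -> reranker.score(query, candidate))
                .toArray();
        return IntStream.range(0, candidates.size())
                .boxed()
                .sorted(Comparator.comparingDouble((Integer i) -> scores[i]).reversed())
                .limit(topN)
                .map(candidates::get)
                .toList();
    }

    public static void main(String[] args) {
        // Stage 1 would be a vector search with a large topK (50-100 in the text above);
        // a short hard-coded list stands in for its results here.
        List<String> candidates = List.of(
                "Spring AI advisors wrap calls to the chat model",
                "Re-ranking reorders retrieved chunks by relevance",
                "A vector store returns chunks by embedding similarity",
                "Unrelated note about build tooling");
        List<String> top = rerank("how does re-ranking reorder chunks", candidates, TERM_OVERLAP, 2);
        // Only the re-ranked top N are joined into the prompt context.
        System.out.println(String.join(System.lineSeparator(), top));
    }
}
```

The point of the design is that the string handed to the prompt is built from the N survivors, not from the full candidate set.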

I hope the RAG-related Advisors can be improved in the following ways:

  1. Retrieve the document contents and attach them to UserParams only after all Advisors have completed processing.
  2. Add support for a re-ranking model, which is a common retrieval optimization technique.
  3. A re-ranking model typically returns only the indexes and relevance scores. Ideally, the re-ranking step would still return a List<Document> that carries each document's metadata.
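Point 3 can be handled by mapping the returned indexes back onto the documents that were sent to the re-ranker. The following is a sketch under stated assumptions: Doc is a minimal stand-in for Spring AI's Document (content plus metadata only), RerankResult mirrors an index-plus-score response, and the rerank_score metadata key is a made-up name.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class RerankToDocumentsSketch {

    // Minimal stand-in for a document with content and metadata (hypothetical, not Spring AI's class).
    record Doc(String content, Map<String, Object> metadata) {}

    // One result from an index-only rerank response: position in the submitted list plus a score.
    record RerankResult(int index, double relevanceScore) {}

    // Hypothetical metadata key; not a Spring AI constant.
    static final String RERANK_SCORE_KEY = "rerank_score";

    // Maps rerank results back onto the original documents in the order the reranker returned them,
    // copying each document's metadata and adding the relevance score.
    static List<Doc> toDocuments(List<Doc> candidates, List<RerankResult> results) {
        List<Doc> reranked = new ArrayList<>(results.size());
        for (RerankResult r : results) {
            if (r.index() < 0 || r.index() >= candidates.size()) {
                throw new IllegalArgumentException("Rerank index out of range: " + r.index());
            }
            Doc original = candidates.get(r.index());
            // Copy the metadata so the input documents are not mutated.
            Map<String, Object> metadata = new HashMap<>(original.metadata());
            metadata.put(RERANK_SCORE_KEY, r.relevanceScore());
            reranked.add(new Doc(original.content(), metadata));
        }
        return reranked;
    }

    public static void main(String[] args) {
        List<Doc> candidates = List.of(
                new Doc("chunk A", Map.of("source", "a.pdf", "page", 1)),
                new Doc("chunk B", Map.of("source", "b.pdf", "page", 7)),
                new Doc("chunk C", Map.of("source", "c.pdf", "page", 3)));
        // Example response: the reranker ranked chunk C first, then chunk A.
        List<RerankResult> results = List.of(new RerankResult(2, 0.91), new RerankResult(0, 0.42));
        for (Doc d : toDocuments(candidates, results)) {
            System.out.println(d.content() + " " + d.metadata().get("source") + " " + d.metadata().get(RERANK_SCORE_KEY));
        }
    }
}
```

Because only indexes travel back from the model, the metadata (source file, page, IDs) never has to be serialized into the rerank request.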

Below is the code I have modified.

```java
// Imports assume the Spring AI 1.0.0-M1 package layout; adjust them for other versions.
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestClient;

public class RerankRAGAdvisor implements RequestResponseAdvisor {

	private static final String DEFAULT_USER_TEXT_ADVISE = """
			Context information is below.
			---------------------
			{question_answer_context}
			---------------------
			Given the context and provided history information and not prior knowledge,
			reply to the user comment. If the answer is not in the context, inform
			the user that you can't answer the question.
			""";

	public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";

	public static final String FILTER_EXPRESSION = "qa_filter_expression";

	private static final String RERANK_URL = "https://api.voyageai.com/v1/rerank";

	// Number of candidates fetched from the vector store before re-ranking.
	private static final int CANDIDATE_TOP_K = 100;

	// Number of chunks kept after re-ranking and placed in the prompt.
	private static final int RERANK_TOP_K = 5;

	private final VectorStore vectorStore;

	private final String userTextAdvise;

	private final SearchRequest searchRequest;

	private final RestClient restClient;

	private final String apiKey = System.getenv("VOYAGE_KEY");

	// Response body of the Voyage AI re-ranking API.
	@JsonInclude(Include.NON_NULL)
	public record RerankList(
			@JsonProperty("object") String object,
			@JsonProperty("data") List<Rerank> data,
			@JsonProperty("model") String model,
			@JsonProperty("usage") RerankUsage usage) {
	}

	// One result: index into the submitted documents, relevance score, and the document text
	// (present only when return_documents is true).
	@JsonInclude(Include.NON_NULL)
	public record Rerank(
			@JsonProperty("index") Integer index,
			@JsonProperty("relevance_score") Double relevanceScore,
			@JsonProperty("document") String document) {
	}

	// Token usage as returned by the API, e.g. {"total_tokens": 123}.
	@JsonInclude(Include.NON_NULL)
	public record RerankUsage(@JsonProperty("total_tokens") Integer totalTokens) {
	}

	public RerankRAGAdvisor(RestClient restClient, VectorStore vectorStore) {
		this(restClient, vectorStore, SearchRequest.defaults(), DEFAULT_USER_TEXT_ADVISE);
	}

	public RerankRAGAdvisor(RestClient restClient, VectorStore vectorStore, SearchRequest searchRequest) {
		this(restClient, vectorStore, searchRequest, DEFAULT_USER_TEXT_ADVISE);
	}

	public RerankRAGAdvisor(RestClient restClient, VectorStore vectorStore, SearchRequest searchRequest,
			String userTextAdvise) {
		Assert.notNull(restClient, "The restClient must not be null!");
		Assert.notNull(vectorStore, "The vectorStore must not be null!");
		Assert.notNull(searchRequest, "The searchRequest must not be null!");
		Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!");
		// Fail fast instead of sending "Bearer null" to the API.
		Assert.hasText(this.apiKey, "The VOYAGE_KEY environment variable must be set!");

		this.restClient = restClient;
		this.vectorStore = vectorStore;
		this.searchRequest = searchRequest;
		this.userTextAdvise = userTextAdvise;
	}

	// Calls the Voyage AI re-ranking API and returns the RERANK_TOP_K most relevant documents.
	public ResponseEntity<RerankList> rerankDocuments(String query, List<Document> documents) {
		Map<String, Object> requestBody = new HashMap<>();
		requestBody.put("query", query);
		requestBody.put("model", "rerank-1");
		requestBody.put("top_k", RERANK_TOP_K);
		requestBody.put("return_documents", true);
		requestBody.put("documents", documents.stream().map(Document::getContent).toList());

		return this.restClient.post()
			.uri(RERANK_URL)
			.contentType(MediaType.APPLICATION_JSON)
			.header("Authorization", "Bearer " + this.apiKey)
			.body(requestBody)
			.retrieve()
			.toEntity(RerankList.class);
	}

	@Override
	public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {
		// 1. Append the context template to the user text.
		String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise;

		// 2. Over-fetch candidates from the vector store; re-ranking narrows them down.
		var searchRequestToUse = SearchRequest.from(this.searchRequest)
			.withQuery(request.userText())
			.withTopK(CANDIDATE_TOP_K)
			.withFilterExpression(doGetFilterExpression(context));
		List<Document> documents = this.vectorStore.similaritySearch(searchRequestToUse);

		// 3. Re-rank the candidates, skipping the API call when nothing was retrieved
		//    and guarding against an empty response body.
		List<Rerank> rerankDocs = List.of();
		if (!documents.isEmpty()) {
			RerankList body = rerankDocuments(request.userText(), documents).getBody();
			if (body != null && body.data() != null) {
				rerankDocs = body.data();
			}
		}
		context.put(RETRIEVED_DOCUMENTS, rerankDocs);

		// 4. Build the prompt context from the re-ranked documents only.
		String documentContext = rerankDocs.stream()
			.map(Rerank::document)
			.collect(Collectors.joining(System.lineSeparator()));

		// 5. Add the context to the user parameters.
		Map<String, Object> advisedUserParams = new HashMap<>(request.userParams());
		advisedUserParams.put("question_answer_context", documentContext);

		return AdvisedRequest.from(request)
			.withUserText(advisedUserText)
			.withUserParams(advisedUserParams)
			.build();
	}

	@Override
	public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
		ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(response);
		chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
		return chatResponseBuilder.build();
	}

	@Override
	public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {
		return fluxResponse.map(cr -> {
			ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(cr);
			chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
			return chatResponseBuilder.build();
		});
	}

	// Uses the filter expression from the advisor context if present, else the default search request's.
	protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {
		if (!context.containsKey(FILTER_EXPRESSION)
				|| !StringUtils.hasText(context.get(FILTER_EXPRESSION).toString())) {
			return this.searchRequest.getFilterExpression();
		}
		return new FilterExpressionTextParser().parse(context.get(FILTER_EXPRESSION).toString());
	}

}
```
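The advisor above still builds the context string itself. For point 1, one possible approach (a sketch, not an existing Spring AI mechanism) is for each stage to pass a list of documents through the shared context map and render the string only once, after all stages have run. CANDIDATES_KEY, Stage, RETRIEVE, and RERANK_TOP_2 are hypothetical names used for illustration.

```java
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class DeferredContextSketch {

    // Hypothetical key shared by all stages; not a Spring AI constant.
    static final String CANDIDATES_KEY = "rag_candidate_documents";

    record Doc(String content, double score) {}

    // Hypothetical stage contract: a stage may read and replace the candidate list
    // in the shared context, but never renders prompt text itself.
    @FunctionalInterface
    interface Stage {
        void apply(Map<String, Object> context);
    }

    @SuppressWarnings("unchecked")
    static List<Doc> candidates(Map<String, Object> context) {
        return (List<Doc>) context.getOrDefault(CANDIDATES_KEY, List.of());
    }

    // Stand-in for the vector search: stores candidates in arbitrary order.
    static final Stage RETRIEVE = ctx -> ctx.put(CANDIDATES_KEY,
            List.of(new Doc("a", 0.2), new Doc("b", 0.9), new Doc("c", 0.5)));

    // Stand-in for re-ranking: keeps the two highest-scoring candidates.
    static final Stage RERANK_TOP_2 = ctx -> ctx.put(CANDIDATES_KEY,
            candidates(ctx).stream()
                    .sorted(Comparator.comparingDouble(Doc::score).reversed())
                    .limit(2)
                    .toList());

    // Runs every stage first, then joins the surviving contents exactly once.
    static String run(List<Stage> stages) {
        Map<String, Object> context = new HashMap<>();
        for (Stage stage : stages) {
            stage.apply(context);
        }
        return candidates(context).stream()
                .map(Doc::content)
                .collect(Collectors.joining(System.lineSeparator()));
    }

    public static void main(String[] args) {
        System.out.println(run(List.of(RETRIEVE, RERANK_TOP_2)));
    }
}
```

In a real advisor chain the rendering step would sit in the last advisor before the model call, so intermediate advisors would never copy 50-100 chunk texts into UserParams.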
markpollack added the advisors and RAG labels on Sep 19, 2024.