diff --git a/backend/app/router/orchestrator.py b/backend/app/router/orchestrator.py index f3db143c..fd15ac74 100644 --- a/backend/app/router/orchestrator.py +++ b/backend/app/router/orchestrator.py @@ -1,18 +1,30 @@ -import dspy -from app.rag.retrieval.web.brave_search import BraveSearchQueryEngine +import asyncio +import re + +from llama_index.core.tools import ToolMetadata +from llama_index.core.selectors import LLMSingleSelector +from llama_index.llms.openai import OpenAI +from llama_index.core.schema import QueryBundle +from llama_index.llms.together import TogetherLLM +from llama_index.core.response_synthesizers import SimpleSummarize + from app.rag.retrieval.clinical_trials.clinical_trial_sql_query_engine import ClinicalTrialText2SQLEngine from app.rag.retrieval.drug_chembl.drug_chembl_graph_query_engine import DrugChEMBLText2CypherEngine -from app.rag.reranker.response_reranker import ReRankEngine -from app.rag.generation.response_synthesis import ResponseSynthesisEngine -from app.config import config, OPENAI_API_KEY, RERANK_TOP_COUNT, ORCHESRATOR_ROUTER_PROMPT_PROGRAM - +from app.rag.retrieval.web.brave_search import BraveSearchQueryEngine +from app.rag.retrieval.pubmed.pubmedqueryengine import PubmedSearchQueryEngine +from app.rag.reranker.response_reranker import TextEmbeddingInferenceRerankEngine +from app.api.common.util import RouteCategory +from app.config import config, OPENAI_API_KEY, TOGETHER_KEY, ORCHESRATOR_ROUTER_PROMPT_PROGRAM from app.services.search_utility import setup_logger +import dspy from app.dspy_integration.router_prompt import Router_module -logger = setup_logger('Orchestrator') +logger = setup_logger("Orchestrator") +TAG_RE = re.compile(r'<[^>]+>') + class Orchestrator: """ Orchestrator is responsible for routing the search engine query. @@ -23,77 +35,104 @@ class Orchestrator: def __init__(self, config): self.config = config + + self.llm = dspy.OpenAI(model="gpt-3.5-turbo", api_key=str(OPENAI_API_KEY)) dspy.settings.configure(lm = self.llm) self.router = Router_module() self.router.load(ORCHESRATOR_ROUTER_PROMPT_PROGRAM) + self.clinicalTrialSearch = ClinicalTrialText2SQLEngine(config) + self.drugChemblSearch = DrugChEMBLText2CypherEngine(config) + self.pubmedsearch = PubmedSearchQueryEngine(config) + self.bravesearch = BraveSearchQueryEngine(config) + async def query_and_get_answer( - self, - search_text: str - ) -> str: - logger.info(f"query_and_get_answer.router_id search_text: {search_text}") - try : - router_id = int(self.router(search_text).answer) - except Exception as e: - logger.exception("query_and_get_answer.router_id Exception -", exc_info = e, stack_info=True) - logger.info(f"query_and_get_answer.router_id router_id: {router_id}") - - breaks_sql = False - - if router_id == 0: - clinicalTrialSearch = ClinicalTrialText2SQLEngine(config) + self, + routecategory: RouteCategory = RouteCategory.PBW, + search_text: str = "") -> str: + # search router call + logger.debug( + f"Orchestrator.query_and_get_answer.router_id search_text: {search_text}" + ) + + #initialize router with bad value + router_id = -1 + + # user not specified + if routecategory == RouteCategory.NS: + logger.info(f"query_and_get_answer.router_id search_text: {search_text}") + try : + router_id = int(self.router(search_text).answer) + except Exception as e: + logger.exception("query_and_get_answer.router_id Exception -", exc_info = e, stack_info=True) + logger.info(f"query_and_get_answer.router_id router_id: {router_id}") + + breaks_sql = False + + #routing + if router_id == 0 or routecategory == RouteCategory.CT: + # clinical trial call + logger.debug( + "Orchestrator.query_and_get_answer.router_id clinical trial Entered." + ) try: - sqlResponse = await clinicalTrialSearch.call_text2sql(search_text=search_text) - result = sqlResponse.get('result', '') - logger.info(f"query_and_get_answer.sqlResponse sqlResponse: {result}") + sqlResponse = self.clinicalTrialSearch.call_text2sql(search_text=search_text) + result = str(sqlResponse) + sources = result + + logger.debug(f"Orchestrator.query_and_get_answer.sqlResponse sqlResponse: {result}") except Exception as e: breaks_sql = True - logger.exception("query_and_get_answer.sqlResponse Exception -", exc_info = e, stack_info=True) + logger.exception("Orchestrator.query_and_get_answer.sqlResponse Exception -", exc_info = e, stack_info=True) + pass - elif router_id == 1: + elif router_id == 1 or routecategory == RouteCategory.DRUG: # drug information call - logger.info("query_and_get_answer.router_id drug_information_choice Entered.") - - drugChemblSearch = DrugChEMBLText2CypherEngine(config) - result = [] - + logger.debug( + "Orchestrator.query_and_get_answer.router_id drug_information_choice Entered." + ) try: - cypherResponse = await drugChemblSearch.call_text2cypher(search_text=search_text) + cypherResponse = self.drugChemblSearch.call_text2cypher( + search_text=search_text + ) result = str(cypherResponse) - - logger.info(f"query_and_get_answer.cypherResponse cypherResponse: {result}") + sources = result + logger.debug( + f"Orchestrator.query_and_get_answer.cypherResponse cypherResponse: {result}" + ) except Exception as e: breaks_sql = True - logger.exception("query_and_get_answer.cypherResponse Exception -", exc_info = e, stack_info=True) - - print() - - if router_id == 2 or breaks_sql: - logger.info("query_and_get_answer.router_id Fallback Entered.") + logger.exception( + "Orchestrator.query_and_get_answer.cypherResponse Exception -", + exc_info=e, + stack_info=True, + ) + + if router_id == 2 or routecategory == RouteCategory.PBW or routecategory == RouteCategory.NS or breaks_sql: + logger.debug( + "Orchestrator.query_and_get_answer.router_id Fallback Entered." + ) - bravesearch = BraveSearchQueryEngine(config) - extracted_retrieved_results = await bravesearch.call_brave_search_api(search_text=search_text) + extracted_pubmed_results, extracted_web_results = await asyncio.gather( + self.pubmedsearch.call_pubmed_vectors(search_text=search_text), self.bravesearch.call_brave_search_api(search_text=search_text) + ) + extracted_results = extracted_pubmed_results + extracted_web_results + logger.debug( + f"Orchestrator.query_and_get_answer.extracted_results count: {len(extracted_pubmed_results), len(extracted_web_results)}" + ) - logger.info(f"query_and_get_answer.extracted_retrieved_results: {extracted_retrieved_results}") + # rerank call + reranked_results = TextEmbeddingInferenceRerankEngine(top_n=2)._postprocess_nodes( + nodes = extracted_results, + query_bundle=QueryBundle(query_str=search_text)) + summarizer = SimpleSummarize(llm=TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key=str(TOGETHER_KEY))) + result = summarizer.get_response(query_str=search_text, text_chunks=[TAG_RE.sub('', node.get_content()) for node in reranked_results]) + sources = [node.node.metadata for node in reranked_results ] - #rerank call - rerank = ReRankEngine(config) - rerankResponse = await rerank.call_embedding_api( - search_text=search_text, - retrieval_results=extracted_retrieved_results - ) - rerankResponse_sliced = rerankResponse[:RERANK_TOP_COUNT] - logger.info(f"query_and_get_answer.rerankResponse_sliced: {rerankResponse_sliced}") - - #generation call - response_synthesis = ResponseSynthesisEngine(config) - result = await response_synthesis.call_llm_service_api( - search_text=search_text, - reranked_results=rerankResponse_sliced - ) - result = result.get('result', '') + "\n\n" + "Source: " + ', '.join(result.get('source', [])) - logger.info(f"query_and_get_answer.response_synthesis: {result}") - logger.info(f"query_and_get_answer. result: {result}") - return result \ No newline at end of file + return { + "result" : result, + "sources": sources + } + \ No newline at end of file