Skip to content

Commit

Permalink
keeping changes as it is
Browse files Browse the repository at this point in the history
  • Loading branch information
someshfengde committed Mar 26, 2024
1 parent 4871169 commit dda9a50
Showing 1 changed file with 101 additions and 62 deletions.
163 changes: 101 additions & 62 deletions backend/app/router/orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
import dspy
from app.rag.retrieval.web.brave_search import BraveSearchQueryEngine
import asyncio
import re

from llama_index.core.tools import ToolMetadata
from llama_index.core.selectors import LLMSingleSelector
from llama_index.llms.openai import OpenAI
from llama_index.core.schema import QueryBundle
from llama_index.llms.together import TogetherLLM
from llama_index.core.response_synthesizers import SimpleSummarize

from app.rag.retrieval.clinical_trials.clinical_trial_sql_query_engine import ClinicalTrialText2SQLEngine
from app.rag.retrieval.drug_chembl.drug_chembl_graph_query_engine import DrugChEMBLText2CypherEngine
from app.rag.reranker.response_reranker import ReRankEngine
from app.rag.generation.response_synthesis import ResponseSynthesisEngine
from app.config import config, OPENAI_API_KEY, RERANK_TOP_COUNT, ORCHESRATOR_ROUTER_PROMPT_PROGRAM

from app.rag.retrieval.web.brave_search import BraveSearchQueryEngine
from app.rag.retrieval.pubmed.pubmedqueryengine import PubmedSearchQueryEngine
from app.rag.reranker.response_reranker import TextEmbeddingInferenceRerankEngine
from app.api.common.util import RouteCategory
from app.config import config, OPENAI_API_KEY, TOGETHER_KEY, ORCHESRATOR_ROUTER_PROMPT_PROGRAM
from app.services.search_utility import setup_logger

import dspy
from app.dspy_integration.router_prompt import Router_module

logger = setup_logger('Orchestrator')


logger = setup_logger("Orchestrator")
TAG_RE = re.compile(r'<[^>]+>')

class Orchestrator:
"""
Orchestrator is responsible for routing the search engine query.
Expand All @@ -23,77 +35,104 @@ class Orchestrator:
def __init__(self, config):
self.config = config



self.llm = dspy.OpenAI(model="gpt-3.5-turbo", api_key=str(OPENAI_API_KEY))
dspy.settings.configure(lm = self.llm)
self.router = Router_module()
self.router.load(ORCHESRATOR_ROUTER_PROMPT_PROGRAM)

self.clinicalTrialSearch = ClinicalTrialText2SQLEngine(config)
self.drugChemblSearch = DrugChEMBLText2CypherEngine(config)
self.pubmedsearch = PubmedSearchQueryEngine(config)
self.bravesearch = BraveSearchQueryEngine(config)

async def query_and_get_answer(
self,
search_text: str
) -> str:
logger.info(f"query_and_get_answer.router_id search_text: {search_text}")
try :
router_id = int(self.router(search_text).answer)
except Exception as e:
logger.exception("query_and_get_answer.router_id Exception -", exc_info = e, stack_info=True)
logger.info(f"query_and_get_answer.router_id router_id: {router_id}")

breaks_sql = False

if router_id == 0:
clinicalTrialSearch = ClinicalTrialText2SQLEngine(config)
self,
routecategory: RouteCategory = RouteCategory.PBW,
search_text: str = "") -> str:
# search router call
logger.debug(
f"Orchestrator.query_and_get_answer.router_id search_text: {search_text}"
)

#initialize router with bad value
router_id = -1

# user not specified
if routecategory == RouteCategory.NS:
logger.info(f"query_and_get_answer.router_id search_text: {search_text}")
try :
router_id = int(self.router(search_text).answer)
except Exception as e:
logger.exception("query_and_get_answer.router_id Exception -", exc_info = e, stack_info=True)
logger.info(f"query_and_get_answer.router_id router_id: {router_id}")

breaks_sql = False

#routing
if router_id == 0 or routecategory == RouteCategory.CT:
# clinical trial call
logger.debug(
"Orchestrator.query_and_get_answer.router_id clinical trial Entered."
)
try:
sqlResponse = await clinicalTrialSearch.call_text2sql(search_text=search_text)
result = sqlResponse.get('result', '')
logger.info(f"query_and_get_answer.sqlResponse sqlResponse: {result}")
sqlResponse = self.clinicalTrialSearch.call_text2sql(search_text=search_text)
result = str(sqlResponse)
sources = result

logger.debug(f"Orchestrator.query_and_get_answer.sqlResponse sqlResponse: {result}")
except Exception as e:
breaks_sql = True
logger.exception("query_and_get_answer.sqlResponse Exception -", exc_info = e, stack_info=True)
logger.exception("Orchestrator.query_and_get_answer.sqlResponse Exception -", exc_info = e, stack_info=True)
pass

elif router_id == 1:
elif router_id == 1 or routecategory == RouteCategory.DRUG:
# drug information call
logger.info("query_and_get_answer.router_id drug_information_choice Entered.")

drugChemblSearch = DrugChEMBLText2CypherEngine(config)
result = []

logger.debug(
"Orchestrator.query_and_get_answer.router_id drug_information_choice Entered."
)
try:
cypherResponse = await drugChemblSearch.call_text2cypher(search_text=search_text)
cypherResponse = self.drugChemblSearch.call_text2cypher(
search_text=search_text
)
result = str(cypherResponse)

logger.info(f"query_and_get_answer.cypherResponse cypherResponse: {result}")
sources = result
logger.debug(
f"Orchestrator.query_and_get_answer.cypherResponse cypherResponse: {result}"
)
except Exception as e:
breaks_sql = True
logger.exception("query_and_get_answer.cypherResponse Exception -", exc_info = e, stack_info=True)

print()

if router_id == 2 or breaks_sql:
logger.info("query_and_get_answer.router_id Fallback Entered.")
logger.exception(
"Orchestrator.query_and_get_answer.cypherResponse Exception -",
exc_info=e,
stack_info=True,
)

if router_id == 2 or routecategory == RouteCategory.PBW or routecategory == RouteCategory.NS or breaks_sql:
logger.debug(
"Orchestrator.query_and_get_answer.router_id Fallback Entered."
)

bravesearch = BraveSearchQueryEngine(config)
extracted_retrieved_results = await bravesearch.call_brave_search_api(search_text=search_text)
extracted_pubmed_results, extracted_web_results = await asyncio.gather(
self.pubmedsearch.call_pubmed_vectors(search_text=search_text), self.bravesearch.call_brave_search_api(search_text=search_text)
)
extracted_results = extracted_pubmed_results + extracted_web_results
logger.debug(
f"Orchestrator.query_and_get_answer.extracted_results count: {len(extracted_pubmed_results), len(extracted_web_results)}"
)

logger.info(f"query_and_get_answer.extracted_retrieved_results: {extracted_retrieved_results}")
# rerank call
reranked_results = TextEmbeddingInferenceRerankEngine(top_n=2)._postprocess_nodes(
nodes = extracted_results,
query_bundle=QueryBundle(query_str=search_text))

summarizer = SimpleSummarize(llm=TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key=str(TOGETHER_KEY)))
result = summarizer.get_response(query_str=search_text, text_chunks=[TAG_RE.sub('', node.get_content()) for node in reranked_results])
sources = [node.node.metadata for node in reranked_results ]

#rerank call
rerank = ReRankEngine(config)
rerankResponse = await rerank.call_embedding_api(
search_text=search_text,
retrieval_results=extracted_retrieved_results
)
rerankResponse_sliced = rerankResponse[:RERANK_TOP_COUNT]
logger.info(f"query_and_get_answer.rerankResponse_sliced: {rerankResponse_sliced}")

#generation call
response_synthesis = ResponseSynthesisEngine(config)
result = await response_synthesis.call_llm_service_api(
search_text=search_text,
reranked_results=rerankResponse_sliced
)
result = result.get('result', '') + "\n\n" + "Source: " + ', '.join(result.get('source', []))
logger.info(f"query_and_get_answer.response_synthesis: {result}")
logger.info(f"query_and_get_answer. result: {result}")
return result
return {
"result" : result,
"sources": sources
}

0 comments on commit dda9a50

Please sign in to comment.