From 834914caab8ffa71b9191acad6b355262df857be Mon Sep 17 00:00:00 2001 From: Somesh Fengade Date: Wed, 27 Mar 2024 16:43:08 +0530 Subject: [PATCH] adding AI model names to config --- backend/app/config.py | 7 ++++++- .../clinical_trials/clinical_trial_sql_query_engine.py | 8 +++++--- backend/app/router/orchestrator.py | 4 ++-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/backend/app/config.py b/backend/app/config.py index e8c37ae2..30cae042 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -135,4 +135,9 @@ # Dspy Integration Configuration CLINICAL_TRIAL_SQL_PROGRAM: str = "app/dspy_integration/dspy_programs/clinical_trials_sql_generation.json" CLINICAL_TRIALS_RESPONSE_REFINEMENT_PROGRAM: str = "app/dspy_integration/dspy_programs/clinical_trials_response_refinement.json" -ORCHESRATOR_ROUTER_PROMPT_PROGRAM: str = "app/dspy_integration/dspy_programs/orchestrator_router_prompt.json" \ No newline at end of file +ORCHESRATOR_ROUTER_PROMPT_PROGRAM: str = "app/dspy_integration/dspy_programs/orchestrator_router_prompt.json" + +#AI models +ROUTER_MODEL: str = "gpt-3.5-turbo" +SQL_GENERATION_MODEL: str = "codellama/CodeLlama-13b-Instruct-hf" +RESPONSE_SYNTHESIZER_MODEL: str = "NousResearch/Nous-Hermes-llama-2-7b" \ No newline at end of file diff --git a/backend/app/rag/retrieval/clinical_trials/clinical_trial_sql_query_engine.py b/backend/app/rag/retrieval/clinical_trials/clinical_trial_sql_query_engine.py index c1ac8be6..d1ddee5c 100644 --- a/backend/app/rag/retrieval/clinical_trials/clinical_trial_sql_query_engine.py +++ b/backend/app/rag/retrieval/clinical_trials/clinical_trial_sql_query_engine.py @@ -26,7 +26,9 @@ EMBEDDING_MODEL_NAME, CLINICAL_TRIAL_SQL_PROGRAM, CLINICAL_TRIALS_RESPONSE_REFINEMENT_PROGRAM, - TOGETHER_KEY + TOGETHER_KEY, + SQL_GENERATION_MODEL, + RESPONSE_SYNTHESIZER_MODEL ) from app.services.search_utility import setup_logger @@ -62,8 +64,8 @@ class ClinicalTrialText2SQLEngine: def __init__(self, config): self.config = config - self.nous =dspy.Together(model = "NousResearch/Nous-Hermes-llama-2-7b", api_key=str(TOGETHER_KEY)) - self.llm = dspy.Together(model = "codellama/CodeLlama-13b-Instruct-hf", api_key=str(TOGETHER_KEY)) + self.nous =dspy.Together(model = str(RESPONSE_SYNTHESIZER_MODEL), api_key=str(TOGETHER_KEY)) + self.llm = dspy.Together(model = str(SQL_GENERATION_MODEL), api_key=str(TOGETHER_KEY)) dspy.settings.configure(lm = self.llm) self.sql_module = SQL_module() diff --git a/backend/app/router/orchestrator.py b/backend/app/router/orchestrator.py index 7f2ff19d..bfa6f693 100644 --- a/backend/app/router/orchestrator.py +++ b/backend/app/router/orchestrator.py @@ -11,7 +11,7 @@ from app.rag.retrieval.pubmed.pubmedqueryengine import PubmedSearchQueryEngine from app.rag.reranker.response_reranker import TextEmbeddingInferenceRerankEngine from app.api.common.util import RouteCategory -from app.config import OPENAI_API_KEY, TOGETHER_KEY, ORCHESRATOR_ROUTER_PROMPT_PROGRAM +from app.config import OPENAI_API_KEY, TOGETHER_KEY, ORCHESRATOR_ROUTER_PROMPT_PROGRAM, ROUTER_MODEL from app.services.search_utility import setup_logger import dspy @@ -34,7 +34,7 @@ def __init__(self, config): - self.llm = dspy.OpenAI(model="gpt-3.5-turbo", api_key=str(OPENAI_API_KEY)) + self.llm = dspy.OpenAI(model=str(ROUTER_MODEL), api_key=str(OPENAI_API_KEY)) dspy.settings.configure(lm = self.llm) self.router = Router_module() self.router.load(ORCHESRATOR_ROUTER_PROMPT_PROGRAM)