From 2559a2e0d7b54a0453a8264502a9ffe01d33946b Mon Sep 17 00:00:00 2001 From: Ansh Joshi <54464396+Ansh5461@users.noreply.github.com> Date: Sat, 13 Apr 2024 13:03:59 +0530 Subject: [PATCH] changes in engine params (#282) * changes in engine params * Changed workflow * Changed workflow * Changed version --- querent/workflow/workflow.py | 33 ++++++++++++++++++++++++++++++--- setup.py | 2 +- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/querent/workflow/workflow.py b/querent/workflow/workflow.py index bba60a91..ec410a4b 100644 --- a/querent/workflow/workflow.py +++ b/querent/workflow/workflow.py @@ -20,11 +20,38 @@ async def start_workflow(config_dict: dict): # Start the workflow workflow_config = config_dict.get("workflow") - engine_params = workflow_config.get("config").get("engine_params", None) + engine_params = workflow_config.get("config", None) is_engine_params = False try: if engine_params is not None: - engine_params = json.loads(engine_params) + engine_params_json = {} + + if engine_params.get("fixed_entities") is not None: + engine_params_json["fixed_entities"] = [x for x in engine_params.get("fixed_entities").split(",")] + + if engine_params.get("sample_entities") is not None: + engine_params_json["sample_entities"] = [x for x in engine_params.get("fixed_entities").split(",")] + + if engine_params.get("ner_model_name") is not None: + engine_params_json["ner_model_name"] = engine_params.get("ner_model_name") + + if engine_params.get("enable_filtering") is not None: + engine_params_json["enable_filtering"] = engine_params.get("enable_filtering") + + engine_params_json["filter_params"] = { + "score_threshold": float(engine_params.get("score_threshold")) if engine_params.get("score_threshold") is not None else None, + "attention_score_threshold": float(engine_params.get("attention_score_threshold")) if engine_params.get("attention_score_threshold") is not None else None, + "similarity_threshold": float(engine_params.get("similarity_threshold")) if engine_params.get("similarity_threshold") is not None else None, + "min_cluster_size": int(engine_params.get("min_cluster_size")) if engine_params.get("min_cluster_size") is not None else None, + "min_samples": int(engine_params.get("min_samples")) if engine_params.get("min_samples") is not None else None, + "cluster_persistence_threshold": float(engine_params.get("cluster_persistence_threshold")) if engine_params.get("cluster_persistence_threshold") is not None else None, + } + + if engine_params.get("is_confined_search") is not None: + engine_params_json["is_confined_search"] = engine_params.get("is_confined_search") + + if engine_params.get("user_context") is not None: + engine_params_json["user_context"] = engine_params.get("user_context") is_engine_params = True except Exception as e: logger.error("Got error while loading engine params: ", e) @@ -37,7 +64,7 @@ async def start_workflow(config_dict: dict): engines = [] for engine_config in engine_configs: if is_engine_params: - engine_config.update(engine_params) + engine_config.update(engine_params_json) engine_config_source = engine_config.get("config", {}) if engine_config["name"] == "knowledge_graph_using_openai": engine_config.update({"openai_api_key": engine_config["config"]["openai_api_key"]}) diff --git a/setup.py b/setup.py index fe07170b..9b86a07a 100644 --- a/setup.py +++ b/setup.py @@ -83,7 +83,7 @@ setup( name="querent", - version="3.0.1", + version="3.0.2", author="Querent AI", description="The Asynchronous Data Dynamo and Graph Neural Network Catalyst", long_description=long_description,