Skip to content

Commit

Permalink
changes in engine params (#282)
Browse files Browse the repository at this point in the history
* changes in engine params

* Changed workflow

* Changed workflow

* Changed version
  • Loading branch information
Ansh5461 authored Apr 13, 2024
1 parent e56210a commit 2559a2e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
33 changes: 30 additions & 3 deletions querent/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,38 @@
async def start_workflow(config_dict: dict):
# Start the workflow
workflow_config = config_dict.get("workflow")
engine_params = workflow_config.get("config").get("engine_params", None)
engine_params = workflow_config.get("config", None)
is_engine_params = False
try:
if engine_params is not None:
engine_params = json.loads(engine_params)
engine_params_json = {}

if engine_params.get("fixed_entities") is not None:
engine_params_json["fixed_entities"] = [x for x in engine_params.get("fixed_entities").split(",")]

if engine_params.get("sample_entities") is not None:
engine_params_json["sample_entities"] = [x for x in engine_params.get("fixed_entities").split(",")]

if engine_params.get("ner_model_name") is not None:
engine_params_json["ner_model_name"] = engine_params.get("ner_model_name")

if engine_params.get("enable_filtering") is not None:
engine_params_json["enable_filtering"] = engine_params.get("enable_filtering")

engine_params_json["filter_params"] = {
"score_threshold": float(engine_params.get("score_threshold")) if engine_params.get("score_threshold") is not None else None,
"attention_score_threshold": float(engine_params.get("attention_score_threshold")) if engine_params.get("attention_score_threshold") is not None else None,
"similarity_threshold": float(engine_params.get("similarity_threshold")) if engine_params.get("similarity_threshold") is not None else None,
"min_cluster_size": int(engine_params.get("min_cluster_size")) if engine_params.get("min_cluster_size") is not None else None,
"min_samples": int(engine_params.get("min_samples")) if engine_params.get("min_samples") is not None else None,
"cluster_persistence_threshold": float(engine_params.get("cluster_persistence_threshold")) if engine_params.get("cluster_persistence_threshold") is not None else None,
}

if engine_params.get("is_confined_search") is not None:
engine_params_json["is_confined_search"] = engine_params.get("is_confined_search")

if engine_params.get("user_context") is not None:
engine_params_json["user_context"] = engine_params.get("user_context")
is_engine_params = True
except Exception as e:
logger.error("Got error while loading engine params: ", e)
Expand All @@ -37,7 +64,7 @@ async def start_workflow(config_dict: dict):
engines = []
for engine_config in engine_configs:
if is_engine_params:
engine_config.update(engine_params)
engine_config.update(engine_params_json)
engine_config_source = engine_config.get("config", {})
if engine_config["name"] == "knowledge_graph_using_openai":
engine_config.update({"openai_api_key": engine_config["config"]["openai_api_key"]})
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@

setup(
name="querent",
version="3.0.1",
version="3.0.2",
author="Querent AI",
description="The Asynchronous Data Dynamo and Graph Neural Network Catalyst",
long_description=long_description,
Expand Down

0 comments on commit 2559a2e

Please sign in to comment.