diff --git a/bolna/agent_types/graph_based_conversational_agent.py b/bolna/agent_types/graph_based_conversational_agent.py index a6d43233..a89c1add 100644 --- a/bolna/agent_types/graph_based_conversational_agent.py +++ b/bolna/agent_types/graph_based_conversational_agent.py @@ -23,21 +23,28 @@ def __init__(self, node_id, node_label, content, classification_labels: list = N class Graph: - def __init__(self, conversation_data, preprocessed=False): + def __init__(self, conversation_data, preprocessed=False, context_data=None): self.preprocessed = preprocessed self.root = None - self.graph = self._create_graph(conversation_data) + self.graph = self._create_graph(conversation_data, context_data) - def _create_graph(self, data): + def _create_graph(self, data, context_data=None): logger.info(f"Creating graph") node_map = dict() for node_id, node_data in data.items(): + prompt_parts = node_data.get("prompt").split('###Examples') + prompt = node_data.get('prompt') + if len(prompt_parts) == 2: + classification_prompt = prompt_parts[0] + user_prompt = update_prompt_with_context(prompt_parts[1], context_data) + prompt = '###Examples'.join([classification_prompt, user_prompt]) + node = Node( node_id=node_id, node_label=node_data["label"], content=node_data["content"], classification_labels=node_data.get("classification_labels", []), - prompt=node_data.get("prompt"), + prompt=prompt, children=[], milestone_check_prompt=node_data.get("milestone_check_prompt", ""), ) @@ -68,7 +75,7 @@ def __init__(self, llm, prompts, context_data=None, preprocessed=True): self.conversation_intro_done = False def load_prompts_and_create_graph(self, prompts): - self.graph = Graph(prompts) + self.graph = Graph(prompts, context_data=self.context_data) self.current_node = self.graph.root self.current_node_interim = self.graph.root #Handle interim node because we are dealing with interim results