Skip to content

Commit

Permalink
Merge pull request #109 from bolna-ai/prompt_context_updates
Browse files Browse the repository at this point in the history
adding logic for context data in prompts
  • Loading branch information
prateeksachan authored Apr 1, 2024
2 parents 9c23bf0 + 7467f52 commit f97f2a0
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions bolna/agent_types/graph_based_conversational_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,28 @@ def __init__(self, node_id, node_label, content, classification_labels: list = N


class Graph:
def __init__(self, conversation_data, preprocessed=False):
def __init__(self, conversation_data, preprocessed=False, context_data=None):
self.preprocessed = preprocessed
self.root = None
self.graph = self._create_graph(conversation_data)
self.graph = self._create_graph(conversation_data, context_data)

def _create_graph(self, data):
def _create_graph(self, data, context_data=None):
logger.info(f"Creating graph")
node_map = dict()
for node_id, node_data in data.items():
prompt_parts = node_data.get("prompt").split('###Examples')
prompt = node_data.get('prompt')
if len(prompt_parts) == 2:
classification_prompt = prompt_parts[0]
user_prompt = update_prompt_with_context(prompt_parts[1], context_data)
prompt = '###Examples'.join([classification_prompt, user_prompt])

node = Node(
node_id=node_id,
node_label=node_data["label"],
content=node_data["content"],
classification_labels=node_data.get("classification_labels", []),
prompt=node_data.get("prompt"),
prompt=prompt,
children=[],
milestone_check_prompt=node_data.get("milestone_check_prompt", ""),
)
Expand Down Expand Up @@ -68,7 +75,7 @@ def __init__(self, llm, prompts, context_data=None, preprocessed=True):
self.conversation_intro_done = False

def load_prompts_and_create_graph(self, prompts):
self.graph = Graph(prompts)
self.graph = Graph(prompts, context_data=self.context_data)
self.current_node = self.graph.root
self.current_node_interim = self.graph.root #Handle interim node because we are dealing with interim results

Expand Down

0 comments on commit f97f2a0

Please sign in to comment.