-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
84 lines (63 loc) · 2.21 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from dotenv import load_dotenv
from typing import Sequence
from langchain_core.messages import HumanMessage, BaseMessage
from langgraph.prebuilt import ToolNode
load_dotenv()
from llm_utils.chains import extract_chain, search_chain, recommend_chain
from llm_utils.tools import crawl_restaurants, query_restaurants
# Node names used to wire the LangGraph graph further down in this module.
EXTRACT, SEARCH, GET_RESTAURANTS, RECOMMEND = (
    "extract",
    "search",
    "get_restaurants",
    "recommend",
)
# Prebuilt node that executes the restaurant tools requested by the previous
# message (in this graph, the tool call emitted by the SEARCH node).
tool_node = ToolNode(tools=[crawl_restaurants, query_restaurants])
def extract_node(state: Sequence[BaseMessage]):
    """Condense the latest message into a space-separated search query.

    Runs ``extract_chain`` on the content of the most recent message and
    joins the field values of the first structured result with spaces.

    Args:
        state: Conversation so far; only ``state[-1]`` is read.

    Returns:
        The query string. Presumably MessageGraph converts this string into
        a message and appends it to the state (confirm against the installed
        langgraph version).
    """
    res = extract_chain.invoke(input={"search_query": [state[-1].content]})
    # NOTE(review): assumes ``res`` is a non-empty list of pydantic models
    # exposing ``.dict()``; confirm against extract_chain's output parser.
    fields = res[0].dict().values()
    # str.join raises TypeError on any non-str element, so unset (None) or
    # empty fields are skipped and every other value is stringified.
    return " ".join(
        str(value) for value in fields if value is not None and value != ""
    )
def search_node(state: Sequence[BaseMessage]):
    """Ask ``search_chain`` to produce a restaurant-lookup tool call.

    The first pass (fewer than 3 messages, i.e. the user request plus the
    extracted query) uses the newest message as the search term and selects
    the ``query_restaurants`` tool. Later passes are reached when
    ``decide_node`` saw an empty lookup result. They reuse the extracted
    query at ``state[1]`` and select ``crawl_restaurants`` instead.

    Raises:
        ValueError: if the state holds more than 10 messages. This stops a
            SEARCH <-> GET_RESTAURANTS cycle that never finds anything.
    """
    if len(state) > 10:
        raise ValueError("Too many iterations")

    if len(state) < 3:
        term, tool = state[-1].content, "query_restaurants"
    else:
        term, tool = state[1].content, "crawl_restaurants"

    return search_chain.invoke(
        input={"search_term": [term], "tool_choice": [tool]}
    )
def recommend_node(state: Sequence[BaseMessage]):
    """Produce the final recommendation from the user's request and the results.

    Feeds ``recommend_chain`` the original user message (``state[0]``) and
    the latest message (``state[-1]``), which holds the restaurant list
    returned by GET_RESTAURANTS.
    """
    user_request = state[0].content
    restaurants = state[-1].content
    payload = {
        "user_input": [user_request],
        "restaurants_list": [restaurants],
    }
    return recommend_chain.invoke(input=payload)
def decide_node(state: Sequence[BaseMessage]):
    """Choose the node that follows GET_RESTAURANTS.

    Returns SEARCH (retry the lookup) when the latest message's content is
    exactly the empty string. Otherwise returns RECOMMEND.
    """
    found_nothing = state[-1].content == ""
    return SEARCH if found_nothing else RECOMMEND
from langgraph.graph import END, MessageGraph

# Flow: EXTRACT -> SEARCH -> GET_RESTAURANTS -> (RECOMMEND -> END | SEARCH).
builder = MessageGraph()
for _name, _action in (
    (EXTRACT, extract_node),
    (SEARCH, search_node),
    (RECOMMEND, recommend_node),
    (GET_RESTAURANTS, tool_node),
):
    builder.add_node(_name, _action)
# Drop the loop variables so the module namespace stays as before.
del _name, _action

builder.set_entry_point(EXTRACT)
builder.add_edge(EXTRACT, SEARCH)
builder.add_edge(SEARCH, GET_RESTAURANTS)
# decide_node returns a node name; each name maps to the node of the same name.
builder.add_conditional_edges(
    GET_RESTAURANTS,
    decide_node,
    {route: route for route in (RECOMMEND, SEARCH)},
)
builder.add_edge(RECOMMEND, END)

graph = builder.compile()
if __name__ == "__main__":
    # Sample request (Korean): "Recommend pasta near Gangnam Station; it has
    # to be a place that takes reservations."
    human_message = HumanMessage(
        content="강남역 파스타 추천해줘 예약가능한 가게여야만 해"
    )
    res = graph.invoke(input=human_message)
    # MessageGraph's final state is the accumulated message list. RECOMMEND
    # is the only node wired to END, so the last entry holds its output.
    print(res[-1].content)
    # To inspect the graph structure, print graph.get_graph().draw_mermaid().