Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating tools to use newer APIs from LangChain #19

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 0 additions & 31 deletions agents/l4m_agent.py

This file was deleted.

229 changes: 19 additions & 210 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,211 +1,20 @@
import os

import rasterio as rio
import folium
import streamlit as st
from streamlit_folium import folium_static

import langchain
from langchain.agents import AgentType
from langchain.chat_models import ChatOpenAI
from langchain.tools import Tool, DuckDuckGoSearchRun
from langchain.callbacks import (
StreamlitCallbackHandler,
AimCallbackHandler,
get_openai_callback,
)

from tools.mercantile_tool import MercantileTool
from tools.geopy.geocode import GeopyGeocodeTool
from tools.geopy.distance import GeopyDistanceTool
from tools.osmnx.geometry import OSMnxGeometryTool
from tools.osmnx.network import OSMnxNetworkTool
from tools.stac.search import STACSearchTool
from agents.l4m_agent import base_agent

# DEBUG: enable LangChain's global verbose debug logging for every chain/agent call.
langchain.debug = True


@st.cache_resource(ttl="1h")
def get_agent(
    openai_api_key, agent_type=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION
):
    """Build and cache the geospatial LLM agent.

    Cached by Streamlit for one hour so reruns of the script do not
    reconstruct the model client and tool instances on every request.

    Parameters
    ----------
    openai_api_key : str
        API key passed straight through to the OpenAI chat client.
    agent_type : AgentType
        LangChain agent flavour; defaults to structured-chat zero-shot ReAct.

    Returns
    -------
    The agent produced by ``base_agent``.
    """
    chat_model = ChatOpenAI(
        temperature=0,
        openai_api_key=openai_api_key,
        model_name="gpt-3.5-turbo-0613",
    )

    # Web-search tool for current events and general place questions.
    duckduckgo_tool = Tool(
        name="DuckDuckGo",
        description="Use this tool to answer questions about current events and places. \
Please ask targeted questions.",
        func=DuckDuckGoSearchRun().run,
    )

    # Full geospatial toolkit the agent may choose from for a query.
    toolkit = [
        duckduckgo_tool,
        GeopyGeocodeTool(),
        GeopyDistanceTool(),
        MercantileTool(),
        OSMnxGeometryTool(),
        OSMnxNetworkTool(),
        STACSearchTool(),
    ]

    return base_agent(chat_model, toolkit, agent_type=agent_type)


def run_query(agent, query):
    """Run a single natural-language query through the agent.

    Fixes a bug in the original implementation, which returned the
    undefined local name ``response`` — a guaranteed ``NameError`` on
    every call.

    Parameters
    ----------
    agent : object
        Any object exposing a ``run(query)`` method (e.g. a LangChain agent).
    query : str
        The user's natural-language question.

    Returns
    -------
    Whatever the agent's ``run`` method returns.
    """
    return agent.run(query)


def plot_raster(items):
    """Preview the least-cloudy STAC item as a static image.

    Picks the item with the lowest ``eo:cloud_cover`` property and shows
    its ``rendered_preview`` asset in the Streamlit app.
    """
    st.subheader("Preview of the first item sorted by cloud cover")
    least_cloudy = min(items, key=lambda it: it.properties["eo:cloud_cover"])
    preview_href = least_cloudy.assets["rendered_preview"].href
    # NOTE(review): an earlier prototype rendered the raw raster via
    # rasterio + a folium ImageOverlay; that path was commented out in
    # favour of displaying the rendered preview asset directly.
    st.image(preview_href)


def plot_vector(df):
    """Render a GeoDataFrame's geometries on an interactive folium map."""
    st.subheader("Add the geometry to the Map")
    # Centre the map on the centroid of the first geometry.
    centroid = df.centroid.iloc[0]
    fmap = folium.Map(location=[centroid.y, centroid.x], zoom_start=12)
    folium.GeoJson(df).add_to(fmap)
    folium_static(fmap)


# Page chrome and header.
st.set_page_config(page_title="LLLLM", page_icon="🤖", layout="wide")
st.subheader("🤖 I am Geo LLM Agent!")

# Chat transcript: a list of {"role", "avatar", "content"} dicts,
# persisted across Streamlit reruns in session_state.
if "msgs" not in st.session_state:
    st.session_state.msgs = []

# Cumulative OpenAI usage counters, updated after each agent call.
if "total_tokens" not in st.session_state:
    st.session_state.total_tokens = 0

if "prompt_tokens" not in st.session_state:
    st.session_state.prompt_tokens = 0

if "completion_tokens" not in st.session_state:
    st.session_state.completion_tokens = 0

if "total_cost" not in st.session_state:
    st.session_state.total_cost = 0

with st.sidebar:
    # Prefer the environment variable; fall back to an in-app password field.
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        openai_api_key = st.text_input("OpenAI API Key", type="password")

    # st.empty() placeholders so the usage figures can be rewritten in
    # place after each agent call, without appending new widgets.
    st.subheader("OpenAI Usage")
    total_tokens = st.empty()
    prompt_tokens = st.empty()
    completion_tokens = st.empty()
    total_cost = st.empty()

    total_tokens.write(f"Total Tokens: {st.session_state.total_tokens:,.0f}")
    prompt_tokens.write(f"Prompt Tokens: {st.session_state.prompt_tokens:,.0f}")
    completion_tokens.write(
        f"Completion Tokens: {st.session_state.completion_tokens:,.0f}"
    )
    total_cost.write(f"Total Cost (USD): ${st.session_state.total_cost:,.4f}")


# Replay the stored transcript on every rerun so history stays visible.
for msg in st.session_state.msgs:
    with st.chat_message(name=msg["role"], avatar=msg["avatar"]):
        st.markdown(msg["content"])

if prompt := st.chat_input("Ask me anything about the flat world..."):
    # Echo the user's message and persist it to the transcript.
    with st.chat_message(name="user", avatar="🧑‍💻"):
        st.markdown(prompt)

    st.session_state.msgs.append({"role": "user", "avatar": "🧑‍💻", "content": prompt})

    # Hard stop: the agent cannot run without an API key.
    if not openai_api_key:
        st.info("Please add your OpenAI API key to continue.")
        st.stop()

    # Aim experiment tracker; flushed after the agent run below.
    aim_callback = AimCallbackHandler(
        repo=".",
        experiment_name="LLLLLM: Base Agent v0.1",
    )

    agent = get_agent(openai_api_key)

    # get_openai_callback captures token counts and dollar cost for this call;
    # the Streamlit callback streams intermediate agent steps into the UI.
    with get_openai_callback() as cb:
        st_callback = StreamlitCallbackHandler(st.container())
        response = agent.run(prompt, callbacks=[st_callback, aim_callback])

    aim_callback.flush_tracker(langchain_asset=agent, reset=False, finish=True)

    # Log OpenAI stats
    # print(f"Model name: {response.llm_output.get('model_name', '')}")
    st.session_state.total_tokens += cb.total_tokens
    st.session_state.prompt_tokens += cb.prompt_tokens
    st.session_state.completion_tokens += cb.completion_tokens
    st.session_state.total_cost += cb.total_cost

    # Rewrite the sidebar placeholders with the updated running totals.
    total_tokens.write(f"Total Tokens: {st.session_state.total_tokens:,.0f}")
    prompt_tokens.write(f"Prompt Tokens: {st.session_state.prompt_tokens:,.0f}")
    completion_tokens.write(
        f"Completion Tokens: {st.session_state.completion_tokens:,.0f}"
    )
    total_cost.write(f"Total Cost (USD): ${st.session_state.total_cost:,.4f}")

    with st.chat_message(name="assistant", avatar="🤖"):
        # A plain string is rendered directly; otherwise the response is
        # assumed to be a (tool_name, result) pair and dispatched on the
        # tool name to the matching visualization helper.
        if type(response) == str:
            content = response
            st.markdown(response)
        else:
            tool, result = response

            match tool:
                case "stac-search":
                    content = f"Found {len(result)} items from the catalog."
                    st.markdown(content)
                    if len(result) > 0:
                        plot_raster(result)
                case "geometry":
                    content = f"Found {len(result)} geometries."
                    gdf = result
                    st.markdown(content)
                    plot_vector(gdf)
                case "network":
                    content = f"Found {len(result)} network geometries."
                    ndf = result
                    st.markdown(content)
                    plot_vector(ndf)
                case _:
                    # Unknown tool: fall back to showing the raw response.
                    content = response
                    st.markdown(content)

    # Persist the assistant's reply to the transcript.
    st.session_state.msgs.append(
        {"role": "assistant", "avatar": "🤖", "content": content}
    )
from langchain_core.messages import HumanMessage

from graphs.l4m_graph import graph

if prompt := st.chat_input():
    st.chat_message("user").write(prompt)
    # NOTE(review): thread_id only has an effect if the graph is compiled
    # with a checkpointer — confirm against graphs/l4m_graph.py.
    config = {"configurable": {"thread_id": "1"}}
    # Stream node-by-node updates from the LangGraph graph; each chunk is
    # a dict keyed by the node that produced it.
    for chunk in graph.stream(
        {"messages": [HumanMessage(content=prompt)]}, config, stream_mode="updates"
    ):
        # for chunk in graph.invoke(
        #     {"messages": [HumanMessage(content=prompt)]}, config, stream_mode="updates"
        # ):
        #     st.markdown(chunk)

        # Route the chunk to the matching chat bubble: assistant or tools.
        node = "assistant" if "assistant" in chunk else "tools"
        with st.chat_message(node):
            for msg in chunk[node]["messages"]:
                st.markdown(msg.content)
32 changes: 17 additions & 15 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@ name: llllm-env
channels:
- conda-forge
dependencies:
- python=3
- python=3.12
- pip
- osmnx=1.3.1
- rasterio
- pip:
- openai==0.27.8
- langchain==0.0.215
- duckduckgo-search==3.8.3
- mercantile==1.2.1
- geopy==2.3.0
- ipywidgets==8.0.6
- jupyterlab==4.0.2
- planetary-computer==0.5.1
- pystac-client==0.7.2
- streamlit==1.24.1
- streamlit-folium==0.12.0
- watchdog==3.0.0
- aim==3.17.5
- langchain
- langchain-ollama
- langchain-community
- duckduckgo-search
- mercantile
- geopy
- ipywidgets
- jupyterlab
- planetary-computer
- pystac-client
- streamlit
- streamlit-folium
- watchdog
- altair
- osmnx
File renamed without changes.
72 changes: 72 additions & 0 deletions graphs/l4m_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import langchain
from langchain_core.messages import SystemMessage
from langchain_ollama import ChatOllama
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from tools.geopy.distance import distance_tool
from tools.geopy.geocode import geocode_tool
from tools.mercantile_tool import mercantile_tool
from tools.osmnx.geometry import geometry_tool
from tools.osmnx.network import network_tool
from tools.stac.search import stac_search

# from tools.duck_tool import duckduckgo_tool

# DEBUG
langchain.debug = True

# Local Llama 3.2 model served by Ollama; temperature 0 for deterministic output.
llm = ChatOllama(
    model="llama3.2",
    temperature=0,
)

# Geospatial tools the assistant node may call.
tools = [
    # duckduckgo_tool,
    geocode_tool,
    distance_tool,
    mercantile_tool,
    geometry_tool,
    network_tool,
    stac_search,
]

# Disable parallel tool calling so tool invocations happen sequentially; see
# https://python.langchain.com/docs/how_to/tool_calling_parallel/
llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)


# System prompt steering the model toward geographic Q&A and away from
# unnecessary tool calls. The two literals are implicitly concatenated;
# the original was missing the separating space, producing the garbled
# prompt "...geographic inputs.do not use tools...".
sys_msg = SystemMessage(
    content="You are a helpful assistant tasked with answering questions on a set of geographic inputs. "
    "Do not use tools unless the message does not contain geographic inputs."
)


def assistant(state: MessagesState):
    """Assistant node: prepend the system prompt and invoke the tool-bound LLM."""
    reply = llm_with_tools.invoke([sys_msg] + state["messages"])
    return {"messages": [reply]}


# Graph: assistant <-> tools ReAct-style loop over MessagesState.
builder = StateGraph(MessagesState)

# Define nodes: these do the work
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

# Define edges: these determine how the control flow moves
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
    # If the latest message (result) from assistant is not a tool call -> tools_condition routes to END
    tools_condition,
)
# After any tool run, hand control back to the assistant.
builder.add_edge("tools", "assistant")

# NOTE(review): compiled without a checkpointer, so no conversation state
# persists between invocations — confirm this is intended.
graph = builder.compile()
8 changes: 8 additions & 0 deletions tools/duck_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Standalone DuckDuckGo web-search tool.
#
# Updated for newer LangChain releases: community-maintained tools such as
# DuckDuckGoSearchRun were moved out of ``langchain.tools`` into the
# ``langchain_community`` package (declared in environment.yaml), matching
# this PR's goal of using the newer LangChain APIs.
from langchain.tools import Tool
from langchain_community.tools import DuckDuckGoSearchRun

duckduckgo_tool = Tool(
    name="DuckDuckGo",
    description="Use this tool to answer questions about current events and places. \
Please ask targeted questions.",
    func=DuckDuckGoSearchRun().run,
)
Loading