Skip to content

Commit

Permalink
Langfuse for General Agent! (#407)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Aug 29, 2024
1 parent 53d86e7 commit 609c337
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 206 deletions.
404 changes: 204 additions & 200 deletions poetry.lock

Large diffs are not rendered by default.

24 changes: 19 additions & 5 deletions prediction_market_agent/agents/microchain_agent/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@

import streamlit as st
from microchain import Agent
from prediction_market_agent_tooling.deploy.agent import initialize_langfuse
from prediction_market_agent_tooling.markets.markets import MarketType
from prediction_market_agent_tooling.tools.costs import openai_costs
from prediction_market_agent_tooling.tools.langfuse_ import langfuse_context, observe
from prediction_market_agent_tooling.tools.streamlit_user_login import streamlit_login
from prediction_market_agent_tooling.tools.utils import utcnow
from streamlit_extras.bottom_container import bottom

from prediction_market_agent.agents.microchain_agent.deploy import GENERAL_AGENT_TAG
from prediction_market_agent.agents.microchain_agent.memory import ChatHistory
from prediction_market_agent.agents.microchain_agent.microchain_agent import (
SupportedModel,
Expand All @@ -45,7 +49,7 @@
get_initial_history_length,
has_been_run_past_initialization,
)
from prediction_market_agent.agents.utils import AgentIdentifier
from prediction_market_agent.agents.utils import STREAMLIT_TAG, AgentIdentifier
from prediction_market_agent.db.long_term_memory_table_handler import (
LongTermMemoryTableHandler,
)
Expand All @@ -56,15 +60,24 @@
AGENT_IDENTIFIER = AgentIdentifier.MICROCHAIN_AGENT_STREAMLIT
ALLOW_STOP = False

st.session_state.session_id = st.session_state.get(
"session_id", "StrealitGeneralAgent - " + utcnow().strftime("%Y-%m-%d %H:%M:%S")
)


@observe()
def run_agent(agent: Agent, iterations: int, model: SupportedModel) -> None:
langfuse_context.update_current_trace(
tags=[GENERAL_AGENT_TAG, STREAMLIT_TAG], session_id=st.session_state.session_id
)
maybe_initialize_long_term_memory()
with openai_costs(
model.value if model.is_openai else None
) as costs: # TODO: Support for Replicate costs (below as well).
with st.spinner("Agent is running..."):
for _ in range(iterations):
agent.run(iterations=1, resume=True)
agent.run(iterations=1, resume=st.session_state.total_iterations > 0)
st.session_state.total_iterations += 1
st.session_state.running_cost += costs.cost


Expand Down Expand Up @@ -137,6 +150,7 @@ def maybe_initialize_agent(

# Initialize the agent
if not agent_is_initialized():
initialize_langfuse(ENABLE_LANGFUSE)
st.session_state.agent = build_agent(
market_type=MARKET_TYPE,
model=model,
Expand All @@ -147,11 +161,10 @@ def maybe_initialize_agent(
functions_config=FunctionsConfig.from_system_prompt_choice(
st.session_state.system_prompt_select
),
enable_langfuse=ENABLE_LANGFUSE,
)
st.session_state.agent.reset()
st.session_state.agent.build_initial_messages()
st.session_state.total_iterations = 0
st.session_state.running_cost = 0.0

# Add a callback to display the agent's history after each run
st.session_state.agent.on_iteration_end = display_new_history_callback

Expand Down Expand Up @@ -179,6 +192,7 @@ def get_function_bullet_point_list(agent: Agent) -> str:
streamlit_login()
check_required_api_keys(["OPENAI_API_KEY", "BET_FROM_PRIVATE_KEY"])
KEYS = APIKeys()
ENABLE_LANGFUSE = KEYS.default_enable_langfuse
maybe_initialize_long_term_memory()

with st.sidebar:
Expand Down
7 changes: 7 additions & 0 deletions prediction_market_agent/agents/microchain_agent/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from prediction_market_agent.db.prompt_table_handler import PromptTableHandler
from prediction_market_agent.utils import APIKeys

GENERAL_AGENT_TAG = "general_agent"


class DeployableMicrochainAgent(DeployableAgent):
model = SupportedModel.gpt_4o
Expand All @@ -38,6 +40,10 @@ def run(
Override main 'run' method, as the all logic from the helper methods
is handed over to the agent.
"""
self.langfuse_update_current_trace(
tags=[GENERAL_AGENT_TAG, self.system_prompt_choice, self.task_description]
)

long_term_memory = LongTermMemoryTableHandler(
task_description=self.task_description
)
Expand All @@ -58,6 +64,7 @@ def run(
functions_config=FunctionsConfig.from_system_prompt_choice(
self.system_prompt_choice
),
enable_langfuse=self.enable_langfuse,
)

# Save formatted system prompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def build_agent(
model: SupportedModel,
unformatted_system_prompt: str,
functions_config: FunctionsConfig,
enable_langfuse: bool,
api_base: str = "https://api.openai.com/v1",
long_term_memory: LongTermMemoryTableHandler | None = None,
allow_stop: bool = True,
Expand All @@ -156,6 +157,7 @@ def build_agent(
api_key=keys.openai_api_key.get_secret_value(),
api_base=api_base,
temperature=0.7,
enable_langfuse=enable_langfuse,
)
if model.is_openai
else (
Expand All @@ -165,6 +167,7 @@ def build_agent(
model
),
api_key=keys.replicate_api_key.get_secret_value(),
enable_langfuse=enable_langfuse,
)
if model.is_replicate
else should_not_happen()
Expand All @@ -185,6 +188,7 @@ def step_end_callback(agent: Agent, step_output: StepOutput) -> None:
llm=LLM(generator=generator),
engine=engine,
on_iteration_step=on_iteration_step,
enable_langfuse=enable_langfuse,
)

for f in build_agent_functions(
Expand Down
2 changes: 2 additions & 0 deletions prediction_market_agent/agents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
)
from prediction_market_agent.utils import DEFAULT_OPENAI_MODEL, APIKeys

STREAMLIT_TAG = "streamlit"


class AgentIdentifier(str, Enum):
THINK_THOROUGHLY = "think-thoroughly-agent"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ autoflake = "^2.2.1"
isort = "^5.13.2"
markdownify = "^0.11.6"
tavily-python = "^0.3.9"
microchain-python = "^0.4.4"
microchain-python = { git = "https://github.com/galatolofederico/microchain", rev = "98e601f6b7413ea48fb0b099309d686c4b10ff5c" }
pysqlite3-binary = {version="^0.5.2.post3", markers = "sys_platform == 'linux'"}
psycopg2-binary = "^2.9.9"
sqlmodel = "^0.0.21"
Expand Down
1 change: 1 addition & 0 deletions scripts/deployed_general_agent_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def to_api_keys(self, identifier: AgentIdentifier) -> APIKeys:
include_universal_functions=True, # placeholder, not used
include_agent_functions=True, # placeholder, not used
),
enable_langfuse=False, # placeholder, not used
)
tab1, tab2 = st.tabs(["Overall", "Per-Session"])
usage_count_col_name = "Usage Count"
Expand Down

0 comments on commit 609c337

Please sign in to comment.