Merge pull request #45 from langchain-ai/nc/permchain
WIP Use permchain agent executor
nfcampos authored Nov 20, 2023
2 parents fdfda6f + 669d957 commit 94f7fee
Showing 38 changed files with 1,565 additions and 1,389 deletions.
3 changes: 3 additions & 0 deletions backend/Makefile
@@ -3,6 +3,9 @@
# Default target executed when no arguments are given to make.
all: help

build_ui:
	cd ../frontend && yarn build && cp -r dist/* ../backend/ui

######################
# TESTING AND COVERAGE
######################
23 changes: 23 additions & 0 deletions backend/app/api/__init__.py
@@ -0,0 +1,23 @@
from fastapi import APIRouter

from app.api.assistants import router as assistants_router
from app.api.runs import router as runs_router
from app.api.threads import router as threads_router

router = APIRouter()

router.include_router(
assistants_router,
prefix="/assistants",
tags=["assistants"],
)
router.include_router(
runs_router,
prefix="/runs",
tags=["runs"],
)
router.include_router(
threads_router,
prefix="/threads",
tags=["threads"],
)
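
For context, a minimal sketch (assumed, not part of this diff) of how the aggregated router above might be mounted in the FastAPI app; the app object and module path are illustrative:

from fastapi import FastAPI

from app.api import router as api_router  # the router assembled above

app = FastAPI()
# One include_router call exposes the /assistants, /runs, and /threads groups.
app.include_router(api_router)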
87 changes: 87 additions & 0 deletions backend/app/api/assistants.py
@@ -0,0 +1,87 @@
from typing import Annotated, List, Optional
from uuid import uuid4

from fastapi import APIRouter, HTTPException, Path, Query
from pydantic import BaseModel, Field

import app.storage as storage
from app.schema import Assistant, AssistantWithoutUserId, OpengptsUserId

router = APIRouter()

FEATURED_PUBLIC_ASSISTANTS = [
"ba721964-b7e4-474c-b817-fb089d94dc5f",
"dc3ec482-aafc-4d90-8a1a-afb9b2876cde",
]


class AssistantPayload(BaseModel):
"""Payload for creating an assistant."""

name: str = Field(..., description="The name of the assistant.")
config: dict = Field(..., description="The assistant config.")
public: bool = Field(default=False, description="Whether the assistant is public.")


AssistantID = Annotated[str, Path(description="The ID of the assistant.")]


@router.get("/")
def list_assistants(opengpts_user_id: OpengptsUserId) -> List[AssistantWithoutUserId]:
"""List all assistants for the current user."""
return storage.list_assistants(opengpts_user_id)


@router.get("/public/")
def list_public_assistants(
shared_id: Annotated[
Optional[str], Query(description="ID of a publicly shared assistant.")
] = None,
) -> List[AssistantWithoutUserId]:
"""List all public assistants."""
return storage.list_public_assistants(
FEATURED_PUBLIC_ASSISTANTS + ([shared_id] if shared_id else [])
)


@router.get("/{aid}")
def get_assistant(
opengpts_user_id: OpengptsUserId,
aid: AssistantID,
) -> Assistant:
"""Get an assistant by ID."""
assistant = storage.get_assistant(opengpts_user_id, aid)
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
return assistant


@router.post("")
def create_assistant(
opengpts_user_id: OpengptsUserId,
payload: AssistantPayload,
) -> Assistant:
"""Create an assistant."""
return storage.put_assistant(
opengpts_user_id,
str(uuid4()),
name=payload.name,
config=payload.config,
public=payload.public,
)


@router.put("/{aid}")
def upsert_assistant(
opengpts_user_id: OpengptsUserId,
aid: AssistantID,
payload: AssistantPayload,
) -> Assistant:
"""Create or update an assistant."""
return storage.put_assistant(
opengpts_user_id,
aid,
name=payload.name,
config=payload.config,
public=payload.public,
)
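
As a usage sketch (not part of this diff), creating an assistant through this router might look like the following; the host/port and the opengpts_user_id cookie are assumptions about how OpengptsUserId is resolved:

import httpx

response = httpx.post(
    "http://localhost:8100/assistants",  # assumed host/port
    cookies={"opengpts_user_id": "user-123"},  # assumed user-id mechanism
    json={
        "name": "My assistant",
        "config": {"configurable": {}},  # AssistantPayload.config
        "public": False,
    },
)
assistant = response.json()  # the stored Assistant, with its generated UUID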
213 changes: 213 additions & 0 deletions backend/app/api/runs.py
@@ -0,0 +1,213 @@
import asyncio
import json
from typing import AsyncIterator, Sequence
from uuid import uuid4

import langsmith.client
import orjson
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from gizmo_agent import agent
from langchain.pydantic_v1 import ValidationError
from langchain.schema.messages import AnyMessage, FunctionMessage
from langchain.schema.output import ChatGeneration
from langchain.schema.runnable import RunnableConfig
from langserve.callbacks import AsyncEventAggregatorCallback
from langserve.schema import FeedbackCreateRequest
from langserve.serialization import WellKnownLCSerializer
from langserve.server import _get_base_run_id_as_str, _unpack_input
from langsmith.utils import tracing_is_enabled
from pydantic import BaseModel, Field
from sse_starlette import EventSourceResponse

from app.schema import OpengptsUserId
from app.storage import get_assistant, get_thread_messages, public_user_id
from app.stream import StreamMessagesHandler

router = APIRouter()


_serializer = WellKnownLCSerializer()


class AgentInput(BaseModel):
"""An input into an agent."""

messages: Sequence[AnyMessage] = Field(default_factory=list)


class CreateRunPayload(BaseModel):
"""Payload for creating a run."""

assistant_id: str
thread_id: str
input: AgentInput = Field(default_factory=AgentInput)


async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUserId):
try:
body = await request.json()
except json.JSONDecodeError:
raise RequestValidationError(errors=["Invalid JSON body"])
assistant, public_assistant, state = await asyncio.gather(
asyncio.get_running_loop().run_in_executor(
None, get_assistant, opengpts_user_id, body["assistant_id"]
),
asyncio.get_running_loop().run_in_executor(
None, get_assistant, public_user_id, body["assistant_id"]
),
asyncio.get_running_loop().run_in_executor(
None, get_thread_messages, opengpts_user_id, body["thread_id"]
),
)
assistant = assistant or public_assistant
if not assistant:
raise HTTPException(status_code=404, detail="Assistant not found")
config: RunnableConfig = {
**assistant["config"],
"configurable": {
**assistant["config"]["configurable"],
"user_id": opengpts_user_id,
"thread_id": body["thread_id"],
"assistant_id": body["assistant_id"],
},
}
try:
input_ = _unpack_input(agent.get_input_schema(config).validate(body["input"]))
except ValidationError as e:
raise RequestValidationError(e.errors(), body=body)

return input_, config, state["messages"]


@router.post("")
async def create_run(
request: Request,
payload: CreateRunPayload, # for openapi docs
opengpts_user_id: OpengptsUserId,
background_tasks: BackgroundTasks,
):
"""Create a run."""
input_, config, messages = await _run_input_and_config(request, opengpts_user_id)
background_tasks.add_task(agent.ainvoke, input_, config)
return {"status": "ok"} # TODO add a run id


@router.post("/stream")
async def stream_run(
request: Request,
payload: CreateRunPayload, # for openapi docs
opengpts_user_id: OpengptsUserId,
):
"""Create a run."""
input_, config, messages = await _run_input_and_config(request, opengpts_user_id)
streamer = StreamMessagesHandler(messages + input_["messages"])
event_aggregator = AsyncEventAggregatorCallback()
config["callbacks"] = [streamer, event_aggregator]

# Call the runnable in streaming mode,
# add each chunk to the output stream
async def consume_astream() -> None:
try:
async for chunk in agent.astream(input_, config):
await streamer.send_stream.send(chunk)
# Hack: function messages aren't generated by the chat model,
# so the callback handler doesn't know about them; add them manually.
if chunk["messages"]:
message = chunk["messages"][-1]
if isinstance(message, FunctionMessage):
streamer.output[uuid4()] = ChatGeneration(message=message)
except Exception as e:
await streamer.send_stream.send(e)
finally:
await streamer.send_stream.aclose()

# Start the runnable in the background
task = asyncio.create_task(consume_astream())

# Consume the stream into an EventSourceResponse
async def _stream() -> AsyncIterator[dict]:
has_sent_metadata = False

async for chunk in streamer.receive_stream:
if isinstance(chunk, BaseException):
yield {
"event": "error",
# Do not expose the error message to the client since
# the message may contain sensitive information.
# We'll add client side errors for validation as well.
"data": orjson.dumps(
{"status_code": 500, "message": "Internal Server Error"}
).decode(),
}
raise chunk
else:
if not has_sent_metadata and event_aggregator.callback_events:
yield {
"event": "metadata",
"data": orjson.dumps(
{"run_id": _get_base_run_id_as_str(event_aggregator)}
).decode(),
}
has_sent_metadata = True

yield {
# EventSourceResponse expects a string for data
# so after serializing into bytes, we decode into utf-8
# to get a string.
"data": _serializer.dumps(chunk).decode("utf-8"),
"event": "data",
}

# Send an end event to signal the end of the stream
yield {"event": "end"}
# Wait for the runnable to finish
await task

return EventSourceResponse(_stream())
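
A minimal client sketch (assumed, not part of this diff) for consuming this endpoint's metadata / data / end events over SSE, reusing the payload sketch above:

import json

import httpx

with httpx.stream(
    "POST",
    "http://localhost:8100/runs/stream",  # assumed host/port
    cookies={"opengpts_user_id": "user-123"},  # assumed user-id mechanism
    json=payload,  # same shape as the CreateRunPayload sketch above
) as response:
    event = None
    for line in response.iter_lines():
        if line.startswith("event:"):
            event = line.split(":", 1)[1].strip()
        elif line.startswith("data:") and event == "data":
            # Each data event is a serialized chunk of agent output.
            print(json.loads(line.split(":", 1)[1]))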


@router.get("/input_schema")
async def input_schema() -> dict:
"""Return the input schema of the runnable."""
return agent.get_input_schema().schema()


@router.get("/output_schema")
async def output_schema() -> dict:
"""Return the output schema of the runnable."""
return agent.get_output_schema().schema()


@router.get("/config_schema")
async def config_schema() -> dict:
"""Return the config schema of the runnable."""
return agent.config_schema().schema()


if tracing_is_enabled():
langsmith_client = langsmith.client.Client()

@router.post("/feedback")
def create_run_feedback(feedback_create_req: FeedbackCreateRequest) -> dict:
"""
Send feedback on an individual run to langsmith
Note that a successful response means that feedback was successfully
submitted. It does not guarantee that the feedback is recorded by
langsmith. Requests may be silently rejected if they are
unauthenticated or invalid by the server.
"""

langsmith_client.create_feedback(
feedback_create_req.run_id,
feedback_create_req.key,
score=feedback_create_req.score,
value=feedback_create_req.value,
comment=feedback_create_req.comment,
source_info={
"from_langserve": True,
},
)

return {"status": "ok"}
