diff --git a/backend/app/agent.py b/backend/app/agent.py index d2695e79..512d85b2 100644 --- a/backend/app/agent.py +++ b/backend/app/agent.py @@ -27,11 +27,11 @@ Arxiv, AvailableTools, Connery, + DallE, DDGSearch, PressReleases, PubMed, Retrieval, - DallE, SecFilings, Tavily, TavilyAnswer, diff --git a/backend/app/tools.py b/backend/app/tools.py index 81b24aa7..bb6e993c 100644 --- a/backend/app/tools.py +++ b/backend/app/tools.py @@ -5,14 +5,13 @@ from langchain.pydantic_v1 import BaseModel, Field from langchain.tools.retriever import create_retriever_tool from langchain_community.agent_toolkits.connery import ConneryToolkit -from langchain_core.tools import Tool from langchain_community.retrievers.kay import KayAiRetriever from langchain_community.retrievers.pubmed import PubMedRetriever from langchain_community.retrievers.wikipedia import WikipediaRetriever from langchain_community.retrievers.you import YouRetriever -from langchain_community.tools.ddg_search.tool import DuckDuckGoSearchRun from langchain_community.tools.arxiv.tool import ArxivQueryRun from langchain_community.tools.connery import ConneryService +from langchain_community.tools.ddg_search.tool import DuckDuckGoSearchRun from langchain_community.tools.tavily_search import ( TavilyAnswer as _TavilyAnswer, ) @@ -20,12 +19,12 @@ TavilySearchResults, ) from langchain_community.utilities.arxiv import ArxivAPIWrapper -from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper +from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper +from langchain_core.tools import Tool from langchain_robocorp import ActionServerToolkit from typing_extensions import TypedDict - from app.upload import vstore @@ -40,6 +39,7 @@ class ArxivInput(BaseModel): class PythonREPLInput(BaseModel): query: str = Field(description="python command to run") + class DallEInput(BaseModel): query: str = Field(description="image description to generate image from") @@ -60,8 +60,7 @@ class AvailableTools(str, Enum): DALL_E = "dall_e" -class ToolConfig(TypedDict): - ... +class ToolConfig(TypedDict): ... class BaseTool(BaseModel): @@ -193,7 +192,10 @@ class Retrieval(BaseTool): class DallE(BaseTool): type: AvailableTools = Field(AvailableTools.DALL_E, const=True) name: str = Field("Generate Image (Dall-E)", const=True) - description: str = Field("Generates images from a text description using OpenAI's DALL-E model.", const=True) + description: str = Field( + "Generates images from a text description using OpenAI's DALL-E model.", + const=True, + ) RETRIEVAL_DESCRIPTION = """Can be used to look up information that was uploaded to this assistant. @@ -296,6 +298,7 @@ def _get_connery_actions(): tools = connery_toolkit.get_tools() return tools + @lru_cache(maxsize=1) def _get_dalle_tools(): return Tool( @@ -304,6 +307,7 @@ def _get_dalle_tools(): "A wrapper around OpenAI DALL-E API. Useful for when you need to generate images from a text description. Input should be an image description.", ) + TOOLS = { AvailableTools.ACTION_SERVER: _get_action_server, AvailableTools.CONNERY: _get_connery_actions,