Skip to content

Commit

Permalink
Merge pull request #307 from Sema4AI/dalle-fixed
Browse files Browse the repository at this point in the history
Dalle fixed
  • Loading branch information
mkorpela authored Apr 16, 2024
2 parents cb39b9b + dcec7c8 commit bb498e3
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
2 changes: 2 additions & 0 deletions backend/app/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Arxiv,
AvailableTools,
Connery,
DallE,
DDGSearch,
PressReleases,
PubMed,
Expand All @@ -53,6 +54,7 @@
Tavily,
TavilyAnswer,
Retrieval,
DallE,
]


Expand Down
2 changes: 1 addition & 1 deletion backend/app/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

import boto3
import httpx
from langchain_anthropic import ChatAnthropic
from langchain_community.chat_models import BedrockChat, ChatFireworks
from langchain_community.chat_models.ollama import ChatOllama
from langchain_google_vertexai import ChatVertexAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain_anthropic import ChatAnthropic

logger = logging.getLogger(__name__)

Expand Down
37 changes: 31 additions & 6 deletions backend/app/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,23 @@
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.retriever import create_retriever_tool
from langchain_community.agent_toolkits.connery import ConneryToolkit
from langchain_community.retrievers import (
KayAiRetriever,
PubMedRetriever,
WikipediaRetriever,
)
from langchain_community.retrievers.kay import KayAiRetriever
from langchain_community.retrievers.pubmed import PubMedRetriever
from langchain_community.retrievers.wikipedia import WikipediaRetriever
from langchain_community.retrievers.you import YouRetriever
from langchain_community.tools import ArxivQueryRun, DuckDuckGoSearchRun
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.connery import ConneryService
from langchain_community.tools.ddg_search.tool import DuckDuckGoSearchRun
from langchain_community.tools.tavily_search import (
TavilyAnswer as _TavilyAnswer,
)
from langchain_community.tools.tavily_search import (
TavilySearchResults,
)
from langchain_community.utilities.arxiv import ArxivAPIWrapper
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langchain_core.tools import Tool
from langchain_robocorp import ActionServerToolkit
from typing_extensions import TypedDict

Expand All @@ -39,6 +40,10 @@ class PythonREPLInput(BaseModel):
query: str = Field(description="python command to run")


class DallEInput(BaseModel):
query: str = Field(description="image description to generate image from")


class AvailableTools(str, Enum):
ACTION_SERVER = "action_server_by_robocorp"
CONNERY = "ai_action_runner_by_connery"
Expand All @@ -52,6 +57,7 @@ class AvailableTools(str, Enum):
PRESS_RELEASES = "press_releases_kai_ai"
PUBMED = "pubmed"
WIKIPEDIA = "wikipedia"
DALL_E = "dall_e"


class ToolConfig(TypedDict):
Expand Down Expand Up @@ -184,6 +190,15 @@ class Retrieval(BaseTool):
description: str = Field("Look up information in uploaded files.", const=True)


class DallE(BaseTool):
type: AvailableTools = Field(AvailableTools.DALL_E, const=True)
name: str = Field("Generate Image (Dall-E)", const=True)
description: str = Field(
"Generates images from a text description using OpenAI's DALL-E model.",
const=True,
)


RETRIEVAL_DESCRIPTION = """Can be used to look up information that was uploaded to this assistant.
If the user is referencing particular files, that is often a good hint that information may be here.
If the user asks a vague question, they are likely meaning to look up info from this retriever, and you should call it!"""
Expand Down Expand Up @@ -285,6 +300,15 @@ def _get_connery_actions():
return tools


@lru_cache(maxsize=1)
def _get_dalle_tools():
return Tool(
"Dall-E-Image-Generator",
DallEAPIWrapper(size="1024x1024", quality="hd").run,
"A wrapper around OpenAI DALL-E API. Useful for when you need to generate images from a text description. Input should be an image description.",
)


TOOLS = {
AvailableTools.ACTION_SERVER: _get_action_server,
AvailableTools.CONNERY: _get_connery_actions,
Expand All @@ -297,4 +321,5 @@ def _get_connery_actions():
AvailableTools.TAVILY: _get_tavily,
AvailableTools.WIKIPEDIA: _get_wikipedia,
AvailableTools.TAVILY_ANSWER: _get_tavily_answer,
AvailableTools.DALL_E: _get_dalle_tools,
}

0 comments on commit bb498e3

Please sign in to comment.