feat: support bedrock llm provider #213

Merged 6 commits on Aug 9, 2024
Changes from all commits
31 changes: 31 additions & 0 deletions backend/app/alembic/versions/04d81be446c3_.py
@@ -0,0 +1,31 @@
"""empty message

Revision ID: 04d81be446c3
Revises: e32f1e546eec
Create Date: 2024-08-08 17:11:50.178696

"""
from alembic import op
from sqlalchemy.dialects import mysql

# revision identifiers, used by Alembic.
revision = '04d81be446c3'
down_revision = 'e32f1e546eec'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column('llms', 'provider',
                    existing_type=mysql.ENUM('OPENAI', 'GEMINI', 'ANTHROPIC_VERTEX', 'OPENAI_LIKE', 'BEDROCK'),
                    nullable=False)
    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column('llms', 'provider',
                    existing_type=mysql.ENUM('OPENAI', 'GEMINI', 'ANTHROPIC_VERTEX', 'OPENAI_LIKE'),
                    nullable=False)
    # ### end Alembic commands ###
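For reference, applying this revision follows the standard Alembic flow; a minimal sketch (not part of this PR, assuming the backend's usual `alembic.ini` is present):

# Minimal sketch: apply the migration programmatically instead of running
# `alembic upgrade 04d81be446c3` from the CLI. Assumes a standard alembic.ini.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "04d81be446c3")  # revision added in this PR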
3 changes: 3 additions & 0 deletions backend/app/api/admin_routes/llm.py
@@ -1,3 +1,4 @@
import logging
from typing import List
from pydantic import BaseModel

@@ -28,6 +29,7 @@
)

router = APIRouter()
logger = logging.getLogger(__name__)


@router.get("/admin/llms/options")
@@ -81,6 +83,7 @@ def test_llm(
        success = True
        error = ""
    except Exception as e:
        logger.debug(e)
        success = False
        error = str(e)
    return LLMTestResult(success=success, error=error)
25 changes: 25 additions & 0 deletions backend/app/rag/chat_config.py
@@ -3,11 +3,13 @@
from typing import Optional

import dspy
from llama_index.llms.bedrock.utils import BEDROCK_FOUNDATION_LLMS
from pydantic import BaseModel
from llama_index.llms.openai.utils import DEFAULT_OPENAI_API_BASE
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_like import OpenAILike
from llama_index.llms.gemini import Gemini
from llama_index.llms.bedrock import Bedrock
from llama_index.core.llms.llm import LLM
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.postprocessor.types import BaseNodePostprocessor
@@ -170,6 +172,29 @@ def get_llm(
        case LLMProvider.GEMINI:
            os.environ["GOOGLE_API_KEY"] = credentials
            return Gemini(model=model, api_key=credentials, **config)
        case LLMProvider.BEDROCK:
            access_key_id = credentials["aws_access_key_id"]
            secret_access_key = credentials["aws_secret_access_key"]
            region_name = credentials["aws_region_name"]

            # llama-index only knows the context window of the Bedrock
            # foundation models it ships metadata for; fall back to 200k
            # tokens for anything else.
            context_size = None
            if model not in BEDROCK_FOUNDATION_LLMS:
                context_size = 200000

            llm = Bedrock(
                model=model,
                aws_access_key_id=access_key_id,
                aws_secret_access_key=secret_access_key,
                region_name=region_name,
                context_size=context_size,
            )
            # Note: llama-index's Bedrock class does not store these values as
            # attributes in its constructor, so we set them again here so that
            # `get_dspy_lm_by_llama_llm` can read them later.
            llm.aws_access_key_id = access_key_id
            llm.aws_secret_access_key = secret_access_key
            llm.region_name = region_name
            return llm
        case LLMProvider.ANTHROPIC_VERTEX:
            google_creds: service_account.Credentials = (
                service_account.Credentials.from_service_account_info(
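A minimal usage sketch of the new branch (hypothetical, not part of the diff; the exact `get_llm` signature is not shown in this hunk, so the keyword arguments below are assumptions):

# Hypothetical call into the new Bedrock branch of get_llm; argument names
# are assumed from the code above, and the credential values are placeholders.
from app.rag.chat_config import get_llm
from app.types import LLMProvider

llm = get_llm(
    provider=LLMProvider.BEDROCK,
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    config={},
    credentials={
        "aws_access_key_id": "AKIAIOSFODNN7EXAMPLE",  # AWS docs placeholder
        "aws_secret_access_key": "****",
        "aws_region_name": "us-west-2",
    },
)
print(llm.complete("Hello from Bedrock").text)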
13 changes: 13 additions & 0 deletions backend/app/rag/llm_option.py
@@ -55,4 +55,17 @@ class LLMOption(BaseModel):
"private_key_id": "****",
},
),
LLMOption(
provider=LLMProvider.BEDROCK,
default_llm_model="anthropic.claude-3-5-sonnet-20240620-v1:0",
llm_model_description="",
credentials_display_name="AWS Bedrock Credentials JSON",
credentials_description="The JSON Object of AWS Credentials, refer to https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html#cli-configure-files-global",
credentials_type="dict",
default_credentials={
"aws_access_key_id": "****",
"aws_secret_access_key": "****",
"aws_region_name": "us-west-2"
},
),
]
6 changes: 6 additions & 0 deletions backend/app/rag/types.py
@@ -7,6 +7,7 @@ class LLMProvider(str, enum.Enum):
OPENAI = "openai"
GEMINI = "gemini"
VERTEX = "vertex"
BEDROCK = "bedrock"


class OpenAIModel(str, enum.Enum):
@@ -23,6 +24,11 @@ class VertexModel(str, enum.Enum):
CLAUDE_35_SONNET = "claude-3-5-sonnet@20240620"


# Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/models
class BedrockModel(str, enum.Enum):
    CLAUDE_35_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0"

Reviewer comment (Collaborator): This class appears to be deprecated; I will delete it after merging this PR.



class EmbeddingProvider(str, enum.Enum):
OPENAI = "openai"

1 change: 1 addition & 0 deletions backend/app/types.py
@@ -6,6 +6,7 @@ class LLMProvider(str, enum.Enum):
GEMINI = "gemini"
ANTHROPIC_VERTEX = "anthropic_vertex"
OPENAI_LIKE = "openai_like"
BEDROCK = "bedrock"


class EmbeddingProvider(str, enum.Enum):
35 changes: 35 additions & 0 deletions backend/app/utils/dspy.py
@@ -1,9 +1,12 @@
import os

import dspy

from llama_index.core.base.llms.base import BaseLLM
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_like import OpenAILike
from llama_index.llms.gemini import Gemini
from llama_index.llms.bedrock import Bedrock
from app.rag.llms.anthropic_vertex import AnthropicVertex


@@ -36,6 +39,38 @@ def get_dspy_lm_by_llama_llm(llama_llm: BaseLLM) -> dspy.LM:
            model=llama_llm.model.split("models/")[1],
            max_output_tokens=llama_llm.max_tokens or 8192,
        )
    elif isinstance(llama_llm, Bedrock):
        # Note: dspy.Bedrock currently does not support configuring access keys
        # through parameters, so we fall back to environment variables, which
        # risks contaminating global state.
        os.environ["AWS_ACCESS_KEY_ID"] = llama_llm.aws_access_key_id
        os.environ["AWS_SECRET_ACCESS_KEY"] = llama_llm.aws_secret_access_key
        bedrock = dspy.Bedrock(region_name=llama_llm.region_name)
        if llama_llm.model.startswith("anthropic"):
            return dspy.AWSAnthropic(
                bedrock,
                model=llama_llm.model,
                max_new_tokens=llama_llm.max_tokens,
            )
        elif llama_llm.model.startswith("meta"):
            return dspy.AWSMeta(
                bedrock,
                model=llama_llm.model,
                max_new_tokens=llama_llm.max_tokens,
            )
        elif llama_llm.model.startswith("mistral"):
            return dspy.AWSMistral(
                bedrock,
                model=llama_llm.model,
                max_new_tokens=llama_llm.max_tokens,
            )
        elif llama_llm.model.startswith("amazon"):
            return dspy.AWSModel(
                bedrock,
                model=llama_llm.model,
                max_new_tokens=llama_llm.max_tokens,
            )
        else:
            raise ValueError(f"Bedrock model {llama_llm.model} is not supported by dspy.")
    elif isinstance(llama_llm, AnthropicVertex):
        raise ValueError("AnthropicVertex is not supported by dspy.")
    else:
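Because the Bedrock branch above mutates process-wide environment variables, one hypothetical mitigation (not in this PR, and only safe if dspy reads the variables at construction time) is to scope them with a restore-on-exit context manager:

# Hypothetical helper, not from this PR: temporarily sets env vars and
# restores the previous values on exit to avoid contaminating global state.
import os
from contextlib import contextmanager

@contextmanager
def scoped_env(**overrides):
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old

# Usage sketch:
# with scoped_env(AWS_ACCESS_KEY_ID=key_id, AWS_SECRET_ACCESS_KEY=secret):
#     bedrock = dspy.Bedrock(region_name=region)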
3 changes: 2 additions & 1 deletion backend/pyproject.toml
@@ -35,14 +35,15 @@ dependencies = [
"fastapi-pagination>=0.12.25",
"gunicorn>=22.0.0",
"pyyaml>=6.0.1",
"anthropic[vertex]>=0.31.0",
"anthropic[vertex]>=0.28.1",
"google-cloud-aiplatform>=1.59.0",
"deepeval>=0.21.73",
"llama-index-llms-openai>=0.1.27",
"llama-index-llms-openai-like>=0.1.3",
"playwright>=1.45.1",
"markdownify>=0.13.1",
"llama-index-postprocessor-cohere-rerank>=0.1.7",
"llama-index-llms-bedrock>=0.1.12",
]
readme = "README.md"
requires-python = ">= 3.8"
20 changes: 18 additions & 2 deletions backend/requirements-dev.lock
@@ -7,6 +7,7 @@
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: true

aiohttp==3.9.5
# via datasets
@@ -23,7 +24,8 @@ amqp==5.2.0
# via kombu
annotated-types==0.7.0
# via pydantic
anthropic==0.31.0
anthropic==0.28.1
# via llama-index-llms-anthropic
anyio==4.4.0
# via anthropic
# via httpx
@@ -52,6 +54,7 @@ billiard==4.2.0
# via celery
boto3==1.34.156
# via cohere
# via llama-index-llms-bedrock
botocore==1.34.156
# via boto3
# via s3transfer
@@ -85,6 +88,12 @@ click-repl==0.3.0
# via celery
cohere==5.6.2
# via llama-index-postprocessor-cohere-rerank
colorama==0.4.6 ; platform_system == 'Windows' or sys_platform == 'win32'
# via click
# via colorlog
# via pytest
# via tqdm
# via uvicorn
colorlog==6.8.2
# via optuna
cryptography==42.0.8
@@ -306,6 +315,8 @@ llama-index-core==0.10.59
# via llama-index-cli
# via llama-index-embeddings-openai
# via llama-index-indices-managed-llama-cloud
# via llama-index-llms-anthropic
# via llama-index-llms-bedrock
# via llama-index-llms-gemini
# via llama-index-llms-openai
# via llama-index-llms-openai-like
@@ -324,6 +335,9 @@ llama-index-indices-managed-llama-cloud==0.2.5
# via llama-index
llama-index-legacy==0.9.48
# via llama-index
llama-index-llms-anthropic==0.1.16
# via llama-index-llms-bedrock
llama-index-llms-bedrock==0.1.12
llama-index-llms-gemini==0.1.11
llama-index-llms-openai==0.1.27
# via llama-index
@@ -538,6 +552,8 @@ python-multipart==0.0.9
pytz==2024.1
# via flower
# via pandas
pywin32==306 ; platform_system == 'Windows'
# via portalocker
pyyaml==6.0.1
# via datasets
# via huggingface-hub
@@ -689,7 +705,7 @@ urllib3==2.2.1
# via types-requests
uvicorn==0.30.3
# via fastapi
uvloop==0.19.0
uvloop==0.19.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
# via uvicorn
vine==5.1.0
# via amqp
20 changes: 18 additions & 2 deletions backend/requirements.lock
@@ -7,6 +7,7 @@
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: true

aiohttp==3.9.5
# via datasets
@@ -23,7 +24,8 @@ amqp==5.2.0
# via kombu
annotated-types==0.7.0
# via pydantic
anthropic==0.31.0
anthropic==0.28.1
# via llama-index-llms-anthropic
anyio==4.4.0
# via anthropic
# via httpx
@@ -52,6 +54,7 @@ billiard==4.2.0
# via celery
boto3==1.34.156
# via cohere
# via llama-index-llms-bedrock
botocore==1.34.156
# via boto3
# via s3transfer
@@ -85,6 +88,12 @@ click-repl==0.3.0
# via celery
cohere==5.6.2
# via llama-index-postprocessor-cohere-rerank
colorama==0.4.6 ; platform_system == 'Windows' or sys_platform == 'win32'
# via click
# via colorlog
# via pytest
# via tqdm
# via uvicorn
colorlog==6.8.2
# via optuna
cryptography==42.0.8
@@ -306,6 +315,8 @@ llama-index-core==0.10.59
# via llama-index-cli
# via llama-index-embeddings-openai
# via llama-index-indices-managed-llama-cloud
# via llama-index-llms-anthropic
# via llama-index-llms-bedrock
# via llama-index-llms-gemini
# via llama-index-llms-openai
# via llama-index-llms-openai-like
@@ -324,6 +335,9 @@ llama-index-indices-managed-llama-cloud==0.2.5
# via llama-index
llama-index-legacy==0.9.48
# via llama-index
llama-index-llms-anthropic==0.1.16
# via llama-index-llms-bedrock
llama-index-llms-bedrock==0.1.12
llama-index-llms-gemini==0.1.11
llama-index-llms-openai==0.1.27
# via llama-index
@@ -538,6 +552,8 @@ python-multipart==0.0.9
pytz==2024.1
# via flower
# via pandas
pywin32==306 ; platform_system == 'Windows'
# via portalocker
pyyaml==6.0.1
# via datasets
# via huggingface-hub
@@ -689,7 +705,7 @@ urllib3==2.2.1
# via types-requests
uvicorn==0.30.3
# via fastapi
uvloop==0.19.0
uvloop==0.19.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
# via uvicorn
vine==5.1.0
# via amqp