feat: support bedrock llm provider (#213)
Mini256 authored Aug 9, 2024
1 parent 3eb8d4e commit b61c9a8
Showing 10 changed files with 152 additions and 5 deletions.
31 changes: 31 additions & 0 deletions backend/app/alembic/versions/04d81be446c3_.py
@@ -0,0 +1,31 @@
"""empty message
Revision ID: 04d81be446c3
Revises: e32f1e546eec
Create Date: 2024-08-08 17:11:50.178696
"""
from alembic import op
from sqlalchemy.dialects import mysql

# revision identifiers, used by Alembic.
revision = '04d81be446c3'
down_revision = 'e32f1e546eec'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('llms', 'provider',
existing_type=mysql.ENUM('OPENAI', 'GEMINI', 'ANTHROPIC_VERTEX', 'OPENAI_LIKE', 'BEDROCK'),
nullable=False)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('llms', 'provider',
existing_type=mysql.ENUM('OPENAI', 'GEMINI', 'ANTHROPIC_VERTEX', 'OPENAI_LIKE'),
nullable=False)
# ### end Alembic commands ###
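
For reference, a minimal sketch of applying this migration through Alembic's Python API — the alembic.ini path is an assumption, and `alembic upgrade head` from the backend directory is the usual CLI equivalent:

from alembic import command
from alembic.config import Config

# Hypothetical config path; adjust to wherever the project's alembic.ini lives.
alembic_cfg = Config("backend/alembic.ini")
command.upgrade(alembic_cfg, "head")  # runs upgrade() above, widening the provider ENUM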
3 changes: 3 additions & 0 deletions backend/app/api/admin_routes/llm.py
@@ -1,3 +1,4 @@
import logging
from typing import List
from pydantic import BaseModel

@@ -28,6 +29,7 @@
)

router = APIRouter()
logger = logging.getLogger(__name__)


@router.get("/admin/llms/options")
@@ -81,6 +83,7 @@ def test_llm(
success = True
error = ""
except Exception as e:
logger.debug(e)
success = False
error = str(e)
return LLMTestResult(success=success, error=error)
25 changes: 25 additions & 0 deletions backend/app/rag/chat_config.py
@@ -3,11 +3,13 @@
from typing import Optional

import dspy
from llama_index.llms.bedrock.utils import BEDROCK_FOUNDATION_LLMS
from pydantic import BaseModel
from llama_index.llms.openai.utils import DEFAULT_OPENAI_API_BASE
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_like import OpenAILike
from llama_index.llms.gemini import Gemini
from llama_index.llms.bedrock import Bedrock
from llama_index.core.llms.llm import LLM
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.postprocessor.types import BaseNodePostprocessor
@@ -170,6 +172,29 @@ def get_llm(
case LLMProvider.GEMINI:
os.environ["GOOGLE_API_KEY"] = credentials
return Gemini(model=model, api_key=credentials, **config)
case LLMProvider.BEDROCK:
access_key_id = credentials["aws_access_key_id"]
secret_access_key = credentials["aws_secret_access_key"]
region_name = credentials["aws_region_name"]

context_size = None
if model not in BEDROCK_FOUNDATION_LLMS:
context_size = 200000

llm = Bedrock(
model=model,
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
region_name=region_name,
context_size=context_size
)
# Note: The llama-index Bedrock class does not assign these values to the
# corresponding attributes in its constructor, so we set them again here so that
# `get_dspy_lm_by_llama_llm` can read them later.
llm.aws_access_key_id = access_key_id
llm.aws_secret_access_key = secret_access_key
llm.region_name = region_name
return llm
case LLMProvider.ANTHROPIC_VERTEX:
google_creds: service_account.Credentials = (
service_account.Credentials.from_service_account_info(
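
For illustration, a hedged sketch of how this new branch could be exercised. The keyword names mirror the match arms above, but the full `get_llm` signature is not shown in this diff, so treat the call shape as an assumption:

from app.rag.chat_config import get_llm
from app.types import LLMProvider

# Hypothetical call; the credentials keys match default_credentials in llm_option.py below.
llm = get_llm(
    provider=LLMProvider.BEDROCK,
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    credentials={
        "aws_access_key_id": "****",
        "aws_secret_access_key": "****",
        "aws_region_name": "us-west-2",
    },
    config={},
)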
13 changes: 13 additions & 0 deletions backend/app/rag/llm_option.py
@@ -55,4 +55,17 @@ class LLMOption(BaseModel):
"private_key_id": "****",
},
),
LLMOption(
provider=LLMProvider.BEDROCK,
default_llm_model="anthropic.claude-3-5-sonnet-20240620-v1:0",
llm_model_description="",
credentials_display_name="AWS Bedrock Credentials JSON",
credentials_description="The JSON object containing your AWS credentials; see https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html#cli-configure-files-global",
credentials_type="dict",
default_credentials={
"aws_access_key_id": "****",
"aws_secret_access_key": "****",
"aws_region_name": "us-west-2"
},
),
]
6 changes: 6 additions & 0 deletions backend/app/rag/types.py
@@ -7,6 +7,7 @@ class LLMProvider(str, enum.Enum):
OPENAI = "openai"
GEMINI = "gemini"
VERTEX = "vertex"
BEDROCK = "bedrock"


class OpenAIModel(str, enum.Enum):
@@ -23,6 +24,11 @@ class VertexModel(str, enum.Enum):
CLAUDE_35_SONNET = "claude-3-5-sonnet@20240620"


# Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/models
class BedrockModel(str, enum.Enum):
CLAUDE_35_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0"


class EmbeddingProvider(str, enum.Enum):
OPENAI = "openai"

1 change: 1 addition & 0 deletions backend/app/types.py
@@ -6,6 +6,7 @@ class LLMProvider(str, enum.Enum):
GEMINI = "gemini"
ANTHROPIC_VERTEX = "anthropic_vertex"
OPENAI_LIKE = "openai_like"
BEDROCK = "bedrock"


class EmbeddingProvider(str, enum.Enum):
35 changes: 35 additions & 0 deletions backend/app/utils/dspy.py
@@ -1,9 +1,12 @@
import os

import dspy

from llama_index.core.base.llms.base import BaseLLM
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_like import OpenAILike
from llama_index.llms.gemini import Gemini
from llama_index.llms.bedrock import Bedrock
from app.rag.llms.anthropic_vertex import AnthropicVertex


@@ -36,6 +39,38 @@ def get_dspy_lm_by_llama_llm(llama_llm: BaseLLM) -> dspy.LM:
model=llama_llm.model.split("models/")[1],
max_output_tokens=llama_llm.max_tokens or 8192,
)
elif isinstance(llama_llm, Bedrock):
# Note: dspy.Bedrock currently does not support configuring access keys through
# parameters, so we fall back to environment variables here, which risks polluting
# global process state.
os.environ["AWS_ACCESS_KEY_ID"] = llama_llm.aws_access_key_id
os.environ["AWS_SECRET_ACCESS_KEY"] = llama_llm.aws_secret_access_key
bedrock = dspy.Bedrock(region_name=llama_llm.region_name)
if llama_llm.model.startswith("anthropic"):
return dspy.AWSAnthropic(
bedrock,
model=llama_llm.model,
max_new_tokens=llama_llm.max_tokens
)
elif llama_llm.model.startswith("meta"):
return dspy.AWSMeta(
bedrock,
model=llama_llm.model,
max_new_tokens=llama_llm.max_tokens
)
elif llama_llm.model.startswith("mistral"):
return dspy.AWSMistral(
bedrock,
model=llama_llm.model,
max_new_tokens=llama_llm.max_tokens
)
elif llama_llm.model.startswith("amazon"):
return dspy.AWSModel(
bedrock,
model=llama_llm.model,
max_new_tokens=llama_llm.max_tokens
)
else:
raise ValueError("Bedrock model " + llama_llm.model + " is not supported by dspy.")
elif isinstance(llama_llm, AnthropicVertex):
raise ValueError("AnthropicVertex is not supported by dspy.")
else:
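
A hedged sketch of the round trip this enables, pairing the attribute re-assignment from chat_config.py with this new branch; credentials are placeholders:

from llama_index.llms.bedrock import Bedrock
from app.utils.dspy import get_dspy_lm_by_llama_llm

# context_size mirrors the fallback in get_llm above for models missing
# from BEDROCK_FOUNDATION_LLMS.
llama_llm = Bedrock(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    aws_access_key_id="****",
    aws_secret_access_key="****",
    region_name="us-west-2",
    context_size=200000,
)
# Re-set the attributes that Bedrock's constructor does not retain (see chat_config.py).
llama_llm.aws_access_key_id = "****"
llama_llm.aws_secret_access_key = "****"
llama_llm.region_name = "us-west-2"

dspy_lm = get_dspy_lm_by_llama_llm(llama_llm)  # anthropic.* models map to dspy.AWSAnthropic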
3 changes: 2 additions & 1 deletion backend/pyproject.toml
@@ -35,14 +35,15 @@ dependencies = [
"fastapi-pagination>=0.12.25",
"gunicorn>=22.0.0",
"pyyaml>=6.0.1",
"anthropic[vertex]>=0.31.0",
"anthropic[vertex]>=0.28.1",
"google-cloud-aiplatform>=1.59.0",
"deepeval>=0.21.73",
"llama-index-llms-openai>=0.1.27",
"llama-index-llms-openai-like>=0.1.3",
"playwright>=1.45.1",
"markdownify>=0.13.1",
"llama-index-postprocessor-cohere-rerank>=0.1.7",
"llama-index-llms-bedrock>=0.1.12",
]
readme = "README.md"
requires-python = ">= 3.8"
20 changes: 18 additions & 2 deletions backend/requirements-dev.lock
@@ -7,6 +7,7 @@
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: true

aiohttp==3.9.5
# via datasets
@@ -23,7 +24,8 @@ amqp==5.2.0
# via kombu
annotated-types==0.7.0
# via pydantic
-anthropic==0.31.0
+anthropic==0.28.1
# via llama-index-llms-anthropic
anyio==4.4.0
# via anthropic
# via httpx
@@ -52,6 +54,7 @@ billiard==4.2.0
# via celery
boto3==1.34.156
# via cohere
# via llama-index-llms-bedrock
botocore==1.34.156
# via boto3
# via s3transfer
@@ -85,6 +88,12 @@ click-repl==0.3.0
# via celery
cohere==5.6.2
# via llama-index-postprocessor-cohere-rerank
colorama==0.4.6 ; platform_system == 'Windows' or sys_platform == 'win32'
# via click
# via colorlog
# via pytest
# via tqdm
# via uvicorn
colorlog==6.8.2
# via optuna
cryptography==42.0.8
@@ -306,6 +315,8 @@ llama-index-core==0.10.59
# via llama-index-cli
# via llama-index-embeddings-openai
# via llama-index-indices-managed-llama-cloud
# via llama-index-llms-anthropic
# via llama-index-llms-bedrock
# via llama-index-llms-gemini
# via llama-index-llms-openai
# via llama-index-llms-openai-like
@@ -324,6 +335,9 @@ llama-index-indices-managed-llama-cloud==0.2.5
# via llama-index
llama-index-legacy==0.9.48
# via llama-index
llama-index-llms-anthropic==0.1.16
# via llama-index-llms-bedrock
llama-index-llms-bedrock==0.1.12
llama-index-llms-gemini==0.1.11
llama-index-llms-openai==0.1.27
# via llama-index
@@ -538,6 +552,8 @@ python-multipart==0.0.9
pytz==2024.1
# via flower
# via pandas
pywin32==306 ; platform_system == 'Windows'
# via portalocker
pyyaml==6.0.1
# via datasets
# via huggingface-hub
@@ -689,7 +705,7 @@ urllib3==2.2.1
# via types-requests
uvicorn==0.30.3
# via fastapi
-uvloop==0.19.0
+uvloop==0.19.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
# via uvicorn
vine==5.1.0
# via amqp
20 changes: 18 additions & 2 deletions backend/requirements.lock
@@ -7,6 +7,7 @@
# all-features: false
# with-sources: false
# generate-hashes: false
# universal: true

aiohttp==3.9.5
# via datasets
@@ -23,7 +24,8 @@ amqp==5.2.0
# via kombu
annotated-types==0.7.0
# via pydantic
-anthropic==0.31.0
+anthropic==0.28.1
# via llama-index-llms-anthropic
anyio==4.4.0
# via anthropic
# via httpx
@@ -52,6 +54,7 @@ billiard==4.2.0
# via celery
boto3==1.34.156
# via cohere
# via llama-index-llms-bedrock
botocore==1.34.156
# via boto3
# via s3transfer
@@ -85,6 +88,12 @@ click-repl==0.3.0
# via celery
cohere==5.6.2
# via llama-index-postprocessor-cohere-rerank
colorama==0.4.6 ; platform_system == 'Windows' or sys_platform == 'win32'
# via click
# via colorlog
# via pytest
# via tqdm
# via uvicorn
colorlog==6.8.2
# via optuna
cryptography==42.0.8
@@ -306,6 +315,8 @@ llama-index-core==0.10.59
# via llama-index-cli
# via llama-index-embeddings-openai
# via llama-index-indices-managed-llama-cloud
# via llama-index-llms-anthropic
# via llama-index-llms-bedrock
# via llama-index-llms-gemini
# via llama-index-llms-openai
# via llama-index-llms-openai-like
@@ -324,6 +335,9 @@ llama-index-indices-managed-llama-cloud==0.2.5
# via llama-index
llama-index-legacy==0.9.48
# via llama-index
llama-index-llms-anthropic==0.1.16
# via llama-index-llms-bedrock
llama-index-llms-bedrock==0.1.12
llama-index-llms-gemini==0.1.11
llama-index-llms-openai==0.1.27
# via llama-index
@@ -538,6 +552,8 @@ python-multipart==0.0.9
pytz==2024.1
# via flower
# via pandas
pywin32==306 ; platform_system == 'Windows'
# via portalocker
pyyaml==6.0.1
# via datasets
# via huggingface-hub
@@ -689,7 +705,7 @@ urllib3==2.2.1
# via types-requests
uvicorn==0.30.3
# via fastapi
-uvloop==0.19.0
+uvloop==0.19.0 ; platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'
# via uvicorn
vine==5.1.0
# via amqp
