From 5ad09fa0d10f59c92a872d51cb6ee24c76338681 Mon Sep 17 00:00:00 2001 From: Sergey Pugachev Date: Thu, 2 Nov 2023 17:14:04 +0000 Subject: [PATCH] Allow to deploy Kendra without SageMaker endpoints --- bin/config.ts | 2 +- lib/aws-genai-llm-chatbot-stack.ts | 4 ++++ .../api-handler/routes/cross_encoders.py | 5 +--- .../api-handler/routes/embeddings.py | 6 ++--- lib/chatbot-api/rest-api.ts | 21 +++++++++-------- lib/model-interfaces/langchain/index.ts | 11 ++++++--- .../data-import/file-import-batch-job.ts | 2 +- lib/rag-engines/data-import/index.ts | 10 ++++---- lib/rag-engines/index.ts | 21 +++++++++-------- .../python/genai_core/cross_encoder.py | 11 +++++++++ .../python/genai_core/embeddings.py | 14 ++++++++++- lib/user-interface/index.ts | 4 ++++ .../react-app/src/common/types.ts | 3 ++- .../src/components/navigation-panel.tsx | 23 ++++++++++++++----- .../rag/cross-encoders/cross-encoders.tsx | 1 + 15 files changed, 94 insertions(+), 44 deletions(-) diff --git a/bin/config.ts b/bin/config.ts index c13ffb22e..344a04013 100644 --- a/bin/config.ts +++ b/bin/config.ts @@ -39,7 +39,6 @@ export function getConfig(): SystemConfig { { provider: "sagemaker", name: "intfloat/multilingual-e5-large", - default: true, dimensions: 1024, }, { @@ -51,6 +50,7 @@ export function getConfig(): SystemConfig { provider: "bedrock", name: "amazon.titan-embed-text-v1", dimensions: 1536, + default: true, }, { provider: "openai", diff --git a/lib/aws-genai-llm-chatbot-stack.ts b/lib/aws-genai-llm-chatbot-stack.ts index 4dc196a17..2d3b509f0 100644 --- a/lib/aws-genai-llm-chatbot-stack.ts +++ b/lib/aws-genai-llm-chatbot-stack.ts @@ -151,6 +151,10 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack { restApi: chatBotApi.restApi, webSocketApi: chatBotApi.webSocketApi, chatbotFilesBucket: chatBotApi.filesBucket, + crossEncodersEnabled: + typeof ragEngines?.sageMakerRagModels?.model !== "undefined", + sagemakerEmbeddingsEnabled: + typeof ragEngines?.sageMakerRagModels?.model !== "undefined", }); } } diff --git a/lib/chatbot-api/functions/api-handler/routes/cross_encoders.py b/lib/chatbot-api/functions/api-handler/routes/cross_encoders.py index 29721729b..52472be01 100644 --- a/lib/chatbot-api/functions/api-handler/routes/cross_encoders.py +++ b/lib/chatbot-api/functions/api-handler/routes/cross_encoders.py @@ -21,8 +21,7 @@ class CrossEncodersRequest(BaseModel): @router.get("/cross-encoders/models") @tracer.capture_method def models(): - config = genai_core.parameters.get_config() - models = config["rag"]["crossEncoderModels"] + models = genai_core.cross_encoder.get_cross_encoder_models() return {"ok": True, "data": models} @@ -30,8 +29,6 @@ def models(): @router.post("/cross-encoders") @tracer.capture_method def cross_encoders(): - config = genai_core.parameters.get_config() - data: dict = router.current_event.json_body request = CrossEncodersRequest(**data) selected_model = genai_core.cross_encoder.get_cross_encoder_model( diff --git a/lib/chatbot-api/functions/api-handler/routes/embeddings.py b/lib/chatbot-api/functions/api-handler/routes/embeddings.py index da598cdce..faf467756 100644 --- a/lib/chatbot-api/functions/api-handler/routes/embeddings.py +++ b/lib/chatbot-api/functions/api-handler/routes/embeddings.py @@ -20,8 +20,7 @@ class EmbeddingsRequest(BaseModel): @router.get("/embeddings/models") @tracer.capture_method def models(): - config = genai_core.parameters.get_config() - models = config["rag"]["embeddingsModels"] + models = genai_core.embeddings.get_embeddings_models() return {"ok": True, "data": models} @@ -38,6 +37,7 @@ def embeddings(): if selected_model is None: raise genai_core.types.CommonError("Model not found") - ret_value = genai_core.embeddings.generate_embeddings(selected_model, request.input) + ret_value = genai_core.embeddings.generate_embeddings( + selected_model, request.input) return {"ok": True, "data": ret_value} diff --git a/lib/chatbot-api/rest-api.ts b/lib/chatbot-api/rest-api.ts index a5b59d2d2..1fe572139 100644 --- a/lib/chatbot-api/rest-api.ts +++ b/lib/chatbot-api/rest-api.ts @@ -1,4 +1,9 @@ +import * as path from "path"; import * as cdk from "aws-cdk-lib"; +import { SageMakerModelEndpoint, SystemConfig } from "../shared/types"; +import { Construct } from "constructs"; +import { RagEngines } from "../rag-engines"; +import { Shared } from "../shared"; import * as apigateway from "aws-cdk-lib/aws-apigateway"; import * as cognito from "aws-cdk-lib/aws-cognito"; import * as dynamodb from "aws-cdk-lib/aws-dynamodb"; @@ -7,11 +12,6 @@ import * as iam from "aws-cdk-lib/aws-iam"; import * as lambda from "aws-cdk-lib/aws-lambda"; import * as logs from "aws-cdk-lib/aws-logs"; import * as ssm from "aws-cdk-lib/aws-ssm"; -import { Construct } from "constructs"; -import * as path from "path"; -import { RagEngines } from "../rag-engines"; -import { Shared } from "../shared"; -import { SageMakerModelEndpoint, SystemConfig } from "../shared/types"; export interface RestApiProps { readonly shared: Shared; @@ -74,7 +74,8 @@ export class RestApi extends Construct { DOCUMENTS_BY_COMPOUND_KEY_INDEX_NAME: props.ragEngines?.documentsByCompountKeyIndexName ?? "", SAGEMAKER_RAG_MODELS_ENDPOINT: - props.ragEngines?.sageMakerRagModelsEndpoint?.attrEndpointName ?? "", + props.ragEngines?.sageMakerRagModels?.model.endpoint + ?.attrEndpointName ?? "", DELETE_WORKSPACE_WORKFLOW_ARN: props.ragEngines?.deleteWorkspaceWorkflow?.stateMachineArn ?? "", CREATE_AURORA_WORKSPACE_WORKFLOW_ARN: @@ -190,7 +191,9 @@ export class RestApi extends Construct { new iam.PolicyStatement({ actions: ["kendra:Retrieve", "kendra:Query"], resources: [ - `arn:${cdk.Aws.PARTITION}:kendra:${item.region}:${cdk.Aws.ACCOUNT_ID}:index/${item.kendraId}`, + `arn:${cdk.Aws.PARTITION}:kendra:${ + item.region ?? cdk.Aws.REGION + }:${cdk.Aws.ACCOUNT_ID}:index/${item.kendraId}`, ], }) ); @@ -210,11 +213,11 @@ export class RestApi extends Construct { props.ragEngines.deleteWorkspaceWorkflow.grantStartExecution(apiHandler); } - if (props.ragEngines?.sageMakerRagModelsEndpoint) { + if (props.ragEngines?.sageMakerRagModels) { apiHandler.addToRolePolicy( new iam.PolicyStatement({ actions: ["sagemaker:InvokeEndpoint"], - resources: [props.ragEngines?.sageMakerRagModelsEndpoint.ref], + resources: [props.ragEngines.sageMakerRagModels.model.endpoint.ref], }) ); } diff --git a/lib/model-interfaces/langchain/index.ts b/lib/model-interfaces/langchain/index.ts index 050259176..a6b3fdf0e 100644 --- a/lib/model-interfaces/langchain/index.ts +++ b/lib/model-interfaces/langchain/index.ts @@ -60,7 +60,8 @@ export class LangChainInterface extends Construct { AURORA_DB_SECRET_ID: props.ragEngines?.auroraPgVector?.database?.secret ?.secretArn as string, SAGEMAKER_RAG_MODELS_ENDPOINT: - props.ragEngines?.sageMakerRagModelsEndpoint?.attrEndpointName ?? "", + props.ragEngines?.sageMakerRagModels?.model.endpoint + ?.attrEndpointName ?? "", OPEN_SEARCH_COLLECTION_ENDPOINT: props.ragEngines?.openSearchVector?.openSearchCollectionEndpoint ?? "", @@ -126,11 +127,13 @@ export class LangChainInterface extends Construct { if (props.ragEngines) { props.ragEngines.workspacesTable.grantReadWriteData(requestHandler); props.ragEngines.documentsTable.grantReadWriteData(requestHandler); + } + if (props.ragEngines?.sageMakerRagModels) { requestHandler.addToRolePolicy( new iam.PolicyStatement({ actions: ["sagemaker:InvokeEndpoint"], - resources: [props.ragEngines?.sageMakerRagModelsEndpoint.ref], + resources: [props.ragEngines.sageMakerRagModels.model.endpoint.ref], }) ); } @@ -162,7 +165,9 @@ export class LangChainInterface extends Construct { new iam.PolicyStatement({ actions: ["kendra:Retrieve", "kendra:Query"], resources: [ - `arn:${cdk.Aws.PARTITION}:kendra:${item.region}:${cdk.Aws.ACCOUNT_ID}:index/${item.kendraId}`, + `arn:${cdk.Aws.PARTITION}:kendra:${ + item.region ?? cdk.Aws.REGION + }:${cdk.Aws.ACCOUNT_ID}:index/${item.kendraId}`, ], }) ); diff --git a/lib/rag-engines/data-import/file-import-batch-job.ts b/lib/rag-engines/data-import/file-import-batch-job.ts index 09ec66d28..53c216c9f 100644 --- a/lib/rag-engines/data-import/file-import-batch-job.ts +++ b/lib/rag-engines/data-import/file-import-batch-job.ts @@ -37,7 +37,7 @@ export class FileImportBatchJob extends Construct { vpc: props.shared.vpc, allocationStrategy: batch.AllocationStrategy.BEST_FIT, maxvCpus: 4, - minvCpus: 4, + minvCpus: 0, replaceComputeEnvironment: true, updateTimeout: cdk.Duration.minutes(30), updateToLatestImageVersion: true, diff --git a/lib/rag-engines/data-import/index.ts b/lib/rag-engines/data-import/index.ts index e32e3f08c..babc12dc1 100644 --- a/lib/rag-engines/data-import/index.ts +++ b/lib/rag-engines/data-import/index.ts @@ -9,13 +9,13 @@ import { FileImportWorkflow } from "./file-import-workflow"; import { WebsiteCrawlingWorkflow } from "./website-crawling-workflow"; import { OpenSearchVector } from "../opensearch-vector"; import { KendraRetrieval } from "../kendra-retrieval"; +import { SageMakerRagModels } from "../sagemaker-rag-models"; import * as s3 from "aws-cdk-lib/aws-s3"; import * as sqs from "aws-cdk-lib/aws-sqs"; import * as lambda from "aws-cdk-lib/aws-lambda"; import * as logs from "aws-cdk-lib/aws-logs"; import * as ec2 from "aws-cdk-lib/aws-ec2"; import * as iam from "aws-cdk-lib/aws-iam"; -import * as sagemaker from "aws-cdk-lib/aws-sagemaker"; import * as dynamodb from "aws-cdk-lib/aws-dynamodb"; import * as s3Notifications from "aws-cdk-lib/aws-s3-notifications"; import * as lambdaEventSources from "aws-cdk-lib/aws-lambda-event-sources"; @@ -29,7 +29,7 @@ export interface DataImportProps { readonly ragDynamoDBTables: RagDynamoDBTables; readonly openSearchVector?: OpenSearchVector; readonly kendraRetrieval?: KendraRetrieval; - readonly sageMakerRagModelsEndpoint?: sagemaker.CfnEndpoint; + readonly sageMakerRagModels?: SageMakerRagModels; readonly workspacesTable: dynamodb.Table; readonly documentsTable: dynamodb.Table; readonly workspacesByObjectTypeIndexName: string; @@ -109,7 +109,7 @@ export class DataImport extends Construct { processingBucket, auroraDatabase: props.auroraDatabase, ragDynamoDBTables: props.ragDynamoDBTables, - sageMakerRagModelsEndpoint: props.sageMakerRagModelsEndpoint, + sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model.endpoint, openSearchVector: props.openSearchVector, } ); @@ -134,7 +134,7 @@ export class DataImport extends Construct { processingBucket, auroraDatabase: props.auroraDatabase, ragDynamoDBTables: props.ragDynamoDBTables, - sageMakerRagModelsEndpoint: props.sageMakerRagModelsEndpoint, + sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model.endpoint, openSearchVector: props.openSearchVector, } ); @@ -170,7 +170,7 @@ export class DataImport extends Construct { DOCUMENTS_BY_COMPOUND_KEY_INDEX_NAME: props.documentsByCompountKeyIndexName ?? "", SAGEMAKER_RAG_MODELS_ENDPOINT: - props.sageMakerRagModelsEndpoint?.attrEndpointName ?? "", + props.sageMakerRagModels?.model.endpoint.attrEndpointName ?? "", FILE_IMPORT_WORKFLOW_ARN: fileImportWorkflow?.stateMachine.stateMachineArn ?? "", DEFAULT_KENDRA_S3_DATA_SOURCE_BUCKET_NAME: diff --git a/lib/rag-engines/index.ts b/lib/rag-engines/index.ts index 26ee2d201..712ade4fc 100644 --- a/lib/rag-engines/index.ts +++ b/lib/rag-engines/index.ts @@ -1,6 +1,5 @@ import * as dynamodb from "aws-cdk-lib/aws-dynamodb"; import * as s3 from "aws-cdk-lib/aws-s3"; -import * as sagemaker from "aws-cdk-lib/aws-sagemaker"; import * as sfn from "aws-cdk-lib/aws-stepfunctions"; import { Construct } from "constructs"; import { Shared } from "../shared"; @@ -22,13 +21,13 @@ export class RagEngines extends Construct { public readonly auroraPgVector: AuroraPgVector | null; public readonly openSearchVector: OpenSearchVector | null; public readonly kendraRetrieval: KendraRetrieval | null; + public readonly sageMakerRagModels: SageMakerRagModels | null; public readonly uploadBucket: s3.Bucket; public readonly processingBucket: s3.Bucket; public readonly documentsTable: dynamodb.Table; public readonly workspacesTable: dynamodb.Table; public readonly workspacesByObjectTypeIndexName: string; public readonly documentsByCompountKeyIndexName: string; - public readonly sageMakerRagModelsEndpoint: sagemaker.CfnEndpoint; public readonly fileImportWorkflow?: sfn.StateMachine; public readonly websiteCrawlingWorkflow?: sfn.StateMachine; public readonly deleteWorkspaceWorkflow?: sfn.StateMachine; @@ -38,14 +37,16 @@ export class RagEngines extends Construct { const tables = new RagDynamoDBTables(this, "RagDynamoDBTables"); - const sageMakerRagModels = new SageMakerRagModels( - this, - "SageMaker", - { + let sageMakerRagModels: SageMakerRagModels | null = null; + if ( + props.config.rag.engines.aurora.enabled || + props.config.rag.engines.opensearch.enabled + ) { + sageMakerRagModels = new SageMakerRagModels(this, "SageMaker", { shared: props.shared, config: props.config, - } - ); + }); + } let auroraPgVector: AuroraPgVector | null = null; if (props.config.rag.engines.aurora.enabled) { @@ -78,7 +79,7 @@ export class RagEngines extends Construct { shared: props.shared, config: props.config, auroraDatabase: auroraPgVector?.database, - sageMakerRagModelsEndpoint: sageMakerRagModels.model.endpoint, + sageMakerRagModels: sageMakerRagModels ?? undefined, workspacesTable: tables.workspacesTable, documentsTable: tables.documentsTable, ragDynamoDBTables: tables, @@ -101,7 +102,7 @@ export class RagEngines extends Construct { this.auroraPgVector = auroraPgVector; this.openSearchVector = openSearchVector; this.kendraRetrieval = kendraRetrieval; - this.sageMakerRagModelsEndpoint = sageMakerRagModels.model.endpoint; + this.sageMakerRagModels = sageMakerRagModels; this.uploadBucket = dataImport.uploadBucket; this.processingBucket = dataImport.processingBucket; this.workspacesTable = tables.workspacesTable; diff --git a/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py b/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py index f3cc800d6..5600f1de9 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py +++ b/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py @@ -22,6 +22,17 @@ def rank_passages( raise genai_core.typesCommonError(f"Unknown provider") +def get_cross_encoder_models(): + config = genai_core.parameters.get_config() + models = config["rag"]["crossEncoderModels"] + + if not SAGEMAKER_RAG_MODELS_ENDPOINT: + models = list( + filter(lambda x: x["provider"] != "sagemaker", models)) + + return models + + def get_cross_encoder_model( provider: str, name: str ) -> Optional[genai_core.types.CrossEncoderModel]: diff --git a/lib/shared/layers/python-sdk/python/genai_core/embeddings.py b/lib/shared/layers/python-sdk/python/genai_core/embeddings.py index f2ee9928e..2366c7762 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/embeddings.py +++ b/lib/shared/layers/python-sdk/python/genai_core/embeddings.py @@ -18,7 +18,8 @@ def generate_embeddings( input = list(map(lambda x: x[:10000], input)) ret_value = [] - batch_split = [input[i : i + batch_size] for i in range(0, len(input), batch_size)] + batch_split = [input[i: i + batch_size] + for i in range(0, len(input), batch_size)] for batch in batch_split: if model.provider == "openai": @@ -33,6 +34,17 @@ def generate_embeddings( return ret_value +def get_embeddings_models(): + config = genai_core.parameters.get_config() + models = config["rag"]["embeddingsModels"] + + if not SAGEMAKER_RAG_MODELS_ENDPOINT: + models = list( + filter(lambda x: x["provider"] != "sagemaker", models)) + + return models + + def get_embeddings_model( provider: str, name: str ) -> Optional[genai_core.types.EmbeddingsModel]: diff --git a/lib/user-interface/index.ts b/lib/user-interface/index.ts index 41deb794f..e602ffe82 100644 --- a/lib/user-interface/index.ts +++ b/lib/user-interface/index.ts @@ -25,6 +25,8 @@ export interface UserInterfaceProps { readonly restApi: apigateway.RestApi; readonly webSocketApi: apigwv2.WebSocketApi; readonly chatbotFilesBucket: s3.Bucket; + readonly crossEncodersEnabled: boolean; + readonly sagemakerEmbeddingsEnabled: boolean; } export class UserInterface extends Construct { @@ -172,6 +174,8 @@ export class UserInterface extends Construct { api_endpoint: `https://${distribution.distributionDomainName}/api`, websocket_endpoint: `wss://${distribution.distributionDomainName}/socket`, rag_enabled: props.config.rag.enabled, + cross_encoders_enabled: props.crossEncodersEnabled, + sagemaker_embeddings_enabled: props.sagemakerEmbeddingsEnabled, default_embeddings_model: Utils.getDefaultEmbeddingsModel(props.config), default_cross_encoder_model: Utils.getDefaultCrossEncoderModel( props.config diff --git a/lib/user-interface/react-app/src/common/types.ts b/lib/user-interface/react-app/src/common/types.ts index 28c8b78a8..888d1f2db 100644 --- a/lib/user-interface/react-app/src/common/types.ts +++ b/lib/user-interface/react-app/src/common/types.ts @@ -16,6 +16,8 @@ export interface AppConfig { name: CognitoHostedUIIdentityProvider; }; rag_enabled: boolean; + cross_encoders_enabled: boolean; + sagemaker_embeddings_enabled: boolean; api_endpoint: string; websocket_endpoint: string; default_embeddings_model: string; @@ -27,7 +29,6 @@ export interface AppConfig { region: string; }; }; - } export interface NavigationPanelState { diff --git a/lib/user-interface/react-app/src/components/navigation-panel.tsx b/lib/user-interface/react-app/src/components/navigation-panel.tsx index 1f2f0f3da..f697dd15d 100644 --- a/lib/user-interface/react-app/src/components/navigation-panel.tsx +++ b/lib/user-interface/react-app/src/components/navigation-panel.tsx @@ -24,7 +24,11 @@ export default function NavigationPanel() { text: "Chatbot", items: [ { type: "link", text: "Playground", href: "/chatbot/playground" }, - { type: "link", text: "Multi-chat playground", href: "/chatbot/multichat" }, + { + type: "link", + text: "Multi-chat playground", + href: "/chatbot/multichat", + }, { type: "link", text: "Models", @@ -35,6 +39,17 @@ export default function NavigationPanel() { ]; if (appContext?.config.rag_enabled) { + const crossEncodersItems: SideNavigationProps.Item[] = appContext?.config + .cross_encoders_enabled + ? [ + { + type: "link", + text: "Cross-encoders", + href: "/rag/cross-encoders", + }, + ] + : []; + items.push({ type: "section", text: "Retrieval-Augmented Generation (RAG)", @@ -51,11 +66,7 @@ export default function NavigationPanel() { text: "Embeddings", href: "/rag/embeddings", }, - { - type: "link", - text: "Cross-encoders", - href: "/rag/cross-encoders", - }, + ...crossEncodersItems, { type: "link", text: "Engines", href: "/rag/engines" }, ], }); diff --git a/lib/user-interface/react-app/src/pages/rag/cross-encoders/cross-encoders.tsx b/lib/user-interface/react-app/src/pages/rag/cross-encoders/cross-encoders.tsx index 59aac030b..e7bbd8442 100644 --- a/lib/user-interface/react-app/src/pages/rag/cross-encoders/cross-encoders.tsx +++ b/lib/user-interface/react-app/src/pages/rag/cross-encoders/cross-encoders.tsx @@ -255,6 +255,7 @@ export default function CrossEncoders() { onChange={({ detail: { selectedOption } }) => onChange({ crossEncoderModel: selectedOption }) } + empty={
No cross-encoder models found
} />