Skip to content

Commit

Permalink
Allow to deploy Kendra without SageMaker endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
spugachev committed Nov 2, 2023
1 parent db8411e commit 5ad09fa
Show file tree
Hide file tree
Showing 15 changed files with 94 additions and 44 deletions.
2 changes: 1 addition & 1 deletion bin/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ export function getConfig(): SystemConfig {
{
provider: "sagemaker",
name: "intfloat/multilingual-e5-large",
default: true,
dimensions: 1024,
},
{
Expand All @@ -51,6 +50,7 @@ export function getConfig(): SystemConfig {
provider: "bedrock",
name: "amazon.titan-embed-text-v1",
dimensions: 1536,
default: true,
},
{
provider: "openai",
Expand Down
4 changes: 4 additions & 0 deletions lib/aws-genai-llm-chatbot-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
restApi: chatBotApi.restApi,
webSocketApi: chatBotApi.webSocketApi,
chatbotFilesBucket: chatBotApi.filesBucket,
crossEncodersEnabled:
typeof ragEngines?.sageMakerRagModels?.model !== "undefined",
sagemakerEmbeddingsEnabled:
typeof ragEngines?.sageMakerRagModels?.model !== "undefined",
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,14 @@ class CrossEncodersRequest(BaseModel):
@router.get("/cross-encoders/models")
@tracer.capture_method
def models():
config = genai_core.parameters.get_config()
models = config["rag"]["crossEncoderModels"]
models = genai_core.cross_encoder.get_cross_encoder_models()

return {"ok": True, "data": models}


@router.post("/cross-encoders")
@tracer.capture_method
def cross_encoders():
config = genai_core.parameters.get_config()

data: dict = router.current_event.json_body
request = CrossEncodersRequest(**data)
selected_model = genai_core.cross_encoder.get_cross_encoder_model(
Expand Down
6 changes: 3 additions & 3 deletions lib/chatbot-api/functions/api-handler/routes/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ class EmbeddingsRequest(BaseModel):
@router.get("/embeddings/models")
@tracer.capture_method
def models():
config = genai_core.parameters.get_config()
models = config["rag"]["embeddingsModels"]
models = genai_core.embeddings.get_embeddings_models()

return {"ok": True, "data": models}

Expand All @@ -38,6 +37,7 @@ def embeddings():
if selected_model is None:
raise genai_core.types.CommonError("Model not found")

ret_value = genai_core.embeddings.generate_embeddings(selected_model, request.input)
ret_value = genai_core.embeddings.generate_embeddings(
selected_model, request.input)

return {"ok": True, "data": ret_value}
21 changes: 12 additions & 9 deletions lib/chatbot-api/rest-api.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import * as path from "path";
import * as cdk from "aws-cdk-lib";
import { SageMakerModelEndpoint, SystemConfig } from "../shared/types";
import { Construct } from "constructs";
import { RagEngines } from "../rag-engines";
import { Shared } from "../shared";
import * as apigateway from "aws-cdk-lib/aws-apigateway";
import * as cognito from "aws-cdk-lib/aws-cognito";
import * as dynamodb from "aws-cdk-lib/aws-dynamodb";
Expand All @@ -7,11 +12,6 @@ import * as iam from "aws-cdk-lib/aws-iam";
import * as lambda from "aws-cdk-lib/aws-lambda";
import * as logs from "aws-cdk-lib/aws-logs";
import * as ssm from "aws-cdk-lib/aws-ssm";
import { Construct } from "constructs";
import * as path from "path";
import { RagEngines } from "../rag-engines";
import { Shared } from "../shared";
import { SageMakerModelEndpoint, SystemConfig } from "../shared/types";

export interface RestApiProps {
readonly shared: Shared;
Expand Down Expand Up @@ -74,7 +74,8 @@ export class RestApi extends Construct {
DOCUMENTS_BY_COMPOUND_KEY_INDEX_NAME:
props.ragEngines?.documentsByCompountKeyIndexName ?? "",
SAGEMAKER_RAG_MODELS_ENDPOINT:
props.ragEngines?.sageMakerRagModelsEndpoint?.attrEndpointName ?? "",
props.ragEngines?.sageMakerRagModels?.model.endpoint
?.attrEndpointName ?? "",
DELETE_WORKSPACE_WORKFLOW_ARN:
props.ragEngines?.deleteWorkspaceWorkflow?.stateMachineArn ?? "",
CREATE_AURORA_WORKSPACE_WORKFLOW_ARN:
Expand Down Expand Up @@ -190,7 +191,9 @@ export class RestApi extends Construct {
new iam.PolicyStatement({
actions: ["kendra:Retrieve", "kendra:Query"],
resources: [
`arn:${cdk.Aws.PARTITION}:kendra:${item.region}:${cdk.Aws.ACCOUNT_ID}:index/${item.kendraId}`,
`arn:${cdk.Aws.PARTITION}:kendra:${
item.region ?? cdk.Aws.REGION
}:${cdk.Aws.ACCOUNT_ID}:index/${item.kendraId}`,
],
})
);
Expand All @@ -210,11 +213,11 @@ export class RestApi extends Construct {
props.ragEngines.deleteWorkspaceWorkflow.grantStartExecution(apiHandler);
}

if (props.ragEngines?.sageMakerRagModelsEndpoint) {
if (props.ragEngines?.sageMakerRagModels) {
apiHandler.addToRolePolicy(
new iam.PolicyStatement({
actions: ["sagemaker:InvokeEndpoint"],
resources: [props.ragEngines?.sageMakerRagModelsEndpoint.ref],
resources: [props.ragEngines.sageMakerRagModels.model.endpoint.ref],
})
);
}
Expand Down
11 changes: 8 additions & 3 deletions lib/model-interfaces/langchain/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ export class LangChainInterface extends Construct {
AURORA_DB_SECRET_ID: props.ragEngines?.auroraPgVector?.database?.secret
?.secretArn as string,
SAGEMAKER_RAG_MODELS_ENDPOINT:
props.ragEngines?.sageMakerRagModelsEndpoint?.attrEndpointName ?? "",
props.ragEngines?.sageMakerRagModels?.model.endpoint
?.attrEndpointName ?? "",
OPEN_SEARCH_COLLECTION_ENDPOINT:
props.ragEngines?.openSearchVector?.openSearchCollectionEndpoint ??
"",
Expand Down Expand Up @@ -126,11 +127,13 @@ export class LangChainInterface extends Construct {
if (props.ragEngines) {
props.ragEngines.workspacesTable.grantReadWriteData(requestHandler);
props.ragEngines.documentsTable.grantReadWriteData(requestHandler);
}

if (props.ragEngines?.sageMakerRagModels) {
requestHandler.addToRolePolicy(
new iam.PolicyStatement({
actions: ["sagemaker:InvokeEndpoint"],
resources: [props.ragEngines?.sageMakerRagModelsEndpoint.ref],
resources: [props.ragEngines.sageMakerRagModels.model.endpoint.ref],
})
);
}
Expand Down Expand Up @@ -162,7 +165,9 @@ export class LangChainInterface extends Construct {
new iam.PolicyStatement({
actions: ["kendra:Retrieve", "kendra:Query"],
resources: [
`arn:${cdk.Aws.PARTITION}:kendra:${item.region}:${cdk.Aws.ACCOUNT_ID}:index/${item.kendraId}`,
`arn:${cdk.Aws.PARTITION}:kendra:${
item.region ?? cdk.Aws.REGION
}:${cdk.Aws.ACCOUNT_ID}:index/${item.kendraId}`,
],
})
);
Expand Down
2 changes: 1 addition & 1 deletion lib/rag-engines/data-import/file-import-batch-job.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export class FileImportBatchJob extends Construct {
vpc: props.shared.vpc,
allocationStrategy: batch.AllocationStrategy.BEST_FIT,
maxvCpus: 4,
minvCpus: 4,
minvCpus: 0,
replaceComputeEnvironment: true,
updateTimeout: cdk.Duration.minutes(30),
updateToLatestImageVersion: true,
Expand Down
10 changes: 5 additions & 5 deletions lib/rag-engines/data-import/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import { FileImportWorkflow } from "./file-import-workflow";
import { WebsiteCrawlingWorkflow } from "./website-crawling-workflow";
import { OpenSearchVector } from "../opensearch-vector";
import { KendraRetrieval } from "../kendra-retrieval";
import { SageMakerRagModels } from "../sagemaker-rag-models";
import * as s3 from "aws-cdk-lib/aws-s3";
import * as sqs from "aws-cdk-lib/aws-sqs";
import * as lambda from "aws-cdk-lib/aws-lambda";
import * as logs from "aws-cdk-lib/aws-logs";
import * as ec2 from "aws-cdk-lib/aws-ec2";
import * as iam from "aws-cdk-lib/aws-iam";
import * as sagemaker from "aws-cdk-lib/aws-sagemaker";
import * as dynamodb from "aws-cdk-lib/aws-dynamodb";
import * as s3Notifications from "aws-cdk-lib/aws-s3-notifications";
import * as lambdaEventSources from "aws-cdk-lib/aws-lambda-event-sources";
Expand All @@ -29,7 +29,7 @@ export interface DataImportProps {
readonly ragDynamoDBTables: RagDynamoDBTables;
readonly openSearchVector?: OpenSearchVector;
readonly kendraRetrieval?: KendraRetrieval;
readonly sageMakerRagModelsEndpoint?: sagemaker.CfnEndpoint;
readonly sageMakerRagModels?: SageMakerRagModels;
readonly workspacesTable: dynamodb.Table;
readonly documentsTable: dynamodb.Table;
readonly workspacesByObjectTypeIndexName: string;
Expand Down Expand Up @@ -109,7 +109,7 @@ export class DataImport extends Construct {
processingBucket,
auroraDatabase: props.auroraDatabase,
ragDynamoDBTables: props.ragDynamoDBTables,
sageMakerRagModelsEndpoint: props.sageMakerRagModelsEndpoint,
sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model.endpoint,
openSearchVector: props.openSearchVector,
}
);
Expand All @@ -134,7 +134,7 @@ export class DataImport extends Construct {
processingBucket,
auroraDatabase: props.auroraDatabase,
ragDynamoDBTables: props.ragDynamoDBTables,
sageMakerRagModelsEndpoint: props.sageMakerRagModelsEndpoint,
sageMakerRagModelsEndpoint: props.sageMakerRagModels?.model.endpoint,
openSearchVector: props.openSearchVector,
}
);
Expand Down Expand Up @@ -170,7 +170,7 @@ export class DataImport extends Construct {
DOCUMENTS_BY_COMPOUND_KEY_INDEX_NAME:
props.documentsByCompountKeyIndexName ?? "",
SAGEMAKER_RAG_MODELS_ENDPOINT:
props.sageMakerRagModelsEndpoint?.attrEndpointName ?? "",
props.sageMakerRagModels?.model.endpoint.attrEndpointName ?? "",
FILE_IMPORT_WORKFLOW_ARN:
fileImportWorkflow?.stateMachine.stateMachineArn ?? "",
DEFAULT_KENDRA_S3_DATA_SOURCE_BUCKET_NAME:
Expand Down
21 changes: 11 additions & 10 deletions lib/rag-engines/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import * as dynamodb from "aws-cdk-lib/aws-dynamodb";
import * as s3 from "aws-cdk-lib/aws-s3";
import * as sagemaker from "aws-cdk-lib/aws-sagemaker";
import * as sfn from "aws-cdk-lib/aws-stepfunctions";
import { Construct } from "constructs";
import { Shared } from "../shared";
Expand All @@ -22,13 +21,13 @@ export class RagEngines extends Construct {
public readonly auroraPgVector: AuroraPgVector | null;
public readonly openSearchVector: OpenSearchVector | null;
public readonly kendraRetrieval: KendraRetrieval | null;
public readonly sageMakerRagModels: SageMakerRagModels | null;
public readonly uploadBucket: s3.Bucket;
public readonly processingBucket: s3.Bucket;
public readonly documentsTable: dynamodb.Table;
public readonly workspacesTable: dynamodb.Table;
public readonly workspacesByObjectTypeIndexName: string;
public readonly documentsByCompountKeyIndexName: string;
public readonly sageMakerRagModelsEndpoint: sagemaker.CfnEndpoint;
public readonly fileImportWorkflow?: sfn.StateMachine;
public readonly websiteCrawlingWorkflow?: sfn.StateMachine;
public readonly deleteWorkspaceWorkflow?: sfn.StateMachine;
Expand All @@ -38,14 +37,16 @@ export class RagEngines extends Construct {

const tables = new RagDynamoDBTables(this, "RagDynamoDBTables");

const sageMakerRagModels = new SageMakerRagModels(
this,
"SageMaker",
{
let sageMakerRagModels: SageMakerRagModels | null = null;
if (
props.config.rag.engines.aurora.enabled ||
props.config.rag.engines.opensearch.enabled
) {
sageMakerRagModels = new SageMakerRagModels(this, "SageMaker", {
shared: props.shared,
config: props.config,
}
);
});
}

let auroraPgVector: AuroraPgVector | null = null;
if (props.config.rag.engines.aurora.enabled) {
Expand Down Expand Up @@ -78,7 +79,7 @@ export class RagEngines extends Construct {
shared: props.shared,
config: props.config,
auroraDatabase: auroraPgVector?.database,
sageMakerRagModelsEndpoint: sageMakerRagModels.model.endpoint,
sageMakerRagModels: sageMakerRagModels ?? undefined,
workspacesTable: tables.workspacesTable,
documentsTable: tables.documentsTable,
ragDynamoDBTables: tables,
Expand All @@ -101,7 +102,7 @@ export class RagEngines extends Construct {
this.auroraPgVector = auroraPgVector;
this.openSearchVector = openSearchVector;
this.kendraRetrieval = kendraRetrieval;
this.sageMakerRagModelsEndpoint = sageMakerRagModels.model.endpoint;
this.sageMakerRagModels = sageMakerRagModels;
this.uploadBucket = dataImport.uploadBucket;
this.processingBucket = dataImport.processingBucket;
this.workspacesTable = tables.workspacesTable;
Expand Down
11 changes: 11 additions & 0 deletions lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ def rank_passages(
raise genai_core.typesCommonError(f"Unknown provider")


def get_cross_encoder_models():
    """Return the cross-encoder models listed in the configuration.

    The list is read from ``config["rag"]["crossEncoderModels"]`` as returned
    by ``genai_core.parameters.get_config()``. When
    ``SAGEMAKER_RAG_MODELS_ENDPOINT`` is falsy, entries whose ``"provider"``
    is ``"sagemaker"`` are left out of the result.
    """
    available = genai_core.parameters.get_config()["rag"]["crossEncoderModels"]

    if SAGEMAKER_RAG_MODELS_ENDPOINT:
        return available

    # Endpoint value is empty/unset: keep only the non-SageMaker providers.
    return [entry for entry in available if entry["provider"] != "sagemaker"]


def get_cross_encoder_model(
provider: str, name: str
) -> Optional[genai_core.types.CrossEncoderModel]:
Expand Down
14 changes: 13 additions & 1 deletion lib/shared/layers/python-sdk/python/genai_core/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def generate_embeddings(
input = list(map(lambda x: x[:10000], input))

ret_value = []
batch_split = [input[i : i + batch_size] for i in range(0, len(input), batch_size)]
batch_split = [input[i: i + batch_size]
for i in range(0, len(input), batch_size)]

for batch in batch_split:
if model.provider == "openai":
Expand All @@ -33,6 +34,17 @@ def generate_embeddings(
return ret_value


def get_embeddings_models():
    """Return the embeddings models listed in the configuration.

    The list is read from ``config["rag"]["embeddingsModels"]`` as returned
    by ``genai_core.parameters.get_config()``. When
    ``SAGEMAKER_RAG_MODELS_ENDPOINT`` is falsy, entries whose ``"provider"``
    is ``"sagemaker"`` are left out of the result.
    """
    configured = genai_core.parameters.get_config()["rag"]["embeddingsModels"]

    if SAGEMAKER_RAG_MODELS_ENDPOINT:
        return configured

    # Endpoint value is empty/unset: keep only the non-SageMaker providers.
    return [item for item in configured if item["provider"] != "sagemaker"]


def get_embeddings_model(
provider: str, name: str
) -> Optional[genai_core.types.EmbeddingsModel]:
Expand Down
4 changes: 4 additions & 0 deletions lib/user-interface/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ export interface UserInterfaceProps {
readonly restApi: apigateway.RestApi;
readonly webSocketApi: apigwv2.WebSocketApi;
readonly chatbotFilesBucket: s3.Bucket;
readonly crossEncodersEnabled: boolean;
readonly sagemakerEmbeddingsEnabled: boolean;
}

export class UserInterface extends Construct {
Expand Down Expand Up @@ -172,6 +174,8 @@ export class UserInterface extends Construct {
api_endpoint: `https://${distribution.distributionDomainName}/api`,
websocket_endpoint: `wss://${distribution.distributionDomainName}/socket`,
rag_enabled: props.config.rag.enabled,
cross_encoders_enabled: props.crossEncodersEnabled,
sagemaker_embeddings_enabled: props.sagemakerEmbeddingsEnabled,
default_embeddings_model: Utils.getDefaultEmbeddingsModel(props.config),
default_cross_encoder_model: Utils.getDefaultCrossEncoderModel(
props.config
Expand Down
3 changes: 2 additions & 1 deletion lib/user-interface/react-app/src/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ export interface AppConfig {
name: CognitoHostedUIIdentityProvider;
};
rag_enabled: boolean;
cross_encoders_enabled: boolean;
sagemaker_embeddings_enabled: boolean;
api_endpoint: string;
websocket_endpoint: string;
default_embeddings_model: string;
Expand All @@ -27,7 +29,6 @@ export interface AppConfig {
region: string;
};
};

}

export interface NavigationPanelState {
Expand Down
Loading

0 comments on commit 5ad09fa

Please sign in to comment.