diff --git a/modules/sagemaker/sagemaker-endpoint/README.md b/modules/sagemaker/sagemaker-endpoint/README.md
new file mode 100644
index 00000000..da526111
--- /dev/null
+++ b/modules/sagemaker/sagemaker-endpoint/README.md
@@ -0,0 +1,76 @@
+# SageMaker Model Endpoint
+
+## Description
+
+This is an example module that creates a SageMaker real-time inference endpoint for a model.
+
+## Inputs/Outputs
+
+### Input Parameters
+
+#### Required
+
+- `vpc-id`: The VPC-ID that the endpoint will be created in
+- `subnet-ids`: The subnets that the endpoint will be created in
+- `model-package-arn`: Model package ARN. Either this or `model-package-group-name` must be provided.
+- `model-package-group-name`: Model package group name to pull the latest approved model from. Used when `model-package-arn` is not provided.
+- `model-bucket-arn`: Model bucket ARN
+
+#### Optional
+
+- `sagemaker-project-id`: SageMaker project id. When set, the stack is tagged with `sagemaker:project-id`.
+- `sagemaker-project-name`: SageMaker project name. When set, the stack is tagged with `sagemaker:project-name`.
+- `model-execution-role-arn`: Model execution role ARN. Will be created if not provided.
+- `ecr-repo-arn`: ECR repository ARN if custom container is used
+- `variant-name`: Endpoint config production variant name. `AllTraffic` by default.
+- `initial-instance-count`: Initial instance count. `1` by default.
+- `initial-variant-weight`: Initial variant weight. `1` by default.
+- `instance-type`: instance type. `ml.m4.xlarge` by default.
+
+### Sample manifest declaration
+
+```yaml
+name: endpoint
+path: modules/sagemaker/sagemaker-endpoint
+parameters:
+  - name: sagemaker_project_id
+    value: dummy123
+  - name: sagemaker_project_name
+    value: dummy123
+  - name: model_package_arn
+    value: arn:aws:sagemaker:::model-package//1
+  - name: model_bucket_arn
+    value: arn:aws:s3:::
+  - name: instance_type
+    value: ml.m5.large
+  - name: vpc_id
+    valueFrom:
+      moduleMetadata:
+        group: networking
+        name: networking
+        key: VpcId
+  - name: subnet_ids
+    valueFrom:
+      moduleMetadata:
+        group: networking
+        name: networking
+        key: PrivateSubnetIds
+```
+
+### Module Metadata Outputs
+
+- `ModelExecutionRoleArn`: Model execution role ARN
+- `ModelName`: Model name
+- `EndpointName`: Endpoint name
+- `EndpointUrl`: Endpoint Url
+
+#### Output Example
+
+```json
+{
+  "ModelExecutionRoleArn": "arn:aws:iam::xxx:role/xxx",
+  "ModelName": "xxx",
+  "EndpointName": "xxx-endpoint",
+  "EndpointUrl": "xxx-endpoint"
+}
+```
diff --git a/modules/sagemaker/sagemaker-endpoint/app.py b/modules/sagemaker/sagemaker-endpoint/app.py
new file mode 100644
index 00000000..5f7b8d00
--- /dev/null
+++ b/modules/sagemaker/sagemaker-endpoint/app.py
@@ -0,0 +1,101 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import os
+
+import aws_cdk
+from stack import DeployEndpointStack
+
+
+def _param(name: str) -> str:
+    """Return the environment variable name that carries module parameter ``name``."""
+    return f"SEEDFARMER_PARAMETER_{name}"
+
+
+# Deployment identity; the combined prefix is passed to the stack as its id and naming prefix.
+project_name = os.getenv("SEEDFARMER_PROJECT_NAME", "")
+deployment_name = os.getenv("SEEDFARMER_DEPLOYMENT_NAME", "")
+module_name = os.getenv("SEEDFARMER_MODULE_NAME", "")
+app_prefix = f"{project_name}-{deployment_name}-{module_name}"
+
+DEFAULT_SAGEMAKER_PROJECT_ID = None
+DEFAULT_SAGEMAKER_PROJECT_NAME = None
+DEFAULT_MODEL_PACKAGE_ARN = None
+DEFAULT_MODEL_PACKAGE_GROUP_NAME = None
+DEFAULT_MODEL_EXECUTION_ROLE_ARN = None
+DEFAULT_MODEL_BUCKET_ARN = None
+DEFAULT_ECR_REPO_ARN = None
+DEFAULT_VARIANT_NAME = "AllTraffic"
+DEFAULT_INITIAL_INSTANCE_COUNT = 1
+DEFAULT_INITIAL_VARIANT_WEIGHT = 1
+DEFAULT_INSTANCE_TYPE = "ml.m4.xlarge"
+
+environment = aws_cdk.Environment(
+    account=os.environ["CDK_DEFAULT_ACCOUNT"],
+    region=os.environ["CDK_DEFAULT_REGION"],
+)
+
+vpc_id = os.getenv(_param("VPC_ID"))
+# subnet-ids is passed as a JSON-encoded list.
+subnet_ids = json.loads(os.getenv(_param("SUBNET_IDS"), "[]"))
+sagemaker_project_id = os.getenv(_param("SAGEMAKER_PROJECT_ID"), DEFAULT_SAGEMAKER_PROJECT_ID)
+sagemaker_project_name = os.getenv(_param("SAGEMAKER_PROJECT_NAME"), DEFAULT_SAGEMAKER_PROJECT_NAME)
+model_package_arn = os.getenv(_param("MODEL_PACKAGE_ARN"), DEFAULT_MODEL_PACKAGE_ARN)
+model_package_group_name = os.getenv(_param("MODEL_PACKAGE_GROUP_NAME"), DEFAULT_MODEL_PACKAGE_GROUP_NAME)
+model_execution_role_arn = os.getenv(_param("MODEL_EXECUTION_ROLE_ARN"), DEFAULT_MODEL_EXECUTION_ROLE_ARN)
+model_bucket_arn = os.getenv(_param("MODEL_BUCKET_ARN"), DEFAULT_MODEL_BUCKET_ARN)
+ecr_repo_arn = os.getenv(_param("ECR_REPO_ARN"), DEFAULT_ECR_REPO_ARN)
+variant_name = os.getenv(_param("VARIANT_NAME"), DEFAULT_VARIANT_NAME)
+# Environment values are strings, so the numeric defaults are stringified before conversion.
+initial_instance_count = int(os.getenv(_param("INITIAL_INSTANCE_COUNT"), str(DEFAULT_INITIAL_INSTANCE_COUNT)))
+# The variant weight may be fractional (e.g. 0.5 for a traffic split), so parse it as float, not int.
+initial_variant_weight = float(os.getenv(_param("INITIAL_VARIANT_WEIGHT"), str(DEFAULT_INITIAL_VARIANT_WEIGHT)))
+instance_type = os.getenv(_param("INSTANCE_TYPE"), DEFAULT_INSTANCE_TYPE)
+
+if not vpc_id:
+    raise ValueError("Missing input parameter vpc-id")
+
+if not model_package_arn and not model_package_group_name:
+    raise ValueError("Parameter model-package-arn or model-package-group-name is required")
+
+
+app = aws_cdk.App()
+stack = DeployEndpointStack(
+    scope=app,
+    id=app_prefix,
+    app_prefix=app_prefix,
+    sagemaker_project_id=sagemaker_project_id,
+    sagemaker_project_name=sagemaker_project_name,
+    model_package_arn=model_package_arn,
+    model_package_group_name=model_package_group_name,
+    model_execution_role_arn=model_execution_role_arn,
+    vpc_id=vpc_id,
+    subnet_ids=subnet_ids,
+    model_bucket_arn=model_bucket_arn,
+    ecr_repo_arn=ecr_repo_arn,
+    endpoint_config_prod_variant=dict(
+        initial_instance_count=initial_instance_count,
+        initial_variant_weight=initial_variant_weight,
+        instance_type=instance_type,
+        variant_name=variant_name,
+    ),
+    env=environment,
+)
+
+aws_cdk.CfnOutput(
+    scope=stack,
+    id="metadata",
+    value=stack.to_json_string(
+        {
+            "ModelExecutionRoleArn": stack.model_execution_role_arn,
+            "ModelName": stack.model.model_name,
+            "EndpointName": stack.endpoint.endpoint_name,
+            # NOTE(review): EndpointUrl currently carries the endpoint name rather than a URL.
+            "EndpointUrl": stack.endpoint.endpoint_name,
+        }
+    ),
+)
+
+
+app.synth()
diff --git a/modules/sagemaker/sagemaker-endpoint/deployspec.yaml b/modules/sagemaker/sagemaker-endpoint/deployspec.yaml
new file mode 100644
index 00000000..def0bc53
--- /dev/null
+++ b/modules/sagemaker/sagemaker-endpoint/deployspec.yaml
@@ -0,0 +1,23 @@
+publishGenericEnvVariables: true
+deploy:
+  phases:
+    install:
+      commands:
+      - env
+      # Install whatever additional build libraries
+      - npm install -g aws-cdk@2.91.0
+      - pip install -r requirements.txt
+    build:
+      commands:
+      - cdk deploy --require-approval never --progress events --app "python app.py" --outputs-file ./cdk-exports.json
+destroy:
+  phases:
+    install:
+      commands:
+      # Install whatever additional build
libraries + - npm install -g aws-cdk@2.91.0 + - pip install -r requirements.txt + build: + commands: + # execute the CDK + - cdk destroy --force --app "python app.py" \ No newline at end of file diff --git a/modules/sagemaker/sagemaker-endpoint/modulestack.yaml b/modules/sagemaker/sagemaker-endpoint/modulestack.yaml new file mode 100644 index 00000000..34f6603f --- /dev/null +++ b/modules/sagemaker/sagemaker-endpoint/modulestack.yaml @@ -0,0 +1,26 @@ +AWSTemplateFormatVersion: 2010-09-09 +Description: This template deploys a Module specific IAM permissions + +Parameters: + RoleName: + Type: String + Description: The name of the IAM Role + ModelPackageGroupName: + Type: String + Description: The name of the SageMaker Model Package Group + Default: NotPopulated + +Resources: + Policy: + Type: AWS::IAM::Policy + Properties: + PolicyName: "sagemaker-endpoint-modulespecific-policy" + Roles: [!Ref RoleName] + PolicyDocument: + Version: 2012-10-17 + Statement: + - Effect: Allow + Action: + - sagemaker:ListModelPackages + Resource: + - !Sub arn:${AWS::Partition}:sagemaker:${AWS::Region}:${AWS::AccountId}:model-package/${ModelPackageGroupName}/* diff --git a/modules/sagemaker/sagemaker-endpoint/pyproject.toml b/modules/sagemaker/sagemaker-endpoint/pyproject.toml new file mode 100644 index 00000000..ca09c986 --- /dev/null +++ b/modules/sagemaker/sagemaker-endpoint/pyproject.toml @@ -0,0 +1,30 @@ +[tool.black] +line-length = 120 +target-version = ["py36", "py37", "py38"] +exclude = ''' +/( + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | \.env + | _build + | buck-out + | build + | dist + | codeseeder.out +)/ +''' + +[tool.isort] +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +ensure_newline_before_comments = true +line_length = 120 +src_paths = ["sagemaker-model-deploy"] +py_version = 36 +skip_gitignore = false \ No newline at end of file diff --git 
a/modules/sagemaker/sagemaker-endpoint/requirements-dev.in b/modules/sagemaker/sagemaker-endpoint/requirements-dev.in new file mode 100644 index 00000000..6412f987 --- /dev/null +++ b/modules/sagemaker/sagemaker-endpoint/requirements-dev.in @@ -0,0 +1,14 @@ +awscli +black +cdk-nag +cfn-lint +check-manifest +flake8 +isort +mypy +pip-tools +pydot +pyroma +pytest +types-PyYAML +types-setuptools diff --git a/modules/sagemaker/sagemaker-endpoint/requirements-dev.txt b/modules/sagemaker/sagemaker-endpoint/requirements-dev.txt new file mode 100644 index 00000000..d9e66211 --- /dev/null +++ b/modules/sagemaker/sagemaker-endpoint/requirements-dev.txt @@ -0,0 +1,273 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# pip-compile requirements-dev.in +# +aiohttp==3.9.1 + # via black +aiosignal==1.3.1 + # via aiohttp +annotated-types==0.6.0 + # via pydantic +async-timeout==4.0.3 + # via aiohttp +attrs==23.1.0 + # via + # aiohttp + # cattrs + # jschema-to-python + # jsii + # jsonschema + # referencing + # sarif-om +aws-cdk-asset-awscli-v1==2.2.201 + # via aws-cdk-lib +aws-cdk-asset-kubectl-v20==2.1.2 + # via aws-cdk-lib +aws-cdk-asset-node-proxy-agent-v6==2.0.1 + # via aws-cdk-lib +aws-cdk-lib==2.115.0 + # via cdk-nag +aws-sam-translator==1.82.0 + # via cfn-lint +awscli==1.32.0 + # via -r requirements-dev.in +black==23.12.0 + # via -r requirements-dev.in +boto3==1.34.0 + # via aws-sam-translator +botocore==1.34.0 + # via + # awscli + # boto3 + # s3transfer +build==1.0.3 + # via + # check-manifest + # pip-tools + # pyroma +cattrs==23.2.3 + # via jsii +cdk-nag==2.27.216 + # via -r requirements-dev.in +certifi==2023.11.17 + # via requests +cfn-lint==0.83.5 + # via -r requirements-dev.in +charset-normalizer==3.3.2 + # via requests +check-manifest==0.49 + # via -r requirements-dev.in +click==8.1.7 + # via + # black + # pip-tools +colorama==0.4.4 + # via awscli +constructs==10.3.0 + # via + # aws-cdk-lib + # cdk-nag +docutils==0.16 + # 
via + # awscli + # pyroma +exceptiongroup==1.2.0 + # via + # cattrs + # pytest +flake8==6.1.0 + # via -r requirements-dev.in +frozenlist==1.4.0 + # via + # aiohttp + # aiosignal +idna==3.6 + # via + # requests + # yarl +importlib-metadata==7.0.0 + # via build +importlib-resources==6.1.1 + # via jsii +iniconfig==2.0.0 + # via pytest +isort==5.13.2 + # via -r requirements-dev.in +jmespath==1.0.1 + # via + # boto3 + # botocore +jschema-to-python==1.2.3 + # via cfn-lint +jsii==1.93.0 + # via + # aws-cdk-asset-awscli-v1 + # aws-cdk-asset-kubectl-v20 + # aws-cdk-asset-node-proxy-agent-v6 + # aws-cdk-lib + # cdk-nag + # constructs +jsonpatch==1.33 + # via cfn-lint +jsonpickle==3.0.2 + # via jschema-to-python +jsonpointer==2.4 + # via jsonpatch +jsonschema==4.20.0 + # via + # aws-sam-translator + # cfn-lint +jsonschema-specifications==2023.11.2 + # via jsonschema +junit-xml==1.9 + # via cfn-lint +mccabe==0.7.0 + # via flake8 +mpmath==1.3.0 + # via sympy +multidict==6.0.4 + # via + # aiohttp + # yarl +mypy==1.7.1 + # via -r requirements-dev.in +mypy-extensions==1.0.0 + # via + # black + # mypy +networkx==3.2.1 + # via cfn-lint +packaging==23.2 + # via + # black + # build + # pyroma + # pytest +pathspec==0.12.1 + # via black +pbr==6.0.0 + # via + # jschema-to-python + # sarif-om +pip-tools==7.3.0 + # via -r requirements-dev.in +platformdirs==4.1.0 + # via black +pluggy==1.3.0 + # via pytest +publication==0.0.3 + # via + # aws-cdk-asset-awscli-v1 + # aws-cdk-asset-kubectl-v20 + # aws-cdk-asset-node-proxy-agent-v6 + # aws-cdk-lib + # cdk-nag + # constructs + # jsii +pyasn1==0.5.1 + # via rsa +pycodestyle==2.11.1 + # via flake8 +pydantic==2.5.2 + # via aws-sam-translator +pydantic-core==2.14.5 + # via pydantic +pydot==1.4.2 + # via -r requirements-dev.in +pyflakes==3.1.0 + # via flake8 +pygments==2.17.2 + # via pyroma +pyparsing==3.1.1 + # via pydot +pyproject-hooks==1.0.0 + # via build +pyroma==4.2 + # via -r requirements-dev.in +pytest==7.4.3 + # via -r requirements-dev.in 
+python-dateutil==2.8.2 + # via + # botocore + # jsii +pyyaml==6.0.1 + # via + # awscli + # cfn-lint +referencing==0.32.0 + # via + # jsonschema + # jsonschema-specifications +regex==2023.10.3 + # via cfn-lint +requests==2.31.0 + # via pyroma +rpds-py==0.13.2 + # via + # jsonschema + # referencing +rsa==4.7.2 + # via awscli +s3transfer==0.9.0 + # via + # awscli + # boto3 +sarif-om==1.0.4 + # via cfn-lint +six==1.16.0 + # via + # junit-xml + # python-dateutil +sympy==1.12 + # via cfn-lint +tomli==2.0.1 + # via + # black + # build + # check-manifest + # mypy + # pip-tools + # pyproject-hooks + # pytest +trove-classifiers==2023.11.29 + # via pyroma +typeguard==2.13.3 + # via + # aws-cdk-asset-awscli-v1 + # aws-cdk-asset-kubectl-v20 + # aws-cdk-asset-node-proxy-agent-v6 + # aws-cdk-lib + # cdk-nag + # constructs + # jsii +types-pyyaml==6.0.12.12 + # via -r requirements-dev.in +types-setuptools==69.0.0.0 + # via -r requirements-dev.in +typing-extensions==4.9.0 + # via + # aws-sam-translator + # black + # cattrs + # jsii + # mypy + # pydantic + # pydantic-core +urllib3==1.26.18 + # via + # botocore + # requests +wheel==0.42.0 + # via pip-tools +yarl==1.9.4 + # via aiohttp +zipp==3.17.0 + # via + # importlib-metadata + # importlib-resources + +# The following packages are considered to be unsafe in a requirements file: +# pip +# setuptools diff --git a/modules/sagemaker/sagemaker-endpoint/requirements.txt b/modules/sagemaker/sagemaker-endpoint/requirements.txt new file mode 100644 index 00000000..790a706e --- /dev/null +++ b/modules/sagemaker/sagemaker-endpoint/requirements.txt @@ -0,0 +1,4 @@ +aws-cdk-lib==2.91.0 +cdk-nag==2.12.29 +yamldataclassconfig +boto3 \ No newline at end of file diff --git a/modules/sagemaker/sagemaker-endpoint/scripts/__init__.py b/modules/sagemaker/sagemaker-endpoint/scripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modules/sagemaker/sagemaker-endpoint/scripts/get_approved_package.py 
b/modules/sagemaker/sagemaker-endpoint/scripts/get_approved_package.py
new file mode 100644
index 00000000..4790d311
--- /dev/null
+++ b/modules/sagemaker/sagemaker-endpoint/scripts/get_approved_package.py
@@ -0,0 +1,61 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Optional
+
+import boto3
+from botocore.exceptions import ClientError
+
+logger = logging.getLogger(__name__)
+
+
+def get_approved_package(region_name: str, model_package_group_name: str) -> Optional[str]:
+    """Return the ARN of the first approved model package listed for a model package group.
+
+    Approved packages are requested sorted by creation time. While a page comes back empty and
+    the response carries a continuation token, the next page is requested.
+
+    Args:
+        region_name: AWS region used to create the SageMaker client.
+        model_package_group_name: Model package group whose approved packages are listed.
+
+    Returns:
+        The SageMaker model package ARN, or ``None`` when no approved package is found.
+
+    Raises:
+        Exception: If the SageMaker API call fails. The message is the service error message and
+            the original ``ClientError`` is chained as the cause.
+    """
+    sm_client = boto3.client("sagemaker", region_name=region_name)
+    # NOTE(review): SortOrder is not passed, so treating the first entry as the latest package
+    # relies on the service's default ordering - confirm against the ListModelPackages API.
+    list_kwargs = {
+        "ModelPackageGroupName": model_package_group_name,
+        "ModelApprovalStatus": "Approved",
+        "SortBy": "CreationTime",
+        "MaxResults": 100,
+    }
+
+    try:
+        response = sm_client.list_model_packages(**list_kwargs)
+        approved_packages = response["ModelPackageSummaryList"]
+
+        # Keep paging only while nothing has been found and a continuation token is available.
+        while not approved_packages and "NextToken" in response:
+            logger.debug("Getting more packages for token: %s", response["NextToken"])
+            response = sm_client.list_model_packages(**list_kwargs, NextToken=response["NextToken"])
+            approved_packages.extend(response["ModelPackageSummaryList"])
+    except ClientError as e:
+        error_message = e.response["Error"]["Message"]
+        logger.error(error_message)
+        # Callers see the same plain Exception as before; chaining keeps the original error attached.
+        raise Exception(error_message) from e
+
+    if not approved_packages:
+        logger.warning("No approved ModelPackage found for ModelPackageGroup: %s", model_package_group_name)
+        return None
+
+    model_package_arn = approved_packages[0]["ModelPackageArn"]
+    logger.info("Identified the latest approved model package: %s", model_package_arn)
+    return model_package_arn
diff --git a/modules/sagemaker/sagemaker-endpoint/setup.py b/modules/sagemaker/sagemaker-endpoint/setup.py
new file mode 100644
index 00000000..44552496
--- /dev/null
+++ b/modules/sagemaker/sagemaker-endpoint/setup.py
@@ -0,0 +1,24 @@
+import setuptools
+
+with open("README.md") as fp:
+    long_description = fp.read()
+
+setuptools.setup(
+    version="0.1.0",
+    description="A short description of the module.",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    author="author",
+    install_requires=[
+        "aws-cdk-lib==2.20.0",
+    ],
+    python_requires=">=3.6",
+    classifiers=[
+        "Development Status :: 4 - Beta",
+        "Intended Audience :: Developers",
+        "License :: OSI Approved :: MIT License",
+        "Programming Language :: Python :: 3 :: Only",
+        "Programming Language :: Python :: 3.6",
+        "Typing :: Typed",
+    ],
+)
diff --git a/modules/sagemaker/sagemaker-endpoint/stack.py b/modules/sagemaker/sagemaker-endpoint/stack.py
new file mode 100644
index 00000000..0a80ce71
--- /dev/null
+++ b/modules/sagemaker/sagemaker-endpoint/stack.py
@@ -0,0 +1,195 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from datetime import datetime, timezone
+from typing import Any, List, Optional, TypedDict
+
+import constructs
+from aws_cdk import Aspects, Stack, Tags
+from aws_cdk import aws_ec2 as ec2
+from aws_cdk import aws_iam as iam
+from aws_cdk import aws_kms as kms
+from aws_cdk import aws_s3 as s3
+from aws_cdk import aws_sagemaker as sagemaker
+from cdk_nag import AwsSolutionsChecks, NagPackSuppression, NagSuppressions
+from scripts.get_approved_package import get_approved_package
+from typing_extensions import NotRequired, Required
+
+
+class EndpointConfigProductionVariant(TypedDict):
+    """Settings unpacked into the endpoint config's production variant."""
+
+    variant_name: Required[str]
+    initial_instance_count: NotRequired[float]
+    initial_variant_weight: Required[float]
+    instance_type: NotRequired[str]
+
+
+def get_timestamp() -> str:
+    """Return the current UTC time formatted as ``YYYYMMDDHHMMSS``."""
+    # Request an aware UTC "now"; stamping tzinfo onto a naive local time would keep the local clock digits.
+    now = datetime.now(tz=timezone.utc)
+    return now.strftime("%Y%m%d%H%M%S")
+
+
+class DeployEndpointStack(Stack):
+    """Stack defining a SageMaker model, endpoint config and endpoint with their supporting resources."""
+
+    def __init__(
+        self,
+        scope: constructs.Construct,
+        id: str,
+        app_prefix: str,
+        sagemaker_project_id: Optional[str],
+        sagemaker_project_name: Optional[str],
+        model_package_arn: Optional[str],
+        model_package_group_name: Optional[str],
+        model_execution_role_arn: Optional[str],
+        vpc_id: str,
+        subnet_ids: List[str],
+        model_bucket_arn: Optional[str],
+        ecr_repo_arn: Optional[str],
+        endpoint_config_prod_variant: EndpointConfigProductionVariant,
+        **kwargs: Any,
+    ) -> None:
+        """Define the endpoint resources.
+
+        Args:
+            scope: Parent construct.
+            id: Construct id of the stack.
+            app_prefix: Prefix used for construct ids and resource names.
+            sagemaker_project_id: When set, added to the stack as the ``sagemaker:project-id`` tag.
+            sagemaker_project_name: When set, added to the stack as the ``sagemaker:project-name`` tag.
+            model_package_arn: Model package to deploy. When empty, the model package group is
+                queried for an approved package instead.
+            model_package_group_name: Group queried when ``model_package_arn`` is empty.
+            model_execution_role_arn: Existing execution role for the model. When empty, a role is
+                created and granted access to the model bucket and ECR repository, if given.
+            vpc_id: VPC looked up for the model's security group.
+            subnet_ids: Subnets placed in the model's VPC config.
+            model_bucket_arn: Bucket the created execution role may read and write.
+            ecr_repo_arn: Repository the created execution role gets ``ecr:Get*`` on.
+            endpoint_config_prod_variant: Production variant settings for the endpoint config.
+            **kwargs: Forwarded to ``Stack``.
+        """
+        super().__init__(scope, id, **kwargs)
+
+        if sagemaker_project_id:
+            Tags.of(self).add("sagemaker:project-id", sagemaker_project_id)
+        if sagemaker_project_name:
+            Tags.of(self).add("sagemaker:project-name", sagemaker_project_name)
+
+        # Import VPC, create security group, and add ingress rule
+        vpc = ec2.Vpc.from_lookup(self, f"{app_prefix}-vpc", vpc_id=vpc_id)
+        security_group = ec2.SecurityGroup(self, f"{app_prefix}-sg", vpc=vpc, allow_all_outbound=True)
+        security_group.add_ingress_rule(peer=ec2.Peer.ipv4(vpc.vpc_cidr_block), connection=ec2.Port.all_tcp())
+
+        if not model_execution_role_arn:
+            # Create model execution role
+            model_execution_role = iam.Role(
+                self,
+                f"{app_prefix}-model-exec-role",
+                assumed_by=iam.ServicePrincipal("sagemaker.amazonaws.com"),
+                managed_policies=[
+                    iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSageMakerFullAccess"),
+                ],
+            )
+
+            if model_bucket_arn:
+                # Grant model assets bucket read-write permissions
+                model_bucket = s3.Bucket.from_bucket_arn(self, f"{app_prefix}-model-bucket", model_bucket_arn)
+                model_bucket.grant_read_write(model_execution_role)
+
+            if ecr_repo_arn:
+                # Add ECR permissions
+                model_execution_role.add_to_policy(
+                    iam.PolicyStatement(
+                        actions=["ecr:Get*"],
+                        effect=iam.Effect.ALLOW,
+                        resources=[ecr_repo_arn],
+                    )
+                )
+
+            model_execution_role_arn = model_execution_role.role_arn
+
+        self.model_execution_role_arn = model_execution_role_arn
+
+        if not model_package_arn:
+            # Get latest approved model package from the model registry
+            # NOTE(review): the lookup returns None when the group has no approved package; that None is passed on.
+            model_package_arn = get_approved_package(self.region, model_package_group_name)
+
+        # Create model instance
+        model_name = f"{app_prefix}-{get_timestamp()}"
+        model = sagemaker.CfnModel(
+            self,
+            f"{app_prefix}-model",
+            execution_role_arn=model_execution_role_arn,
+            model_name=model_name,
+            containers=[sagemaker.CfnModel.ContainerDefinitionProperty(model_package_name=model_package_arn)],
+            vpc_config=sagemaker.CfnModel.VpcConfigProperty(
+                security_group_ids=[security_group.security_group_id],
+                subnets=subnet_ids,
+            ),
+        )
+        self.model = model
+
+        # Create kms key to be used by the endpoint assets bucket
+        kms_key = kms.Key(
+            self,
+            f"{app_prefix}-endpoint-key",
+            description="Key used for encryption of data in Amazon SageMaker Endpoint",
+            enable_key_rotation=True,
+        )
+        kms_key.grant_encrypt_decrypt(iam.AccountRootPrincipal())
+
+        # Create endpoint config
+        endpoint_config_name = f"{app_prefix}-endpoint-config"
+        endpoint_config_prod_variant = endpoint_config_prod_variant or {}
+        endpoint_config = sagemaker.CfnEndpointConfig(
+            self,
+            f"{app_prefix}-endpoint-config",
+            endpoint_config_name=endpoint_config_name,
+            kms_key_id=kms_key.key_id,
+            production_variants=[
+                sagemaker.CfnEndpointConfig.ProductionVariantProperty(
+                    model_name=model_name,
+                    **endpoint_config_prod_variant,
+                )
+            ],
+        )
+        endpoint_config.add_depends_on(model)
+
+        # Create endpoint
+        endpoint_name = f"{app_prefix}-endpoint"
+        endpoint = sagemaker.CfnEndpoint(
+            self,
+            "Endpoint",
+            endpoint_config_name=endpoint_config.endpoint_config_name,
+            endpoint_name=endpoint_name,
+        )
+        endpoint.add_depends_on(endpoint_config)
+        self.endpoint = endpoint
+
+        # Add CDK nag solutions checks
+        Aspects.of(self).add(AwsSolutionsChecks())
+
+        NagSuppressions.add_stack_suppressions(
+            self,
+            suppressions=[
+                NagPackSuppression(
+                    id="AwsSolutions-IAM4",
+                    reason="Managed Policies are for service account roles only.",
+                )
+            ],
+        )
+
+        NagSuppressions.add_stack_suppressions(
+            self,
+            suppressions=[
+                NagPackSuppression(
+                    id="AwsSolutions-IAM5",
+                    reason="Model execution role requires s3 permissions to the bucket.",
+                )
+            ],
+        )
diff --git a/modules/sagemaker/sagemaker-endpoint/tests/__init__.py b/modules/sagemaker/sagemaker-endpoint/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modules/sagemaker/sagemaker-endpoint/tests/test_app.py b/modules/sagemaker/sagemaker-endpoint/tests/test_app.py
new file mode 100644
index 00000000..c73e3793
--- /dev/null
+++ b/modules/sagemaker/sagemaker-endpoint/tests/test_app.py
@@ -0,0 +1,34 @@
+import os
+import sys
+
+import pytest
+
+
+@pytest.fixture(scope="function")
+def stack_defaults():
+    os.environ["SEEDFARMER_PROJECT_NAME"] = "test-project"
+    os.environ["SEEDFARMER_DEPLOYMENT_NAME"] = "test-deployment"
+    os.environ["SEEDFARMER_MODULE_NAME"] = "test-module"
+    os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111"
+    os.environ["CDK_DEFAULT_REGION"] = "us-east-1"
+
+    os.environ["SEEDFARMER_PARAMETER_VPC_ID"] = "vpc-12345"
+
os.environ["SEEDFARMER_PARAMETER_SAGEMAKER_PROJECT_ID"] = "12345"
+    os.environ["SEEDFARMER_PARAMETER_SAGEMAKER_PROJECT_NAME"] = "sagemaker-project"
+    os.environ["SEEDFARMER_PARAMETER_MODEL_PACKAGE_ARN"] = "example-arn"
+    os.environ["SEEDFARMER_PARAMETER_MODEL_BUCKET_ARN"] = "arn:aws:s3:::test-bucket"
+
+    # Unload the app import so that subsequent tests don't reuse
+    if "app" in sys.modules:
+        del sys.modules["app"]
+
+
+def test_app(stack_defaults):
+    import app  # noqa: F401
+
+
+def test_vpc_id(stack_defaults):
+    del os.environ["SEEDFARMER_PARAMETER_VPC_ID"]
+
+    with pytest.raises(ValueError):
+        import app  # noqa: F401
diff --git a/modules/sagemaker/sagemaker-endpoint/tests/test_stack.py b/modules/sagemaker/sagemaker-endpoint/tests/test_stack.py
new file mode 100644
index 00000000..0c823d68
--- /dev/null
+++ b/modules/sagemaker/sagemaker-endpoint/tests/test_stack.py
@@ -0,0 +1,132 @@
+import os
+import sys
+from unittest import mock
+
+import aws_cdk as cdk
+import botocore.session
+import pytest
+from aws_cdk.assertions import Template
+from botocore.stub import Stubber
+
+
+@pytest.fixture(scope="function")
+def stack_defaults() -> None:
+    """Set the CDK account/region variables and force a fresh import of ``stack``."""
+    os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111"
+    os.environ["CDK_DEFAULT_REGION"] = "us-east-1"
+
+    # Unload the stack import so that subsequent tests don't reuse
+    if "stack" in sys.modules:
+        del sys.modules["stack"]
+
+
+def test_synthesize_stack_model_package_input(stack_defaults) -> None:
+    """The stack synthesizes one endpoint when a model package ARN is passed directly."""
+    import stack
+
+    app = cdk.App()
+
+    project_name = "test-project"
+    dep_name = "test-deployment"
+    mod_name = "test-module"
+    app_prefix = f"{project_name}-{dep_name}-{mod_name}"
+
+    sagemaker_project_id = "12345"
+    sagemaker_project_name = "sagemaker-project"
+    vpc_id = "vpc-12345"
+    model_package_arn = "example-arn"
+    model_bucket_arn = "arn:aws:s3:::test-bucket"
+
+    endpoint_stack = stack.DeployEndpointStack(
+        scope=app,
+        id=app_prefix,
+        app_prefix=app_prefix,
+        sagemaker_project_id=sagemaker_project_id,
+        sagemaker_project_name=sagemaker_project_name,
+        model_package_arn=model_package_arn,
+        model_package_group_name=None,
+        model_execution_role_arn=None,
+        vpc_id=vpc_id,
+        subnet_ids=[],
+        model_bucket_arn=model_bucket_arn,
+        ecr_repo_arn=None,
+        env=cdk.Environment(
+            account=os.environ["CDK_DEFAULT_ACCOUNT"],
+            region=os.environ["CDK_DEFAULT_REGION"],
+        ),
+        endpoint_config_prod_variant=dict(
+            variant_name="AllTraffic",
+            initial_variant_weight=1,
+        ),
+    )
+
+    template = Template.from_stack(endpoint_stack)
+    template.resource_count_is("AWS::SageMaker::Endpoint", 1)
+
+
+@mock.patch("scripts.get_approved_package.boto3.client")
+def test_synthesize_stack_latest_approved_model_package(mock_boto3_client, stack_defaults) -> None:
+    """The stack synthesizes one endpoint when the package comes from a stubbed model package group."""
+    import stack
+
+    app = cdk.App()
+
+    project_name = "test-project"
+    dep_name = "test-deployment"
+    mod_name = "test-module"
+    app_prefix = f"{project_name}-{dep_name}-{mod_name}"
+
+    sagemaker_project_id = "12345"
+    sagemaker_project_name = "sagemaker-project"
+    vpc_id = "vpc-12345"
+    model_package_group_name = "example-group"
+    model_bucket_arn = "arn:aws:s3:::test-bucket"
+    dev_account = "111111111111"
+
+    sagemaker_client = botocore.session.get_session().create_client("sagemaker", region_name="us-east-1")
+    mock_boto3_client.return_value = sagemaker_client
+
+    with Stubber(sagemaker_client) as stubber:
+        expected_params = {
+            "ModelPackageGroupName": model_package_group_name,
+            "ModelApprovalStatus": "Approved",
+            "SortBy": "CreationTime",
+            "MaxResults": 100,
+        }
+        response = {
+            "ModelPackageSummaryList": [
+                {
+                    "ModelPackageArn": f"arn:aws:sagemaker:us-east-1:{dev_account}:model-package/{model_package_group_name}/1",
+                    "ModelPackageStatus": "Completed",
+                    "ModelPackageName": model_package_group_name,
+                    "CreationTime": "2021-01-01T00:00:00Z",
+                },
+            ],
+        }
+        stubber.add_response("list_model_packages", response, expected_params)
+
+        endpoint_stack = stack.DeployEndpointStack(
+            scope=app,
+            id=app_prefix,
+            app_prefix=app_prefix,
+            sagemaker_project_id=sagemaker_project_id,
+            sagemaker_project_name=sagemaker_project_name,
+            model_package_arn=None,
+            model_package_group_name=model_package_group_name,
+            model_execution_role_arn=None,
+            vpc_id=vpc_id,
+            subnet_ids=[],
+            model_bucket_arn=model_bucket_arn,
+            ecr_repo_arn=None,
+            env=cdk.Environment(
+                account=os.environ["CDK_DEFAULT_ACCOUNT"],
+                region=os.environ["CDK_DEFAULT_REGION"],
+            ),
+            endpoint_config_prod_variant=dict(
+                variant_name="AllTraffic",
+                initial_variant_weight=1,
+            ),
+        )
+
+    template = Template.from_stack(endpoint_stack)
+    template.resource_count_is("AWS::SageMaker::Endpoint", 1)