Skip to content

Commit

Permalink
typing & coverage
Browse files Browse the repository at this point in the history
Signed-off-by: Anton Kukushkin <[email protected]>
  • Loading branch information
kukushking committed Feb 5, 2024
1 parent 0393ee9 commit 9a72c4d
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 18 deletions.
3 changes: 3 additions & 0 deletions modules/sagemaker/sagemaker-endpoint/coverage.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[run]
omit =
tests/*
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
# SPDX-License-Identifier: Apache-2.0

from logging import Logger
from typing import Optional
from typing import Any

import boto3
from botocore.exceptions import ClientError

logger = Logger(name="get_approved_package")


def get_approved_package(region_name: str, model_package_group_name: str) -> Optional[str]:
def get_approved_package(region_name: str, model_package_group_name: str) -> Any:
"""Gets the latest approved model package for a model package group.
Returns:
The SageMaker Model Package ARN.
Expand Down
24 changes: 9 additions & 15 deletions modules/sagemaker/sagemaker-endpoint/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime, timezone
from typing import Any, List, Optional, TypedDict
from typing import Any, Dict, List, Optional

import constructs
from aws_cdk import Aspects, Stack, Tags
Expand All @@ -12,18 +12,10 @@
from aws_cdk import aws_s3 as s3
from aws_cdk import aws_sagemaker as sagemaker
from cdk_nag import AwsSolutionsChecks, NagPackSuppression, NagSuppressions
from typing_extensions import NotRequired, Required

from scripts.get_approved_package import get_approved_package


class EndpointConfigProductionVariant(TypedDict):
variant_name: Required[str]
initial_instance_count: NotRequired[float]
initial_variant_weight: Required[float]
instance_type: NotRequired[str]


def get_timestamp() -> str:
now = datetime.now().replace(tzinfo=timezone.utc)
return now.strftime("%Y%m%d%H%M%S")
Expand All @@ -44,7 +36,7 @@ def __init__(
subnet_ids: List[str],
model_bucket_arn: Optional[str],
ecr_repo_arn: Optional[str],
endpoint_config_prod_variant: EndpointConfigProductionVariant,
endpoint_config_prod_variant: Dict[str, Any],
**kwargs: Any,
) -> None:
super().__init__(scope, id, **kwargs)
Expand Down Expand Up @@ -88,13 +80,16 @@ def __init__(
)
)

model_execution_role_arn: str = model_execution_role.role_arn
model_execution_role_arn = model_execution_role.role_arn

self.model_execution_role_arn = model_execution_role_arn

if not model_package_arn:
# Get latest approved model package from the model registry
model_package_arn = get_approved_package(self.region, model_package_group_name)
if model_package_group_name:
model_package_arn = get_approved_package(self.region, model_package_group_name)
else:
raise ValueError("Either model_package_arn or model_package_group_name is required")

# Create model instance
model_name = f"{app_prefix}-{get_timestamp()}"
Expand All @@ -121,8 +116,7 @@ def __init__(
kms_key.grant_encrypt_decrypt(iam.AccountRootPrincipal())

# Create endpoint config
endpoint_config_name = f"{app_prefix}-endpoint-config"
endpoint_config_prod_variant = endpoint_config_prod_variant or {}
endpoint_config_name: str = f"{app_prefix}-endpoint-config"
endpoint_config = sagemaker.CfnEndpointConfig(
self,
f"{app_prefix}-endpoint-config",
Expand All @@ -142,7 +136,7 @@ def __init__(
endpoint = sagemaker.CfnEndpoint(
self,
"Endpoint",
endpoint_config_name=endpoint_config.endpoint_config_name,
endpoint_config_name=endpoint_config.endpoint_config_name, # type: ignore[arg-type]
endpoint_name=endpoint_name,
)
endpoint.add_depends_on(endpoint_config)
Expand Down
2 changes: 1 addition & 1 deletion modules/sagemaker/sagemaker-endpoint/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ def test_app(stack_defaults):
def test_vpc_id(stack_defaults):
del os.environ["SEEDFARMER_PARAMETER_VPC_ID"]

with pytest.raises(Exception):
with pytest.raises(Exception, match="Missing input parameter vpc-id"):
import app # noqa: F401

0 comments on commit 9a72c4d

Please sign in to comment.