From ac69a335533d0dfda1981f75a9c0642322ad6a3e Mon Sep 17 00:00:00 2001 From: Leon Luttenberger Date: Mon, 26 Feb 2024 12:09:57 -0600 Subject: [PATCH] make RDS persistence optional --- modules/mlflow/mlflow-fargate/README.md | 6 +- modules/mlflow/mlflow-fargate/app.py | 6 -- modules/mlflow/mlflow-fargate/stack.py | 65 +++++++++++++----- .../mlflow/mlflow-fargate/tests/test_app.py | 67 +++++++++++++------ .../mlflow/mlflow-fargate/tests/test_stack.py | 28 +++++--- modules/mlflow/mlflow-image/src/Dockerfile | 17 +++-- 6 files changed, 130 insertions(+), 59 deletions(-) diff --git a/modules/mlflow/mlflow-fargate/README.md b/modules/mlflow/mlflow-fargate/README.md index 0a16ffa4..0cd04333 100644 --- a/modules/mlflow/mlflow-fargate/README.md +++ b/modules/mlflow/mlflow-fargate/README.md @@ -4,7 +4,7 @@ This module runs Mlflow on AWS Fargate. -Uses EFS and RDS for backend storage. +By default, uses EFS for backend storage. Optionally, an RDS instance can be used for storage. ### Architecture @@ -20,8 +20,6 @@ Uses EFS and RDS for backend storage. - `subnet-ids`: The subnets that the Fargate task will use. - `ecr-repository-name`: The name of the ECR repository to pull the image from. - `artifacts-bucket-name`: Name of the artifacts store bucket -- `rds-hostname`: Endpoint address of the RDS instance -- `rds-credentials-secret-arn`: RDS database credentials stored in SecretsManager #### Optional @@ -29,6 +27,8 @@ Uses EFS and RDS for backend storage. - `service-name`: Name of the service. - `task-cpu-units`: The number of cpu units used by the Fargate task. - `task-memory-limit-mb`: The amount (in MiB) of memory used by the Fargate task. 
+- `rds-hostname`: Endpoint address of the RDS instance. Must be set together with `rds-credentials-secret-arn`. +- `rds-credentials-secret-arn`: ARN of the SecretsManager secret that stores the RDS database credentials. Must be set together with `rds-hostname`. - `lb-access-logs-bucket-name`: Name of the bucket to store load balancer access logs - `lb-access-logs-bucket-prefix`: Prefix for load balancer access logs diff --git a/modules/mlflow/mlflow-fargate/app.py b/modules/mlflow/mlflow-fargate/app.py index 7f69a786..d677e11c 100644 --- a/modules/mlflow/mlflow-fargate/app.py +++ b/modules/mlflow/mlflow-fargate/app.py @@ -54,12 +54,6 @@ def _param(name: str) -> str: if not artifacts_bucket_name: raise ValueError("Missing input parameter artifacts-bucket-name") -if not rds_hostname: - raise ValueError("Missing input parameter rds-hostname") - -if not rds_credentials_secret_arn: - raise ValueError("Missing input parameter rds-credentials-secret-arn") - app = aws_cdk.App() diff --git a/modules/mlflow/mlflow-fargate/stack.py b/modules/mlflow/mlflow-fargate/stack.py index 89dae0ce..359d5cb0 100644 --- a/modules/mlflow/mlflow-fargate/stack.py +++ b/modules/mlflow/mlflow-fargate/stack.py @@ -31,8 +31,8 @@ def __init__( task_memory_limit_mb: int, autoscale_max_capacity: int, artifacts_bucket_name: str, - rds_hostname: str, - rds_credentials_secret_arn: str, + rds_hostname: Optional[str], + rds_credentials_secret_arn: Optional[str], lb_access_logs_bucket_name: Optional[str], lb_access_logs_bucket_prefix: Optional[str], **kwargs: Any, @@ -57,9 +57,6 @@ def __init__( vpc = ec2.Vpc.from_lookup(self, "Vpc", vpc_id=vpc_id) subnets = [ec2.Subnet.from_subnet_id(self, f"Subnet {subnet_id}", subnet_id) for subnet_id in subnet_ids] - # Load database - secret = rds.DatabaseSecret.from_secret_complete_arn(self, "Secret", rds_credentials_secret_arn) - # Create ECS cluster cluster = ecs.Cluster( self, @@ -78,8 +75,8 @@ def __init__( memory_limit_mib=task_memory_limit_mb, ) - container = task_definition.add_container( - "ContainerDef", + container = self._add_container_to_task_definition( 
task_definition=task_definition, image=ecs.ContainerImage.from_ecr_repository( repository=ecr.Repository.from_repository_name( self, @@ -87,18 +84,12 @@ def __init__( repository_name=ecr_repo_name, ), ), - environment={ - "BUCKET": model_bucket.s3_url_for_object(), - "HOST": rds_hostname, - "PORT": secret.secret_value_from_json("port").to_string(), - "DATABASE": secret.secret_value_from_json("dbname").to_string(), - "USERNAME": secret.secret_value_from_json("username").to_string(), - }, - secrets={ - "PASSWORD": ecs.Secret.from_secrets_manager(secret, "password"), - }, logging=ecs.LogDriver.aws_logs(stream_prefix="mlflow"), + model_bucket=model_bucket, + rds_hostname=rds_hostname, + rds_credentials_secret_arn=rds_credentials_secret_arn, ) + port_mapping = ecs.PortMapping(container_port=5000, host_port=5000, protocol=ecs.Protocol.TCP) container.add_port_mappings(port_mapping) @@ -216,3 +207,43 @@ def __init__( ), ], ) + + def _add_container_to_task_definition( + self, + task_definition: ecs.FargateTaskDefinition, + image: ecs.EcrImage, + logging: ecs.LogDriver, + model_bucket: s3.IBucket, + rds_hostname: Optional[str], + rds_credentials_secret_arn: Optional[str], + ) -> ecs.ContainerDefinition: + if rds_hostname and rds_credentials_secret_arn: + secret = rds.DatabaseSecret.from_secret_complete_arn(self, "Secret", rds_credentials_secret_arn) + + return task_definition.add_container( + "ContainerDef", + image=image, + environment={ + "BUCKET": model_bucket.s3_url_for_object(), + "HOST": rds_hostname, + "PORT": secret.secret_value_from_json("port").to_string(), + "DATABASE": secret.secret_value_from_json("dbname").to_string(), + "USERNAME": secret.secret_value_from_json("username").to_string(), + }, + secrets={ + "PASSWORD": ecs.Secret.from_secrets_manager(secret, "password"), + }, + logging=logging, + ) + + if rds_hostname or rds_credentials_secret_arn: + raise ValueError("Either both rds-hostname and rds-credentials-secret-arn need to be defined or neither.") + + 
return task_definition.add_container( + "ContainerDef", + image=image, + environment={ + "BUCKET": model_bucket.s3_url_for_object(), + }, + logging=logging, + ) diff --git a/modules/mlflow/mlflow-fargate/tests/test_app.py b/modules/mlflow/mlflow-fargate/tests/test_app.py index 37a37438..2fbeecf2 100644 --- a/modules/mlflow/mlflow-fargate/tests/test_app.py +++ b/modules/mlflow/mlflow-fargate/tests/test_app.py @@ -1,28 +1,28 @@ import os import sys +from unittest import mock import pytest @pytest.fixture(scope="function") def stack_defaults(): - os.environ["SEEDFARMER_PROJECT_NAME"] = "test-project" - os.environ["SEEDFARMER_DEPLOYMENT_NAME"] = "test-deployment" - os.environ["SEEDFARMER_MODULE_NAME"] = "test-module" - os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111" - os.environ["CDK_DEFAULT_REGION"] = "us-east-1" - - os.environ["SEEDFARMER_PARAMETER_VPC_ID"] = "vpc-12345" - os.environ["SEEDFARMER_PARAMETER_ECR_REPOSITORY_NAME"] = "repo5" - os.environ["SEEDFARMER_PARAMETER_ARTIFACTS_BUCKET_NAME"] = "bucket" - os.environ["SEEDFARMER_PARAMETER_RDS_HOSTNAME"] = "xxxxx" - os.environ[ - "SEEDFARMER_PARAMETER_RDS_CREDENTIALS_SECRET_ARN" - ] = "arn:aws:secretsmanager:us-east-1:111111111111:secret:xxxxxx/xxxxxx-yyyyyy" + with mock.patch.dict(os.environ, {}, clear=True): + os.environ["SEEDFARMER_PROJECT_NAME"] = "test-project" + os.environ["SEEDFARMER_DEPLOYMENT_NAME"] = "test-deployment" + os.environ["SEEDFARMER_MODULE_NAME"] = "test-module" + os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111" + os.environ["CDK_DEFAULT_REGION"] = "us-east-1" + + os.environ["SEEDFARMER_PARAMETER_VPC_ID"] = "vpc-12345" + os.environ["SEEDFARMER_PARAMETER_ECR_REPOSITORY_NAME"] = "repo5" + os.environ["SEEDFARMER_PARAMETER_ARTIFACTS_BUCKET_NAME"] = "bucket" + + # Unload the app import so that subsequent tests don't reuse + if "app" in sys.modules: + del sys.modules["app"] - # Unload the app import so that subsequent tests don't reuse - if "app" in sys.modules: - del sys.modules["app"] + yield def 
test_app(stack_defaults): @@ -32,19 +32,48 @@ def test_app(stack_defaults): def test_vpc_id(stack_defaults): del os.environ["SEEDFARMER_PARAMETER_VPC_ID"] - with pytest.raises(Exception, match="Missing input parameter vpc-id"): + with pytest.raises(ValueError, match="Missing input parameter vpc-id"): import app # noqa: F401 def test_ecr_repository_name(stack_defaults): del os.environ["SEEDFARMER_PARAMETER_ECR_REPOSITORY_NAME"] - with pytest.raises(Exception, match="Missing input parameter ecr-repository-name"): + with pytest.raises(ValueError, match="Missing input parameter ecr-repository-name"): import app # noqa: F401 def test_artifacts_bucket_name(stack_defaults): del os.environ["SEEDFARMER_PARAMETER_ARTIFACTS_BUCKET_NAME"] - with pytest.raises(Exception, match="Missing input parameter artifacts-bucket-name"): + with pytest.raises(ValueError, match="Missing input parameter artifacts-bucket-name"): + import app # noqa: F401 + + +def test_rds_settings(stack_defaults): + os.environ["SEEDFARMER_PARAMETER_RDS_HOSTNAME"] = "xxxxx" + os.environ["SEEDFARMER_PARAMETER_RDS_CREDENTIALS_SECRET_ARN"] = ( + "arn:aws:secretsmanager:us-east-1:111111111111:secret:xxxxxx/xxxxxx-yyyyyy" + ) + + import app # noqa: F401 + + +def test_rds_settings_missing_hostname(stack_defaults): + os.environ["SEEDFARMER_PARAMETER_RDS_CREDENTIALS_SECRET_ARN"] = ( + "arn:aws:secretsmanager:us-east-1:111111111111:secret:xxxxxx/xxxxxx-yyyyyy" + ) + + with pytest.raises( + ValueError, match="Either both rds-hostname and rds-credentials-secret-arn need to be defined or neither." + ): + import app # noqa: F401 + + +def test_rds_settings_missing_credentials(stack_defaults): + os.environ["SEEDFARMER_PARAMETER_RDS_HOSTNAME"] = "xxxxx" + + with pytest.raises( + ValueError, match="Either both rds-hostname and rds-credentials-secret-arn need to be defined or neither." 
+ ): import app # noqa: F401 diff --git a/modules/mlflow/mlflow-fargate/tests/test_stack.py b/modules/mlflow/mlflow-fargate/tests/test_stack.py index 89a0d447..91bad457 100644 --- a/modules/mlflow/mlflow-fargate/tests/test_stack.py +++ b/modules/mlflow/mlflow-fargate/tests/test_stack.py @@ -1,5 +1,6 @@ import os import sys +from unittest import mock import aws_cdk as cdk import pytest @@ -7,16 +8,20 @@ @pytest.fixture(scope="function") -def stack_defaults() -> None: - os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111" - os.environ["CDK_DEFAULT_REGION"] = "us-east-1" +def stack_defaults(): + with mock.patch.dict(os.environ, {}, clear=True): + os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111" + os.environ["CDK_DEFAULT_REGION"] = "us-east-1" - # Unload the app import so that subsequent tests don't reuse - if "stack" in sys.modules: - del sys.modules["stack"] + # Unload the app import so that subsequent tests don't reuse + if "stack" in sys.modules: + del sys.modules["stack"] + yield -def test_synthesize_stack() -> None: + +@pytest.mark.parametrize("use_rds", [False, True]) +def test_synthesize_stack(stack_defaults, use_rds) -> None: import stack app = cdk.App() @@ -33,8 +38,13 @@ def test_synthesize_stack() -> None: task_memory_limit_mb = 8 * 1024 autoscale_max_capacity = 2 artifacts_bucket_name = "bucket" - rds_hostname = "hostname" - secret_arn = "arn:aws:secretsmanager:us-east-1:111111111111:secret:xxxxxx/xxxxxx-yyyyyy" + + if use_rds: + rds_hostname = "hostname" + secret_arn = "arn:aws:secretsmanager:us-east-1:111111111111:secret:xxxxxx/xxxxxx-yyyyyy" + else: + rds_hostname = None + secret_arn = None stack = stack.MlflowFargateStack( scope=app, diff --git a/modules/mlflow/mlflow-image/src/Dockerfile b/modules/mlflow/mlflow-image/src/Dockerfile index d8558169..7bca6848 100644 --- a/modules/mlflow/mlflow-image/src/Dockerfile +++ b/modules/mlflow/mlflow-image/src/Dockerfile @@ -8,8 +8,15 @@ RUN pip install \ EXPOSE 5000 -CMD mlflow server \ - --host 0.0.0.0 \ - 
--port 5000 \ - --default-artifact-root ${BUCKET} \ - --backend-store-uri mysql+pymysql://${USERNAME}:${PASSWORD}@${HOST}:${PORT}/${DATABASE} +CMD if [ -n "$HOST" ]; then \ + mlflow server \ + --host 0.0.0.0 \ + --port 5000 \ + --default-artifact-root ${BUCKET} \ + --backend-store-uri mysql+pymysql://${USERNAME}:${PASSWORD}@${HOST}:${PORT}/${DATABASE}; \ + else \ + mlflow server \ + --host 0.0.0.0 \ + --port 5000 \ + --default-artifact-root ${BUCKET}; \ + fi