Skip to content

Commit

Permalink
make RDS persistence optional
Browse files Browse the repository at this point in the history
  • Loading branch information
LeonLuttenberger committed Feb 26, 2024
1 parent d482e66 commit ac69a33
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 59 deletions.
6 changes: 3 additions & 3 deletions modules/mlflow/mlflow-fargate/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

This module runs Mlflow on AWS Fargate.

Uses EFS and RDS for backend storage.
By default, uses EFS for backend storage. Optionally, an RDS instance can be used for storage.

### Architecture

Expand All @@ -20,15 +20,15 @@ Uses EFS and RDS for backend storage.
- `subnet-ids`: The subnets that the Fargate task will use.
- `ecr-repository-name`: The name of the ECR repository to pull the image from.
- `artifacts-bucket-name`: Name of the artifacts store bucket
- `rds-hostname`: Endpoint address of the RDS instance
- `rds-credentials-secret-arn`: RDS database credentials stored in SecretsManager

#### Optional

- `ecs-cluster-name`: Name of the ECS cluster.
- `service-name`: Name of the service.
- `task-cpu-units`: The number of cpu units used by the Fargate task.
- `task-memory-limit-mb`: The amount (in MiB) of memory used by the Fargate task.
- `rds-hostname`: Endpoint address of the RDS instance
- `rds-credentials-secret-arn`: RDS database credentials stored in SecretsManager
- `lb-access-logs-bucket-name`: Name of the bucket to store load balancer access logs
- `lb-access-logs-bucket-prefix`: Prefix for load balancer access logs

Expand Down
6 changes: 0 additions & 6 deletions modules/mlflow/mlflow-fargate/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,6 @@ def _param(name: str) -> str:
if not artifacts_bucket_name:
raise ValueError("Missing input parameter artifacts-bucket-name")

if not rds_hostname:
raise ValueError("Missing input parameter rds-hostname")

if not rds_credentials_secret_arn:
raise ValueError("Missing input parameter rds-credentials-secret-arn")


app = aws_cdk.App()

Expand Down
65 changes: 48 additions & 17 deletions modules/mlflow/mlflow-fargate/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def __init__(
task_memory_limit_mb: int,
autoscale_max_capacity: int,
artifacts_bucket_name: str,
rds_hostname: str,
rds_credentials_secret_arn: str,
rds_hostname: Optional[str],
rds_credentials_secret_arn: Optional[str],
lb_access_logs_bucket_name: Optional[str],
lb_access_logs_bucket_prefix: Optional[str],
**kwargs: Any,
Expand All @@ -57,9 +57,6 @@ def __init__(
vpc = ec2.Vpc.from_lookup(self, "Vpc", vpc_id=vpc_id)
subnets = [ec2.Subnet.from_subnet_id(self, f"Subnet {subnet_id}", subnet_id) for subnet_id in subnet_ids]

# Load database
secret = rds.DatabaseSecret.from_secret_complete_arn(self, "Secret", rds_credentials_secret_arn)

# Create ECS cluster
cluster = ecs.Cluster(
self,
Expand All @@ -78,27 +75,21 @@ def __init__(
memory_limit_mib=task_memory_limit_mb,
)

container = task_definition.add_container(
"ContainerDef",
container = self._add_container_to_task_definition(
task_definition=task_definition,
image=ecs.ContainerImage.from_ecr_repository(
repository=ecr.Repository.from_repository_name(
self,
"ECRRepo",
repository_name=ecr_repo_name,
),
),
environment={
"BUCKET": model_bucket.s3_url_for_object(),
"HOST": rds_hostname,
"PORT": secret.secret_value_from_json("port").to_string(),
"DATABASE": secret.secret_value_from_json("dbname").to_string(),
"USERNAME": secret.secret_value_from_json("username").to_string(),
},
secrets={
"PASSWORD": ecs.Secret.from_secrets_manager(secret, "password"),
},
logging=ecs.LogDriver.aws_logs(stream_prefix="mlflow"),
model_bucket=model_bucket,
rds_hostname=rds_hostname,
rds_credentials_secret_arn=rds_credentials_secret_arn,
)

port_mapping = ecs.PortMapping(container_port=5000, host_port=5000, protocol=ecs.Protocol.TCP)
container.add_port_mappings(port_mapping)

Expand Down Expand Up @@ -216,3 +207,43 @@ def __init__(
),
],
)

def _add_container_to_task_definition(
self,
task_definition: ecs.FargateTaskDefinition,
image: ecs.EcrImage,
logging: ecs.LogDriver,
model_bucket: s3.IBucket,
rds_hostname: Optional[str],
rds_credentials_secret_arn: Optional[str],
) -> ecs.ContainerDefinition:
if rds_hostname and rds_credentials_secret_arn:
secret = rds.DatabaseSecret.from_secret_complete_arn(self, "Secret", rds_credentials_secret_arn)

return task_definition.add_container(
"ContainerDef",
image=image,
environment={
"BUCKET": model_bucket.s3_url_for_object(),
"HOST": rds_hostname,
"PORT": secret.secret_value_from_json("port").to_string(),
"DATABASE": secret.secret_value_from_json("dbname").to_string(),
"USERNAME": secret.secret_value_from_json("username").to_string(),
},
secrets={
"PASSWORD": ecs.Secret.from_secrets_manager(secret, "password"),
},
logging=logging,
)

if rds_hostname or rds_credentials_secret_arn:
raise ValueError("Either both rds-hostname and rds-credentials-secret-arn need to be defined or neither.")

return task_definition.add_container(
"ContainerDef",
image=image,
environment={
"BUCKET": model_bucket.s3_url_for_object(),
},
logging=logging,
)
67 changes: 48 additions & 19 deletions modules/mlflow/mlflow-fargate/tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
import os
import sys
from unittest import mock

import pytest


@pytest.fixture(scope="function")
def stack_defaults():
os.environ["SEEDFARMER_PROJECT_NAME"] = "test-project"
os.environ["SEEDFARMER_DEPLOYMENT_NAME"] = "test-deployment"
os.environ["SEEDFARMER_MODULE_NAME"] = "test-module"
os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111"
os.environ["CDK_DEFAULT_REGION"] = "us-east-1"

os.environ["SEEDFARMER_PARAMETER_VPC_ID"] = "vpc-12345"
os.environ["SEEDFARMER_PARAMETER_ECR_REPOSITORY_NAME"] = "repo5"
os.environ["SEEDFARMER_PARAMETER_ARTIFACTS_BUCKET_NAME"] = "bucket"
os.environ["SEEDFARMER_PARAMETER_RDS_HOSTNAME"] = "xxxxx"
os.environ[
"SEEDFARMER_PARAMETER_RDS_CREDENTIALS_SECRET_ARN"
] = "arn:aws:secretsmanager:us-east-1:111111111111:secret:xxxxxx/xxxxxx-yyyyyy"
with mock.patch.dict(os.environ, {}, clear=True):
os.environ["SEEDFARMER_PROJECT_NAME"] = "test-project"
os.environ["SEEDFARMER_DEPLOYMENT_NAME"] = "test-deployment"
os.environ["SEEDFARMER_MODULE_NAME"] = "test-module"
os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111"
os.environ["CDK_DEFAULT_REGION"] = "us-east-1"

os.environ["SEEDFARMER_PARAMETER_VPC_ID"] = "vpc-12345"
os.environ["SEEDFARMER_PARAMETER_ECR_REPOSITORY_NAME"] = "repo5"
os.environ["SEEDFARMER_PARAMETER_ARTIFACTS_BUCKET_NAME"] = "bucket"

# Unload the app import so that subsequent tests don't reuse
if "app" in sys.modules:
del sys.modules["app"]

# Unload the app import so that subsequent tests don't reuse
if "app" in sys.modules:
del sys.modules["app"]
yield


def test_app(stack_defaults):
Expand All @@ -32,19 +32,48 @@ def test_app(stack_defaults):
def test_vpc_id(stack_defaults):
del os.environ["SEEDFARMER_PARAMETER_VPC_ID"]

with pytest.raises(Exception, match="Missing input parameter vpc-id"):
with pytest.raises(ValueError, match="Missing input parameter vpc-id"):
import app # noqa: F401


def test_ecr_repository_name(stack_defaults):
del os.environ["SEEDFARMER_PARAMETER_ECR_REPOSITORY_NAME"]

with pytest.raises(Exception, match="Missing input parameter ecr-repository-name"):
with pytest.raises(ValueError, match="Missing input parameter ecr-repository-name"):
import app # noqa: F401


def test_artifacts_bucket_name(stack_defaults):
del os.environ["SEEDFARMER_PARAMETER_ARTIFACTS_BUCKET_NAME"]

with pytest.raises(Exception, match="Missing input parameter artifacts-bucket-name"):
with pytest.raises(ValueError, match="Missing input parameter artifacts-bucket-name"):
import app # noqa: F401


def test_rds_settings(stack_defaults):
os.environ["SEEDFARMER_PARAMETER_RDS_HOSTNAME"] = "xxxxx"
os.environ["SEEDFARMER_PARAMETER_RDS_CREDENTIALS_SECRET_ARN"] = (
"arn:aws:secretsmanager:us-east-1:111111111111:secret:xxxxxx/xxxxxx-yyyyyy"
)

import app # noqa: F401


def test_rds_settings_missing_hostname(stack_defaults):
os.environ["SEEDFARMER_PARAMETER_RDS_CREDENTIALS_SECRET_ARN"] = (
"arn:aws:secretsmanager:us-east-1:111111111111:secret:xxxxxx/xxxxxx-yyyyyy"
)

with pytest.raises(
ValueError, match="Either both rds-hostname and rds-credentials-secret-arn need to be defined or neither."
):
import app # noqa: F401


def test_rds_settings_missing_credentials(stack_defaults):
os.environ["SEEDFARMER_PARAMETER_RDS_HOSTNAME"] = "xxxxx"

with pytest.raises(
ValueError, match="Either both rds-hostname and rds-credentials-secret-arn need to be defined or neither."
):
import app # noqa: F401
28 changes: 19 additions & 9 deletions modules/mlflow/mlflow-fargate/tests/test_stack.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
import os
import sys
from unittest import mock

import aws_cdk as cdk
import pytest
from aws_cdk.assertions import Template


@pytest.fixture(scope="function")
def stack_defaults() -> None:
os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111"
os.environ["CDK_DEFAULT_REGION"] = "us-east-1"
def stack_defaults():
with mock.patch.dict(os.environ, {}, clear=True):
os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111"
os.environ["CDK_DEFAULT_REGION"] = "us-east-1"

# Unload the app import so that subsequent tests don't reuse
if "stack" in sys.modules:
del sys.modules["stack"]
# Unload the app import so that subsequent tests don't reuse
if "stack" in sys.modules:
del sys.modules["stack"]

yield

def test_synthesize_stack() -> None:

@pytest.mark.parametrize("use_rds", [False, True])
def test_synthesize_stack(stack_defaults, use_rds) -> None:
import stack

app = cdk.App()
Expand All @@ -33,8 +38,13 @@ def test_synthesize_stack() -> None:
task_memory_limit_mb = 8 * 1024
autoscale_max_capacity = 2
artifacts_bucket_name = "bucket"
rds_hostname = "hostname"
secret_arn = "arn:aws:secretsmanager:us-east-1:111111111111:secret:xxxxxx/xxxxxx-yyyyyy"

if use_rds:
rds_hostname = "hostname"
secret_arn = "arn:aws:secretsmanager:us-east-1:111111111111:secret:xxxxxx/xxxxxx-yyyyyy"
else:
rds_hostname = None
secret_arn = None

stack = stack.MlflowFargateStack(
scope=app,
Expand Down
17 changes: 12 additions & 5 deletions modules/mlflow/mlflow-image/src/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@ RUN pip install \

EXPOSE 5000

CMD mlflow server \
--host 0.0.0.0 \
--port 5000 \
--default-artifact-root ${BUCKET} \
--backend-store-uri mysql+pymysql://${USERNAME}:${PASSWORD}@${HOST}:${PORT}/${DATABASE}
CMD if [ -n "$HOST" ]; then \
mlflow server \
--host 0.0.0.0 \
--port 5000 \
--default-artifact-root ${BUCKET} \
--backend-store-uri mysql+pymysql://${USERNAME}:${PASSWORD}@${HOST}:${PORT}/${DATABASE}; \
else \
mlflow server \
--host 0.0.0.0 \
--port 5000 \
--default-artifact-root ${BUCKET}; \
fi

0 comments on commit ac69a33

Please sign in to comment.