Skip to content

Commit

Permalink
fix(dms): fix MagicMock errors using Botocore
Browse files Browse the repository at this point in the history
  • Loading branch information
danibarranqueroo committed Nov 5, 2024
1 parent f0809ea commit 2a7fb16
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 80 deletions.
2 changes: 1 addition & 1 deletion prowler/providers/aws/services/dms/dms_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,4 @@ class RepInstance(BaseModel):
security_groups: list[str] = []
multi_az: bool
region: str
tags: Optional[list]
tags: Optional[list] = []
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest import mock

import botocore
from boto3 import client
from moto import mock_aws

Expand All @@ -16,59 +17,91 @@
)
KMS_KEY_ID = f"arn:aws:kms:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:key/abcdabcd-1234-abcd-1234-abcdabcdabcd"

make_api_call = botocore.client.BaseClient._make_api_call


def mock_make_api_call_public(self, operation_name, kwargs):
if operation_name == "DescribeReplicationInstances":
return {
"ReplicationInstances": [
{
"ReplicationInstanceIdentifier": DMS_INSTANCE_NAME,
"ReplicationInstanceStatus": "available",
"AutoMinorVersionUpgrade": True,
"PubliclyAccessible": True,
"ReplicationInstanceArn": DMS_INSTANCE_ARN,
"MultiAZ": True,
"VpcSecurityGroups": [],
"KmsKeyId": KMS_KEY_ID,
},
]
}

return make_api_call(self, operation_name, kwargs)


def mock_make_api_call_private(self, operation_name, kwargs):
if operation_name == "DescribeReplicationInstances":
return {
"ReplicationInstances": [
{
"ReplicationInstanceIdentifier": DMS_INSTANCE_NAME,
"ReplicationInstanceStatus": "available",
"AutoMinorVersionUpgrade": True,
"PubliclyAccessible": False,
"ReplicationInstanceArn": DMS_INSTANCE_ARN,
"MultiAZ": True,
"VpcSecurityGroups": [],
"KmsKeyId": KMS_KEY_ID,
},
]
}

return make_api_call(self, operation_name, kwargs)


class Test_dms_instance_no_public_access:
@mock_aws
def test_dms_no_instances(self):
dms_client = mock.MagicMock()
dms_client = client("dms", region_name=AWS_REGION_US_EAST_1)
dms_client.instances = []

aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])

from prowler.providers.aws.services.dms.dms_service import DMS

with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider([AWS_REGION_US_EAST_1]),
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access.dms_client",
new=DMS(aws_provider),
):
with mock.patch(
"prowler.providers.aws.services.dms.dms_service.DMS",
new=dms_client,
), mock.patch(
"prowler.providers.aws.services.dms.dms_client.dms_client",
new=dms_client,
):
from prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access import (
dms_instance_no_public_access,
)
from prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access import (
dms_instance_no_public_access,
)

check = dms_instance_no_public_access()
result = check.execute()
assert len(result) == 0
check = dms_instance_no_public_access()
result = check.execute()
assert len(result) == 0

@mock_aws
def test_dms_private(self):
dms_client = mock.MagicMock()
dms_client.instances = []
dms_client.instances.append(
RepInstance(
id=DMS_INSTANCE_NAME,
arn=DMS_INSTANCE_ARN,
status="available",
public=False,
security_groups=[],
kms_key=KMS_KEY_ID,
auto_minor_version_upgrade=False,
multi_az=False,
region=AWS_REGION_US_EAST_1,
tags=[{"Key": "Name", "Value": DMS_INSTANCE_NAME}],
)
)

with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider([AWS_REGION_US_EAST_1]),
"botocore.client.BaseClient._make_api_call",
new=mock_make_api_call_private,
):

aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])

from prowler.providers.aws.services.dms.dms_service import DMS

with mock.patch(
"prowler.providers.aws.services.dms.dms_service.DMS",
new=dms_client,
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.dms.dms_client.dms_client",
new=dms_client,
"prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access.dms_client",
new=DMS(aws_provider),
):
from prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access import (
dms_instance_no_public_access,
Expand All @@ -85,41 +118,25 @@ def test_dms_private(self):
assert result[0].region == AWS_REGION_US_EAST_1
assert result[0].resource_id == DMS_INSTANCE_NAME
assert result[0].resource_arn == DMS_INSTANCE_ARN
assert result[0].resource_tags == [
{
"Key": "Name",
"Value": DMS_INSTANCE_NAME,
}
]
assert result[0].resource_tags == []

@mock_aws
def test_dms_public(self):
dms_client = mock.MagicMock()
dms_client.instances = []
dms_client.instances.append(
RepInstance(
id=DMS_INSTANCE_NAME,
arn=DMS_INSTANCE_ARN,
status="available",
public=True,
security_groups=[],
kms_key=KMS_KEY_ID,
auto_minor_version_upgrade=False,
multi_az=False,
region=AWS_REGION_US_EAST_1,
tags=[{"Key": "Name", "Value": DMS_INSTANCE_NAME}],
)
)

with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=set_mocked_aws_provider([AWS_REGION_US_EAST_1]),
"botocore.client.BaseClient._make_api_call",
new=mock_make_api_call_public,
):

aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])

from prowler.providers.aws.services.dms.dms_service import DMS

with mock.patch(
"prowler.providers.aws.services.dms.dms_service.DMS",
new=dms_client,
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
), mock.patch(
"prowler.providers.aws.services.dms.dms_client.dms_client",
new=dms_client,
"prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access.dms_client",
new=DMS(aws_provider),
):
from prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access import (
dms_instance_no_public_access,
Expand All @@ -136,12 +153,7 @@ def test_dms_public(self):
assert result[0].region == AWS_REGION_US_EAST_1
assert result[0].resource_id == DMS_INSTANCE_NAME
assert result[0].resource_arn == DMS_INSTANCE_ARN
assert result[0].resource_tags == [
{
"Key": "Name",
"Value": DMS_INSTANCE_NAME,
}
]
assert result[0].resource_tags == []

@mock_aws
def test_dms_public_with_public_sg(self):
Expand All @@ -160,6 +172,7 @@ def test_dms_public_with_public_sg(self):
}
],
)
dms_client = mock.MagicMock
dms_client = mock.MagicMock()
dms_client.instances = []
dms_client.instances.append(
Expand All @@ -176,14 +189,12 @@ def test_dms_public_with_public_sg(self):
tags=[{"Key": "Name", "Value": DMS_INSTANCE_NAME}],
)
)

from prowler.providers.aws.services.ec2.ec2_service import EC2

aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
aws_provider.audit_metadata.expected_checks = [
"ec2_securitygroup_allow_ingress_from_internet_to_any_port"
]

with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
Expand All @@ -205,7 +216,6 @@ def test_dms_public_with_public_sg(self):

check = dms_instance_no_public_access()
result = check.execute()

assert len(result) == 1
assert result[0].status == "FAIL"
assert (
Expand Down Expand Up @@ -255,14 +265,12 @@ def test_dms_public_with_filtered_sg(self):
tags=[{"Key": "Name", "Value": DMS_INSTANCE_NAME}],
)
)

from prowler.providers.aws.services.ec2.ec2_service import EC2

aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1])
aws_provider.audit_metadata.expected_checks = [
"ec2_securitygroup_allow_ingress_from_internet_to_any_port"
]

with mock.patch(
"prowler.providers.common.provider.Provider.get_global_provider",
return_value=aws_provider,
Expand All @@ -284,7 +292,6 @@ def test_dms_public_with_filtered_sg(self):

check = dms_instance_no_public_access()
result = check.execute()

assert len(result) == 1
assert result[0].status == "PASS"
assert (
Expand Down

0 comments on commit 2a7fb16

Please sign in to comment.