From 2a7fb16ea095863c24768946e2d964633413ce43 Mon Sep 17 00:00:00 2001 From: Daniel Barranquero Date: Tue, 5 Nov 2024 11:15:02 +0100 Subject: [PATCH] fix(dms): fix MagicMock errors using Botocore --- .../providers/aws/services/dms/dms_service.py | 2 +- .../dms_no_public_access_test.py | 165 +++++++++--------- 2 files changed, 87 insertions(+), 80 deletions(-) diff --git a/prowler/providers/aws/services/dms/dms_service.py b/prowler/providers/aws/services/dms/dms_service.py index c7385afb706..27ec105ea39 100644 --- a/prowler/providers/aws/services/dms/dms_service.py +++ b/prowler/providers/aws/services/dms/dms_service.py @@ -112,4 +112,4 @@ class RepInstance(BaseModel): security_groups: list[str] = [] multi_az: bool region: str - tags: Optional[list] + tags: Optional[list] = [] diff --git a/tests/providers/aws/services/dms/dms_instance_no_public_access/dms_no_public_access_test.py b/tests/providers/aws/services/dms/dms_instance_no_public_access/dms_no_public_access_test.py index 83d3fd91725..0b9a5209c52 100644 --- a/tests/providers/aws/services/dms/dms_instance_no_public_access/dms_no_public_access_test.py +++ b/tests/providers/aws/services/dms/dms_instance_no_public_access/dms_no_public_access_test.py @@ -1,5 +1,6 @@ from unittest import mock +import botocore from boto3 import client from moto import mock_aws @@ -16,59 +17,91 @@ ) KMS_KEY_ID = f"arn:aws:kms:{AWS_REGION_US_EAST_1}:{AWS_ACCOUNT_NUMBER}:key/abcdabcd-1234-abcd-1234-abcdabcdabcd" +make_api_call = botocore.client.BaseClient._make_api_call + + +def mock_make_api_call_public(self, operation_name, kwargs): + if operation_name == "DescribeReplicationInstances": + return { + "ReplicationInstances": [ + { + "ReplicationInstanceIdentifier": DMS_INSTANCE_NAME, + "ReplicationInstanceStatus": "available", + "AutoMinorVersionUpgrade": True, + "PubliclyAccessible": True, + "ReplicationInstanceArn": DMS_INSTANCE_ARN, + "MultiAZ": True, + "VpcSecurityGroups": [], + "KmsKeyId": KMS_KEY_ID, + }, + ] + } + + return make_api_call(self, operation_name, kwargs) + + +def mock_make_api_call_private(self, operation_name, kwargs): + if operation_name == "DescribeReplicationInstances": + return { + "ReplicationInstances": [ + { + "ReplicationInstanceIdentifier": DMS_INSTANCE_NAME, + "ReplicationInstanceStatus": "available", + "AutoMinorVersionUpgrade": True, + "PubliclyAccessible": False, + "ReplicationInstanceArn": DMS_INSTANCE_ARN, + "MultiAZ": True, + "VpcSecurityGroups": [], + "KmsKeyId": KMS_KEY_ID, + }, + ] + } + + return make_api_call(self, operation_name, kwargs) + class Test_dms_instance_no_public_access: + @mock_aws def test_dms_no_instances(self): - dms_client = mock.MagicMock() + dms_client = client("dms", region_name=AWS_REGION_US_EAST_1) dms_client.instances = [] + aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1]) + + from prowler.providers.aws.services.dms.dms_service import DMS + with mock.patch( "prowler.providers.common.provider.Provider.get_global_provider", - return_value=set_mocked_aws_provider([AWS_REGION_US_EAST_1]), + return_value=aws_provider, + ), mock.patch( + "prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access.dms_client", + new=DMS(aws_provider), ): - with mock.patch( - "prowler.providers.aws.services.dms.dms_service.DMS", - new=dms_client, - ), mock.patch( - "prowler.providers.aws.services.dms.dms_client.dms_client", - new=dms_client, - ): - from prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access import ( - dms_instance_no_public_access, - ) + from prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access import ( + dms_instance_no_public_access, + ) - check = dms_instance_no_public_access() - result = check.execute() - assert len(result) == 0 + check = dms_instance_no_public_access() + result = check.execute() + assert len(result) == 0 + @mock_aws def test_dms_private(self): - dms_client = mock.MagicMock() - dms_client.instances = [] - dms_client.instances.append( - RepInstance( - id=DMS_INSTANCE_NAME, - arn=DMS_INSTANCE_ARN, - status="available", - public=False, - security_groups=[], - kms_key=KMS_KEY_ID, - auto_minor_version_upgrade=False, - multi_az=False, - region=AWS_REGION_US_EAST_1, - tags=[{"Key": "Name", "Value": DMS_INSTANCE_NAME}], - ) - ) - with mock.patch( - "prowler.providers.common.provider.Provider.get_global_provider", - return_value=set_mocked_aws_provider([AWS_REGION_US_EAST_1]), + "botocore.client.BaseClient._make_api_call", + new=mock_make_api_call_private, ): + + aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1]) + + from prowler.providers.aws.services.dms.dms_service import DMS + with mock.patch( - "prowler.providers.aws.services.dms.dms_service.DMS", - new=dms_client, + "prowler.providers.common.provider.Provider.get_global_provider", + return_value=aws_provider, ), mock.patch( - "prowler.providers.aws.services.dms.dms_client.dms_client", - new=dms_client, + "prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access.dms_client", + new=DMS(aws_provider), ): from prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access import ( dms_instance_no_public_access, @@ -85,41 +118,25 @@ def test_dms_private(self): assert result[0].region == AWS_REGION_US_EAST_1 assert result[0].resource_id == DMS_INSTANCE_NAME assert result[0].resource_arn == DMS_INSTANCE_ARN - assert result[0].resource_tags == [ - { - "Key": "Name", - "Value": DMS_INSTANCE_NAME, - } - ] + assert result[0].resource_tags == [] + @mock_aws def test_dms_public(self): - dms_client = mock.MagicMock() - dms_client.instances = [] - dms_client.instances.append( - RepInstance( - id=DMS_INSTANCE_NAME, - arn=DMS_INSTANCE_ARN, - status="available", - public=True, - security_groups=[], - kms_key=KMS_KEY_ID, - auto_minor_version_upgrade=False, - multi_az=False, - region=AWS_REGION_US_EAST_1, - tags=[{"Key": "Name", "Value": DMS_INSTANCE_NAME}], - ) - ) - with mock.patch( - "prowler.providers.common.provider.Provider.get_global_provider", - return_value=set_mocked_aws_provider([AWS_REGION_US_EAST_1]), + "botocore.client.BaseClient._make_api_call", + new=mock_make_api_call_public, ): + + aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1]) + + from prowler.providers.aws.services.dms.dms_service import DMS + with mock.patch( - "prowler.providers.aws.services.dms.dms_service.DMS", - new=dms_client, + "prowler.providers.common.provider.Provider.get_global_provider", + return_value=aws_provider, ), mock.patch( - "prowler.providers.aws.services.dms.dms_client.dms_client", - new=dms_client, + "prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access.dms_client", + new=DMS(aws_provider), ): from prowler.providers.aws.services.dms.dms_instance_no_public_access.dms_instance_no_public_access import ( dms_instance_no_public_access, @@ -136,12 +153,7 @@ def test_dms_public(self): assert result[0].region == AWS_REGION_US_EAST_1 assert result[0].resource_id == DMS_INSTANCE_NAME assert result[0].resource_arn == DMS_INSTANCE_ARN - assert result[0].resource_tags == [ - { - "Key": "Name", - "Value": DMS_INSTANCE_NAME, - } - ] + assert result[0].resource_tags == [] @mock_aws def test_dms_public_with_public_sg(self): @@ -160,6 +172,7 @@ def test_dms_public_with_public_sg(self): } ], ) + dms_client = mock.MagicMock dms_client = mock.MagicMock() dms_client.instances = [] dms_client.instances.append( @@ -176,14 +189,12 @@ def test_dms_public_with_public_sg(self): tags=[{"Key": "Name", "Value": DMS_INSTANCE_NAME}], ) ) - from prowler.providers.aws.services.ec2.ec2_service import EC2 aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1]) aws_provider.audit_metadata.expected_checks = [ "ec2_securitygroup_allow_ingress_from_internet_to_any_port" ] - with mock.patch( "prowler.providers.common.provider.Provider.get_global_provider", return_value=aws_provider, @@ -205,7 +216,6 @@ def test_dms_public_with_public_sg(self): check = dms_instance_no_public_access() result = check.execute() - assert len(result) == 1 assert result[0].status == "FAIL" assert ( @@ -255,14 +265,12 @@ def test_dms_public_with_filtered_sg(self): tags=[{"Key": "Name", "Value": DMS_INSTANCE_NAME}], ) ) - from prowler.providers.aws.services.ec2.ec2_service import EC2 aws_provider = set_mocked_aws_provider([AWS_REGION_US_EAST_1]) aws_provider.audit_metadata.expected_checks = [ "ec2_securitygroup_allow_ingress_from_internet_to_any_port" ] - with mock.patch( "prowler.providers.common.provider.Provider.get_global_provider", return_value=aws_provider, @@ -284,7 +292,6 @@ def test_dms_public_with_filtered_sg(self): check = dms_instance_no_public_access() result = check.execute() - assert len(result) == 1 assert result[0].status == "PASS" assert (