Skip to content

Commit

Permalink
refactor(aws): Rename get_regions and validate partition (#5772)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfagoagas authored Nov 14, 2024
1 parent 3608aa3 commit cb74dae
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 39 deletions.
60 changes: 38 additions & 22 deletions prowler/providers/aws/aws_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
AWSMFAInfo,
AWSOrganizationsInfo,
AWSSession,
Partition,
)
from prowler.providers.common.models import Audit_Metadata, Connection
from prowler.providers.common.provider import Provider
Expand Down Expand Up @@ -107,7 +108,7 @@ def __init__(
"""
Initializes the AWS provider.
Arguments:
Args:
- retries_max_attempts: The maximum number of retries for the AWS client.
- role_arn: The ARN of the IAM role to assume.
- session_duration: The duration of the session in seconds.
Expand Down Expand Up @@ -368,7 +369,7 @@ def get_organizations_info(
"""
get_organizations_info returns a AWSOrganizationsInfo object if the account to be audited is a delegated administrator for AWS Organizations or if the AWS Organizations Role ARN (--organizations-role) is passed.
Arguments:
Args:
- organizations_session: needs to be a Session object with permissions to do organizations:DescribeAccount and organizations:ListTagsForResource.
- aws_account_id: is the AWS Account ID from which we want to get the AWS Organizations account metadata
Expand Down Expand Up @@ -660,7 +661,7 @@ def get_available_aws_service_regions(
"""
get_available_aws_service_regions returns the available regions for the given service and partition.
Arguments:
Args:
- service: The AWS service name.
- partition: The AWS partition name. Default is "aws".
- audited_regions: A set of regions to audit. Default is None.
Expand Down Expand Up @@ -1274,7 +1275,7 @@ def create_sts_session(
"""
Create an STS session client.
Parameters:
Args:
- session (session.Session): The AWS session object.
- aws_region (str): The AWS region to use for the session.
Expand All @@ -1299,42 +1300,57 @@ def create_sts_session(
raise error

@staticmethod
def get_regions_by_partition(partition: str = None) -> set:
def get_regions(partition: Partition = Partition.aws) -> set:
"""
Get the available AWS regions from the AWS services JSON file with the ability of filtering by partition.
Args:
- partition (str): The AWS partition name. Default is None.
partition (str): The AWS partition to retrieve regions for. Defaults to "aws".
Returns:
set: A set of available AWS regions. All if no `partition` is especified.
set: A set of region names.
Raises:
AWSInvalidPartitionError: If the provided partition name is invalid.
Example:
>>> AwsProvider.get_regions("aws")
{"af-south-1"}
"""

try:
regions = set()
data = read_aws_regions_file()

regions = set()
if partition is None:
for service in data["services"].values():
for partition in service["regions"]:
for item in service["regions"][partition]:
regions.add(item)
regions.update(service["regions"][partition])
else:
partition = Partition(partition)
for service in data["services"].values():
try:
for item in service["regions"][partition]:
regions.add(item)
except KeyError as key_error:
logger.error(
f"{key_error.__class__.__name__}[{key_error.__traceback__.tb_lineno}]: {key_error}"
)
raise AWSInvalidPartitionError(
message=f"Invalid partition name: {partition}",
file=os.path.basename(__file__),
)
regions.update(service["regions"][partition.value])

return regions
except ValueError as value_error:
logger.error(
f"{value_error.__class__.__name__}[{value_error.__traceback__.tb_lineno}]: {value_error}"
)
raise AWSInvalidPartitionError(
message=f"Invalid partition: {partition}",
file=os.path.basename(__file__),
)
except KeyError as key_error:
logger.error(
f"{key_error.__class__.__name__}[{key_error.__traceback__.tb_lineno}]: {key_error}"
)
raise AWSInvalidPartitionError(
message=f"Invalid partition: {partition}",
file=os.path.basename(__file__),
)
except Exception as error:
logger.error(f"{error.__class__.__name__}: {error}")
return set()
raise error


def read_aws_regions_file() -> dict:
Expand Down
2 changes: 1 addition & 1 deletion prowler/providers/aws/lib/arguments/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def init_parser(self):
"-f",
nargs="+",
help="AWS region names to run Prowler against",
choices=AwsProvider.get_regions_by_partition(),
choices=AwsProvider.get_regions(partition=None),
)
# AWS Organizations
aws_orgs_subparser = aws_parser.add_argument_group("AWS Organizations")
Expand Down
24 changes: 24 additions & 0 deletions prowler/providers/aws/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum

from boto3.session import Session
from botocore.config import Config
Expand Down Expand Up @@ -82,6 +83,29 @@ class AWSMFAInfo:
totp: str


class Partition(str, Enum):
"""
Enum class representing different AWS partitions.
Attributes:
aws (str): Represents the standard AWS commercial regions.
aws_cn (str): Represents the AWS China regions.
aws_us_gov (str): Represents the AWS GovCloud (US) Regions.
aws_iso (str): Represents the AWS ISO (US) Regions.
aws_iso_b (str): Represents the AWS ISOB (US) Regions.
aws_iso_e (str): Represents the AWS ISOE (Europe) Regions.
aws_iso_f (str): Represents the AWS ISOF Regions.
"""

aws = "aws"
aws_cn = "aws-cn"
aws_us_gov = "aws-us-gov"
aws_iso = "aws-iso"
aws_iso_b = "aws-iso-b"
aws_iso_e = "aws-iso-e"
aws_iso_f = "aws-iso-f"


class AWSOutputOptions(ProviderOutputOptions):
security_hub_enabled: bool

Expand Down
10 changes: 0 additions & 10 deletions tests/config/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
load_and_validate_config_file,
load_and_validate_fixer_config_file,
)
from prowler.providers.aws.aws_provider import AwsProvider

MOCK_PROWLER_VERSION = "3.3.0"
MOCK_OLD_PROWLER_VERSION = "0.0.0"
Expand Down Expand Up @@ -346,15 +345,6 @@ def mock_prowler_get_latest_release(_, **kwargs):


class Test_Config:
def test_get_regions_by_partition(self):
assert len(AwsProvider.get_regions_by_partition()) == 34

def test_get_regions_by_partition_with_partition(self):
assert len(AwsProvider.get_regions_by_partition("aws-cn")) == 2

def test_get_regions_by_partition_with_unknown_partition(self):
assert len(AwsProvider.get_regions_by_partition("unknown")) == 0

@mock.patch(
"prowler.config.config.requests.get", new=mock_prowler_get_latest_release
)
Expand Down
78 changes: 72 additions & 6 deletions tests/providers/aws/aws_provider_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import botocore
import botocore.exceptions
import pytest
from boto3 import client, resource, session
from mock import patch
from moto import mock_aws
Expand All @@ -24,6 +25,7 @@
from prowler.providers.aws.exceptions.exceptions import (
AWSArgumentTypeValidationError,
AWSIAMRoleARNInvalidResourceTypeError,
AWSInvalidPartitionError,
AWSInvalidProviderIdError,
AWSNoCredentialsError,
)
Expand Down Expand Up @@ -1716,7 +1718,16 @@ def test_get_regions_from_audit_resources_without_regions(self):
)
assert not recovered_regions

def test_get_regions_by_partition(self):
def test_get_regions_all_count(self):
assert len(AwsProvider.get_regions(partition=None)) == 34

def test_get_regions_cn_count(self):
assert len(AwsProvider.get_regions("aws-cn")) == 2

def test_get_regions_aws_count(self):
assert len(AwsProvider.get_regions(partition="aws")) == 30

def test_get_all_regions(self):
with patch(
"prowler.providers.aws.aws_provider.read_aws_regions_file",
return_value={
Expand All @@ -1737,13 +1748,38 @@ def test_get_regions_by_partition(self):
}
},
):
assert AwsProvider.get_regions_by_partition() == {
assert AwsProvider.get_regions(partition=None) == {
"af-south-1",
"cn-north-1",
"us-gov-west-1",
}

def test_get_regions_by_partition_with_partition(self):
def test_get_regions_with_us_gov_partition(self):
with patch(
"prowler.providers.aws.aws_provider.read_aws_regions_file",
return_value={
"services": {
"acm": {
"regions": {
"aws": [
"af-south-1",
],
"aws-cn": [
"cn-north-1",
],
"aws-us-gov": [
"us-gov-west-1",
],
}
}
}
},
):
assert AwsProvider.get_regions("aws-us-gov") == {
"us-gov-west-1",
}

def test_get_regions_with_aws_partition(self):
with patch(
"prowler.providers.aws.aws_provider.read_aws_regions_file",
return_value={
Expand All @@ -1764,11 +1800,36 @@ def test_get_regions_by_partition_with_partition(self):
}
},
):
assert AwsProvider.get_regions_by_partition("aws-cn") == {
assert AwsProvider.get_regions("aws") == {
"af-south-1",
}

def test_get_regions_with_cn_partition(self):
with patch(
"prowler.providers.aws.aws_provider.read_aws_regions_file",
return_value={
"services": {
"acm": {
"regions": {
"aws": [
"af-south-1",
],
"aws-cn": [
"cn-north-1",
],
"aws-us-gov": [
"us-gov-west-1",
],
}
}
}
},
):
assert AwsProvider.get_regions("aws-cn") == {
"cn-north-1",
}

def test_get_regions_by_partition_with_unknown_partition(self):
def test_get_regions_with_unknown_partition(self):
with patch(
"prowler.providers.aws.aws_provider.read_aws_regions_file",
return_value={
Expand All @@ -1789,7 +1850,12 @@ def test_get_regions_by_partition_with_unknown_partition(self):
}
},
):
assert AwsProvider.get_regions_by_partition("unknown") == set()
partition = "unknown"
with pytest.raises(AWSInvalidPartitionError) as exception:
AwsProvider.get_regions(partition)

assert exception.type == AWSInvalidPartitionError
assert f"Invalid partition: {partition}" in exception.value.args[0]

def test_get_aws_region_for_sts_input_regions_none_session_region_none(self):
input_regions = None
Expand Down

0 comments on commit cb74dae

Please sign in to comment.