Skip to content

Commit

Permalink
add expanded actions to TF AWS IAM resources
Browse files Browse the repository at this point in the history
  • Loading branch information
gruebel committed Sep 24, 2023
1 parent 26109f3 commit 50d691b
Show file tree
Hide file tree
Showing 16 changed files with 231 additions and 57 deletions.
13 changes: 11 additions & 2 deletions checkov/common/checks_infra/checks_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ def parse_raw_check(self, raw_check: Dict[str, Dict[str, Any]], **kwargs: Any) -
check.frameworks = raw_check.get("metadata", {}).get("frameworks", [])
check.guideline = raw_check.get("metadata", {}).get("guideline")
check.check_path = kwargs.get("check_path", "")
extensions = raw_check.get("metadata", {}).get("extensions", None)
if extensions is not None:
check.extensions = extensions
solver = self.get_check_solver(check)
check.set_solver(solver)

Expand Down Expand Up @@ -282,9 +285,11 @@ def get_check_solver(self, check: BaseGraphCheck) -> BaseSolver:
if check.sub_checks:
sub_solvers = []
for sub_solver in check.sub_checks:
# set extensions
sub_solver.extensions = check.extensions
sub_solvers.append(self.get_check_solver(sub_solver))

type_to_solver = {
type_to_solver: dict[SolverType | None, BaseSolver | None] = {
SolverType.COMPLEX_CONNECTION: operator_to_complex_connection_solver_classes.get(
check.operator, lambda *args: None
)(sub_solvers, check.operator),
Expand All @@ -300,9 +305,13 @@ def get_check_solver(self, check: BaseGraphCheck) -> BaseSolver:
),
}

solver = type_to_solver.get(check.type) # type:ignore[arg-type] # if not str will return None
solver = type_to_solver.get(check.type)
if not solver:
raise NotImplementedError(f"solver type {check.type} with operator {check.operator} is not supported")

# set extensions
solver.extensions = check.extensions

return solver


Expand Down
Empty file.
85 changes: 85 additions & 0 deletions checkov/common/checks_infra/extensions/iam_action_expansion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

from typing import Any

from policy_sentry.analysis.expand import expand # type:ignore[import] # will be fixed with the next version
from typing_extensions import Self

from checkov.common.graph.checks_infra.extensions.base_extension import BaseGraphCheckExtension
from checkov.common.graph.graph_builder.graph_components.attribute_names import CustomAttributes
from checkov.common.models.enums import GraphCheckExtension
from checkov.common.util.data_structures_utils import pickle_deepcopy
from checkov.common.util.type_forcers import force_list

SUPPORTED_IAM_BLOCKS = {
"aws_iam_group_policy",
"aws_iam_policy",
"aws_iam_role_policy",
"aws_iam_user_policy",
"aws_ssoadmin_permission_set_inline_policy",
"data.aws_iam_policy_document",
}
IAM_POLICY_BLOCKS = {
"aws_iam_group_policy",
"aws_iam_policy",
"aws_iam_role_policy",
"aws_iam_user_policy",
}


class IamActionExpansion(BaseGraphCheckExtension):
_instance = None # noqa: CCE003 # singleton

name = GraphCheckExtension.IAM_ACTION_EXPANSION # noqa: CCE003 # a static attribute
iam_expanded_actions_cache: dict[str, dict[str, Any]] = {} # noqa: CCE003 # global cache

def __new__(cls) -> Self:
if cls._instance is None:
cls._instance = super().__new__(cls)

return cls._instance

def extend(self, vertex_data: dict[str, Any]) -> dict[str, Any]:
if not vertex_data[CustomAttributes.RESOURCE_TYPE] in SUPPORTED_IAM_BLOCKS:
return vertex_data

cache_key = f"{vertex_data[CustomAttributes.FILE_PATH]}:{vertex_data[CustomAttributes.RESOURCE_TYPE]}:{vertex_data[CustomAttributes.BLOCK_NAME]}"
if cache_key in IamActionExpansion.iam_expanded_actions_cache:
return IamActionExpansion.iam_expanded_actions_cache[cache_key]

expanded_actions = self._expand_iam_actions(vertex_data=vertex_data)
IamActionExpansion.iam_expanded_actions_cache[cache_key] = expanded_actions
return expanded_actions

def _expand_iam_actions(self, vertex_data: dict[str, Any]) -> dict[str, Any]:
"""Returns resource data with the expanded actions of an IAM statement
Info: Only AWS Terraform resources are supported for now
"""

vertex_data = pickle_deepcopy(vertex_data)
resource_type = vertex_data[CustomAttributes.RESOURCE_TYPE]
if resource_type == "data.aws_iam_policy_document":
self._adjust_action_value(policy=vertex_data, statement_key="statement", action_key="actions")
elif resource_type in IAM_POLICY_BLOCKS:
policy = vertex_data.get("policy")
if isinstance(policy, dict):
self._adjust_action_value(policy=policy, statement_key="Statement", action_key="Action")
elif resource_type == "aws_ssoadmin_permission_set_inline_policy":
policy = vertex_data.get("inline_policy")
if isinstance(policy, dict):
self._adjust_action_value(policy=policy, statement_key="Statement", action_key="Action")

return vertex_data

def _adjust_action_value(self, policy: dict[str, Any], statement_key: str, action_key: str) -> None:
for statement in force_list(policy.get(statement_key, [])):
if action_key in statement:
original_actions = statement[action_key]
expanded_actions = expand(action=original_actions)
if isinstance(original_actions, list):
statement[action_key].extend(expanded_actions)
else:
expanded_actions = list(expanded_actions) # fix in policy_sentry to be a list not a set
expanded_actions.append(original_actions)
statement[action_key] = expanded_actions
34 changes: 34 additions & 0 deletions checkov/common/checks_infra/extensions_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

import logging
from typing import Any

from typing_extensions import Self

from checkov.common.checks_infra.extensions.iam_action_expansion import IamActionExpansion
from checkov.common.models.enums import GraphCheckExtension

logger = logging.getLogger(__name__)


class GraphCheckExtensionsRegistry:
_instance = None # noqa: CCE003 # singleton

def __new__(cls) -> Self:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls.extensions = {
IamActionExpansion.name: IamActionExpansion(),
}

return cls._instance

def run(self, extensions: list[GraphCheckExtension], vertex_data: dict[str, Any]) -> dict[str, Any]:
for extension in extensions:
if extension not in self.extensions:
logger.info(f"Extension {extension} doesn't exist")
continue

vertex_data = self.extensions[extension].extend(vertex_data=vertex_data)

return vertex_data
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from checkov.common.graph.checks_infra import debug
from checkov.common.graph.checks_infra.enums import SolverType
from checkov.common.checks_infra.extensions_registry import GraphCheckExtensionsRegistry
from checkov.common.graph.checks_infra.solvers.base_solver import BaseSolver

from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -54,15 +55,19 @@ def run(self, graph_connector: LibraryGraph) -> Tuple[List[Dict[str, Any]], List
else:
select_kwargs = {"block_type__in": list(SUPPORTED_BLOCK_TYPES)}

for data in graph_connector.vs.select(**select_kwargs)["attr"]:
result = self.get_operation(vertex=data)
# A None indicate for UNKNOWN result - the vertex shouldn't be added to the passed or the failed vertices
if result is None:
unknown_vertices.append(data)
elif result:
passed_vertices.append(data)
else:
failed_vertices.append(data)
try:
for data in graph_connector.vs.select(**select_kwargs)["attr"]:
result = self.get_operation(vertex=data)
# A None indicate for UNKNOWN result - the vertex shouldn't be added to the passed or the failed vertices
if result is None:
unknown_vertices.append(data)
elif result:
passed_vertices.append(data)
else:
failed_vertices.append(data)
except KeyError:
# igraph throws a KeyError, when it can't find any related vertices
pass

return passed_vertices, failed_vertices, unknown_vertices

Expand Down Expand Up @@ -96,6 +101,8 @@ def get_operation(self, vertex: Dict[str, Any]) -> Optional[bool]:
and self.value != '':
return None

vertex = GraphCheckExtensionsRegistry().run(extensions=self.extensions, vertex_data=vertex)

if self.attribute and (self.is_jsonpath_check or re.match(WILDCARD_PATTERN, self.attribute)):
attribute_matches = self.get_attribute_matches(vertex)
filtered_attribute_matches = attribute_matches
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,18 @@ def run(self, graph_connector: LibraryGraph) -> Tuple[List[Dict[str, Any]], List
if self.resource_types:
select_kwargs = {"resource_type_in": self.resource_types}

for data in graph_connector.vs.select(**select_kwargs)["attr"]:
result = self.get_operation(data)
if result is None:
unknown_vertices.append(data)
elif result:
passed_vertices.append(data)
else:
failed_vertices.append(data)
try:
for data in graph_connector.vs.select(**select_kwargs)["attr"]:
result = self.get_operation(data)
if result is None:
unknown_vertices.append(data)
elif result:
passed_vertices.append(data)
else:
failed_vertices.append(data)
except KeyError:
# igraph throws a KeyError, when it can't find any related vertices
pass

debug.complex_connection_block(
solvers=self.solvers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,16 @@ def set_vertices(self, graph_connector: LibraryGraph, exclude_vertices: List[Dic
if self.resource_types:
select_kwargs = {"resource_type_in": self.resource_types}

self.vertices_under_resource_types = [
data for data in graph_connector.vs.select(**select_kwargs)["attr"]
]
self.vertices_under_connected_resources_types = [
data for data in graph_connector.vs.select(resource_type_in=self.connected_resources_types)["attr"]
]
try:
self.vertices_under_resource_types = [
data for data in graph_connector.vs.select(**select_kwargs)["attr"]
]
self.vertices_under_connected_resources_types = [
data for data in graph_connector.vs.select(resource_type_in=self.connected_resources_types)["attr"]
]
except KeyError:
# igraph throws a KeyError, when it can't find any related vertices
pass
else:
self.vertices_under_resource_types = [
v for _, v in graph_connector.nodes(data=True) if self.resource_type_pred(v, self.resource_types)
Expand All @@ -86,12 +90,16 @@ def reduce_graph_by_target_types(self, graph_connector: LibraryGraph) -> Library
return graph_connector

if isinstance(graph_connector, Graph):
resource_nodes = {
vertex for vertex in graph_connector.vs.select(resource_type_in=self.targeted_resources_types)
}
connection_nodes = {
vertex for vertex in graph_connector.vs.select(block_type__in=BaseConnectionSolver.SUPPORTED_CONNECTION_BLOCK_TYPES)
}
try:
resource_nodes = {
vertex for vertex in graph_connector.vs.select(resource_type_in=self.targeted_resources_types)
}
connection_nodes = {
vertex for vertex in graph_connector.vs.select(block_type__in=BaseConnectionSolver.SUPPORTED_CONNECTION_BLOCK_TYPES)
}
except KeyError:
# igraph throws a KeyError, when it can't find any related vertices
return Graph()
else:
resource_nodes = {
node
Expand Down
8 changes: 6 additions & 2 deletions checkov/common/graph/checks_infra/base_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

from checkov.common.graph.checks_infra.enums import SolverType
from checkov.common.graph.checks_infra.solvers.base_solver import BaseSolver
from checkov.common.models.enums import GraphCheckExtension

if TYPE_CHECKING:
from checkov.common.typing import LibraryGraph
from checkov.common.bridgecrew.severities import Severity
from networkx import DiGraph


class BaseGraphCheck:
Expand All @@ -33,11 +34,14 @@ def __init__(self) -> None:
self.frameworks: List[str] = []
self.is_jsonpath_check: bool = False
self.check_path: str = ""
self.extensions: list[GraphCheckExtension] = [
GraphCheckExtension.IAM_ACTION_EXPANSION,
]

def set_solver(self, solver: BaseSolver) -> None:
self.solver = solver

def run(self, graph_connector: DiGraph) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
def run(self, graph_connector: LibraryGraph) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
if not self.solver:
raise AttributeError("solver attribute was not set")

Expand Down
Empty file.
14 changes: 14 additions & 0 deletions checkov/common/graph/checks_infra/extensions/base_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

from checkov.common.models.enums import GraphCheckExtension


class BaseGraphCheckExtension(ABC):
name: GraphCheckExtension

@abstractmethod
def extend(self, vertex_data: dict[str, Any]) -> dict[str, Any]:
pass
27 changes: 10 additions & 17 deletions checkov/common/graph/checks_infra/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations
import concurrent.futures
import logging
from typing import Any, TYPE_CHECKING

Expand All @@ -24,20 +23,12 @@ def load_checks(self) -> None:
def run_checks(
self, graph_connector: LibraryGraph, runner_filter: RunnerFilter, report_type: str
) -> dict[BaseGraphCheck, list[_CheckResult]]:
return {
check: self.run_check(check=check, graph_connector=graph_connector)
for check in (c for c in self.checks if runner_filter.should_run_check(c, report_type=report_type))
}

check_results: "dict[BaseGraphCheck, list[_CheckResult]]" = {}
checks_to_run = [c for c in self.checks if runner_filter.should_run_check(c, report_type=report_type)]
with concurrent.futures.ThreadPoolExecutor() as executor:
concurrent.futures.wait(
[executor.submit(self.run_check_parallel, check, check_results, graph_connector)
for check in checks_to_run]
)
return check_results

def run_check_parallel(
self, check: BaseGraphCheck, check_results: dict[BaseGraphCheck, list[_CheckResult]],
graph_connector: LibraryGraph
) -> None:
def run_check(self, check: BaseGraphCheck, graph_connector: LibraryGraph) -> list[_CheckResult]:
logging.debug(f'Running graph check: {check.id}')
debug.graph_check(check_id=check.id, check_name=check.name)

Expand All @@ -46,7 +37,7 @@ def run_check_parallel(
check_result = self._process_check_result(passed, [], CheckResult.PASSED, evaluated_keys)
check_result = self._process_check_result(failed, check_result, CheckResult.FAILED, evaluated_keys)
check_result = self._process_check_result(unknown, check_result, CheckResult.UNKNOWN, evaluated_keys)
check_results[check] = check_result
return check_result

@staticmethod
def _process_check_result(
Expand All @@ -55,6 +46,8 @@ def _process_check_result(
result: CheckResult,
evaluated_keys: list[str],
) -> list[_CheckResult]:
for vertex in results:
processed_results.append({"result": result, "entity": vertex, "evaluated_keys": evaluated_keys})
processed_results.extend(
{"result": result, "entity": vertex, "evaluated_keys": evaluated_keys}
for vertex in results
)
return processed_results
Loading

0 comments on commit 50d691b

Please sign in to comment.